In [1]:
import sys
import torch
sys.path.append("../src")
from interpret import Abs_interpreter
from domain import Interval

In [2]:
# Class to handle interval domain operations
class IntervalDomainHandler:
    """Interval arithmetic on closed intervals represented as ``(low, high)`` pairs.

    Only elements ``[0]`` and ``[1]`` of each operand are read, and every
    method returns a new ``(low, high)`` tuple.
    """

    def add(self, i1, i2):
        """Return ``[a, b] + [c, d] = [a + c, b + d]``."""
        low = i1[0] + i2[0]
        high = i1[1] + i2[1]
        return (low, high)

    def subtract(self, i1, i2):
        """Return ``[a, b] - [c, d] = [a - d, b - c]``."""
        low = i1[0] - i2[1]
        high = i1[1] - i2[0]
        return (low, high)

    def mul(self, i1, i2):
        """Return the smallest interval containing every endpoint product.

        For ``i1 = (a, b)`` and ``i2 = (c, d)`` the candidates are
        ``a*c, a*d, b*c, b*d``; the result is ``(min, max)`` of those.
        """
        products = [x * y for x in (i1[0], i1[1]) for y in (i2[0], i2[1])]
        return (min(products), max(products))

In [5]:
def compute_pinn_burgers_residual_bounds(model_name, 
                                         inpx_range = (-1 + 0.001, 1 - 0.001), 
                                         inpt_range = (0 + 0.001, 2 - 0.001), 
                                         abstract_domain = 'Interval'):
    """
    Bound the residual ``u_t + u * u_x`` of a PINN Burgers model over an input box.

    The model is abstractly interpreted over the box ``inpx_range x inpt_range`` to obtain
    interval bounds on ``u``, ``u_x`` and ``u_t``. These are combined with interval
    arithmetic (``IntervalDomainHandler``) into a bound on the residual.

    Parameters
    ----------
    model_name : str
        Path of the saved model, loaded with ``torch.load``.
    inpx_range : tuple of float, optional
        ``(low, high)`` range of the first input (x).
    inpt_range : tuple of float, optional
        ``(low, high)`` range of the second input (t).
    abstract_domain : str, optional
        Abstract domain used by the interpreter; only ``'Interval'`` is supported.

    Returns
    -------
    tuple of float
        ``(low, high)`` bound on ``u_t + u * u_x`` over the box.

    Raises
    ------
    NotImplementedError
        If ``abstract_domain`` is not ``'Interval'``.
    """
    # Check the domain first so an unsupported name fails fast, before the model is loaded.
    if abstract_domain == 'Interval':
        domain = Interval()
    else:
        raise NotImplementedError("Unknown abstract domain: " + str(abstract_domain))

    # torch.load unpickles the file and can execute arbitrary code: only load trusted models.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may reject a fully
    # pickled nn.Module; confirm against the installed torch version.
    model = torch.load(model_name)

    abs_interpreter = Abs_interpreter(model, domain)
    # Lower / upper corners of the input box as column vectors, inputs ordered [x, t].
    lows = torch.tensor([inpx_range[0], inpt_range[0]]).reshape(-1,1)
    highs = torch.tensor([inpx_range[1], inpt_range[1]]).reshape(-1,1)
    val_bounds, deriv_bounds = abs_interpreter.forward_pass(lows, highs)

    # Index 0 is taken as the lower bound and index 1 as the upper bound
    # (assumed from the l/r naming -- TODO confirm against Abs_interpreter.forward_pass).
    ul = val_bounds[0][0][0].item()
    ur = val_bounds[1][0][0].item()

    # Derivative column 0 is assumed to be d/dx and column 1 d/dt, matching the [x, t] order.
    uxl = deriv_bounds[0][0][0].item()
    utl = deriv_bounds[0][0][1].item()

    uxr = deriv_bounds[1][0][0].item()
    utr = deriv_bounds[1][0][1].item()

    interval_handler = IntervalDomainHandler()

    # Residual ut + u*ux.
    # NOTE(review): no viscous (u_xx) term is included; confirm this is intended.
    residual_bounds = interval_handler.add((utl, utr), interval_handler.mul((ul, ur), (uxl, uxr)))
    return residual_bounds

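## Computing residual bounds using input splitting

Interval bounds over-approximate more as the input box grows. The box is therefore split into a uniform grid of smaller sub-boxes, and each sub-box is bounded separately. The results are joined by taking the minimum of the lower bounds and the maximum of the upper bounds.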

In [8]:
def _partition_edges(value_range, num_partitions):
    """Return ``num_partitions + 1`` evenly spaced edges from ``value_range[0]`` to ``value_range[1]``.

    The final edge is set to ``value_range[1]`` exactly, so floating-point drift in
    ``low + step * k`` can never leave a sliver at the end of the range uncovered.
    """
    low, high = value_range[0], value_range[1]
    step = (high - low) / num_partitions
    return [low + step * k for k in range(num_partitions)] + [high]


def compute_pinn_bounds_using_input_splitting(model_name, 
                                              num_partitions, 
                                              inpx_range = (-1 + 0.001, 1 - 0.001), 
                                              inpt_range = (0 + 0.001, 2 - 0.001), 
                                              abstract_domain = 'Interval'):
    """
    Bound the PINN Burgers residual over a box by splitting it into a uniform grid.

    The x and t ranges are each cut into ``num_partitions`` equal sub-ranges. The residual
    is bounded on each of the ``num_partitions ** 2`` cells with
    ``compute_pinn_burgers_residual_bounds``, and the cell bounds are joined
    (minimum of the lower bounds, maximum of the upper bounds).

    Parameters
    ----------
    model_name : str
        Path of the saved model, passed through to ``compute_pinn_burgers_residual_bounds``.
    num_partitions : int
        Number of sub-ranges per input dimension; must be at least 1.
    inpx_range, inpt_range : tuple of float, optional
        ``(low, high)`` ranges of the x and t inputs.
    abstract_domain : str, optional
        Passed through to ``compute_pinn_burgers_residual_bounds``.

    Returns
    -------
    tuple of float
        ``(low, high)`` bound on the residual over the whole box.

    Raises
    ------
    ValueError
        If ``num_partitions`` is less than 1.
    """
    if num_partitions < 1:
        raise ValueError("num_partitions must be at least 1, got " + str(num_partitions))

    # Adjacent cells share the same edge values, so the grid covers the box without gaps.
    x_edges = _partition_edges(inpx_range, num_partitions)
    t_edges = _partition_edges(inpt_range, num_partitions)

    l_final = None
    u_final = None

    for i in range(num_partitions):
        for j in range(num_partitions):
            bounds = compute_pinn_burgers_residual_bounds(model_name, 
                                                          inpx_range=(x_edges[i], x_edges[i + 1]),
                                                          inpt_range=(t_edges[j], t_edges[j + 1]),
                                                          abstract_domain = abstract_domain)

            # NOTE(review): Python min/max ignore a NaN that is not the first operand,
            # so a NaN cell bound would be silently dropped here.
            if l_final is None:
                l_final, u_final = bounds[0], bounds[1]
            else:
                l_final = min(l_final, bounds[0])
                u_final = max(u_final, bounds[1])

    return l_final, u_final

In [11]:
# Residual bounds for pinn-burgers.pt over the default box, using a 20 x 20 grid of sub-boxes.
pinn_burgers_bounds = compute_pinn_bounds_using_input_splitting("../trained_models/pinn-burgers.pt", num_partitions=20)
print(pinn_burgers_bounds)

(-1547.4691210356614, 1655.8961358222296)


In [12]:
# Same 20 x 20 grid and default box, for the model saved as pinn-burgers_weak.pt.
pinn_burgers_weak_bounds = compute_pinn_bounds_using_input_splitting("../trained_models/pinn-burgers_weak.pt", num_partitions=20)
print(pinn_burgers_weak_bounds)

(-10.846220347932103, 10.029506655381965)
