In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from timeit import default_timer as timer

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

from util.collision_loss import torch_collision_check, NN_constraint_step
from util.zonotope import Zonotope, TorchZonotope
from util.constrained_zonotope import TorchConstrainedZonotope
from util.NN_con_zono import forward_pass_NN_torch, forward_pass_NN_con_zono_torch

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device only accepts CUDA devices; calling it unconditionally
# raised ValueError on CPU-only machines and defeated the fallback above.
if device.type == "cuda":
    torch.cuda.set_device(device)

In [None]:
class Net(nn.Module):
    """Two-layer MLP: Linear(2, 10) -> ReLU -> Linear(10, 2).

    The attribute names ``fc1`` and ``fc2`` are the keys of the state dict, so
    they must match the checkpoints passed to ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10)  # affine input layer: 2 features -> 10 hidden units
        self.fc2 = nn.Linear(10, 2)  # affine output layer: 10 hidden units -> 2 outputs

    def forward(self, x):
        """Apply ``fc1``, a ReLU, then ``fc2`` (last dimension of ``x`` must be 2)."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# Build the network directly on the selected device (Module.to returns the module
# itself), then show its structure as the cell output.
net = Net().to(device)
net

## Load saved parameters
Load previously trained weights for `net` from a checkpoint in `log/`.

In [None]:
# Checkpoint to evaluate; set to "log/constrained_2_400its.pt" to inspect the other saved run.
CHECKPOINT_PATH = "log/constrained_1_1000its.pt"

# map_location remaps stored tensors onto `device`, so a checkpoint written from a
# GPU run also loads on a CPU-only machine (and vice versa).
# torch.load unpickles the file: only load checkpoints from trusted sources.
net.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
net.eval()

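# Constrained training
Train the function approximator (`net`) subject to obstacle constraints on its output. The cells below do three things:
- define the input set and the output obstacle;
- compute and plot the network's reachable set;
- evaluate the resulting collision loss.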

In [None]:
# Input set: a zonotope with center 0 and identity generators. This is presumably
# the box [-1, 1]^2, which matches the uniform sampling range used for the scatter plot.
Z_in = TorchZonotope(torch.zeros(2, 1, device=device), torch.eye(2, device=device))

# Output obstacle: a constrained zonotope centered at (1.5, 1.5) with diagonal
# generators of 0.5 (presumably the box [1, 2]^2).
c_obs = torch.tensor([[1.5], [1.5]], device=device)
G_obs = torch.tensor([0.5, 0.5], device=device).diag()
Z_obs = TorchConstrainedZonotope(c_obs, G_obs)

In [None]:
# Recompute the output reachable set of `net` over Z_in. The call stays outside
# torch.no_grad() in case later cells differentiate through Z_out.
Z_out = forward_pass_NN_torch(Z_in, net)

N_samples = 10000  # number of uniform input samples for the scatter plot
SAMPLE_SEED = 0    # fixed seed so the scatter plot is reproducible on re-run

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

# Left panel: every element of Z_out (each has a .plot(ax) method) plus the obstacle in red.
for z in Z_out:
    z.plot(ax1)
Z_obs.plot(ax1, 'r')
ax1.set(title="Reachable set of net over Z_in vs. obstacle", xlabel="$y_1$", ylabel="$y_2$")

# Right panel: network outputs on uniform samples of [-1, 1]^2. Points are colored
# by the sum of the two input coordinates, and the obstacle is overlaid for comparison.
sample_rng = np.random.default_rng(SAMPLE_SEED)
X_in = torch.as_tensor(sample_rng.uniform(-1, 1, (N_samples, 2)), dtype=torch.float, device=device)
with torch.no_grad():  # plotting only: no autograd graph needed for the samples
    Y_out = net(X_in)
color_vec = torch.sum(X_in, dim=1).cpu()
sc = ax2.scatter(Y_out[:, 0].cpu().numpy(), Y_out[:, 1].cpu().numpy(), c=color_vec, cmap="gist_rainbow")
Z_obs.plot(ax2, 'r')
ax2.set(title="net outputs on uniformly sampled inputs", xlabel="$y_1$", ylabel="$y_2$")
fig.colorbar(sc, ax=ax2, label="$x_1 + x_2$")
plt.show()

In [None]:
# Collision loss over the elements of Z_out. Every element that the check flags
# (v <= 1) contributes (1 - v)^2, i.e. a squared hinge on 1 - v; others contribute nothing.
# NOTE(review): "v <= 1 means in collision" follows the original inline comment;
# the precise meaning of v is defined by torch_collision_check.
values = []
losses = []
for z in Z_out:
    v = torch_collision_check(z, Z_obs)
    if not (v <= 1):
        continue  # this element is not flagged as colliding
    print("in collision")
    loss = torch.square(1 - v)
    values.append(v.item())
    losses.append(loss.item())

print("collision values: ",values)
print("losses: ",losses)
print("total loss: ",sum(losses))