In [1]:
from data.Safe2Unsafe import DeepAccidentDataset
from method.ResNet import InDCBFTrainer, InDCBFController, Barrier
import torch
import time
import pytorch_lightning as pl
import subprocess
import os

# Copy the proxy variables exported by /etc/network_turbo (if that script exists)
# into this kernel's environment. Presumably needed so that network downloads
# such as `AutoencoderKL.from_pretrained` below go through the proxy — confirm.
# If the script is missing, bash prints to stderr, stdout is empty and nothing is set.
result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
proxy_vars = dict(
    entry.split('=', 1)          # split on the first '=' only; values may contain '='
    for entry in output.splitlines()
    if '=' in entry
)
os.environ.update(proxy_vars)

In [2]:
# Build the DeepAccident data module and take one validation batch to drive the
# exploratory cells below (`i` is fed to the VAE as images, `u` as controls).
data = DeepAccidentDataset(train_batch_size=32, val_batch_size=32, num_workers=16)
data.setup()
train_dataloader = data.train_dataloader()
test_dataloader = data.val_dataloader()

first_batch = next(iter(test_dataloader), None)
if first_batch is None:
    # Fail here with a clear message instead of a NameError several cells later.
    raise RuntimeError("validation dataloader yielded no batches")
i, u, label = first_batch

In [3]:
class VAE(torch.nn.Module):
    """Multi-camera image encoder/decoder built on a pretrained diffusers AutoencoderKL.

    ``forward``/``encode`` turn a stack of per-camera images into one latent
    vector per camera, fused with the previous latent state ``x_p`` and the
    previous control ``u_p``. ``decode`` maps latents back to images through the
    pretrained AutoencoderKL decoder.

    Args:
        latent_dim: size of the per-camera latent produced by ``forward``.
        n_control: size of the control vector ``u_p``.
        model: model id / path passed to ``AutoencoderKL.from_pretrained``.
        hidden_dim: flattened size of the AutoencoderKL latent of one image
            (default 4*28*28; presumably paired with 224x224 inputs — confirm).
        num_cam: number of cameras (stored on the instance).
        freeze_ViT: if True, freeze every parameter of the pretrained
            AutoencoderKL. The name is historical: the backbone is a VAE, not a ViT.
        latent_shape: (channels, height, width) that ``rec``'s output is reshaped
            to before the pretrained decoder; its product should equal ``hidden_dim``.
    """

    def __init__(self, latent_dim, n_control=2, model="stabilityai/sd-vae-ft-mse",
                 hidden_dim=4*28*28, num_cam=6, freeze_ViT=True, latent_shape=(4, 28, 28)):
        super().__init__()
        # Imported here rather than at top level, so defining the class does not
        # require diffusers.
        from diffusers.models import AutoencoderKL
        self.vae = AutoencoderKL.from_pretrained(model)
        if freeze_ViT:
            # Freeze the pretrained autoencoder; only proj/rec/fusion stay trainable.
            self.vae.requires_grad_(False)
        self.proj = torch.nn.Linear(hidden_dim, latent_dim)
        self.rec = torch.nn.Linear(latent_dim, hidden_dim)
        self.fusion = torch.nn.Linear(2 * latent_dim + n_control, latent_dim)
        self.num_cam = num_cam
        self.latent_shape = tuple(latent_shape)

    def forward(self, imgs, x_p, u_p):
        """Encode per-camera images and fuse them with the previous state and control.

        Args:
            imgs: tensor of shape (B, N, C, H, W) — N camera images per sample.
            x_p: previous latent state, expected shape (B, N, latent_dim).
            u_p: previous control, shape (B, n_control); shared by all N cameras.

        Returns:
            Tensor of shape (B, N, latent_dim).
        """
        B, N, C, H, W = imgs.shape
        # Deterministic encoding: take the mode of the latent distribution.
        rep = self.vae.encode(imgs.reshape(-1, C, H, W))['latent_dist'].mode().reshape(B, N, -1)
        rep = self.proj(rep)
        # Broadcast the per-sample control to every camera before fusing.
        rep = torch.cat([rep, x_p, u_p.unsqueeze(1).expand(-1, N, -1)], dim=-1)
        return self.fusion(rep)

    def encode(self, imgs, x_p, u_p):
        """Alias of ``forward``."""
        return self.forward(imgs, x_p, u_p)

    def decode(self, x, trajectory=True):
        """Decode latents back to images with the pretrained AutoencoderKL decoder.

        Args:
            x: latents of shape (B, T, N, latent_dim) when ``trajectory`` is True,
               otherwise (B, N, latent_dim).
            trajectory: selects which of the two input layouts above is expected.

        Returns:
            Images of shape (*x.shape[:-1], C_img, H_img, W_img). The image
            dimensions are whatever the decoder produces for ``latent_shape``
            (presumably 3x224x224 for the default 4x28x28 latent — confirm).

        Raises:
            ValueError: if ``x`` does not have the rank implied by ``trajectory``.
        """
        expected_rank = 4 if trajectory else 3
        if x.dim() != expected_rank:
            raise ValueError(
                f"decode(trajectory={trajectory}) expects a {expected_rank}-D tensor, "
                f"got shape {tuple(x.shape)}"
            )
        lead = x.shape[:-1]
        # Flatten all leading dims into one batch for the 2-D conv decoder.
        rep = self.rec(x).reshape(-1, *self.latent_shape)
        imgs = self.vae.decoder(rep)
        # Take the image size from the decoder output instead of hard-coding it.
        return imgs.reshape(*lead, *imgs.shape[1:])

In [4]:
# Instantiate the VAE (16-dim per-camera latent) on the GPU and build the inputs
# for the first encoding step.
vae = VAE(16, model="stabilityai/sd-vae-ft-mse").cuda()
# Zero initial latent state: one latent vector per camera, sized from the model
# itself rather than hard-coded.
x_init = torch.zeros(i.shape[0], vae.num_cam, vae.proj.out_features, device="cuda")
# Prepend the first control so u_init[:, t] is the control preceding step t
# (the first step reuses u[:, 0]); presumably what `u_p` expects — confirm.
u_init = torch.cat([u[:, :1, :], u], dim=1).cuda()

In [5]:
# Encode index 0 along dim 1 (presumably the first time step) for two samples.
x = vae(imgs=i[:2, 0].cuda(), x_p=x_init[:2], u_p=u_init[:2, 0])

In [7]:
# Fused latent shape: (batch, num_cam, latent_dim).
x.size()

torch.Size([2, 6, 16])

In [6]:
# Decode the single-step latents back to per-camera images (no time axis).
rec = vae.decode(x=x, trajectory=False)

In [8]:
from method.utils import build_mlp, NeuralODE

# Latent dynamics model, integrated below as dx/dt = f(x) + g(x) u.
latent_dim, h_dim, n_control = 16, 256, 2
ode = NeuralODE(
    [latent_dim, h_dim, h_dim, latent_dim],              # presumably the layer sizes of f
    [latent_dim, h_dim, h_dim, latent_dim * n_control],  # presumably the layer sizes of g
).cuda()

In [9]:
# Latent shape that will be integrated by the ODE below.
x.size()

torch.Size([2, 6, 16])

In [10]:
# Control at index 0 of dim 1: (batch, n_control).
u[:, 0].size()

torch.Size([32, 2])

In [15]:
from torchdiffeq import odeint

# Control held constant over the integration window: the (shifted) control at
# step 0 for exactly the samples in `x`, computed once instead of on every call.
u_step = u_init[:x.shape[0], 0]


def odefunc(t, state):
    """Control-affine latent dynamics dx/dt = f(x) + g(x) u for ``odeint``.

    ``t`` is required by the ``odeint`` callback signature but unused. ``state``
    has the shape of ``x`` (B, N, latent_dim); ``g`` is viewed as
    (B, N, latent_dim, n_control) and contracted with ``u_step`` (B, n_control),
    which is broadcast over the N cameras by the einsum (any N works).
    """
    f, g = ode(state)
    g = g.view(g.shape[0], g.shape[1], -1, n_control)
    gu = torch.einsum("bnha,ba->bnh", g, u_step)
    return f + gu


timesteps = torch.tensor([0.0, 0.1], device=x.device)
# odeint returns the state at every requested time; keep the one at t = 0.1.
x_tide = odeint(odefunc, x, timesteps, rtol=5e-6)[1, :, :]

In [7]:
import torch.nn.functional as F

class Barrier(torch.nn.Module):
    """Learned control barrier function (CBF) over multi-camera latent trajectories.

    NOTE(review): this redefinition shadows the ``Barrier`` imported from
    ``method.ResNet`` at the top of the notebook.

    A learnable per-camera weight vector (``attention``) collapses the N camera
    latents of each time step into one vector. An MLP ending in ``Tanh`` maps
    that vector to a barrier value in (-1, 1).

    Args:
        n_control: control dimension.
        latent_dim: per-camera latent size.
        num_cam: number of cameras; ``forward`` expects exactly this many.
        h_dim: hidden width of the barrier MLP.
        eps_safe, eps_unsafe: margins. Safe states are pushed to b >= eps_safe,
            unsafe ones to b <= -eps_unsafe.
        eps_ascent: margin of the derivative condition db/dt + b >= eps_ascent.
        w_safe, w_unsafe, w_grad: weights of the three loss terms.
        eps_descent, w_non_zero, w_lambda, with_gradient, with_nonzero, **kwargs:
            stored or accepted, but not used by this class.

    NOTE(review): because of the Tanh, b never reaches 1, so with the default
    eps_safe = eps_unsafe = 1 the safe/unsafe hinges are never satisfied —
    confirm the margins.
    """

    def __init__(self,
                 n_control,
                 latent_dim,
                 num_cam = 6,
                 h_dim = 64,
                 eps_safe = 1,
                 eps_unsafe = 1,
                 eps_ascent = 1,
                 eps_descent = 1,
                 w_safe=1,
                 w_unsafe=1,
                 w_grad=1,
                 w_non_zero=1,
                 w_lambda=1,
                 with_gradient=False,
                 with_nonzero=False,
                 **kwargs
                 ):
        super().__init__()
        # MLP latent_dim -> h_dim -> h_dim -> 1 with ReLU between layers and a
        # final Tanh.
        hidden_dims = [latent_dim, h_dim, h_dim, 1]
        modules = []
        for k in range(len(hidden_dims) - 1):
            modules.append(torch.nn.Linear(hidden_dims[k], hidden_dims[k + 1]))
            if k != len(hidden_dims) - 2:
                modules.append(torch.nn.ReLU())
        modules.append(torch.nn.Tanh())
        self.attention = torch.nn.parameter.Parameter(torch.rand((num_cam, latent_dim)))
        self.cbf = torch.nn.Sequential(*modules)
        self.n_control = n_control
        self.eps_safe = eps_safe
        self.eps_unsafe = eps_unsafe
        self.eps_ascent = eps_ascent
        self.eps_descent = eps_descent
        self.w_safe = w_safe
        self.w_unsafe = w_unsafe
        self.w_grad = w_grad
        self.w_non_zero = w_non_zero
        self.w_lambda = w_lambda
        self.with_gradient = with_gradient
        self.with_nonzero = with_nonzero

    def forward(self, x):
        """Barrier value per (sample, time step).

        Args:
            x: tensor of shape (B, T, N, latent_dim) with N == num_cam.

        Returns:
            Tensor of shape (B, T, 1).
        """
        B, T, N, H = x.shape
        # One scalar weight per camera: dot product of the camera latent with its
        # learned attention vector (not normalized across cameras).
        weight = torch.einsum("btnh,btnh->btn", self.attention.expand(B, T, -1, -1), x)
        # Weighted sum over cameras -> (B, T, latent_dim).
        x = torch.einsum("btn,btnh->btnh", weight, x).sum(2)
        return self.cbf(x)

    @staticmethod
    def _mean_or_zero(t):
        """Mean of ``t``, or 0 (still attached to the graph) when ``t`` is empty.

        A batch may contain no safe or no unsafe samples, and ``mean`` of an
        empty tensor is NaN, which would poison the summed/logged loss.
        """
        return t.mean() if t.numel() > 0 else t.sum()

    def loss_function(self, x, label, u, ode):
        """Hinge losses for the safe set, the unsafe set and the CBF derivative condition.

        Args:
            x: latent trajectories, shape (B, T, N, latent_dim).
            label: per-sequence label (0 = safe, 1 = unsafe). It is squeezed on its
                last dimension, so presumably shape (B, 1) — confirm against the dataset.
            u: controls, shape (B, T, n_control), aligned with ``x`` in time.
            ode: callable returning ``(f, g)`` for a latent tensor. ``f`` has the
                input's shape; ``g`` has ``latent_dim * n_control`` values in its
                last dimension.

        Returns:
            dict with weighted losses ``loss_safe``, ``loss_unsafe`` and
            ``loss_grad_ascent``, plus monitoring values ``b_safe``, ``b_unsafe`` and
            ``b_grad_ascent``. Terms over an empty subset are 0 instead of NaN.
        """
        label = label.squeeze(dim=-1)
        safe_mask = label == 0
        x_safe = x[safe_mask]
        x_unsafe = x[label == 1]
        b_safe = self.forward(x_safe)
        b_unsafe = self.forward(x_unsafe)
        loss_1 = self._mean_or_zero(F.relu(self.eps_safe - b_safe).sum(dim=-1))
        loss_2 = self._mean_or_zero(F.relu(self.eps_unsafe + b_unsafe).sum(dim=-1))
        output = {
            "loss_safe": self.w_safe * loss_1,
            "loss_unsafe": self.w_unsafe * loss_2,
            "b_safe": self._mean_or_zero(b_safe),
            "b_unsafe": self._mean_or_zero(b_unsafe),
        }

        # Per-sample gradient of b w.r.t. the (detached) safe latents. Each b[k, t]
        # depends only on x[k, t], so differentiating b.sum() yields db/dx for
        # every sample; differentiating b.mean() would divide it by K*T.
        # NOTE(review): no create_graph, so loss_3 trains the barrier only through
        # b_safe, not through the gradient term — confirm this is intended.
        x_g = x_safe.detach().clone().requires_grad_(True)
        b = self.forward(x_g)
        d_b_safe = torch.autograd.grad(b.sum(), x_g)[0]
        with torch.no_grad():
            f, g = ode(x_g)
        # g: (K, T, N, latent_dim * n_control) -> (K, T, N, latent_dim, n_control).
        # The per-step control is shared by all cameras.
        g = g.reshape(*f.shape, self.n_control)
        gu = torch.einsum("btnha,bta->btnh", g, u[safe_mask])
        # db/dt along the dynamics: <grad b, f + g u>, summed over cameras and latent dims.
        ascent_value = torch.einsum("btnh,btnh->bt", d_b_safe, f + gu).unsqueeze(-1)
        loss_3 = self._mean_or_zero(F.relu(self.eps_ascent - ascent_value - b_safe).sum(dim=-1))
        output["loss_grad_ascent"] = self.w_grad * loss_3
        output["b_grad_ascent"] = self._mean_or_zero(ascent_value + b_safe)
        return output

In [8]:
# Barrier over the latent_dim-sized, 6-camera (default num_cam) latents; n_control
# and latent_dim are the values defined alongside the NeuralODE above. Moved to
# the GPU like `vae` and `ode`, whose outputs it consumes.
barrier = Barrier(n_control=n_control, latent_dim=latent_dim).cuda()

torch.Size([2, 6, 16])