In [3]:
import os
import subprocess
import time

import numpy as np
import pytorch_lightning as pl
import torch

from data.Safe2Unsafe import DeepAccidentDataset
from method.barriers import SABLASBarrier
from method.dynamics import SABLASDynamics
from method.trainers import SABLASTrainer

# Copy the proxy-related variables that /etc/network_turbo exports into this
# kernel's os.environ. Lines of the `env` output that contain no '=' are skipped;
# only the first '=' separates the name from the value.
result = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
output = result.stdout
proxy_vars = dict(
    entry.split('=', 1) for entry in output.splitlines() if '=' in entry
)
os.environ.update(proxy_vars)

In [4]:
# DeepAccident data module: batch size 32 for both splits, 16 loader workers.
# The validation split is used as the test set in the cells below.
data = DeepAccidentDataset(
    train_batch_size=32,
    val_batch_size=32,
    num_workers=16,
)
data.setup()
train_dataloader, test_dataloader = data.train_dataloader(), data.val_dataloader()

In [5]:
# Build the SABLAS dynamics model with a ResNet-50 encoder plus the barrier
# network, wrap them in the trainer, and restore the trained weights.
# Other encoders that can be passed as `model=`:
#   "google/vit-base-patch16-224"
#   "openai/clip-vit-base-patch16"
DEVICE = "cuda"
CHECKPOINT_PATH = "/root/tf-logs/SABLAS/version_3/checkpoints/last.ckpt"  # TODO: make configurable
latent_dim = 16

# NOTE(review): the leading `2` is presumably the control dimension — confirm.
barrier = SABLASBarrier(2, latent_dim=latent_dim)
model = SABLASDynamics(2, DEVICE, model="resnet50", latent_dim=latent_dim)
trainer = SABLASTrainer(model, barrier)

# map_location makes loading independent of the device the checkpoint was saved on.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
trainer.load_state_dict(checkpoint['state_dict'])
trainer.to(DEVICE);  # trailing ';' suppresses the very long module repr

SABLASTrainer(
  (model): SABLASDynamics(
    (encoder): ResNetEncoder(
      (ResNet): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affi

In [6]:
# Evaluate the barrier on every state returned by `simulate` for the validation
# set, pair the values with each sample's label, and save them for the
# sign-accuracy check.
b_all, label_all = [], []
trainer.eval()
# Pure inference. Without no_grad, every `b.cpu()` keeps a reference to its
# autograd graph (including GPU encoder activations) until the lists are freed.
with torch.no_grad():
    for i, u, label in test_dataloader:
        i, u = i.to("cuda"), u.to("cuda")
        x, _ = trainer.model.simulate(i, u)  # second output (x_tide) is not needed here
        b = trainer.barrier(x).squeeze(-1)
        b_all.append(b.cpu())
        label_all.append(label.squeeze(-1))

bs = torch.cat(b_all)
labels = torch.cat(label_all)
# Last column holds the label; all preceding columns are barrier values.
results = torch.cat([bs, labels.unsqueeze(-1)], dim=-1).detach().numpy()
np.savetxt("./results_sablas_resnet.txt", results)

100%|██████████| 4/4 [00:00<00:00, 10.55it/s]
100%|██████████| 4/4 [00:00<00:00, 10.59it/s]
100%|██████████| 4/4 [00:00<00:00, 10.34it/s]
100%|██████████| 4/4 [00:00<00:00, 10.53it/s]
100%|██████████| 4/4 [00:00<00:00, 10.54it/s]
100%|██████████| 4/4 [00:00<00:00, 10.55it/s]
100%|██████████| 4/4 [00:00<00:00, 10.47it/s]
100%|██████████| 4/4 [00:00<00:00, 10.58it/s]
100%|██████████| 4/4 [00:00<00:00, 10.57it/s]
100%|██████████| 4/4 [00:00<00:00, 10.56it/s]
100%|██████████| 4/4 [00:00<00:00,  7.88it/s]
100%|██████████| 4/4 [00:00<00:00,  6.74it/s]
100%|██████████| 4/4 [00:00<00:00,  7.67it/s]
100%|██████████| 4/4 [00:01<00:00,  3.70it/s]
100%|██████████| 4/4 [00:01<00:00,  3.84it/s]
100%|██████████| 4/4 [00:00<00:00,  4.69it/s]
100%|██████████| 4/4 [00:00<00:00, 10.21it/s]
100%|██████████| 4/4 [00:00<00:00, 10.54it/s]
100%|██████████| 4/4 [00:00<00:00, 10.51it/s]
100%|██████████| 4/4 [00:00<00:00, 10.54it/s]
100%|██████████| 4/4 [00:00<00:00, 10.54it/s]
100%|██████████| 4/4 [00:00<00:00,

In [10]:
# Sign accuracy of the saved barrier values: rows labelled 0 (regular) should
# have b > 0, rows labelled 1 (collision) should have b < 0. The fraction is taken
# over every barrier entry of the matching rows (all columns except the last).
results = np.loadtxt("./results_sablas_resnet.txt", ndmin=2)  # stay 2-D even for a single row
barrier_values, row_labels = results[:, :-1], results[:, -1]
regular = barrier_values[row_labels == 0]
collision = barrier_values[row_labels == 1]
# Explicit nan instead of numpy's "mean of empty slice" warning when a class is absent.
acc_regular = (regular > 0).mean() if regular.size else float("nan")
acc_collision = (collision < 0).mean() if collision.size else float("nan")
print(acc_regular, acc_collision)

0.9426490066225166 0.39555555555555555


In [8]:
# Discrete-time barrier ascent value b(x_t) + (b(x_{t+1}) - b(x_t)) / DT for
# consecutive states of each simulated sequence, paired with the sample's label
# and saved for the sign-accuracy check.
DT = 0.1  # step size of the one-step nominal prediction
b_all, label_all = [], []
trainer.eval()
# Inference only; also prevents `ascent_value.cpu()` from retaining GPU autograd graphs.
with torch.no_grad():
    for i, u, label in test_dataloader:
        i, u = i.to("cuda"), u.to("cuda")
        # simulate's own second output is discarded; the next states are taken
        # from the sequence shifted by one step instead.
        x_seq, _ = trainer.model.simulate(i, u)
        x_tide = x_seq[:, 1:]
        x = x_seq[:, :-1]
        b = trainer.barrier(x)
        f, g = trainer.model.ode(x)
        # g is reshaped to (g.shape[0], g.shape[1], state dim of f, control dim of u)
        # so the control-affine derivative f + g @ u can be formed per time step.
        g_mat = g.view(g.shape[0], g.shape[1], f.shape[-1], u.shape[-1])
        gu = torch.einsum('btha,bta->bth', g_mat, u[:, :-1])
        x_nom = x + (f + gu) * DT
        # Straight-through trick: the forward value becomes x_tide, while gradients
        # (if any) would flow through x_nom.
        # NOTE(review): in the forward pass f and g therefore cancel out, so the saved
        # value is effectively b + (barrier(x_tide) - b) / DT — confirm that is intended.
        x_nom = x_nom + (x_tide - x_nom).detach()
        ascent_value = b + (trainer.barrier(x_nom) - b) / DT
        b_all.append(ascent_value.cpu())
        label_all.append(label.squeeze(-1))

bs = torch.cat(b_all)
labels = torch.cat(label_all)
# Last column holds the label; all preceding columns are ascent values.
results = torch.cat([bs.squeeze(-1), labels.unsqueeze(-1)], dim=-1).detach().numpy()
np.savetxt("./results_sablas_grad_resnet.txt", results)

100%|██████████| 4/4 [00:00<00:00, 10.22it/s]
100%|██████████| 4/4 [00:00<00:00, 10.03it/s]
100%|██████████| 4/4 [00:00<00:00, 10.06it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.03it/s]
100%|██████████| 4/4 [00:00<00:00, 10.04it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.04it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.03it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.04it/s]
100%|██████████| 4/4 [00:00<00:00, 10.03it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00,  5.18it/s]
100%|██████████| 4/4 [00:00<00:00,  6.72it/s]
100%|██████████| 4/4 [00:00<00:00, 10.42it/s]
100%|██████████| 4/4 [00:00<00:00, 10.03it/s]
100%|██████████| 4/4 [00:00<00:00, 10.05it/s]
100%|██████████| 4/4 [00:00<00:00,

In [9]:
# Sign accuracy of the saved ascent values: rows labelled 0 (regular) should be
# > 0, rows labelled 1 (collision) should be < 0. The fraction is taken over every
# value of the matching rows (all columns except the last).
results = np.loadtxt("./results_sablas_grad_resnet.txt", ndmin=2)  # stay 2-D even for a single row
ascent_values, row_labels = results[:, :-1], results[:, -1]
regular = ascent_values[row_labels == 0]
collision = ascent_values[row_labels == 1]
# Explicit nan instead of numpy's "mean of empty slice" warning when a class is absent.
acc_regular = (regular > 0).mean() if regular.size else float("nan")
acc_collision = (collision < 0).mean() if collision.size else float("nan")
print(acc_regular, acc_collision)

0.930794701986755 0.39444444444444443
