In [1]:
import wandb
import torch
from torch.nn import MSELoss
from traj_dataset import TrajDataset
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

from model import SimpleViT

Using device: cuda:3


In [2]:
# Run on the GPU with index 3 when CUDA is usable, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:3')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda:3


In [5]:
# --- Fetch the W&B run whose configuration and checkpoint are evaluated ---
api = wandb.Api()
run_id = "yc8zb5vo"
run = api.run("/depren/thesis_official_runs/runs/" + run_id)

# Hyper-parameters logged with the run; every setting below is read from here.
conf = run.config

# --- Rebuild the network from the logged architecture settings ---
model = SimpleViT(dim=conf['model_dimension'],
                  device=device,
                  mlp_dim=conf['mlp_dimension'],
                  image_size=(conf['image_size'], conf['image_size']),
                  image_patch_size=(conf['patch_size'], conf['patch_size']),
                  frame_patch_size=conf['patch_depth'],
                  frames=conf['n_prev'],
                  depth=conf['model_depth'],
                  heads=conf['heads'],
                  )

# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads when the GPU it was saved from is unavailable (e.g. the CPU fallback).
state_dict = torch.load(conf['save_name'], map_location=device)
model.load_state_dict(state_dict)

# --- Resolve the dataset folders for this run ---
data_config = conf['dataset']
folders = TrajDataset.conf_to_folders(data_config)
# Sub-folder suffix "<image_size>_<image_size>_<block_size>" (the logged output
# shows e.g. "64_64_4"); it is appended directly to each scene folder.
size = f"{conf['image_size']}_{conf['image_size']}_{conf['block_size']}"
data_folders = ["/waldo/walban/student_datasets/arfranck/SDD/scenes/" + folder + size for folder in folders]

# --- Build the train / validation / test splits ---
props = [conf['train_prop'], conf['val_prop'], conf['test_prop']]
n_prev = conf['n_prev']
n_next = conf['n_next']
img_step = conf['img_step']
# The three datasets differ only in `part`: 0 -> train, 1 -> val, 2 -> test.
train_data, val_data, test_data = (
    TrajDataset(data_folders, n_prev=n_prev, n_next=n_next, img_step=img_step, prop=props, part=part)
    for part in range(3)
)

# --- Wrap the splits in loaders ---
batch_size = conf['batch_size']
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# NOTE(review): the validation and test loaders are shuffled as well, so the
# first test batch (the one plotted later in the notebook) differs between
# runs. Consider shuffle=False or a fixed seed if reproducible plots are needed.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)



opening track 502 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 526 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 510 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 331 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 537 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 691 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 106 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 24 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 631 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 750 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3

opening track 302 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 596 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 78 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 175 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 683 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 405 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 424 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 669 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 850 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 650 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3

opening track 194 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 615 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 387 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 696 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 412 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 783 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 455 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 808 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 799 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 119 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video

opening track 551 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 573 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 777 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 694 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 784 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 767 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 738 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 61 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 415 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 616 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3

opening track 484 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 152 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 207 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 770 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 533 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 660 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 260 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 496 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 63 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 359 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3

opening track 410 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 538 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 122 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 483 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 384 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 318 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 209 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 287 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 569 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 182 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video

opening track 478 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 154 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 284 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 419 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 342 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 4 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 816 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 335 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 713 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 504 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/

opening track 679 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 171 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 243 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 853 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 15 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 554 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 51 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 838 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 564 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 454 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/

opening track 563 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 487 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 132 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 129 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 546 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 70 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 498 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 87 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 809 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 752 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/

opening track 147 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 470 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 659 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 406 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 512 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 672 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 218 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 323 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 736 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 581 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video

opening track 782 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 221 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 160 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 837 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 349 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 856 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 715 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 164 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 826 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video3/64_64_4
opening track 778 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video

In [None]:
criterion = torch.nn.MSELoss()

# Number of autoregressive forward passes made per batch.
# NOTE(review): presumably equal to conf['n_next'] - confirm before replacing the literal.
N_ROLLOUT_STEPS = 8
# Factor applied to every coordinate before plotting; also used as the axis extent.
# NOTE(review): looks like the image size in pixels (coords normalised to [0, 1]) - confirm.
PLOT_SCALE = 64


def _scaled_xy(points):
    """Split a sequence of (x, y, ...) points into scaled x and y lists.

    Only the first two components of each point are used. Values are turned
    into plain Python floats so matplotlib never receives tensors.
    """
    xs = [float(p[0]) * PLOT_SCALE for p in points]
    ys = [float(p[1]) * PLOT_SCALE for p in points]
    return xs, ys


# Switching to evaluation mode once is enough; it does not depend on the batch.
model.eval()

with torch.no_grad():
    for id_b2, batch_test in enumerate(test_loader):
        print(f"Test Batch {id_b2}")

        # Move the input to the device once instead of on every rollout step.
        X_test = batch_test["src"].to(device)
        Y_test = batch_test["tgt"]
        X_coords = batch_test["coords"]

        print(X_coords.shape)

        # Autoregressive rollout: each pass receives the previous pass's `output`
        # as `future`; only the prediction of the last pass is kept.
        future = None
        for _ in range(N_ROLLOUT_STEPS):
            pred, output = model(X_test, future, train=False)
            future = output

        # Loss and plot cover the first sample of the batch only.
        print(criterion(pred[0], Y_test[0].to(device)).item())

        prev_x, prev_y = _scaled_xy(X_coords[0])
        truth_x, truth_y = _scaled_xy(Y_test[0])
        pred_x, pred_y = _scaled_xy(pred[0].cpu().detach().numpy())
        print(len(truth_x))

        # One colour per predicted point, with the final point in black.
        # Deriving the list from the actual point count avoids the length
        # mismatch a fixed 8-entry list causes for any other horizon.
        pred_colors = ["blue"] * len(pred_x)
        if pred_colors:
            pred_colors[-1] = "black"

        fig, ax = plt.subplots()
        ax.scatter(prev_x, prev_y, label="Prev")
        ax.scatter(truth_x, truth_y, label="Truth")
        ax.scatter(pred_x, pred_y, label="Prediction", color=pred_colors)
        ax.legend()
        ax.set_xlim(0, PLOT_SCALE)
        ax.set_ylim(0, PLOT_SCALE)
        # Optional background overlay, disabled in the original:
        # ax.imshow(plt.imread("reference.jpg"))

        fig.savefig("/waldo/walban/student_datasets/arfranck/SDD/plots/test_3.pdf")
        plt.show()
        # Release the figure so removing the `break` below does not pile up open figures.
        plt.close(fig)

        # Only the first test batch is visualised.
        break