In [1]:
import wandb
import torch
from torch.nn import MSELoss
from traj_dataset import TrajDataset
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader
import PIL
from model import SimpleViT

# GPU index this notebook was run on. Only request it when the host actually
# exposes that many GPUs; otherwise fall back to the first GPU, then to the
# CPU, so the notebook still runs on smaller machines.
preferred_gpu = 3
if torch.cuda.is_available() and torch.cuda.device_count() > preferred_gpu:
    device = torch.device(f'cuda:{preferred_gpu}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)


Using device: cuda:3


In [45]:
# Fetch the training run from Weights & Biases; its stored config supplies the
# hyper-parameters needed to rebuild the network below.
api = wandb.Api()
run_id = "lg0uz2gb"
run = api.run("/depren/thesis_official_runs/runs/" + run_id)
conf = run.config

# Constructor arguments for SimpleViT, all taken from the run config.
# The config stores a single side length for images and patches, so the
# same value is used for both entries of each size tuple.
vit_kwargs = dict(
    dim=conf['model_dimension'],
    device=device,
    mlp_dim=conf['mlp_dimension'],
    image_size=(conf['image_size'],) * 2,
    image_patch_size=(conf['patch_size'],) * 2,
    frame_patch_size=conf['patch_depth'],
    frames=conf['n_prev'],
    depth=conf['model_depth'],
    heads=conf['heads'],
)
model = SimpleViT(**vit_kwargs)
print(model.device)

# Restore the weights saved under conf['save_name'], remapped onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(conf['save_name'], map_location=device)
)
print(model.device)




cuda:3
cuda:3


In [46]:

# Turn the run's dataset config into the list of scene folders to read.
data_config = run.config['dataset']
folders = TrajDataset.conf_to_folders(data_config)

# Every scene folder is suffixed with "<image_size>_<image_size>_<block_size>".
size = f"{conf['image_size']}_{conf['image_size']}_{conf['block_size']}"
scenes_root = "/waldo/walban/student_datasets/arfranck/SDD/scenes/"
data_folders = []
for folder in folders:
    data_folders.append(scenes_root + folder + size)

# Split proportions and window settings, as stored in the run config.
props = [conf['train_prop'], conf['val_prop'], conf['test_prop']]
n_prev = conf['n_prev']
n_next = conf['n_next']
img_step = conf['img_step']

# Arguments shared by the three datasets; only `part` differs between them.
# NOTE(review): limit=20 presumably caps how much data is loaded per dataset -
# confirm against TrajDataset.
shared_kwargs = dict(n_prev=n_prev, n_next=n_next, img_step=img_step,
                     prop=props, limit=20)

# part 0, 1 and 2 are built in that order and bound to train, val and test.
train_data, val_data, test_data = [
    TrajDataset(data_folders, part=split_idx, **shared_kwargs)
    for split_idx in range(3)
]

opening track 44 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 109 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 9 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 144 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 106 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 73 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 172 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 52 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 211 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 113 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 2

opening track 185 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 118 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 354 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 152 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 302 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 366 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 167 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 49 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 194 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 12 from /waldo/walban/student_datasets/arfranck/SDD/scenes/hyang/video4/64_64_4
opening track 37 from /waldo/walban/student_datasets

opening track 105 from /waldo/walban/student_datasets/arfranck/SDD/scenes/coupa/video3/64_64_4
opening track 117 from /waldo/walban/student_datasets/arfranck/SDD/scenes/coupa/video3/64_64_4
opening track 102 from /waldo/walban/student_datasets/arfranck/SDD/scenes/coupa/video3/64_64_4
opening track 642 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening track 707 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening track 889 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening track 222 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening track 78 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening track 1141 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening track 365 from /waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4
opening 

opening track 156 from /waldo/walban/student_datasets/arfranck/SDD/scenes/nexus/video2/64_64_4
opening track 44 from /waldo/walban/student_datasets/arfranck/SDD/scenes/nexus/video2/64_64_4
opening track 22 from /waldo/walban/student_datasets/arfranck/SDD/scenes/quad/video2/64_64_4
opening track 233 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 173 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 26 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 188 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 35 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 57 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 62 from /waldo/walban/student_datasets/arfranck/SDD/scenes/bookstore/video0/64_64_4
opening track 70 from /waldo/

opening track 25 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 24 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 180 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 110 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 123 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 234 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 258 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 201 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 46 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 11 from /waldo/walban/student_datasets/arfranck/SDD/scenes/little/video3/64_64_4
opening track 49 from /waldo/walban/student_

In [47]:
# Batch size is fixed here rather than taken from conf['batch_size'].
batch_size = 20

# NOTE(review): shuffle=True on the validation and test loaders means the
# sample plotted further down differs between runs; consider shuffle=False
# if reproducible evaluation order is wanted.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [None]:
def _pixel_xy(track, scale):
    """Split a sequence of rows into x and y lists in pixel units.

    Only the first two entries of each row are used; both are multiplied by
    ``scale``. Returns ``(xs, ys)`` as two lists of equal length.
    """
    xs, ys = zip(*((row[0] * scale, row[1] * scale) for row in track))
    return list(xs), list(ys)


criterion = torch.nn.MSELoss()

# Number of consecutive model calls per batch; each call receives the second
# output of the previous call as `future`.
# NOTE(review): presumably this should match n_next - confirm.
ROLLOUT_STEPS = 8
# Factor applied to the stored coordinates before drawing them on the
# reference image (taken from a "64_64_4" folder).
PIXEL_SCALE = 64
REFERENCE_IMAGE = "/waldo/walban/student_datasets/arfranck/SDD/scenes/deathCircle/video1/64_64_4/reference.jpg"
PLOT_PATH = "/waldo/walban/student_datasets/arfranck/SDD/plots/test_3.pdf"

# Loop-invariant setup: switch to inference mode and open the background once.
# NOTE(review): the background is always the deathCircle/video1 reference
# frame, whichever scene the plotted sample comes from - confirm this is
# intended.
model.eval()
img = PIL.Image.open(REFERENCE_IMAGE)

for id_b2, batch_test in enumerate(test_loader):
    print(f"Test Batch {id_b2}")

    X_test = batch_test["src"]
    Y_test = batch_test["tgt"]
    X_coords = batch_test["coords"]
    print(X_coords.shape)

    with torch.no_grad():
        # Move the batch to the model's device once instead of on every step.
        X_dev = X_test.to(device)
        Y_dev = Y_test.to(device)

        future = None
        for _ in range(ROLLOUT_STEPS):
            pred, output = model(X_dev, future, train=False)
            future = output

        # Loss over the whole batch, then over the first sample only.
        print(criterion(pred, Y_dev).item())
        print(criterion(pred[0], Y_dev[0]).item())

    # Only the first sample of each batch is visualised.
    prev_x, prev_y = _pixel_xy(X_coords[0], PIXEL_SCALE)
    x, y = _pixel_xy(Y_test[0], PIXEL_SCALE)
    x2, y2 = _pixel_xy(pred[0].cpu().detach().numpy(), PIXEL_SCALE)
    print(len(x))

    # Last predicted point in black, all earlier ones in blue. The list is
    # sized from the prediction itself so scatter never gets a colour list
    # whose length disagrees with the number of points.
    pred_colors = ["blue"] * (len(x2) - 1) + ["black"]

    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.scatter(prev_x, prev_y, label="Prev")
    ax.scatter(x, y, label="Truth")
    ax.scatter(x2, y2, label="Prediction", color=pred_colors)
    ax.legend()

    # NOTE(review): every batch writes to the same path, so only the last
    # batch's figure survives on disk.
    fig.savefig(PLOT_PATH)
    plt.show()
    # Release the figure so figures do not accumulate across batches.
    plt.close(fig)




In [None]:
# Number of batches the test loader yields in one pass.
n_test_batches = len(test_loader)
print(n_test_batches)