In [None]:
import wandb
import pandas as pd
import torch
from traj_dataset import TrajDataset
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD, Adagrad
from torch.nn import MSELoss
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

from model import SimpleViT

# GPU index this notebook prefers when it is present on the host.
PREFERRED_GPU = 2


def select_device(preferred_index=PREFERRED_GPU):
    """Pick the torch device to run on.

    Returns ``cuda:<preferred_index>`` when CUDA is available and that index
    exists, ``cuda:0`` when CUDA is available but the preferred index does not
    exist, and ``cpu`` otherwise.  Checking the index up front avoids building
    a device object that only fails later, at the first tensor transfer.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if 0 <= preferred_index < torch.cuda.device_count():
        return torch.device(f'cuda:{preferred_index}')
    return torch.device('cuda:0')


device = select_device()
print('Using device:', device)

# W&B API client; the next cell uses it to look up a run via `api.run(...)`.
api = wandb.Api()

In [None]:
# --- Fetch the W&B run whose checkpoint and data split are inspected below ---
run_id = "97y0w076"
run = api.run("/depren/thesis_official_runs/runs/" + run_id)
conf = run.config

# Read the hyperparameters explicitly rather than via `locals().update(conf)`:
# bulk-injecting the config can silently overwrite notebook names such as
# `device` or `model`, and hides which settings this cell actually needs.
# A missing key now fails right here with a KeyError naming the key.
model_dimension = conf['model_dimension']
mlp_dimension = conf['mlp_dimension']
image_size = conf['image_size']
patch_size = conf['patch_size']
patch_dept = conf['patch_dept']  # spelled like this in the run config
model_depth = conf['model_depth']
heads = conf['heads']
n_prev = conf['n_prev']
n_next = conf['n_next']
img_step = conf['img_step']
block_size = conf['block_size']
batch_size = conf['batch_size']

# --- Rebuild the model with the run's architecture and restore its weights ---
model = SimpleViT(
    dim=model_dimension,
    device=device,
    mlp_dim=mlp_dimension,
    image_size=(image_size, image_size),
    image_patch_size=(patch_size, patch_size),
    frame_patch_size=patch_dept,
    frames=n_prev,
    depth=model_depth,
    heads=heads,
)

# `map_location` remaps tensors that were saved on another device (e.g. a GPU
# index that does not exist on this machine) onto the device selected above.
state_dict = torch.load(conf['save_name'], map_location=device)
model.load_state_dict(state_dict)

# --- Recreate the dataset splits used by the run ---
data_config = conf['dataset']
folders = TrajDataset.conf_to_folders(data_config)
size = f"{image_size}_{image_size}_{block_size}"
data_folders = [
    "/waldo/walban/student_datasets/arfranck/SDD/scenes/" + folder + size
    for folder in folders
]

# Split proportions in the order (train, val, test); `part` selects one of them.
props = [conf['train_prop'], conf['val_prop'], conf['test_prop']]
split_kwargs = dict(n_prev=n_prev, n_next=n_next, img_step=img_step, prop=props)
train_data = TrajDataset(data_folders, part=0, **split_kwargs)
val_data = TrajDataset(data_folders, part=1, **split_kwargs)
test_data = TrajDataset(data_folders, part=2, **split_kwargs)

# NOTE(review): the validation and test loaders are shuffled as well, so the
# "first batch" inspected later differs between runs unless a seed is set —
# confirm whether that is intended.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [None]:
criterion = torch.nn.MSELoss()

# Number of times the model is called, each call receiving the previous call's
# second output as `future`.
N_DECODE_STEPS = 8  # NOTE(review): presumably equals n_next — confirm
# Factor applied to every coordinate before plotting; the axes span [0, PLOT_SCALE].
PLOT_SCALE = 64  # NOTE(review): looks like the image side length — confirm


def _to_xy(points, scale=PLOT_SCALE):
    """Return ``(xs, ys)`` lists: the first two columns of `points` times `scale`."""
    scaled = torch.as_tensor(points).detach().cpu() * scale
    return scaled[:, 0].tolist(), scaled[:, 1].tolist()


# Inference mode is set once up front instead of on every loop iteration.
model.eval()

# Only the first batch of the test loader is inspected.
batch_test = next(iter(test_loader), None)

if batch_test is None:
    # Make an empty test split visible instead of silently producing no output.
    print("Test loader is empty; nothing to evaluate.")
else:
    print("Test Batch 0")

    X_test = batch_test["src"]
    Y_test = batch_test["tgt"]
    X_coords = batch_test["coords"]
    print(X_coords.shape)

    with torch.no_grad():
        # Move the input to the device once rather than on every decoding pass.
        src_on_device = X_test.to(device)
        future = None
        for _ in range(N_DECODE_STEPS):
            pred, output = model(src_on_device, future, train=False)
            future = output

        # Loss of the first sample in the batch only.
        print(criterion(pred[0], Y_test[0].to(device)).item())

    prev_x, prev_y = _to_xy(X_coords[0])
    x, y = _to_xy(Y_test[0])
    x2, y2 = _to_xy(pred[0])
    print(len(x))

    # All predicted points are blue except the final one, which is black.
    # The list is sized from the prediction itself so `scatter` cannot fail on
    # a colour/point count mismatch when the horizon is not 8.
    n_pred = len(x2)
    pred_colors = ["blue"] * max(n_pred - 1, 0) + ["black"] * min(n_pred, 1)

    fig, ax = plt.subplots()
    ax.scatter(prev_x, prev_y, label="Prev")
    ax.scatter(x, y, label="Truth")
    ax.scatter(x2, y2, label="Prediction", color=pred_colors)
    # Optional backdrop (disabled): ax.imshow(plt.imread("reference.jpg"))
    ax.legend()
    ax.set_xlim(0, PLOT_SCALE)
    ax.set_ylim(0, PLOT_SCALE)

    fig.savefig("/waldo/walban/student_datasets/arfranck/SDD/plots/test_3.pdf")
    plt.show()