In [1]:
from utils import CNNTransformerDataset, CNNTransformerDatasetMulti
import numpy as np
import torch
from vanilla_transformer.network import TrajectoryPredictTransformerV1
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt

In [7]:
#dataset_nums = ['data/DJI_0008', 'data/DJI_0009', 'data/DJI_0010', 'data/DJI_0011', 'data/DJI_0012']
dataset_nums = ['data/DJI_0012']
dataset = CNNTransformerDatasetMulti(dataset_nums, img_transform=transforms.ToTensor())
val_proportion = 0.25
val_size = int(val_proportion * len(dataset))
train_size = len(dataset) - val_size
validation_dataset, train_dataset = torch.utils.data.random_split(dataset, [val_size, train_size], generator=torch.Generator().manual_seed(42))
trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=12)
testloader = DataLoader(validation_dataset, batch_size=32, shuffle=True, num_workers=12)

#dataset_num = '0012'
#dataset = CNNTransformerDataset(f"data/DJI_{dataset_num}", img_transform=transforms.ToTensor())
#val_proportion = 0.25
#val_size = int(val_proportion * len(dataset))
#train_size = len(dataset) - val_size
#validation_dataset, train_dataset = torch.utils.data.random_split(dataset, [val_size, train_size], generator=torch.Generator().manual_seed(42))
#trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=12)
#testloader = DataLoader(validation_dataset, batch_size=32, shuffle=True, num_workers=12)

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [12]:
def train_loop(model, opt, loss_fn, data_loader, device):
    model.train()
    total_loss = 0
    
    for batch in data_loader:
        #X, y = get_random_batch(points.copy(), 4, 6, batch_size)
        #X, y = torch.tensor(X).float().to(device), torch.tensor(y).float().to(device)
        img, X, y_in, y_label = batch
        img = img.to(device).float()
        X = X.to(device).float()
        y_in = y_in.to(device).float()
        y_label = y_label.to(device).float()
        tgt_mask = model.transformer.generate_square_subsequent_mask(y_in.shape[1]).to(device).float()

        # Standard training except we pass in y_input and tgt_mask
        pred = model(img, X, y_in, tgt_mask=tgt_mask)
        # Permute pred to have batch size first again
        loss = loss_fn(pred, y_label)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.detach().item()
        
    return total_loss / len(data_loader)

def validation_loop(model, loss_fn, dataloader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            img, X, y_in, y_label = batch
            img = img.to(device).float()
            X = X.to(device).float()
            y_in = y_in.to(device).float()
            y_label = y_label.to(device).float()
            tgt_mask = model.transformer.generate_square_subsequent_mask(y_in.shape[1]).to(device).float()
            pred = model(img, X, y_in, tgt_mask)
            loss = loss_fn(pred, y_label)
            total_loss += loss.detach().item()
    return total_loss / len(dataloader)

def fit(model, opt, loss_fn, train_data_loader, val_data_loader, epochs, print_every=10, device="cuda"):
    
    # Used for plotting later on
    train_loss_list, validation_loss_list = [], []
    print("Training model")
    for epoch in range(epochs):
        if epoch % print_every == print_every - 1:
            print("-"*25, f"Epoch {epoch + 1}","-"*25)
            train_loss = train_loop(model, opt, loss_fn, train_data_loader, device)
            train_loss_list += [train_loss]
            validation_loss = validation_loop(model, loss_fn, val_data_loader, device)
            validation_loss_list += [validation_loss]
            print(f"Training loss: {train_loss:.4f}")
            print(f"Validation loss: {validation_loss:.4f}")
            print()
        else:
            train_loss = train_loop(model, opt, loss_fn, train_data_loader, device)
            train_loss_list += [train_loss]
            validation_loss = validation_loop(model, loss_fn, val_data_loader, device)
            validation_loss_list += [validation_loss]
    return train_loss_list, validation_loss_list

In [13]:

model = TrajectoryPredictTransformerV1().to(device)
loss_fn = nn.MSELoss()

In [14]:
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
fit(model=model, opt=opt, loss_fn=loss_fn, train_data_loader=trainloader, val_data_loader=testloader, epochs=10, print_every=2)

Training model
------------------------- Epoch 2 -------------------------
Training loss: 2.0372
Validation loss: 1.3257

------------------------- Epoch 4 -------------------------
Training loss: 1.1794
Validation loss: 1.0175

------------------------- Epoch 6 -------------------------
Training loss: 1.0107
Validation loss: 0.7357

------------------------- Epoch 8 -------------------------
Training loss: 0.9360
Validation loss: 0.5722

------------------------- Epoch 10 -------------------------
Training loss: 0.7745
Validation loss: 0.5278



([3.07314021512866,
  2.0371827678754926,
  1.3981338292360306,
  1.1793798794969916,
  1.1078955791890621,
  1.010742605663836,
  0.9237363622523844,
  0.9359869426116347,
  0.8698587561957538,
  0.774512683507055],
 [2.213006019592285,
  1.3256649049845608,
  0.8752166005698118,
  1.017468891360543,
  0.7925079546191476,
  0.7356586781415072,
  0.6827595843510195,
  0.5721793858842417,
  0.5143501636656848,
  0.5278471586379138])

In [None]:
model_state = torch.load('models/CNN_Transformer_03-04-2022_12-25-47.pth')
model.load_state_dict(model_state)
model = TrajectoryPredictTransformerV1().to(device)

In [15]:
img, X, y_in, y_lbl = dataset[700]

In [16]:
y_in, y_lbl

(tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.1600e+00, -1.5123e-03, -1.0000e-04],
         [ 2.2800e+00, -2.9724e-03,  7.0000e-04],
         [ 3.3000e+00,  5.6979e-03,  2.3000e-03],
         [ 4.1700e+00,  4.5637e-03,  4.9000e-03],
         [ 4.9200e+00,  2.3586e-02,  8.3000e-03],
         [ 5.5900e+00,  3.2712e-02,  1.2500e-02],
         [ 6.1701e+00,  5.1956e-02,  1.7000e-02],
         [ 6.6801e+00,  7.1291e-02,  2.1900e-02],
         [ 7.1101e+00,  9.0731e-02,  2.6600e-02]], dtype=torch.float64),
 tensor([[ 1.1600e+00, -1.5123e-03, -1.0000e-04],
         [ 2.2800e+00, -2.9724e-03,  7.0000e-04],
         [ 3.3000e+00,  5.6979e-03,  2.3000e-03],
         [ 4.1700e+00,  4.5637e-03,  4.9000e-03],
         [ 4.9200e+00,  2.3586e-02,  8.3000e-03],
         [ 5.5900e+00,  3.2712e-02,  1.2500e-02],
         [ 6.1701e+00,  5.1956e-02,  1.7000e-02],
         [ 6.6801e+00,  7.1291e-02,  2.1900e-02],
         [ 7.1101e+00,  9.0731e-02,  2.6600e-02],
         [ 7.4701e+00,  1.0

In [17]:
model(img.cuda()[None,:,:,:].float(), X.cuda()[None,:,:].float(), y_in.cuda()[None,:,:].float())

tensor([[[10.1882, -0.0126,  0.2286],
         [11.8836,  0.3826,  0.1827],
         [11.9498,  0.4735,  0.1552],
         [11.9288,  0.5096,  0.1423],
         [11.9095,  0.5255,  0.1353],
         [11.8947,  0.5352,  0.1309],
         [11.8841,  0.5417,  0.1280],
         [11.8765,  0.5471,  0.1257],
         [11.8732,  0.5511,  0.1239],
         [11.8684,  0.5541,  0.1225]]], device='cuda:0',
       grad_fn=<AddBackward0>)