In [10]:
from utils import CNNTransformerDataset
import numpy as np
import torch
from vanilla_transformer.network import TrajectoryPredictTransformerV1
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt

In [19]:
# Identifier of the dataset to load; data is read from data/DJI_<dataset_num>.
dataset_num = '0008'
dataset = CNNTransformerDataset(f"data/DJI_{dataset_num}", img_transform=transforms.ToTensor())

# Hold out 25% of the samples for validation, the remainder is used for training.
val_proportion = 0.25
val_size = int(val_proportion * len(dataset))
train_size = len(dataset) - val_size
# A fixed generator seed makes the train/validation split identical on every run.
validation_dataset, train_dataset = torch.utils.data.random_split(dataset, [val_size, train_size], generator=torch.Generator().manual_seed(42))
trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=12)
# Evaluation gains nothing from shuffling. A fixed order keeps the batch
# composition (including the final partial batch) identical across epochs, so
# the per-epoch validation losses are directly comparable.
testloader = DataLoader(validation_dataset, batch_size=32, shuffle=False, num_workers=12)

In [20]:
# Sanity check on the loaded data: display the shape of the dataset's image array.
# The recorded run shows (1363, 400, 400, 3) - presumably
# (samples, height, width, channels); TODO confirm against CNNTransformerDataset.
dataset.image_features.shape

(1363, 400, 400, 3)

In [4]:
# Prefer the GPU whenever PyTorch reports a usable CUDA device, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [12]:
def train_loop(model, opt, loss_fn, data_loader, device):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    The model is switched to training mode. Every batch is unpacked into
    ``(img, X, y_in, y_label)``; each tensor is moved to ``device`` and cast to
    float. A square subsequent mask sized by ``y_in.shape[1]`` is obtained from
    ``model.transformer.generate_square_subsequent_mask`` and handed to the
    model as ``tgt_mask``.

    Args:
        model: Module called as ``model(img, X, y_in, tgt_mask=...)``; must
            expose ``transformer.generate_square_subsequent_mask``.
        opt: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(pred, y_label)`` returning a scalar tensor.
        data_loader: Sized iterable yielding 4-tuples of tensors.
        device: Device the batch tensors and the mask are moved to.

    Returns:
        float: Sum of the per-batch losses divided by ``len(data_loader)``.

    Raises:
        ZeroDivisionError: If ``data_loader`` has length zero.
    """
    model.train()
    batch_losses = []

    for batch in data_loader:
        # Move the four batch tensors to the target device as float, in order.
        img, X, y_in, y_label = (tensor.to(device).float() for tensor in batch)

        target_len = y_in.shape[1]
        tgt_mask = model.transformer.generate_square_subsequent_mask(target_len)
        tgt_mask = tgt_mask.to(device).float()

        # Forward pass: the decoder input and its mask go in alongside the features.
        pred = model(img, X, y_in, tgt_mask=tgt_mask)
        loss = loss_fn(pred, y_label)

        # Clear stale gradients, back-propagate, and update the parameters.
        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_losses.append(loss.detach().item())

    return sum(batch_losses) / len(data_loader)

def validation_loop(model, loss_fn, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    The model is switched to evaluation mode and the whole pass runs under
    ``torch.no_grad()``, so no gradients are tracked. Batches are handled
    exactly as in ``train_loop``: each one is unpacked into
    ``(img, X, y_in, y_label)``, moved to ``device`` as float, and a square
    subsequent mask sized by ``y_in.shape[1]`` is passed to the model.

    Args:
        model: Module called as ``model(img, X, y_in, tgt_mask=...)``; must
            expose ``transformer.generate_square_subsequent_mask``.
        loss_fn: Callable ``loss_fn(pred, y_label)`` returning a scalar tensor.
        dataloader: Sized iterable yielding 4-tuples of tensors.
        device: Device the batch tensors and the mask are moved to.

    Returns:
        float: Sum of the per-batch losses divided by ``len(dataloader)``.

    Raises:
        ZeroDivisionError: If ``dataloader`` has length zero.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            img, X, y_in, y_label = batch
            img = img.to(device).float()
            X = X.to(device).float()
            y_in = y_in.to(device).float()
            y_label = y_label.to(device).float()
            tgt_mask = model.transformer.generate_square_subsequent_mask(y_in.shape[1]).to(device).float()
            # Pass the mask by keyword, as train_loop does, so it cannot be
            # bound to a different positional parameter of the model's forward().
            pred = model(img, X, y_in, tgt_mask=tgt_mask)
            loss = loss_fn(pred, y_label)
            total_loss += loss.item()
    return total_loss / len(dataloader)

def fit(model, opt, loss_fn, train_data_loader, val_data_loader, epochs, print_every=10, device="cuda"):
    """Train ``model`` for ``epochs`` epochs, validating after every epoch.

    Each epoch runs one ``train_loop`` pass followed by one ``validation_loop``
    pass and records both mean losses. On every ``print_every``-th epoch
    (counting epochs from 1) a banner is printed before the epoch starts and
    the two losses are printed once it has finished.

    Args:
        model: Model forwarded to ``train_loop`` and ``validation_loop``.
        opt: Optimizer forwarded to ``train_loop``.
        loss_fn: Loss callable forwarded to both loops.
        train_data_loader: Loader used for the training pass.
        val_data_loader: Loader used for the validation pass.
        epochs: Number of epochs to run.
        print_every: Reporting interval in epochs. Values ``<= 0`` disable
            progress reporting entirely.
        device: Device forwarded to both loops. Defaults to ``"cuda"``; pass
            the device the model actually lives on when that is not the GPU.

    Returns:
        tuple[list, list]: ``(train_loss_list, validation_loss_list)`` with one
        entry per epoch, suitable for plotting.
    """
    # Used for plotting later on
    train_loss_list, validation_loss_list = [], []
    # A non-positive interval means "never report" instead of dividing by zero.
    reporting_enabled = print_every > 0

    print("Training model")
    for epoch in range(epochs):
        report_this_epoch = reporting_enabled and (epoch + 1) % print_every == 0
        if report_this_epoch:
            print("-"*25, f"Epoch {epoch + 1}","-"*25)

        # The train/validate/record sequence is the same for every epoch;
        # only the printing is conditional.
        train_loss = train_loop(model, opt, loss_fn, train_data_loader, device)
        train_loss_list.append(train_loss)
        validation_loss = validation_loop(model, loss_fn, val_data_loader, device)
        validation_loss_list.append(validation_loss)

        if report_this_epoch:
            print(f"Training loss: {train_loss:.4f}")
            print(f"Validation loss: {validation_loss:.4f}")
            print()
    return train_loss_list, validation_loss_list

In [13]:

# Mean-squared-error objective handed to the training and validation loops.
loss_fn = nn.MSELoss()
# Build the network with its default constructor arguments, then move it to the selected device.
model = TrajectoryPredictTransformerV1()
model = model.to(device)

In [14]:
# Plain SGD over all model parameters.
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
# Pass the detected device explicitly: fit() defaults to "cuda", which would
# send the batches to the GPU even when the model was placed on the CPU.
fit(model=model, opt=opt, loss_fn=loss_fn, train_data_loader=trainloader, val_data_loader=testloader, epochs=20, print_every=2, device=device)

Training model
------------------------- Epoch 2 -------------------------
Training loss: 0.2743
Validation loss: 0.1725

------------------------- Epoch 4 -------------------------
Training loss: 0.2087
Validation loss: 0.2181

------------------------- Epoch 6 -------------------------
Training loss: 0.1729
Validation loss: 0.2622

------------------------- Epoch 8 -------------------------
Training loss: 0.1565
Validation loss: 0.0916

------------------------- Epoch 10 -------------------------
Training loss: 0.1272
Validation loss: 0.0471

------------------------- Epoch 12 -------------------------
Training loss: 0.1371
Validation loss: 0.0792

------------------------- Epoch 14 -------------------------
Training loss: 0.1073
Validation loss: 0.1535

------------------------- Epoch 16 -------------------------
Training loss: 0.1024
Validation loss: 0.0465

------------------------- Epoch 18 -------------------------
Training loss: 0.0909
Validation loss: 0.0326

-----------------

([0.31893681585788725,
  0.2742508664727211,
  0.23005104809999466,
  0.2086819216609001,
  0.2042704254388809,
  0.17286338806152343,
  0.17290397435426713,
  0.15649396926164627,
  0.14779044240713118,
  0.12724338844418526,
  0.1322614885866642,
  0.1370787113904953,
  0.1113824002444744,
  0.1073179230093956,
  0.10734471753239631,
  0.10244339779019355,
  0.1000744141638279,
  0.09091647490859031,
  0.08424308374524117,
  0.09091614037752152],
 [0.1872636377811432,
  0.17253383249044418,
  0.19124091044068336,
  0.21805270574986935,
  0.08317366428673267,
  0.2621913515031338,
  0.10272263549268246,
  0.09160170890390873,
  0.09701055753976107,
  0.04709682706743479,
  0.04286977951414883,
  0.07919076457619667,
  0.1446087220683694,
  0.15346295107156038,
  0.02651057788170874,
  0.046494570560753345,
  0.04471725504845381,
  0.032622088212519884,
  0.03011154057458043,
  0.030866208486258984])