# Motion Prediction UNet Training Pipeline

In [None]:
!pip install py-lz4framed

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append("../")
from models.ImageToVec import UNet

from torch.utils.data import Dataset
from torchvision import transforms

from PIL import Image
from torchvision.transforms.functional import to_pil_image

import os
from tqdm.notebook import tqdm

import lz4framed
import pickle

import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

from utils.misc.flow_viz import plot_vec_field

def load_compressed_tensor(filename):
    """Load an LZ4-framed, pickled NumPy array from disk as a torch tensor.

    Args:
        filename: Path of the compressed file to read.

    Returns:
        The tensor produced by ``torch.from_numpy`` on the unpickled object.

    NOTE(security): ``pickle.loads`` can execute arbitrary code embedded in
    the file. Only call this on files obtained from a trusted source.
    """
    with open(filename, mode='rb') as file:
        compressed = file.read()
    array = pickle.loads(lz4framed.decompress(compressed))
    return torch.from_numpy(array)


# Side length in pixels; used as (img_size, img_size) when the dataset is built.
img_size = 256

In [2]:
class ImageVectorFlowDataset(Dataset):
    """In-memory dataset of (image, vector field) pairs read from one folder.

    Every ``*.pth`` entry in ``path`` is decompressed with
    ``load_compressed_tensor`` and paired with the JPEG whose name is the
    ``.pth`` file name with its last 10 characters replaced by ``input.jpg``
    (presumably the files are named ``<id>_motion.pth`` -- confirm against
    the dataset layout). Both are centre-cropped to a square and resized
    with ``T.Resize`` before being stored.

    Args:
        img_size: Sequence whose first element is the size given to
            ``T.Resize``.
        path: Directory that holds the ``.pth`` / ``input.jpg`` pairs.
    """

    def __init__(self, img_size, path):
        self.img_size = img_size
        self.path = path
        self.imgs = []
        self.vecs = []
        self.resize = T.Resize(self.img_size[0])
        self.load_data()

    @staticmethod
    def _square_crop_bounds(w, h):
        """Return ``(left, top, right, bottom)`` of the centred square crop.

        The square always has side ``min(w, h)``, so an odd difference
        between width and height still yields an exactly square region and a
        difference of one pixel no longer produces an empty slice.
        """
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        return left, top, left + side, top + side

    def load_data(self):
        """Read every sample in ``self.path`` into ``self.imgs``/``self.vecs``."""
        print("Loading data...")
        to_tensor = T.ToTensor()
        # Sorted so the sample order does not depend on the file system.
        for file in tqdm(sorted(os.listdir(self.path))):
            if not file.endswith(".pth"):
                continue

            vec_field = load_compressed_tensor(os.path.join(self.path, file))
            img_path = os.path.join(self.path, f"{file[:-10]}input.jpg")

            # The vector field is unpacked as a 4-D tensor; its last two
            # dimensions are used as height and width for both crops.
            _, _, h, w = vec_field.size()
            left, top, right, bottom = self._square_crop_bounds(w, h)
            if w != h:
                vec_field = vec_field[:, :, top:bottom, left:right]

            # NOTE(review): the vector values are not rescaled after the
            # resize; confirm they are stored in resolution-independent units.
            vec_field = self.resize(vec_field).squeeze(0)

            # Finish all image work inside the context manager so the file
            # handle is released as soon as the tensor has been produced.
            with Image.open(img_path, "r") as raw:
                cropped = raw.crop((left, top, right, bottom)) if w != h else raw
                img = to_tensor(self.resize(cropped))

            self.imgs.append(img)
            self.vecs.append(vec_field)

        print("Done!")

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], self.vecs[idx]

    def select(self, idx):
        """Return a new dataset that holds only the samples chosen by ``idx``.

        Args:
            idx: An ``int``, a ``slice``, or an iterable of integer indices.

        Returns:
            An ``ImageVectorFlowDataset`` sharing the selected tensors with
            this dataset. Nothing is re-read from disk.
        """
        if isinstance(idx, slice):
            imgs, vecs = self.imgs[idx], self.vecs[idx]
        elif isinstance(idx, int):
            imgs, vecs = [self.imgs[idx]], [self.vecs[idx]]
        else:
            indices = [int(i) for i in idx]
            imgs = [self.imgs[i] for i in indices]
            vecs = [self.vecs[i] for i in indices]

        # Bypass __init__ so the subset is built without touching the disk.
        subset = type(self).__new__(type(self))
        subset.img_size = self.img_size
        subset.path = self.path
        subset.resize = self.resize
        subset.imgs = list(imgs)
        subset.vecs = list(vecs)
        return subset

# Download the training data from https://eulerian.cs.washington.edu/dataset/
# and put it in the "../data/vector_flow_prediction/" folder (the path loaded below).


In [6]:
# Load every sample found in the data folder.
dataset = ImageVectorFlowDataset(
    img_size=(img_size, img_size),
    path="../data/vector_flow_prediction/",
)
# Notebook display: shapes of the first (image, vector field) pair.
dataset[0][0].shape, dataset[0][1].shape

Loading data...


  0%|          | 0/10 [00:00<?, ?it/s]

Done!


(torch.Size([3, 256, 256]), torch.Size([2, 256, 256]))

In [7]:
# Randomly split the samples: 80% for training, the remainder for evaluation.
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset, [train_size, eval_size]
)

# Both loaders yield one sample per batch, in shuffled order.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=True)

In [8]:
# Custom loss function for vector fields
class VectorFieldLoss(nn.Module):
    """Sum of a magnitude loss and a direction loss between two vector fields.

    Both inputs are indexed with the vector components along ``dim=1``.

    * Magnitude term: Smooth L1 loss between the per-position vector norms.
    * Direction term: ``1 - mean(cosine similarity)`` taken only over
      positions whose target norm exceeds ``1e-7``; it is zero when no such
      position exists.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predicted_flows, target_flows):
        """Return the scalar loss for ``predicted_flows`` vs ``target_flows``.

        Args:
            predicted_flows: Predicted field, vector components on ``dim=1``.
            target_flows: Target field with the same shape.

        Returns:
            A 0-dim tensor: magnitude loss plus direction loss.
        """
        # Magnitude term: compare vector lengths position by position.
        predicted_magnitudes = torch.norm(predicted_flows, dim=1, keepdim=True)
        target_magnitudes = torch.norm(target_flows, dim=1, keepdim=True)
        magnitude_loss = F.smooth_l1_loss(predicted_magnitudes, target_magnitudes)

        # Direction term: take the cosine along the component axis first, so
        # each value belongs to exactly one position. Masking the flattened
        # components and regrouping them in pairs would mix values coming
        # from different positions.
        cosine = F.cosine_similarity(predicted_flows, target_flows, dim=1, eps=1e-7)

        # Positions with a (near-)zero target have no meaningful direction.
        direction_mask = target_magnitudes.squeeze(1) > 1e-7
        if direction_mask.any():
            direction_loss = 1 - cosine[direction_mask].mean()
        else:
            # Same device and dtype as the magnitude term.
            direction_loss = magnitude_loss.new_zeros(())

        # Total loss is the sum of magnitude loss and direction loss
        return magnitude_loss + direction_loss

# alternative loss function for vector fields
class VectorLoss(nn.Module):
    """Combined direction and magnitude loss, averaged over all positions.

    Per position the loss is
    ``10 * d + m * (1 - d) ** 2`` where ``d`` is ``0.5 - cos / 2`` (forced to
    zero where the target norm is at most ``0.001``) and ``m`` is the squared
    difference between target and output norms. Norms and cosine similarity
    are taken along ``dim=1``.
    """

    def __init__(self):
        super().__init__()
        self.cos_sim = torch.nn.CosineSimilarity(dim=1)

    def forward(self, output, target):
        """Return the mean loss between ``output`` and ``target``."""
        target_norm = target.norm(dim=1)
        output_norm = output.norm(dim=1)

        # Squared difference of the vector lengths.
        magnitude_loss = torch.square(target_norm - output_norm)

        # Maps cosine similarity from [-1, 1] to a loss in [0, 1].
        direction_loss = 0.5 - 0.5 * self.cos_sim(output, target)

        # Do not take into account the direction of vectors that are almost
        # zero in the target.
        minimal_threshold = 0.001
        has_direction = target_norm > minimal_threshold
        direction_loss = torch.where(
            has_direction, direction_loss, torch.zeros_like(direction_loss)
        )

        # The magnitude term is weighted by how well the directions agree.
        agreement = 1 - direction_loss
        total_loss = 10 * direction_loss + magnitude_loss * agreement * agreement

        return total_loss.mean()

In [9]:
# Release cached GPU memory before building the model.
torch.cuda.empty_cache()

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# UNet(3, 2): presumably 3 input image channels and 2 output vector
# components -- confirm against models.ImageToVec.
model = UNet(3, 2).to(device)
criterion = VectorFieldLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Global batch counter for the training loop.
i = 0

model.train()

# Per-epoch loss history and the best evaluation loss seen so far.
train_losses, eval_losses = [], []
best_loss = float("inf")

In [None]:
# Augmentations are built once; they keep no per-batch state.
geometric_aug = T.RandomAffine(0, translate=(0.05, 0.05))
color_aug = T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)

for epoch in tqdm(range(200)):
    # Evaluation below switches to eval mode, so re-enable training mode here.
    model.train()
    epoch_loss = 0
    for img, vec in tqdm(train_dataloader):
        # Stack image and vector field along the channel axis so the random
        # translation is sampled once and applied identically to both.
        # Applying the random transform to each separately would draw
        # different parameters and misalign the pair.
        n_img_channels = img.shape[1]
        vec_dtype = vec.dtype
        stacked = geometric_aug(torch.cat([img, vec.to(img.dtype)], dim=1))

        # One flip decision shared by the image and the vector field.
        flipped = torch.rand(1).item() < 0.5
        if flipped:
            stacked = stacked.flip(-1)

        img = stacked[:, :n_img_channels]
        vec = stacked[:, n_img_channels:].to(vec_dtype)
        if flipped:
            # Mirroring the field also reverses its horizontal component.
            # NOTE(review): assumes channel 0 holds the horizontal (x)
            # component -- confirm against the dataset convention.
            vec = torch.cat([-vec[:, :1], vec[:, 1:]], dim=1)

        # Colour jitter only makes sense for the image.
        img = color_aug(img).to(device)
        vec = vec.to(device)

        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, vec)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        i += 1
    train_losses.append(epoch_loss/len(train_dataloader))
    print(f"Epoch {epoch} loss: {epoch_loss/len(train_dataloader)}")

    # Evaluate in inference mode so layers such as dropout or batch norm (if
    # the model has any) behave deterministically; no gradients are needed.
    model.eval()
    with torch.no_grad():
        epoch_loss = 0
        for img, vec in tqdm(eval_dataloader):
            img = img.to(device)
            vec = vec.to(device)
            output = model(img)
            loss = criterion(output, vec)
            epoch_loss += loss.item()
        eval_losses.append(epoch_loss/len(eval_dataloader))
        print(f"Epoch {epoch} eval loss: {epoch_loss/len(eval_dataloader)}")

    # Save the model whenever it achieves the best evaluation loss so far.
    if eval_losses[-1] < best_loss:
        best_loss = eval_losses[-1]
        print("New best model!, saving...")
        torch.save(model.state_dict(), "model_bloup.pth")
    
        