In [1]:
import torch
import torchvision
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from multiscaleloss import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# datasets.Sintel and utils.flow_to_image (both used below) were added in torchvision 0.12.
torchvision_version = torchvision.__version__
torchvision_version

'0.12.0'

In [3]:
# Plain torchvision Sintel dataset with its default split/pass; it is not
# referenced again below (the TSintel wrapper is used instead).
data = torchvision.datasets.Sintel(root=".")

In [4]:
class TSintel(torchvision.datasets.Sintel):
    """Sintel dataset whose two frames are returned as float tensors.

    Items are ``(img1, img2, flow)``: both frames go through ``ToTensor`` and
    ``flow`` is passed through exactly as the base class returns it.
    """

    # One shared converter instead of constructing a new ToTensor per item.
    _to_tensor = torchvision.transforms.ToTensor()

    def __init__(self, root, split="train", pass_name="clean"):
        # split / pass_name are forwarded unchanged; the defaults match
        # torchvision.datasets.Sintel, so TSintel(root) behaves as before.
        super().__init__(root=root, split=split, pass_name=pass_name)

    def __getitem__(self, index):
        img1, img2, flow = super().__getitem__(index)
        return self._to_tensor(img1), self._to_tensor(img2), flow

In [5]:
# Tensor-returning Sintel wrapper; all splits and loaders below are built from it.
Tdata = TSintel(root=".")

In [6]:
# Train/test split sizes. test_size is the complement of train_size, so the two
# always sum to len(Tdata) regardless of how round() treats ties for other
# fractions (independently rounding both parts can over- or under-count).
TRAIN_FRACTION = 0.8
train_size = round(len(Tdata) * TRAIN_FRACTION)
test_size = len(Tdata) - train_size

In [7]:
# random_split needs the lengths to cover the whole dataset exactly.
assert len(Tdata) == test_size + train_size

In [8]:
# A fixed seed gives the same train/test partition on every run.
split_rng = torch.Generator().manual_seed(42)
train_data, test_data = torch.utils.data.random_split(
    Tdata,
    [train_size, test_size],
    generator=split_rng,
)

In [9]:
# Only the training batches are shuffled; evaluation order stays fixed.
BATCH_SIZE = 16
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
from model import FlowNetS
from multiscaleloss import multiscaleEPE



In [11]:
# Optimisation hyper-parameters read by train_flownet.
train_params = dict(
    epochs=50,
    lr=1e-4,
    weight_decay=4e-4,
)

# The held-out split doubles as the validation set during training.
dataloaders = dict(train=train_loader, val=test_loader)

In [13]:
def train_flownet(dataloaders, train_params):
    """Train a fresh FlowNetS model and record per-epoch average losses.

    Args:
        dataloaders: dict with "train" and "val" DataLoaders, each yielding
            ``(img1, img2, flow)`` batches; the frames are concatenated along
            dim 1 before being fed to the model.
        train_params: dict with "epochs", "lr" and "weight_decay". Optional
            "milestones" (default [100, 150, 200]) and "gamma" (default 0.5)
            configure the MultiStepLR schedule, which is stepped once per epoch.

    Returns:
        (model, losses): the trained model and a dict mapping "train"/"val"
        to lists holding one batch-averaged loss per epoch.
    """
    train_loader = dataloaders.get("train")
    val_loader = dataloaders.get("val")

    epochs = train_params.get("epochs")
    lr = train_params.get("lr")
    weight_decay = train_params.get("weight_decay")
    milestones = train_params.get("milestones", [100, 150, 200])
    gamma = train_params.get("gamma", 0.5)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = FlowNetS().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    losses = {
        "train": [],
        "val": []
    }

    for epoch in range(epochs):
        print("Epoch", str(epoch) + ": ", end="")
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        # Iterate the loader directly so tqdm knows its length (shows a real progress bar).
        for img1, img2, label in tqdm(train_loader):
            image = torch.cat((img1, img2), dim=1).to(device)
            label = label.to(device)
            train_loss += flownet_batch_train(model, optimizer, image, label)

        model.eval()
        with torch.no_grad():
            for img1, img2, target in val_loader:
                image = torch.cat((img1, img2), dim=1).to(device)
                label = target.to(device)
                val_loss += flownet_batch_validate(model, image, label)

        # Advance the LR schedule once per epoch, so milestones count epochs.
        scheduler.step()

        # max(..., 1) avoids ZeroDivisionError if a split produced an empty loader.
        val_loss /= max(len(val_loader), 1)
        train_loss /= max(len(train_loader), 1)

        print("Train Loss", train_loss, "Val Loss", val_loss)
        losses["train"].append(train_loss)
        losses["val"].append(val_loss)

    return model, losses

def flownet_batch_train(model, optimizer, image, label):
    """Run one optimisation step on a batch and return its EPE as a float.

    The first (finest) prediction is resized to the label's spatial size before
    the multi-scale loss is computed; the remaining predictions are passed as-is.
    The returned metric is ``realEPE`` of that resized prediction against ``label``.
    """
    optimizer.zero_grad()
    predictions = model(image)

    target_size = tuple(label.shape[-2:])
    finest = torch.nn.functional.interpolate(predictions[0], target_size)
    predictions = [finest] + list(predictions[1:])

    multiscale_loss = multiscaleEPE(predictions, label, sparse=False)
    multiscale_loss.backward()
    optimizer.step()

    # Reporting metric only; no autograd graph is needed.
    with torch.no_grad():
        epe = realEPE(finest, label, sparse=False)
    return epe.item()

def flownet_batch_validate(model, image, label):
    """Return the validation loss of ``model`` on one batch as a Python float.

    Args:
        model: the network; its output is passed straight to ``realEPE``
            (presumably a single flow tensor in eval mode — confirm).
        image: both frames concatenated along the channel dimension.
        label: ground-truth flow for the batch.

    Returns:
        ``realEPE(model(image), label, sparse=False)`` as a float, so the caller
        can accumulate it across batches.
    """
    # no_grad here too, so the helper is safe even outside the caller's no_grad block.
    with torch.no_grad():
        output = model(image)
        loss = realEPE(output, label, sparse=False)
    return loss.item()

In [None]:
# Train from scratch; yields the trained network and its per-epoch loss history.
model, flownet_losses = train_flownet(dataloaders=dataloaders, train_params=train_params)

Epoch 0: 

6it [00:11,  1.87s/it]

In [None]:
# Persist only the learned weights (state_dict), not the whole module object.
checkpoint_path = "flownet_sintel.pt"
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8' and later
# removed the old name; use whichever one this installation provides.
plot_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
plt.style.use(plot_style)

In [None]:
# Per-epoch batch-averaged losses recorded by train_flownet.
fig, ax = plt.subplots()
ax.plot(range(len(flownet_losses["train"])), flownet_losses["train"], label="Train")
ax.plot(range(len(flownet_losses["val"])), flownet_losses["val"], label="Val")
ax.set(title="FlowNetS on Sintel: loss per epoch", xlabel="Epochs", ylabel="Loss")
ax.legend();

In [None]:
# Predict the flow for one sample of the full dataset (it may lie in either split).
i = 160
# `device` only exists inside train_flownet; recover it from the model instead.
device = next(model.parameters()).device
img1, img2, _ = Tdata[i]  # load the sample once instead of three times

model.eval()  # don't rely on the training loop having left the model in eval mode
with torch.no_grad():
    output = model(torch.cat((img1, img2), dim=0).unsqueeze(dim=0).to(device))
img_size = img1.shape[1:]
output = torch.nn.functional.interpolate(
    output, size=img_size, mode="bilinear", align_corners=False
).squeeze().cpu()

In [None]:
# First frame of sample i, for visual comparison with the predicted flow.
first_frame = Tdata[i][0]
torchvision.transforms.ToPILImage()(first_frame)

In [None]:
# Colour-code the predicted (2, H, W) flow. Detach and move to CPU first so the
# tensor -> image conversion works even if `output` still tracks gradients or
# lives on the GPU.
flow_rgb = torchvision.utils.flow_to_image(output.detach().cpu())
torchvision.transforms.ToPILImage()(flow_rgb)