In [1]:
from __future__ import print_function, division
import argparse
import os
# import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sintel Data Handler

In [3]:
torchvision.__version__

'0.12.0'

In [4]:
data = torchvision.datasets.Sintel(".")

In [5]:
class TSintel(torchvision.datasets.Sintel):
    def __init__(self, root):
        super().__init__(root=root)
        
    def __getitem__(self, index):
        img1, img2, flow = super().__getitem__(index)
        img1 = torchvision.transforms.ToTensor()(img1)
        img2 = torchvision.transforms.ToTensor()(img2)
        return img1, img2, flow

In [6]:
Tdata = TSintel(".")

In [7]:
train_size = round(len(Tdata) * 0.8)
test_size = round(len(Tdata) * 0.2)

In [8]:
assert train_size + test_size == len(Tdata)

In [9]:
train_data, test_data = torch.utils.data.random_split(Tdata, [train_size, test_size], generator=torch.Generator().manual_seed(42))

In [10]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=6, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=6, shuffle=False)

In [11]:
from RAFT_master.core.raft import RAFT
# from RAFT_master import evaluate

In [21]:
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=400):
    """ Loss function defined over sequence of flow predictions """

    n_predictions = len(flow_preds)    
    flow_loss = 0.0

    # exlude invalid pixels and extremely large diplacements
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
#     print(valid.shape)
#     print(mag.shape)
#     valid = (valid >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        i_weight = gamma**(n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()
#         flow_loss += i_weight * (valid[:, None] * i_loss).mean()
        flow_loss += i_weight * i_loss.mean()

#     epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
#     epe = epe.view(-1)[valid.view(-1)]

#     metrics = {
#         'epe': epe.mean().item(),
#         '1px': (epe < 1).float().mean().item(),
#         '3px': (epe < 3).float().mean().item(),
#         '5px': (epe < 5).float().mean().item(),
#     }

#     return flow_loss, metrics
    return flow_loss

In [22]:
from torch.cuda.amp import GradScaler

In [29]:
def crop(img1, img2, flow, valid):
    crop_size =  [384, 512]
    min_scale = -0.2 
    max_scale = 0.5
    ht, wd = img1.shape[2:]
    min_scale1 = np.maximum(
        (crop_size[0] + 8) / float(ht), 
        (crop_size[1] + 8) / float(wd))

    scale = 2 ** np.random.uniform(min_scale, max_scale)
    scale_x = scale
    scale_y = scale

    scale_x = np.clip(scale_x, min_scale1, None)
    scale_y = np.clip(scale_y, min_scale1, None)
    y0 = np.random.randint(0, img1.shape[2] - crop_size[0])
    x0 = np.random.randint(0, img1.shape[3] - crop_size[1])

    img1 = img1[:, :, y0:y0+crop_size[0], x0:x0+crop_size[1]]
    img2 = img2[:, :, y0:y0+crop_size[0], x0:x0+crop_size[1]]
    flow = flow[:, :, y0:y0+crop_size[0], x0:x0+crop_size[1]]
    
    return img1, img2, flow, valid

In [30]:
def train_raft(dataloaders, args):
    
    train_loader = dataloaders.get("train")
    val_loader = dataloaders.get("val")

    lr = 0.00002
    num_steps = 100000
    mixed_precision = True
    iters = 12 
    wdecay = 0.00005
    epsilon =1e-8
    clip =1.0
    dropout = 0.0
    gamma =0.8 # exponential weighting
    add_noise = True
    
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = RAFT(args)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wdecay, eps=epsilon)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, num_steps+100, pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
    model = model.to(device) 

    total_steps = 0
    scaler = GradScaler(enabled=mixed_precision)
    
    add_noise = True
    
    losses = {
        "train": [],
        "val": []
    }   
    
    epochs = 5
    
    for epoch in range(epochs):
    
        print("Epoch", str(epoch) + ": ", end="")
        train_loss = 0.0
        val_loss = 0.0
        
        # training
        model.train()
        for i, (image1, image2, flow) in tqdm(enumerate(train_loader)):
            if torch.cuda.is_available():
                image1, image2, flow = image1.cuda(), image2.cuda(), flow.cuda() # add this line
            optimizer.zero_grad()
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
#             if add_noise:
#                 stdv = np.random.uniform(0.0, 5.0)
#                 image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
#                 image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
            image1, image2, flow, valid = crop(image1, image2, flow, valid) 
            flow_predictions = model(image1, image2, iters=iters)
            loss = sequence_loss(flow_predictions, flow, valid, gamma)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)                
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            
            scaler.step(optimizer)
            scheduler.step()
            scaler.update()           
            train_loss += loss
                  
        # validation
        model.eval()
        with torch.no_grad():
            for i, (image1, image2, flow) in enumerate(val_loader):
                if torch.cuda.is_available():
                    image1, image2, flow = image1.cuda(), image2.cuda(), flow.cuda() # add this line
                valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
                flow_predictions = model(image1, image2, iters=iters)            
                loss = sequence_loss(flow_predictions, flow, valid, gamma)     
                val_loss += loss
       
        val_loss /= len(val_loader)
        train_loss /= len(train_loader)

        print("Train Loss", train_loss, "Val Loss", val_loss)
        losses["train"].append(train_loss)
        losses["val"].append(val_loss)

    return model, loss

In [31]:
from argparse import Namespace
dataloaders = {
    "train": train_loader,
    "val": test_loader,
}
args = Namespace(add_noise=False, batch_size=6, clip=1.0, dropout=0.0, epsilon=1e-08, gamma=0.8, gpus=[0, 1], image_size=[384, 512], iters=12, lr=2e-05, mixed_precision=False, name='raft', num_steps=100000, restore_ckpt=None, small=False, stage=None, validation=None, wdecay=5e-05)
train_raft(dataloaders, args)

Epoch 0: 

139it [05:07,  2.21s/it]


RuntimeError: grid_sampler(): expected grid and input to have same batch size, but got input with sizes [42240, 1, 55, 128] and grid with sizes [41472, 9, 9, 2]

In [None]:
torch.save(model.state_dict(), "raft_sintel.pt")

In [None]:
plt.style.use('seaborn')

In [None]:
plt.figure()

plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(len(flownet_losses["train"])), flownet_losses["train"], label="Train")
plt.plot(range(len(flownet_losses["val"])), flownet_losses["val"], label="Val")
plt.legend()

In [None]:
i = 160
output = model(torch.cat((Tdata[i][0], Tdata[i][1]), dim=0).unsqueeze(dim=0).to(device='cuda'))
img_size = Tdata[i][0].shape[1:]
output = torch.nn.functional.interpolate(output, size=img_size, mode="bilinear", align_corners=False).squeeze()

In [None]:
torchvision.transforms.ToPILImage()(Tdata[i][0])

In [None]:
torchvision.transforms.ToPILImage()(torchvision.utils.flow_to_image(output))