In [37]:
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_, constant_
import torch.nn.functional as F
import torchvision.transforms as transforms
import flow_transforms
import datasets
from tensorboardX import SummaryWriter
import pdb
import os
from multiscaleloss import *
import time

# Utility model helpers
conv, predict_flow, deconv, croplike

In [23]:
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    if batchNorm:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.LeakyReLU(0.1,inplace=True)
        )
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.LeakyReLU(0.1,inplace=True)
        )
    
def predict_flow(in_planes):
    return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=False)

def deconv(in_planes, out_planes):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1,inplace=True)
    )

def crop_like(input, target):
    if input.size()[2:] == target.size()[2:]:
        return input
    else:
        return input[:, :, :target.size(2), :target.size(3)]

# Flow Net Simple Architecture

In [34]:
class FlowNetSimple(nn.Module):
    expansion = 1
    def __init__(self, batchNorm=True):
        super(FlowNetSimple, self).__init__()
        self.batchNorm = batchNorm
        self.conv1   = conv(self.batchNorm,   6,   64, kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, 256,  256)
        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512,  512)
        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512,  512)
        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm,1024, 1024)
        self.deconv5 = deconv(1024,512)
        self.deconv4 = deconv(1026,256)
        self.deconv3 = deconv(770,128)
        self.deconv2 = deconv(386,64)
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)
    
    def forward(self, x):
        # all the conv layers
        out_conv2 = self.conv2(self.conv1(x))
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))
        
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)

        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)

        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)

        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)

        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            return flow2,flow3,flow4,flow5,flow6
        else:
            return flow2
    
    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return '{:.3f} ({:.3f})'.format(self.val, self.avg)

def flow2rgb(flow_map, max_value):
    flow_map_np = flow_map.detach().cpu().numpy()
    _, h, w = flow_map_np.shape
    flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan')
    rgb_map = np.ones((3,h,w)).astype(np.float32)
    if max_value is not None:
        normalized_flow_map = flow_map_np / max_value
    else:
        normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
    rgb_map[0] += normalized_flow_map[0]
    rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
    rgb_map[2] += normalized_flow_map[1]
    return rgb_map.clip(0,1)

## Download Flying Chairs Dataset

In [4]:
# ! wget https://lmb.informatik.uni-freiburg.de/data/FlyingChairs/FlyingChairs.zip

## Load Dataset

In [15]:
dataset_path = "./FlyingChairs_release/data/"

input_transform = transforms.Compose([
    flow_transforms.ArrayToTensor(),
    transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
    transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1])
])
target_transform = transforms.Compose([
    flow_transforms.ArrayToTensor(),
    transforms.Normalize(mean=[0,0],std=[255,255])
])
co_transform = flow_transforms.Compose([
    flow_transforms.RandomTranslate(10),
    flow_transforms.RandomRotate(10,5),
    flow_transforms.RandomCrop((320,448)),
    flow_transforms.RandomVerticalFlip(),
    flow_transforms.RandomHorizontalFlip()
])
train_set, test_set = datasets.__dict__["flying_chairs"](
    dataset_path,
    transform=input_transform,
    target_transform=target_transform,
    co_transform=co_transform,
    split=0.8
)

In [19]:
print len(train_set), len(test_set)

18398 4474


In [21]:
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=8,
    num_workers=8, pin_memory=True, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    test_set, batch_size=8,
    num_workers=8, pin_memory=True, shuffle=False)

## Initialize the model

In [24]:
model = FlowNetSimple(batchNorm=False)
model = torch.nn.DataParallel(model)


## Define train and eval helper functions

In [41]:
def train(train_loader, model, optimizer, epoch, train_writer):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    flow2_EPEs = AverageMeter()
    
    epoch_size = min(len(train_loader), 1000)
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = torch.cat(input,1)
        data_time.update(time.time() - end)
        output = model(input)
        loss = multiscaleEPE(output, target, weights=[0.005, 0.01, 0.02, 0.08, 0.32], sparse=False)
        flow2_EPE = 20 * realEPE(output[0], target, sparse=False)
        losses.update(loss.item(), target.size(0))
        train_writer.add_scalar('train_loss', loss.item(), n_iter)
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % 20 == 0:
            print('Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}'
                  .format(epoch, i, epoch_size, batch_time,
                          data_time, losses, flow2_EPEs))
        n_iter += 1
        if i >= epoch_size:
            break

    return losses.avg, flow2_EPEs.avg


def validate(val_loader, model, epoch, output_writers):

    batch_time = AverageMeter()
    flow2_EPEs = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
#         target = target.to(device)
#         input = torch.cat(input,1).to(device)

        # compute output
        output = model(input)
        flow2_EPE = args.div_flow*realEPE(output, target, sparse=args.sparse)
        # record EPE
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i < len(output_writers):  # log first output of first batches
            if epoch == 0:
                mean_values = torch.tensor([0.411,0.432,0.45], dtype=input.dtype).view(3,1,1)
                output_writers[i].add_image('GroundTruth', flow2rgb(args.div_flow * target[0], max_value=10), 0)
                output_writers[i].add_image('Inputs', (input[0,:3].cpu() + mean_values).clamp(0,1), 0)
                output_writers[i].add_image('Inputs', (input[0,3:].cpu() + mean_values).clamp(0,1), 1)
            output_writers[i].add_image('FlowNet Outputs', flow2rgb(args.div_flow * output[0], max_value=10), epoch)

        if i % 20 == 0:
            print('Test: [{0}/{1}]\t Time {2}\t EPE {3}'
                  .format(i, len(val_loader), batch_time, flow2_EPEs))

    print(' * EPE {:.3f}'.format(flow2_EPEs.avg))

    return flow2_EPEs.avg

## Train the model
optimizer: Adam<br>
bias decay: 0<br>
weights decay: 0.0004 <br>
learning rate: 0.0001 <br>
momentum: 0.9 <br>
beta: 0.999 <br>
milestones: [100, 150, 200] <br>
epochs: 300 <br>
epoch_size: 1000 <br>
div_flow: 20 <br>
sparse: False <br>
multiscaleweights: [0.005, 0.01, 0.02, 0.008, 0.32] <br>

In [None]:
param_groups = [{'params': model.module.bias_parameters(), 'weight_decay': 0.0},
                {'params': model.module.weight_parameters(), 'weight_decay': 0.0004}]
optimizer = torch.optim.Adam(param_groups, 0.0001, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150,200], gamma=0.5)

best_EPE = -1
n_iter = 0
train_writer = SummaryWriter(os.path.join("./save/",'train'))
test_writer = SummaryWriter(os.path.join("./save/",'test'))

for epoch in range(300):
    pdb.set_trace()
    scheduler.step()
    # train 1 step
    train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
    train_writer.add_scalar('mean EPE', train_EPE, epoch)
    # eval 
    with torch.no_grad():
        EPE = validate(val_loader, model, epoch, output_writers)
    
    test_writer.add_scalar('mean EPE', EPE, epoch)

    if best_EPE < 0:
        best_EPE = EPE
    is_best = EPE < best_EPE
    best_EPE = min(EPE, best_EPE)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': "flowNetSimple",
        'state_dict': model.module.state_dict(),
        'best_EPE': best_EPE,
        'div_flow': 20
    }, is_best)


> <ipython-input-42-039067bd6c9f>(13)<module>()
-> scheduler.step()
(Pdb) n
> <ipython-input-42-039067bd6c9f>(15)<module>()
-> train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
(Pdb) n
Epoch: [0][0/1000]	 Time 27.206 (27.206)	 Data 7.185 (7.185)	 Loss 42.536 (42.536)	 EPE 7.206 (7.206)
Epoch: [0][20/1000]	 Time 11.463 (12.185)	 Data 0.035 (0.369)	 Loss 9.026 (16.096)	 EPE 1.962 (2.948)


Process Process-10:
Process Process-11:
Process Process-16:
Process Process-15:
Process Process-12:
Process Process-14:
Process Process-9:
Process Process-13:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/multiprocessing/process.py", line 267, in _bootstrap
  File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/multiprocessing/process.py", line 267, in _bootstrap
  File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/multiprocessing/process.py", line 267, in _bootstrap
  File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/multip

KeyboardInterrupt: None
> <ipython-input-42-039067bd6c9f>(15)<module>()
-> train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
