<a href="https://colab.research.google.com/github/Jeremy26/video_analysis_course/blob/main/FlowNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FlowNet

In this project, we're going to build a FlowNet algorithm with PyTorch! The idea is simple, given two images, output the optical flow!
<p>

![flownet](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ_p_REZwjQ1YqfV51j8vQ1qJodRUDRI8Dd7tPuwbWW-tWUQBhKibGi3Bq1ox6SNp5k2ts&usqp=CAU)

Let's begin with some synchronization and imports!

In [None]:
!wget https://thinkautonomous-flownet.s3.eu-west-3.amazonaws.com/flownet-data.zip && unzip flownet-data.zip && rm flownet-data.zip
!mkdir output
!ls

Imports

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pickle
from google.colab.patches import cv2_imshow
import glob

In [None]:
from __future__ import division
import os.path
import torch.utils.data as data
import os
from imageio import imread
import numbers
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torch.nn.functional as F
import shutil
import torch.nn as nn
from torch.nn.init import kaiming_normal_, constant_
from torch.utils.tensorboard import SummaryWriter
import time
from tqdm import tqdm

## Data

There are a few Optical Flow Datasets we can use:

*   Flying Chairs
*   Scene Flow (KITTI)
*   Middlebury
*   MPI Sintel
*   Kinetics

For the purpose of this course, we'll use KITTI, as it's the closest to autonomous driving. If you'd rather use another dataset, feel free. Just note that you'll need to read the FlowNet paper, as some architectures and techniques work best on specific datasets.

In [None]:
# Collect the image and flow-label file paths; sorting keeps both lists in a stable order.
images_dataset = sorted(glob.glob("dataset/images_2/*.png"))
labels_dataset = sorted(glob.glob("dataset/flow_occ/*.png"))

# Sanity check: print how many files each pattern matched.
print(len(images_dataset))
print(len(labels_dataset))

### 1) Visualize Data & Build a Dataset

Here's what we want:
*   **Input:** A pair of images (t0 and t1)
*   **Labels:** The Flow Map


In [None]:
def bgr2rgb(image):
    """
    Return a copy of `image` with its channels converted from OpenCV's BGR order to RGB.
    """
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return rgb_image

In [None]:
# Make a Dataset: one [[image_t0, image_t1], flow_map] entry per flow label.
# The frame id is taken from the label's file name ("<id>_10.png") rather than from
# fixed character offsets, so it keeps working if ids or directory names change length.
images = []
for flow_map in labels_dataset:
    root_filename = os.path.basename(flow_map).rsplit("_", 1)[0]
    img1 = os.path.join("dataset/images_2/", root_filename + '_10.png')  # frame at t0
    img2 = os.path.join("dataset/images_2/", root_filename + '_11.png')  # frame at t1
    images.append([[img1, img2], flow_map])

In [None]:
def load_flow_from_png(png_path):
    '''
    Read a KITTI flow label PNG and decode it into a float32 array of shape (H, W, 2).

    The label is a 16-bit, 3-channel PNG. OpenCV loads it in BGR order, so channel 0
    holds the validity flag while channels 2 and 1 hold the two encoded flow components
    (see the README and the flow reader functions of the KITTI devkit).

    Args:
        png_path: path of the flow label image.

    Returns:
        The decoded flow, computed as (raw - 32768) / 64. Pixels flagged invalid are
        exactly 0 on both channels; valid values whose magnitude is below 1e-10 are
        set to 1e-10 so that they cannot be mistaken for invalid pixels.

    Raises:
        FileNotFoundError: if OpenCV cannot read `png_path`.
    '''
    # The image is 16 bit, so it must be read unchanged (no conversion to 8-bit BGR).
    flo_file = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)
    if flo_file is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("Could not read flow label image: {}".format(png_path))

    # Channels 2 and 1 (in that order) carry the encoded flow components.
    flo_img = flo_file[:, :, 2:0:-1].astype(np.float32)
    invalid = (flo_file[:, :, 0] == 0)

    # Decode as described in the KITTI devkit.
    flo_img = (flo_img - 32768) / 64

    # Valid and Small Flow = 1e-10
    flo_img[np.abs(flo_img) < 1e-10] = 1e-10

    # Invalid Flow = 0
    flo_img[invalid, :] = 0

    return flo_img

In [None]:
### TEST
# Load the raw flow label of one sample and colour-convert it, purely so it can be displayed.
idx = 100
flo_path = images[idx][1]
#yuv = load_flow_from_png(flo_path)
yuv = cv2.imread(flo_path, -1)
#rgb_map = cv2.cvtColor(yuv, cv2.COLOR_YCR_CB2RGB) 
# NOTE(review): this converts the undecoded label PNG as if it were YUV; it is a display
# trick only, not the decoded flow returned by load_flow_from_png.
rgb_map = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB) 

In [None]:
"""
Visualize the Data
"""
# Index of the sample to display.
idx = 100

# Each entry of `images` is [[path_t0, path_t1], flow_path]; OpenCV reads BGR, so convert for matplotlib.
image_t0 = bgr2rgb(cv2.imread(images[idx][0][0]))
image_t1 = bgr2rgb(cv2.imread(images[idx][0][1]))

# Show the two frames and the flow label side by side.
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30,20))
ax1.imshow(image_t0)
ax1.set_title('Image t0', fontsize=30)
ax2.imshow(image_t1)
ax2.set_title('Image t1', fontsize=30)
# rgb_map is a global produced by the preceding test cell.
ax3.imshow(rgb_map)
ax3.set_title("Flow (label)", fontsize=30)

Now, we're going to do the same thing, but we'll split the images and labels into Training and Testing.

In [None]:
def train_test_split(images, default_split=0.8):
    """
    Randomly partition `images` into a training list and a test list.

    Each sample independently draws a uniform number in [0, 1) from NumPy's global
    random state and goes to the training list when the draw is below `default_split`.
    The original order is preserved inside each list.

    Returns:
        (train_samples, test_samples)
    """
    goes_to_train = np.random.uniform(0, 1, len(images)) < default_split
    train_samples, test_samples = [], []
    for sample, in_train in zip(images, goes_to_train):
        if in_train:
            train_samples.append(sample)
        else:
            test_samples.append(sample)
    return train_samples, test_samples

In [None]:
## Send our Loaded Images and Labels into the train/test split
train_samples, test_samples = train_test_split(images)

In [None]:
print(len(train_samples))
print(len(test_samples))

### 2) Load the Images

So far, we have two variables: train_samples and test_samples. The problem is that both only contain file paths to our data, not the actual data (the visualization above was loaded separately).<p>

We'll have to create a function to load the images based on the paths, and one to load the optical flow images.

In [None]:
def KITTI_loader(root,path_imgs, path_flo):
    """
    Load one sample from disk.

    `path_imgs` and `path_flo` are joined onto `root`. Every image is read with OpenCV,
    flipped from BGR to RGB and cast to float32; the flow label is decoded with
    load_flow_from_png.

    Returns:
        (list of RGB float32 images, flow label array)
    """
    frames = []
    for relative_path in path_imgs:
        bgr = cv2.imread(os.path.join(root, relative_path))
        frames.append(bgr[:, :, ::-1].astype(np.float32))
    flow = load_flow_from_png(os.path.join(root, path_flo))
    return frames, flow

### 3) Build a Dataset Class & Augment

In [None]:
import torch.utils.data as data
import flow_transforms

div_flow = 20 #Value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results.

"""
Div Flow is a factor by which we divide the output (thus >=1). It makes training more stable to deal with low numbers than big ones.
"""

# Define transforms for the Inputs
# Pipeline: array -> tensor (project-local flow_transforms helper), then divide by 255
# (mean 0, std 255), then subtract the per-channel mean (std 1).
input_transform = transforms.Compose([flow_transforms.ArrayToTensor(), transforms.Normalize(mean=[0,0,0], std=[255,255,255]), transforms.Normalize(mean=[0.45,0.432,0.411], std=[1,1,1])])
"""
In these transforms, we essentially normalize the dataset around the values. In the Flying Chair Dataset, the mean is [0.45,0.432,0.411].
https://github.com/ClementPinard/FlowNetPytorch/issues/101#issuecomment-805222823
"""

# Define transforms for the Labels
# The flow label is only rescaled: both channels are divided by div_flow.
target_transform = transforms.Compose([flow_transforms.ArrayToTensor(),transforms.Normalize(mean=[0,0],std=[div_flow,div_flow])])

# Data Augmentation
# Applied jointly to the image pair and the flow label (it is passed as `co_transform` to the dataset).
co_transform = flow_transforms.Compose([flow_transforms.RandomCrop((320,448)), flow_transforms.RandomVerticalFlip(),flow_transforms.RandomHorizontalFlip()])

In [None]:
class ListDataset(data.Dataset):
    """
    Dataset over a list of ``[[img_path_t0, img_path_t1], flow_path]`` samples.

    Paths are resolved by `loader` relative to the working directory captured at
    construction time. `co_transform` is applied jointly to inputs and target,
    `transform` to each of the two input images, `target_transform` to the flow label.
    """

    def __init__(self, path_list, transform=None, target_transform=None, co_transform=None, loader=KITTI_loader):
        self.root = os.getcwd()
        self.path_list = path_list
        self.loader = loader
        self.co_transform = co_transform
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.path_list)

    def __getitem__(self, index):
        """
        Load and transform sample `index` (this is what ``dataset[index]`` calls).

        Returns:
            (inputs, target): the two transformed images and the transformed flow label.
        """
        image_paths, flow_path = self.path_list[index]
        inputs, target = self.loader(self.root, image_paths, flow_path)

        # Joint augmentation first, so images and flow stay aligned.
        if self.co_transform is not None:
            inputs, target = self.co_transform(inputs, target)

        if self.transform is not None:
            for position in (0, 1):
                inputs[position] = self.transform(inputs[position])

        if self.target_transform is not None:
            target = self.target_transform(target)

        return inputs, target

In [None]:
train_dataset = ListDataset(train_samples, input_transform, target_transform, co_transform, loader=KITTI_loader)

test_dataset = ListDataset(test_samples, input_transform,target_transform, flow_transforms.CenterCrop((370,1224)), loader=KITTI_loader)

👉👉👉 The Output of Module 1 are train_dataset and test_dataset

## FlowNet Architecture

FlowNet has two variations:
*   **FlownetS** or Simple, which is a simple version using 2D Convolutions to get to the optical flow computation
*   **FlownetC** or Correlated, which adds a correlation layer and processes the two images separately

![](https://www.programmersought.com/images/62/6e44fdd78581acc06614a38d8eaff3fe.JPEG)

In both, there are two main parts:
*   An **Encoder** Part, learning features
*   A **Refinement** Part, playing the decoder and creating the output Flow Mask.

It looks like we've got some work! Let's go!


### FlowNet Simple Architecture

![](https://miro.medium.com/max/1400/0*XVygX0wF3enVQJLe.)

In [None]:
#Define a Convolution with or without batchnorm and LeakyReLU

def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """
    Build a convolution block: Conv2d -> (optional BatchNorm2d) -> LeakyReLU(0.1).

    Padding is (kernel_size - 1) // 2. The convolution carries a bias only when no
    batch normalisation follows it.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            bias=not batchNorm,
        )
    ]
    if batchNorm:
        layers.append(nn.BatchNorm2d(out_planes))
    layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)

Here's the refinement block:<p>
![](https://i1.wp.com/syncedreview.com/wp-content/uploads/2017/09/image-14.png?fit=692%2C268&ssl=1)
<p>

It looks like we're gonna need:

*   **Deconvolutions** as an upsampling method.
*   and **Flow Prediction** as a way to get the output.



In [None]:
#Define the last convolution (optical flow map prediction)
def predict_flow(in_planes):
    """
    Return the flow prediction head: a bias-free 3x3 convolution mapping `in_planes`
    feature channels to 2 output channels at the same spatial size.
    """
    # Note: In the paper, a kernel Size of 3 is written; but their implementation uses 5
    flow_head = nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False)
    return flow_head

In [None]:
#Define a Deconvolution
def deconv(in_planes, out_planes):
    """
    Build an upsampling block: a bias-free transposed convolution (kernel 4, stride 2,
    padding 1) followed by LeakyReLU(0.1).
    """
    upsample = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False)
    activation = nn.LeakyReLU(0.1, inplace=True)
    return nn.Sequential(upsample, activation)

In [None]:
#Define a Cropping Operation
def crop_like(input, target):
    """
    Crop `input` to the spatial size of `target`.

    When the trailing dimensions (from index 2 on) already match, `input` is returned
    unchanged; otherwise it is sliced from the top-left corner to target's height and width.
    """
    if input.size()[2:] != target.size()[2:]:
        return input[:, :, :target.size(2), :target.size(3)]
    return input

In [None]:
class FlowNetS(nn.Module):
    """
    FlowNetSimple: a convolutional encoder followed by a refinement (decoder) part that
    predicts a 2-channel flow map at five scales (flow6, the coarsest, up to flow2).

    The first convolution takes 6 input channels; in this file the two images of a pair
    are concatenated along the channel dimension before being fed to the model.
    """
    expansion = 1

    def __init__(self,batchNorm=True):
        """
        Build all layers and initialise their parameters.

        Args:
            batchNorm: when True, every encoder block gets a BatchNorm2d after its convolution.
        """
        super(FlowNetS,self).__init__()

        #ENCODER PART
        self.batchNorm = batchNorm
        self.conv1   = conv(self.batchNorm,   6,   64, kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, 256,  256)
        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512,  512)
        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512,  512)
        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm,1024, 1024) # Note: This one doesn't exist in the paper, but it does in their implementation

        #REFINEMENT PART
        # Input channel counts follow the concatenations done in forward():
        # 1026 = 512 (out_conv5) + 512 (out_deconv5) + 2 (upsampled flow6)
        #  770 = 512 (out_conv4) + 256 (out_deconv4) + 2 (upsampled flow5)
        #  386 = 256 (out_conv3) + 128 (out_deconv3) + 2 (upsampled flow4)
        #  194 = 128 (out_conv2) +  64 (out_deconv2) + 2 (upsampled flow3)
        self.deconv5 = deconv(1024,512)
        self.deconv4 = deconv(1026,256)
        self.deconv3 = deconv(770,128)
        self.deconv2 = deconv(386,64)

        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)

        # Learned 2x upsampling of each predicted flow map so it can be concatenated at the next finer scale
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                # "He Initialization" (https://arxiv.org/pdf/1502.01852.pdf) of the (transposed) convolution weights.
                # 0.1 is passed as the second positional argument of kaiming_normal_; it matches the
                # LeakyReLU slope used in the conv/deconv blocks of this file.
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    # Initialize all bias to 0
                    constant_(m.bias, 0)
            # BatchNorm layers start as identity: weight (scale) = 1, bias (shift) = 0
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)

    def forward(self, x):
        """
        Run the network on `x`.

        Returns:
            In training mode the tuple (flow2, flow3, flow4, flow5, flow6), ordered from the
            finest to the coarsest prediction; in eval mode only flow2.
        """
        #ENCODER
        out_conv2 = self.conv2(self.conv1(x))
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))
        
        #REFINEMENT
        # At each level: predict flow, upsample both the flow and the features, crop them to the
        # matching encoder feature map, and concatenate [encoder features, decoder features, flow].
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)

        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)

        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)

        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)

        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
        flow2 = self.predict_flow2(concat2)

        # The coarse predictions are only needed for the multiscale training loss
        if self.training:
            return flow2,flow3,flow4,flow5,flow6
        else:
            return flow2

    def weight_parameters(self):
        """Return every parameter whose name contains 'weight' (this includes BatchNorm weights)."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return every parameter whose name contains 'bias' (this includes BatchNorm biases)."""
        return [param for name, param in self.named_parameters() if 'bias' in name]


Awesome Work! Now, let's define the Flownet S model (with and without batchnorm)

In [None]:
#Define FlowNet S
def flownets(data=None, batchNorm=False):
    """Create a FlowNetS model, as described in
    "Learning Optical Flow with Convolutional Networks" (https://arxiv.org/abs/1504.06852).

    Args:
        data : checkpoint dictionary whose 'state_dict' entry holds pretrained weights;
               when None the model keeps its freshly initialised weights.
        batchNorm : whether the encoder blocks use batch normalisation.
    """
    network = FlowNetS(batchNorm=batchNorm)
    if data is None:
        return network
    network.load_state_dict(data['state_dict'])
    return network

In [None]:
# Load the pretrained weights onto the CPU first, so that a checkpoint saved from a GPU
# can also be opened on a machine without CUDA. The model is moved to the training
# device later in the notebook.
checkpoint = torch.load("models/flownets_bn_EPE2.459.pth.tar", map_location="cpu")
model = flownets(data=checkpoint, batchNorm=True)

#model = flownets()

👉👉👉 The output of Module 2 is this model

## Training
To train a Deep Learning Model, we'll need:

*   Data
*   A Model
*   Parameters
*   A Loss Function


### Data & Model
Define the basic variables, paths to save the training weights, dataloaders, and the model

In [None]:
print('{} samples found, {} train samples and {} test samples '.format(len(test_dataset)+len(train_dataset),
                                                                           len(train_dataset),
                                                                           len(test_dataset)))

In [None]:
# ---- Run configuration ----
arch = "flownetsbn"
solver = "adam" # or sgd
epochs = 400
epoch_size = 0
batch_size = 32
learning_rate = 10e-4  # i.e. 1e-3
workers = 4
pretrained = None

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory named after the run settings; the TensorBoard logs go there
save_path = '{},{},{}epochs,b{},lr{}'.format(arch, solver, epochs,batch_size,learning_rate)
if not os.path.exists(save_path):
    os.makedirs(save_path)

# One writer for training scalars, one for test scalars, three for logged test images
train_log_dir = os.path.join(save_path, 'train')
test_log_dir = os.path.join(save_path, 'test')
train_writer = SummaryWriter(train_log_dir)
test_writer = SummaryWriter(test_log_dir)

output_writers = []
for i in range(3):
    output_writers += [SummaryWriter(os.path.join(test_log_dir, str(i)))]

In [None]:
#With these basic values, create dataloaders (PyTorch way of handling the data)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=workers, pin_memory=True, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,num_workers=workers, pin_memory=True, shuffle=False)

### Parameters

In [None]:
bias_decay = 0
weight_decay = 4e-4
momentum = 0.9 # Momentum for SGD - beta1 for Adam
milestones= [100,150,200] # Epochs at which the learning rate is halved by the scheduler

# Look through a DataParallel wrapper if there is one, so this cell can be re-run after
# the model has already been wrapped (the wrapper does not expose bias_parameters()).
base_model = getattr(model, 'module', model)

# Weight decay is applied to the weights only, not to the biases
param_groups = [{'params': base_model.bias_parameters(), 'weight_decay': bias_decay},
                {'params': base_model.weight_parameters(), 'weight_decay': weight_decay}]


if device.type == "cuda":
    # Wrap only once, even when the cell is executed several times
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True


if solver == 'adam':
    optimizer = torch.optim.Adam(param_groups, learning_rate,betas=(momentum, 0.999)) # In the paper, Adam is used
elif solver == 'sgd':
    optimizer = torch.optim.SGD(param_groups, learning_rate,momentum=momentum)
else:
    # Fail here instead of with a NameError on `optimizer` further down
    raise ValueError("Unknown solver '{}': expected 'adam' or 'sgd'".format(solver))

### Loss Function - End Point Error

Flownet (and most optical flow algorithms) use the end point error (EPE) as a metric for the loss function.
It is simply the euclidean distance between the real value (ground truth) and the predicted one.<p>
EPE = ![](https://latex.codecogs.com/gif.latex?%5Cinline%20%5Cleft%20%5C%7CV_%7Best%7D%20-%20V_%7Bgt%7D%20%5Cright%20%5C%7C)

In [None]:
class AverageMeter(object):
    """Keeps the most recent value of a metric together with its running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, counted `n` times (e.g. a batch mean and the batch size)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __repr__(self):
        # "<last value> (<running average>)"
        return '{:.3f} ({:.3f})'.format(self.val, self.avg)

In [None]:
def EPE(input_flow, target_flow, sparse=False, mean=True):
    """
    End Point Error: the L2 norm, taken over dimension 1, of (target_flow - input_flow).

    Args:
        input_flow: predicted flow.
        target_flow: ground-truth flow, same shape as `input_flow`.
        sparse: when True, positions where both target channels are exactly 0 are
            treated as invalid and left out.
        mean: when True return the mean error over the kept positions; otherwise
            return their sum divided by the batch size.
    """
    per_pixel_error = torch.norm(target_flow - input_flow, p=2, dim=1)
    n_samples = per_pixel_error.size(0)
    if sparse:
        # invalid flow is defined with both flow coordinates to be exactly 0
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        per_pixel_error = per_pixel_error[~invalid]
    if not mean:
        return per_pixel_error.sum() / n_samples
    return per_pixel_error.mean()

In [None]:
def realEPE(output, target, sparse=False):
    """
    Mean End Point Error at the target's resolution: `output` is bilinearly resized to
    the height and width of `target` before the error is computed.
    """
    _, _, height, width = target.size()
    resized_output = F.interpolate(output, (height, width), mode='bilinear', align_corners=False)
    return EPE(resized_output, target, sparse, mean=True)

In [None]:
def sparse_max_pool(input, size):
    '''Downsample `input` to `size`, treating 0 values as invalid.

    No generic interpolation mode resizes a sparse map correctly, so positive values
    are max-pooled, negative values are "min-pooled" (max-pooling of their negation),
    and the second result is subtracted from the first. Compared with nearest
    interpolation this keeps sparsity to a minimum and does not drop isolated data points.'''
    positive_part = input * (input > 0).float()
    negated_negative_part = -input * (input < 0).float()
    pooled_positive = F.adaptive_max_pool2d(positive_part, size)
    pooled_negative = F.adaptive_max_pool2d(negated_negative_part, size)
    return pooled_positive - pooled_negative


def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
    """
    Weighted sum of End Point Errors over several prediction scales.

    Args:
        network_output: one flow prediction, or a tuple/list of predictions at
            different resolutions.
        target_flow: ground-truth flow; it is resized to each prediction's resolution.
        weights: one weight per prediction. Defaults to the five weights of the
            original article, so a matching number of predictions is then required.
        sparse: when True the target is downsampled with sparse_max_pool and invalid
            (all-zero) positions are ignored by EPE.

    Returns:
        The weighted loss.

    Raises:
        ValueError: if the number of weights differs from the number of predictions.
    """
    def one_scale(output, target, sparse):
        # Bring the target to this prediction's resolution, then sum the error per sample
        _, _, h, w = output.size()
        if sparse:
            target_scaled = sparse_max_pool(target, (h, w))
        else:
            target_scaled = F.interpolate(target, (h, w), mode='area')
        return EPE(output, target_scaled, sparse, mean=False)

    # isinstance also accepts tuple/list subclasses (e.g. namedtuples)
    if not isinstance(network_output, (tuple, list)):
        network_output = [network_output]
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
    # Explicit check: an assert would be stripped when Python runs with -O
    if len(weights) != len(network_output):
        raise ValueError("Expected one weight per prediction, got {} weights for {} predictions"
                         .format(len(weights), len(network_output)))

    loss = 0
    for output, weight in zip(network_output, weights):
        loss += weight * one_scale(output, target_flow, sparse)
    return loss

### Train

In [None]:
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)

In [None]:
def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
    """
    Serialise `state` to `save_path/filename`; when `is_best` is true, also copy that
    file to `save_path/model_best.pth.tar`.
    """
    checkpoint_file = os.path.join(save_path, filename)
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    shutil.copyfile(checkpoint_file, os.path.join(save_path, 'model_best.pth.tar'))

In [None]:
def train(train_loader, model, optimizer, epoch, train_writer):
    """
    Train `model` for one pass over `train_loader`.

    Reads the module-level `div_flow` and `device`, and reads and increments the
    module-level step counter `n_iter` (used as the x-axis of the logged training loss).

    Args:
        train_loader: iterable of (inputs, target) batches; `inputs` is a sequence of
            image tensors that get concatenated along dimension 1.
        model: the network; it is switched to train mode.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress printout.
        train_writer: TensorBoard writer receiving 'train_loss' at every step.

    Returns:
        (average loss, average flow2 EPE) over the batches seen.
    """
    global n_iter, div_flow
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    flow2_EPEs = AverageMeter()

    multiscale_weights = [0.005,0.01,0.02,0.08,0.32] # from output_flow to flow6

    epoch_size = len(train_loader)# if epoch_size == 0 else min(len(train_loader), epoch_size)

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(train_loader):
        # Go through the entire data loader
        data_time.update(time.time() - end)

        target = target.to(device)
        # Stack the images of the pair along the channel dimension
        input = torch.cat(input,1).to(device)

        # Forward Pass
        output = model(input)

        # Since Target pooling is not very precise when sparse,
        # take the highest resolution prediction and upsample it instead of downsampling target
        h, w = target.size()[-2:]
        output = [F.interpolate(output[0], (h,w)), *output[1:]]

        # Compute Multiscale EPE (for all predict flows)
        loss = multiscaleEPE(output, target, weights=multiscale_weights, sparse=True)

        # Compute the Output EPE
        # (multiplied by div_flow because the targets were divided by it in target_transform)
        flow2_EPE = div_flow * realEPE(output[0], target, sparse=True)

        # Record loss and EPE
        losses.update(loss.item(), target.size(0))
        train_writer.add_scalar('train_loss', loss.item(), n_iter)
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # compute gradient and do optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 2 == 0:
            # Every 2 steps, print the Loss and EPE
            print('Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}'
                  .format(epoch, i, epoch_size, batch_time,
                          data_time, losses, flow2_EPEs))
        n_iter += 1
        # NOTE(review): epoch_size equals len(train_loader) here, so this condition
        # cannot become true inside the loop; it only matters if epoch_size is capped again.
        if i >= epoch_size:
            break
    #Return the Average Loss and Average EPE on the Training Set
    return losses.avg, flow2_EPEs.avg


In [None]:
def validate(val_loader, model, epoch, output_writers):
    """
    Evaluate `model` on `val_loader` and return the average flow2 End Point Error.

    Reads the module-level `div_flow` and `device`. For the first len(output_writers)
    batches the first sample's prediction is logged as an image every epoch; at epoch 0
    its ground truth and its two input images are logged as well.

    Args:
        val_loader: iterable of (inputs, target) batches.
        model: the network; it is switched to eval mode.
        epoch: epoch index, used as the image step and to decide whether to log the inputs.
        output_writers: TensorBoard writers, one per logged batch.
    """
    global div_flow
    batch_time = AverageMeter()
    flow2_EPEs = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        #Go through the entire validation loader

        target = target.to(device)
        # Stack the images of the pair along the channel dimension
        input = torch.cat(input,1).to(device)

        # Forward Pass
        output = model(input)

        #Compute the EPE
        # (multiplied by div_flow because the targets were divided by it in target_transform)
        flow2_EPE = div_flow*realEPE(output, target, sparse=True)
        # record EPE
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i < len(output_writers):  # log first output of first batches
            if epoch == 0:
                # Add back the per-channel mean subtracted by input_transform so the images display correctly
                mean_values = torch.tensor([0.45,0.432,0.411], dtype=input.dtype).view(3,1,1)
                output_writers[i].add_image('GroundTruth', flow2rgb(div_flow * target[0], max_value=10), 0)
                # Channels 0-2 are the first image of the pair, channels 3-5 the second
                output_writers[i].add_image('Inputs', (input[0,:3].cpu() + mean_values).clamp(0,1), 0)
                output_writers[i].add_image('Inputs', (input[0,3:].cpu() + mean_values).clamp(0,1), 1)
            output_writers[i].add_image('FlowNet Outputs', flow2rgb(div_flow * output[0], max_value=10), epoch)

        if i % 5 == 0:
            print('Test: [{0}/{1}]\t Time {2}\t EPE {3}'
                  .format(i, len(val_loader), batch_time, flow2_EPEs))

    print(' * EPE {:.3f}'.format(flow2_EPEs.avg))
    # Return Average EPE on Validation Set
    return flow2_EPEs.avg

Finally, let's convert the output flow to RGB values

In [None]:
def flow2rgb(flow_map, max_value):
    """
    Convert a 2-channel flow map into an RGB image for visualisation.
    https://github.com/ClementPinard/FlowNetPytorch/issues/86

    Args:
        flow_map: tensor of shape (2, H, W). It is not modified.
        max_value: flow magnitude mapped to full colour saturation. When None, the
            largest absolute value among the valid pixels is used.

    Returns:
        float32 NumPy array of shape (3, H, W) with values clipped to [0, 1]. Pixels
        whose flow is exactly 0 on both channels are considered invalid and are NaN.
    """
    # Work on a copy: for a CPU tensor .numpy() shares memory with it, so writing the
    # NaN markers in place would silently corrupt the caller's flow map.
    flow_map_np = flow_map.detach().cpu().numpy().copy()
    _, h, w = flow_map_np.shape
    invalid = (flow_map_np[0] == 0) & (flow_map_np[1] == 0)
    flow_map_np[:, invalid] = float('nan')
    rgb_map = np.ones((3, h, w)).astype(np.float32)
    if max_value is not None:
        normalized_flow_map = flow_map_np / max_value
    elif invalid.all():
        # Nothing valid to scale by; the result is NaN everywhere anyway
        normalized_flow_map = flow_map_np
    else:
        # nanmax ignores the NaN markers; a plain max() would return NaN and wipe out
        # every valid pixel as soon as a single invalid one is present.
        normalized_flow_map = flow_map_np / np.nanmax(np.abs(flow_map_np))
    rgb_map[0] += normalized_flow_map[0]
    rgb_map[1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1])
    rgb_map[2] += normalized_flow_map[1]
    return rgb_map.clip(0, 1)

Now, let's train it!

In [None]:
epochs = 200
save_path = '{},{},{}epochs{},b{},lr{}'.format(arch, solver, epochs, ',epochSize'+str(epoch_size) if epoch_size > 0 else '', batch_size, learning_rate)
n_iter = 0    # global step counter, incremented by train()
best_EPE = -1 # negative means "no validation result yet"

# We'll start from a model pretrained on "Flying Chairs" and finetune it to KITTI

save_path = os.path.join("models",save_path)

print('=> will save everything to {}'.format(save_path))

os.makedirs(save_path, exist_ok=True)

for epoch in range(0, epochs):
    # Train for one epoch
    train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
    train_writer.add_scalar('mean EPE', train_EPE, epoch)

    # Advance the learning-rate schedule only after this epoch's optimizer steps
    # (stepping the scheduler first skips the first value of the schedule)
    scheduler.step()

    # Evaluate on validation set
    with torch.no_grad():
        endpointerror = validate(val_loader, model, epoch, output_writers)
    test_writer.add_scalar('mean EPE', endpointerror, epoch)

    # Store the best EPE. The first result always counts as best, so that
    # model_best.pth.tar exists even if later epochs never improve on it.
    is_best = best_EPE < 0 or endpointerror < best_EPE
    best_EPE = endpointerror if best_EPE < 0 else min(endpointerror, best_EPE)

    # The model is wrapped in DataParallel only on GPU; save the underlying network's weights either way
    model_to_save = getattr(model, 'module', model)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': arch,
        'state_dict': model_to_save.state_dict(),
        'best_EPE': best_EPE,
        'div_flow': div_flow
    }, is_best, save_path)

## Inference

In [None]:
!ls images

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@torch.no_grad()
def inference(data_dir="images", save_path="output",
              weights_path="models/flownetsbn,adam,500epochs,b32,lr0.001/model_best.pth.tar"):
    """
    Predict the optical flow for consecutive image pairs and save it as colour images.

    The sorted ``*.png`` files of `data_dir` are grouped two by two (1st with 2nd, 3rd
    with 4th, ...). Each prediction is displayed and written to
    ``<save_path>/<first image name without extension>flow.png``.

    Args:
        data_dir: directory containing the input frames.
        save_path: directory receiving the results; created if missing.
        weights_path: checkpoint holding 'state_dict' and 'div_flow', as written by
            save_checkpoint during training.
    """
    # Same normalisation as the training input_transform: imageio returns RGB images and
    # the training loader also feeds RGB, so the mean must be in the same channel order.
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
        transforms.Normalize(mean=[0.45,0.432,0.411], std=[1,1,1])
    ])

    test_files = sorted(glob.glob(os.path.join(data_dir, "*.png")))
    # Group the sorted files two by two; a trailing unpaired file is dropped
    img_pairs = list(zip(*[iter(test_files)] * 2))

    # create model
    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine
    network_data = torch.load(weights_path, map_location=device)
    div_flow = network_data['div_flow']
    #network_data = torch.load("models/flownets_bn_EPE2.459.pth.tar")
    #network_data = torch.load("models/flownets_EPE1.951.pth.tar")

    model = flownets(network_data, batchNorm=True).to(device)
    model.eval()

    cudnn.benchmark = True

    os.makedirs(save_path, exist_ok=True)

    for (img1_file, img2_file) in tqdm(img_pairs):
        print(img1_file)
        print(img2_file)
        img1 = input_transform(imread(str(img1_file)))
        img2 = input_transform(imread(str(img2_file)))
        # Stack both images along the channel dimension and add the batch dimension
        input_var = torch.cat([img1, img2]).unsqueeze(0).to(device)

        # In eval mode the model returns only the finest flow map, with the batch dimension first
        output = model(input_var)

        # Name the result after the first frame. Only the base name is used: keeping the
        # input directory in the name would point into a sub-directory of save_path
        # that does not exist, and cv2.imwrite would then fail without raising.
        stem = os.path.splitext(os.path.basename(img1_file))[0]
        out_file = os.path.join(save_path, stem + "flow" + '.png')

        for flow_output in output:
            rgb_flow = flow2rgb(div_flow * flow_output, max_value=None)
            to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)
            cv2_imshow(to_save)
            if not cv2.imwrite(out_file, to_save):
                print("WARNING: could not write {}".format(out_file))
        
inference()