In [95]:
from __future__ import print_function, division
import argparse
import os
# import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [96]:
from RAFT_master.core.raft import RAFT

In [97]:
# from RAFT_master import evaluate

In [98]:
# Ground-truth flow vectors whose magnitude reaches this value are masked out of the loss.
MAX_FLOW = 400
# SUM_FREQ = 100
# VAL_FREQ = 5000

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Args:
        flow_preds: sequence of predicted flow tensors; each one is compared
            element-wise against ``flow_gt``. The last entry receives weight 1
            and is the one used for the reported metrics.
        flow_gt: ground-truth flow; dimension 1 holds the flow components.
        valid: per-pixel validity scores; a pixel counts as valid when its
            score is >= 0.5.
        gamma: base of the weighting; prediction ``i`` out of ``N`` is weighted
            by ``gamma ** (N - 1 - i)``.
        max_flow: pixels whose ground-truth magnitude is >= this value are
            treated as invalid.

    Returns:
        ``(flow_loss, metrics)`` where ``flow_loss`` is the weighted sum of the
        per-prediction masked L1 means and ``metrics`` is a dict with the mean
        end-point error of the last prediction over kept pixels (``'epe'``) and
        the fraction of kept pixels with error below 1, 3 and 5
        (``'1px'``, ``'3px'``, ``'5px'``).
    """
    # A pixel is kept only if it is flagged valid AND its displacement is not extreme.
    gt_magnitude = flow_gt.pow(2).sum(dim=1).sqrt()
    keep = (valid >= 0.5) & (gt_magnitude < max_flow)
    # Extra axis at dim 1 so the mask broadcasts across the flow components.
    keep_per_component = keep.unsqueeze(1)

    last_index = len(flow_preds) - 1
    flow_loss = 0.0
    for step, prediction in enumerate(flow_preds):
        abs_error = (prediction - flow_gt).abs()
        # Masked entries contribute zero but still count in the mean's denominator.
        masked_mean = (keep_per_component * abs_error).mean()
        flow_loss += gamma ** (last_index - step) * masked_mean

    # End-point error of the final prediction, restricted to kept pixels.
    final_error = (flow_preds[-1] - flow_gt).pow(2).sum(dim=1).sqrt()
    epe = final_error.view(-1)[keep.view(-1)]

    def fraction_below(threshold):
        return (epe < threshold).float().mean().item()

    metrics = {
        'epe': epe.mean().item(),
        '1px': fraction_below(1),
        '3px': fraction_below(3),
        '5px': fraction_below(5),
    }

    return flow_loss, metrics

In [None]:
# --- Optimiser / schedule ---
lr = 2e-5              # AdamW learning rate, also passed to OneCycleLR as its max LR
wdecay = 5e-5          # AdamW weight decay
epsilon = 1e-8         # AdamW eps
num_steps = 100000     # the LR schedule is built for num_steps + 100 steps
clip = 1.0             # max gradient norm given to clip_grad_norm_

# --- Data ---
batch_size = 6
image_size = [384, 512]

# --- Model / loss ---
iters = 12             # forwarded to the model call as `iters`
dropout = 0.0
gamma = 0.8            # exponential weighting of the sequence loss

# Options from the original argparse setup that have not been ported yet:
#     parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
#     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
#     parser.add_argument('--add_noise', action='store_true')

In [None]:
# valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

In [None]:
from torch.cuda.amp import GradScaler

In [None]:
# ---------------------------------------------------------------------------
# Model, optimiser, LR schedule and training loop.
#
# Reads the hyperparameters lr, wdecay, epsilon, num_steps, iters, clip and
# gamma, plus `sequence_loss`, `RAFT` and `GradScaler`.
# NOTE(review): `epochs` and `train_loader` are not defined in the cells shown
# here; they must be created before this cell runs. Each batch is unpacked as
# (image1, image2, flow, valid).
# ---------------------------------------------------------------------------

# Use the GPU when one is available; `device` was previously referenced
# without being defined.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plain flags replace the undefined `args.mixed_precision` / `args.add_noise`
# (the argparse setup is commented out in this notebook).
mixed_precision = False
add_noise = True
VAL_FREQ = 5000

# NOTE(review): RAFT is constructed without arguments exactly as before;
# confirm this fork's constructor does not require a config object.
model = RAFT()
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wdecay, eps=epsilon)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, num_steps+100, pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

total_steps = 0
scaler = GradScaler(enabled=mixed_precision)

losses = []  # one entry per epoch: the sum of the batch losses in that epoch
for epoch in range(epochs):
    # OneCycleLR raises once it is stepped past its configured budget, so
    # training stops after num_steps optimiser steps.
    if total_steps >= num_steps:
        break

    print("Epoch", str(epoch) + ": ", end="")
    epoch_loss = 0.0
    for data_blob in train_loader:
        optimizer.zero_grad()
        image1, image2, flow, valid = [x.to(device) for x in data_blob]

        if add_noise:
            # Additive Gaussian noise with a random std in [0, 5), generated
            # directly on the target device; result clamped to [0, 255].
            stdv = np.random.uniform(0.0, 5.0)
            image1 = (image1 + stdv * torch.randn(image1.shape, device=device)).clamp(0.0, 255.0)
            image2 = (image2 + stdv * torch.randn(image2.shape, device=device)).clamp(0.0, 255.0)

        flow_predictions = model(image1, image2, iters=iters)

        loss, metrics = sequence_loss(flow_predictions, flow, valid, gamma)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        scaler.step(optimizer)
        scheduler.step()
        scaler.update()

#         if total_steps % VAL_FREQ == VAL_FREQ - 1:
#             PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
#             torch.save(model.state_dict(), PATH)

#             results = {}
#             for val_dataset in args.validation:
#                 if val_dataset == 'chairs':
#                     results.update(evaluate.validate_chairs(model.module))
#                 elif val_dataset == 'sintel':
#                     results.update(evaluate.validate_sintel(model.module))
#                 elif val_dataset == 'kitti':
#                     results.update(evaluate.validate_kitti(model.module))

#             logger.write_dict(results)

#             model.train()
#             if args.stage != 'chairs':
#                 model.module.freeze_bn()

        epoch_loss += loss.item()
        total_steps += 1
        if total_steps >= num_steps:
            break

    # Reported once per epoch so the line reads "Epoch N: Loss: <sum>".
    print("Loss:", epoch_loss)
    losses.append(epoch_loss)
