# Initialization
* Directory setup (mount Google Drive, change into the project folder)
* Library imports
* Dataset loading (MPI-Sintel, read from `./data`)



# Mount Google Drive and select the project directory

In [None]:
from google.colab import drive
import sys
#you can save your own drive project path here as a comment
#'/content/gdrive/MyDrive/TUDelft/1.3/DeepLearning/DeepLearningProject'
#'/content/gdrive/MyDrive/DeepLearningProject'
drive.mount('/content/gdrive')
sys.path.append('/content/gdrive/MyDrive/TUDelft/1.3/DeepLearning/DeepLearningProject')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
#cd /content/gdrive/MyDrive/DeepLearningProject

/content/gdrive/.shortcut-targets-by-id/1JFdBnQ-FX1S6xSzRbNQ95iJYaDMD5y9Z/DeepLearningProject


In [None]:
cd /content/gdrive/MyDrive/TUDelft/1.3/DeepLearning/DeepLearningProject/

[Errno 2] No such file or directory: '/content/gdrive/MyDrive/TUDelft/1.3/DeepLearning/DeepLearningProject/'
/content/gdrive/.shortcut-targets-by-id/1JFdBnQ-FX1S6xSzRbNQ95iJYaDMD5y9Z/DeepLearningProject


In [None]:
ls

 2003.12039.pdf  'Project Instructions.gdoc'  'RAFT: Reproduction Plan.gdoc'
 [0m[01;34mdata[0m/            [01;34m__pycache__[0m/                 ReproducibilityCode
 encoders.py     'Raft Architecture.drawio'    TestCopyReproducibilityCode
 [01;34mmodels[0m/         'Raft Architecture.png'


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision.io import read_image
from torchvision.datasets import CIFAR10, Sintel, FlyingChairs
from torchvision import transforms
import torchvision.transforms as T
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torch.cuda.amp import GradScaler

In [None]:
# Root folder of the Sintel dataset (relative to the working directory).
DATASET_PATH = "./data"

# Seed torch (CPU and CUDA) so weight initialisation and DataLoader shuffling are
# repeatable; the cuDNN flags below alone do not make runs reproducible.
SEED = 42
torch.manual_seed(SEED)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Image-only transform: convert to a tensor and normalise each channel with
# mean 0.5 / std 0.5. Currently only referenced by the commented-out CIFAR10
# loader; the flow datasets use LoadFlow instead.
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,),(0.5,))])

In [None]:
class LoadFlow(transforms.ToTensor):
    """Joint transform for torchvision optical-flow datasets.

    Both frames are converted with ``to_tensor``; the flow and its validity mask
    are passed through untouched. Returns the 4-tuple
    ``(img1, img2, flow, valid_flow_mask)``.
    """
    def __call__(self, img1, img2, flow, valid_flow_mask):
        to_tensor = T.functional.to_tensor
        first_frame = to_tensor(img1)
        second_frame = to_tensor(img2)
        return first_frame, second_frame, flow, valid_flow_mask

In [None]:
# MPI-Sintel training split, "clean" rendering pass. LoadFlow converts both frames
# to tensors; the ground-truth flow is left as returned by the dataset.
train_dataset = Sintel(
    root=DATASET_PATH,
    split='train',
    pass_name='clean',
    transforms=LoadFlow(),
)

# Shuffled mini-batches of 2; drop_last keeps every batch at exactly 2 samples.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=2,
)

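# Encoders

*   `ResidualUnits`: residual block (two 3×3 convolutions with GroupNorm and a skip connection)
*   `Encoder`: feature / context encoder producing a 1/8-resolution feature map
*   `UnfoldDecoder`: learned convex ×8 upsampler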


In [None]:
class ResidualUnits(nn.Module):
  """Residual block: conv3x3-GroupNorm-ReLU twice, plus a skip connection.

  Args:
    num_input_channels: channels of the input feature map.
    num_output_channels: channels produced by the block (must be divisible by num_groups).
    stride: stride of the first 3x3 conv (spatial downsampling factor).
    num_groups: number of groups of every GroupNorm in the block.
    norm: currently unused (GroupNorm is always used); kept for interface compatibility.
  """

  def __init__(self,
                num_input_channels : int,
                num_output_channels : int,
                stride : int,
                num_groups : int = 16,
                norm : object = nn.InstanceNorm2d):
      super().__init__()

      self.relu = nn.ReLU(inplace=True)

      self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=num_output_channels)
      self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=num_output_channels)
      self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=num_output_channels)

      self.stride = stride

      # 1x1 projection (+ norm) applied to the skip path when its shape would not
      # match the residual branch.
      self.downscale = nn.Sequential(nn.Conv2d(num_input_channels, num_output_channels, 1, stride=stride), self.norm3)

      self.net1 = nn.Conv2d(num_input_channels, num_output_channels, 3, stride=stride, padding=1)
      self.net2 = nn.Conv2d(num_output_channels, num_output_channels, 3, padding=1)

  def forward(self, x):
      out = self.relu(self.norm1(self.net1(x)))
      out = self.relu(self.norm2(self.net2(out)))

      # Project the skip path whenever the spatial size OR the channel count
      # changes; otherwise ``x + out`` would fail with a shape mismatch.
      if self.stride != 1 or x.shape[1] != out.shape[1]:
        x = self.downscale(x)

      return self.relu(x + out)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing a feature map at 1/8 input resolution.

    Stem: 7x7 conv (stride 2) -> GroupNorm -> ReLU, followed by three stages of two
    ResidualUnits each (channels C, 2C, 3C with C = base_channel_size; the last two
    stages use stride 2), then a 1x1 conv to ``num_output_channels``.

    Args:
      num_input_channels: channels of the input image.
      base_channel_size: C, the width of the first stage.
      num_groups: GroupNorm groups for the stem AND all residual units.
      num_output_channels: channels of the returned feature map.
      norm: currently unused (GroupNorm is always used); kept for interface compatibility.
    """

    def __init__(self,
                 num_input_channels : int = 3,
                 base_channel_size : int = 32,
                 num_groups : int = 16,
                 num_output_channels : int = 128,
                 norm : object = nn.InstanceNorm2d):
        super().__init__()

        self.relu = nn.ReLU(inplace=True)
        channels = base_channel_size

        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels)

        self.init_convolution = nn.Conv2d(num_input_channels, channels, 7, stride=2, padding=3)
        # num_groups is forwarded so that one argument controls every GroupNorm.
        residual_half = nn.Sequential(
            ResidualUnits(channels, channels, stride=1, num_groups=num_groups),
            ResidualUnits(channels, channels, stride=1, num_groups=num_groups))
        residual_quarter = nn.Sequential(
            ResidualUnits(channels, 2*channels, stride=2, num_groups=num_groups),
            ResidualUnits(2*channels, 2*channels, stride=1, num_groups=num_groups))
        residual_eighth = nn.Sequential(
            ResidualUnits(2*channels, 3*channels, stride=2, num_groups=num_groups),
            ResidualUnits(3*channels, 3*channels, stride=1, num_groups=num_groups))
        self.final_convolution = nn.Conv2d(3*channels, num_output_channels, 1)

        self.net = nn.Sequential(
            residual_half,
            residual_quarter,
            residual_eighth
        )

    def forward(self, x):
        x = self.relu(self.norm(self.init_convolution(x)))
        x = self.net(x)
        return self.final_convolution(x)

In [None]:
class UnfoldDecoder(nn.Module):
  """Learned convex x8 upsampler for a coarse multi-channel field.

  A small conv head predicts, for every coarse pixel, 8x8 sets of 9 logits. Each
  fine pixel is a softmax-weighted (convex) combination of the 3x3 coarse
  neighbourhood of its parent pixel.

  NOTE: values are not rescaled. If ``flow`` is a displacement in coarse-pixel
  units, multiply it by 8 before calling to get full-resolution units.

  Args:
    num_input_channels: channels C of the field to upsample.
    base_channel_size: width factor of the mask head (hidden width 4*base).
    act_fn: activation class used inside the mask head.
    mask_input_channels: channels of the feature map ``x`` fed to the mask head.
  """

  def __init__(self,
                num_input_channels : int,
                base_channel_size : int = 32,
                act_fn : object = nn.ReLU,
                mask_input_channels : int = 96):
    super().__init__()

    c_hid = base_channel_size
    self.output = num_input_channels

    self.mask = nn.Sequential(
      nn.Conv2d(mask_input_channels, 4*c_hid, 3, padding=1),
      act_fn(),
      nn.Conv2d(4*c_hid, 8*8*9, 1)
    )

    # Normalise over the 9 neighbour weights (dim 2 of the [B,1,9,8,8,H,W] view).
    self.softmax = nn.Softmax(dim=2)

  def forward(self, x, flow):
    """x: [B, mask_input_channels, H, W]; flow: [B, C, H, W] -> [B, C, 8H, 8W]."""
    B, _, H, W = x.shape

    mask = self.mask(x).view(B, 1, 9, 8, 8, H, W)
    mask = self.softmax(mask)

    # F.unfold returns [B, C*9, H*W] with the channel as the outer index, so this
    # view keeps the element order intact.
    neighbours = F.unfold(flow, 3, padding=1).view(B, self.output, 9, 1, 1, H, W)

    up = torch.sum(mask * neighbours, dim=2)   # [B, C, 8, 8, H, W]
    # Interleave the 8x8 sub-pixels with the coarse grid before flattening.
    up = up.permute(0, 1, 4, 2, 5, 3)          # [B, C, H, 8, W, 8]
    return up.reshape(B, -1, 8*H, 8*W)

# Computing Visual Similarity
* Correlation Volume
* Correlation Pyramid
* Correlation Lookup

In [None]:
def correlationVolume(f1, f2):
    """All-pairs correlation volume between two feature maps.

    C[b, i, j, k, l] = sum_d f1[b, d, i, j] * f2[b, d, k, l]  (no normalisation)

    Args:
        f1: features of image 1, shape [B, D, H1, W1].
        f2: features of image 2, shape [B, D, H2, W2].
    Returns:
        Tensor of shape [B, H1, W1, H2, W2].
    """
    batch, dim, h1, w1 = f1.shape
    h2, w2 = f2.shape[-2:]
    lhs = f1.reshape(batch, dim, h1 * w1).transpose(1, 2)   # [B, H1*W1, D]
    rhs = f2.reshape(batch, dim, h2 * w2)                    # [B, D, H2*W2]
    return torch.bmm(lhs, rhs).view(batch, h1, w1, h2, w2)

In [None]:
def constructPyramid(C, num_levels=4):
    """Build the correlation pyramid by 2x2 average-pooling the last two dims.

    All samples and all source pixels are pooled in one vectorised call. The
    previous per-sample loop re-concatenated the batch on every iteration and
    crashed on an empty batch.

    Args:
        C: correlation volume of shape [B, H1, W1, H2, W2] (e.g. from correlationVolume).
        num_levels: number of levels to return, including C itself (default 4).
    Returns:
        List [C_1, ..., C_L]: C_1 is C, and C_k has shape
        [B, H1, W1, H2 // 2**(k-1), W2 // 2**(k-1)] (odd sizes are floored by pooling).
    """
    B, H1, W1, H2, W2 = C.shape
    # Treat every (b, i, j) correlation map as one single-channel image.
    level = C.reshape(B * H1 * W1, 1, H2, W2)
    pyramid = [C]
    for _ in range(num_levels - 1):
        level = F.avg_pool2d(level, 2, stride=2)
        pyramid.append(level.view(B, H1, W1, level.shape[-2], level.shape[-1]))
    return pyramid


In [None]:
def createFeatureMap(correlation_pyramid, fe, radius=4):
  """Look up correlation features around each pixel's current correspondence.

  For every pixel x = (u, v) of image 1 the estimated correspondence in image 2
  is x' = x + fe(x). At pyramid level k (0-based) a (2r+1)x(2r+1) grid of
  integer offsets around x' / 2**k is sampled bilinearly; samples outside the
  volume are zero.

  Args:
    correlation_pyramid: list of tensors [B, H, W, H_k, W_k] (e.g. constructPyramid).
    fe: current flow estimate [B, 2, H, W]. Channel 0 is the horizontal
        (column / x) displacement and channel 1 the vertical (row / y) one.
    radius: lookup radius r.
  Returns:
    Tensor [B, L*(2r+1)**2, H, W] with L = len(correlation_pyramid).
  """
  B, H, W, _, _ = correlation_pyramid[0].shape
  device = correlation_pyramid[0].device
  dtype = fe.dtype

  # Absolute target coordinates x' = x + f(x) as (x, y) pairs, flattened in the
  # same (b, row, col) order as the first three dims of every pyramid level.
  rows, cols = torch.meshgrid(torch.arange(H, device=device, dtype=dtype),
                              torch.arange(W, device=device, dtype=dtype), indexing='ij')
  base = torch.stack((cols, rows), dim=0).unsqueeze(0)              # [1, 2, H, W]
  centroid = (base + fe).permute(0, 2, 3, 1).reshape(B*H*W, 1, 1, 2)

  # Offsets of the (2r+1)x(2r+1) lookup window.
  offsets = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
  delta = torch.stack(torch.meshgrid(offsets, offsets, indexing='ij'), dim=-1)
  delta = delta.view(1, 2*radius + 1, 2*radius + 1, 2)

  feature_maps = []
  for i, corr in enumerate(correlation_pyramid):
    h_corr, w_corr = corr.shape[-2:]
    corr = corr.reshape(B*H*W, 1, h_corr, w_corr)

    coords_lvl = centroid / 2**i + delta                             # [B*H*W, 2r+1, 2r+1, 2]
    xgrid, ygrid = coords_lvl.split([1, 1], dim=-1)
    # grid_sample wants (x, y) normalised to [-1, 1] (align_corners=True);
    # max(..., 1) avoids a division by zero on a 1-pixel-wide level.
    xgrid = 2 * xgrid / max(w_corr - 1, 1) - 1
    ygrid = 2 * ygrid / max(h_corr - 1, 1) - 1
    grid = torch.cat([xgrid, ygrid], dim=-1)

    # Bilinear interpolation -> [B*H*W, 1, 2r+1, 2r+1]
    sampled = F.grid_sample(corr, grid, align_corners=True)
    feature_maps.append(sampled.view(B, H, W, -1))

  return torch.cat(feature_maps, dim=-1).permute(0, 3, 1, 2)

# Iterative Updates
* Initialization (flow field = 0)
* Inputs (correlation features using previous sections)
* Update (GRU cell)
* Flow Prediction (delta flow)
* Upsampling (final high-res flow field)

In [None]:
class UpdateBlock(nn.Module):
  """Motion encoder: fuses the current flow and the correlation lookup into 80 channels.

    flow (2 ch)                 -> conv7x7 (64) -> ReLU -> conv3x3 (32) -> ReLU
    corr (4*(2r+1)^2 ch)        -> conv1x1 (96) -> ReLU
    concat (32 + 96 = 128 ch)   -> conv3x3 -> 80 ch (no activation)

  ``hidden_dim`` is accepted but not used.
  """

  def __init__(self, hidden_dim=96, radius=4):
    super().__init__()
    self.conv1_flow = nn.Conv2d(2, 64, 7, padding=3)
    self.conv2_flow = nn.Conv2d(64, 32, 3, padding=1)

    # One (2r+1)^2 lookup window per pyramid level, 4 levels.
    lookup_channels = 4 * (2 * radius + 1) ** 2
    self.conv_corr = nn.Conv2d(lookup_channels, 96, 1, padding=0)
    self.conv_final = nn.Conv2d(128, 80, 3, padding=1)

    self.relu = nn.ReLU()

  def forward(self, flow, corr):
    flow_feat = flow
    for conv in (self.conv1_flow, self.conv2_flow):
      flow_feat = self.relu(conv(flow_feat))

    corr_feat = self.relu(self.conv_corr(corr))
    return self.conv_final(torch.cat([flow_feat, corr_feat], dim=1))

In [None]:
class DeltaFBlock(nn.Module):
  """Flow head: predicts the 2-channel flow update from a 96-channel GRU state.

  conv3x3 (96 -> 128) -> ReLU -> conv3x3 (128 -> 2)
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(96, 128, 3, padding=1)
    self.conv2 = nn.Conv2d(128, 2, 3, padding=1)
    self.relu = nn.ReLU()

  def forward(self, h):
    hidden = self.relu(self.conv1(h))
    return self.conv2(hidden)

In [None]:
class GRUCell(torch.nn.Module):
    """Convolutional GRU cell used for the iterative flow updates.

    The fully connected gates of a standard GRU are replaced with 3x3 convolutions
    (``*`` is element-wise multiplication, [a, b] channel concatenation):

        z_t  = sigmoid(Conv3x3([h_{t-1}, x_t], W_z))
        r_t  = sigmoid(Conv3x3([h_{t-1}, x_t], W_r))
        h~_t = tanh(Conv3x3([r_t * h_{t-1}, x_t], W_h))
        h_t  = (1 - z_t) * h_{t-1} + z_t * h~_t

    The cell also predicts the 8*8*9-channel convex-upsampling mask from h_t.
    Backpropagation is handled by autograd.
    """
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.conv_z = torch.nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, padding=1)
        self.conv_r = torch.nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, padding=1)
        self.conv_h = torch.nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, padding=1)
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

        # The mask head reads the hidden state, so its input width must be
        # hidden_size (a fixed 96 only worked for hidden_size == 96).
        self.mask = nn.Sequential(
          nn.Conv2d(hidden_size, 4*hidden_size, 3, padding=1),
          torch.nn.ReLU(inplace=True),
          nn.Conv2d(4*hidden_size, 8*8*9, 1)
        )

    def forward(self, x, h):
        """x: [B, input_size, H, W], h: [B, hidden_size, H, W] -> (h_t, mask logits)."""
        hx = torch.cat((h, x), dim=1)  # shared by the z and r gates
        z = self.sigmoid(self.conv_z(hx))
        r = self.sigmoid(self.conv_r(hx))
        h_hat = self.tanh(self.conv_h(torch.cat((r * h, x), dim=1)))
        h = (1 - z) * h + z * h_hat
        return h, self.mask(h)

    def init_hidden(self, batch_size, height, width, device):
        """All-zero hidden state of shape [batch_size, hidden_size, height, width]."""
        return torch.zeros(batch_size, self.hidden_size, height, width, device=device)

     

In [None]:
class RAFT(nn.Module):
  """RAFT optical-flow network.

  Forward pass:
    1. ``contextEncoder`` on img1, split into the initial GRU hidden state
       (``hidden_dim`` channels, tanh) and context features (``context_dim``
       channels, ReLU) that are fed to the GRU at every iteration.
    2. ``featureEncoder`` (shared weights) on both frames at 1/8 resolution.
    3. All-pairs correlation volume and its pyramid, built once.
    4. ``iters`` GRU updates of a 1/8-resolution flow initialised to zero; after
       every update the flow is convex-upsampled x8.

  NOTE: DeltaFBlock expects a 96-channel hidden state, so only hidden_dim=96 is
  supported end to end.
  """
  def __init__(self, hidden_dim=96, context_dim=64):
    super(RAFT, self).__init__()
    self.featureEncoder = Encoder()
    # The context encoder output is split into [hidden | context] channels.
    self.contextEncoder = Encoder(num_output_channels=hidden_dim + context_dim)

    self.h_dim = hidden_dim
    self.c_dim = context_dim
    # GRU input = motion features (80 ch from UpdateBlock) + context + current flow (2 ch).
    self.gru = GRUCell(80 + self.c_dim + 2, self.h_dim)

    self.update_block = UpdateBlock()
    self.delta_fblock = DeltaFBlock()

  def upsample_flow(self, mask, flow):
    """Convex x8 upsampling of a coarse flow field.

    Args:
      mask: [B, 8*8*9, H, W] logits from the GRU mask head.
      flow: [B, 2, H, W] flow at 1/8 resolution (coarse-pixel units).
    Returns:
      [B, 2, 8H, 8W] flow, scaled by 8 into full-resolution pixel units.
    """
    B, _, H, W = flow.shape
    mask = torch.softmax(mask.view(B, 1, 9, 8, 8, H, W), dim=2)

    # F.unfold yields [B, 2*9, H*W] (channel-major), so this view preserves order.
    neighbours = F.unfold(8*flow, 3, padding=1).view(B, 2, 9, 1, 1, H, W)
    up = torch.sum(mask * neighbours, dim=2)   # [B, 2, 8, 8, H, W]
    up = up.permute(0, 1, 4, 2, 5, 3)          # [B, 2, H, 8, W, 8]
    return up.reshape(B, 2, 8*H, 8*W)

  def forward(self, img1, img2, iters=12):
    """Return a list of ``iters`` upsampled flow predictions [B, 2, 8h, 8w].

    ``h, w`` is the encoder output size (H/8, W/8 for inputs divisible by 8).
    """
    B = img1.shape[0]

    context_features = self.contextEncoder(img1)
    gru_hidden, context = torch.split(context_features, [self.h_dim, self.c_dim], dim=1)
    # Same activations as the reference RAFT implementation: the hidden state lives
    # in the tanh range of the GRU, the context is non-negative.
    gru_hidden = torch.tanh(gru_hidden)
    context = torch.relu(context)

    features1 = self.featureEncoder(img1)
    features2 = self.featureEncoder(img2)

    pyramids = constructPyramid(correlationVolume(features1, features2))

    # Allocate the flow on the inputs' device and at the actual feature resolution,
    # which differs from H//8 when H is not a multiple of 8.
    _, _, h8, w8 = features1.shape
    flow = torch.zeros(B, 2, h8, w8, device=img1.device, dtype=features1.dtype)

    upscaled_flow_estimates = []
    for _ in range(iters):
        corr_lookup = createFeatureMap(pyramids, flow)
        motion_features = self.update_block(flow, corr_lookup)

        gru_input = torch.cat([motion_features, context, flow], dim=1)
        # Carry the hidden state across iterations (it is the GRU's memory).
        gru_hidden, mask = self.gru(gru_input, gru_hidden)

        flow = flow + self.delta_fblock(gru_hidden)
        upscaled_flow_estimates.append(self.upsample_flow(mask, flow))

    return upscaled_flow_estimates

            

# Training Loop

In [None]:
import time
import matplotlib.pyplot as plt

In [None]:
def get_loss(flow_preds, flow, gamma=0.8):
  """Exponentially weighted L1 sequence loss over the per-iteration predictions.

      loss = sum_i gamma**(N - i - 1) * mean(|flow_preds[i] - flow|)

  The last of the N predictions has weight 1 and earlier ones decay geometrically,
  independent of N. Returns 0 for an empty list.

  Args:
    flow_preds: sequence of predicted flows, each broadcastable against ``flow``.
    flow: ground-truth flow.
    gamma: decay factor for earlier iterations.
  """
  n_preds = len(flow_preds)
  total_loss = 0
  for i, pred in enumerate(flow_preds):
    weight = gamma ** (n_preds - i - 1)
    total_loss += weight * (pred - flow).abs().mean()
  return total_loss

In [None]:
def _pad_to_multiple(x, multiple):
  """Zero-pad the last two dims of ``x`` up to the next multiple of ``multiple``.

  Returns ``(padded, (left, right, top, bottom))``. When the padding amount is odd
  the extra pixel goes to the right/bottom, so the result is always an exact
  multiple. Padding ``amount // 2`` on both sides would come up one pixel short.
  """
  H, W = x.shape[-2:]
  pad_h = (-H) % multiple
  pad_w = (-W) % multiple
  top, left = pad_h // 2, pad_w // 2
  pad = (left, pad_w - left, top, pad_h - top)
  return F.pad(x, pad=pad), pad


def train_model(raft, total_epoch=2, batch_print=20, lr=0.0001, loader=None):
  """Train ``raft`` with Adam and the sequence loss ``get_loss``.

  Frames are zero-padded (never resized) so H and W are multiples of 64: the
  encoder downsamples by 8 and the pyramid pools 3 more times by 2. Predictions
  are cropped back to the original frame before the loss, so padded pixels,
  which have no ground-truth flow, do not contribute. The whole module is saved
  to ./models/epoch_<n> after every epoch.

  Args:
    raft: the RAFT model; it is moved to the global ``device``.
    total_epoch: number of passes over the loader.
    batch_print: print progress / timing every this many batches.
    lr: Adam learning rate.
    loader: iterable of (img1, img2, flow) batches; defaults to ``train_loader``.
  """
  import os

  if loader is None:
    loader = train_loader

  raft.to(device)
  raft.train()

  # NOTE: no autocast region is used, so the scaler only rescales fp32 gradients.
  scaler = GradScaler()
  optimizer = torch.optim.Adam(raft.parameters(), lr=lr)

  total_batch = len(loader)
  os.makedirs('./models', exist_ok=True)
  for epoch in range(total_epoch):
    epoch_start = time.time()
    batch_start = epoch_start

    for batch_no, image_batch in enumerate(loader, start=1):
      # First batch of every print window (also correct for batch_print == 1).
      if (batch_no - 1) % batch_print == 0:
        print('Batch No:', batch_no, 'out of', total_batch)
        batch_start = time.time()

      optimizer.zero_grad()
      image_batch1, image_batch2, flow = [x.to(device) for x in image_batch]
      H, W = image_batch1.shape[-2:]

      image_batch1, (left, _, top, _) = _pad_to_multiple(image_batch1, 64)
      image_batch2, _ = _pad_to_multiple(image_batch2, 64)

      # Forward pass, then crop every iteration's prediction back to H x W.
      upscaled_flow_preds = raft(image_batch1, image_batch2)
      upscaled_flow_preds = [pred[..., top:top + H, left:left + W] for pred in upscaled_flow_preds]
      loss = get_loss(upscaled_flow_preds, flow)

      # Backward pass with gradient clipping (gradients unscaled first).
      scaler.scale(loss).backward()
      scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(raft.parameters(), 1.0)
      scaler.step(optimizer)
      scaler.update()

      if batch_no % batch_print == 0:
        batch_end = time.time()
        print('Batch mean time elapsed:', round((batch_end - batch_start) / batch_print, 2), 's loss:', loss.item())

    epoch_end = time.time()
    print('Epoch time elapsed:', round(epoch_end - epoch_start, 2),
          's with Average batch time:', round((epoch_end - epoch_start) / max(total_batch, 1), 2))
    torch.save(raft, './models/epoch_' + str(epoch + 1))

In [None]:
# Build a freshly initialised model and train it; train_model saves the whole
# module to ./models/epoch_<n> after every epoch.
raft = RAFT()
train_model(raft)



Batch No: 1 out of 373
img1 size = torch.Size([2, 3, 448, 1024])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
Batch mean time elapsed: 63.4 s loss: tensor(33.9471, grad_fn=<AddBackward0>)
Batch No: 21 out of 373
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torch.Size([2, 3, 448, 1024])
img1 size = torc

In [None]:
# Reload the whole module pickled by train_model after epoch 1.
# map_location lets a checkpoint written on the GPU load on a CPU-only runtime.
# NOTE: this unpickles arbitrary objects, so only load checkpoints you created
# yourself. Recent PyTorch releases default to weights_only=True, so a
# full-module load may additionally need weights_only=False.
raft_early = torch.load('./models/epoch_1', map_location=device)

In [None]:
# Compare, for one training sample, the flow predicted by the epoch-1 checkpoint,
# the flow predicted by the current model, and the ground-truth flow.
raft.eval()
raft_early.eval()

sample_batch = next(iter(train_loader))
image_batch1, image_batch2, flow = [x.to(device) for x in sample_batch]
H, W = image_batch1.shape[-2:]

# Pad the frames as in training (multiple of 64) instead of resizing the ground
# truth, then crop the predictions back to H x W so all flows line up.
pad_h, pad_w = (-H) % 64, (-W) % 64
top, left = pad_h // 2, pad_w // 2
padding = (left, pad_w - left, top, pad_h - top)
padded1 = F.pad(image_batch1, pad=padding)
padded2 = F.pad(image_batch2, pad=padding)

with torch.no_grad():
    upscaled_flow_preds = raft(padded1, padded2)
    upscaled_flow_preds_early = raft_early(padded1, padded2)

# Each model returns one prediction per GRU iteration: take the LAST iteration,
# then (below) the first sample of the batch.
final_flow = upscaled_flow_preds[-1][..., top:top + H, left:left + W]
final_flow_early = upscaled_flow_preds_early[-1][..., top:top + H, left:left + W]

# Input frames of the first sample.
fig, axes = plt.subplots(1, 2)
fig.suptitle('Images')
axes[0].imshow(image_batch1[0].cpu().permute(1, 2, 0))
axes[1].imshow(image_batch2[0].cpu().permute(1, 2, 0))

# Flows shown as RGB with an all-zero red channel (G = flow channel 0, B = channel 1).
# NOTE: imshow clips float RGB data to [0, 1], so large displacements saturate.
buffer = torch.zeros(H, W, 1)
fig, axes = plt.subplots(1, 3)
fig.suptitle('Flows')
for ax, title, f in zip(axes, ('epoch 1', 'current', 'ground truth'),
                        (final_flow_early[0], final_flow[0], flow[0])):
    ax.imshow(torch.cat((buffer, f.cpu().permute(1, 2, 0)), dim=-1).numpy())
    ax.set_title(title)




In [None]:
import gc

# Drop every reference to the models (the trained one and the reloaded epoch-1
# checkpoint) so their memory can be reclaimed, then release the blocks cached
# by PyTorch's CUDA allocator. empty_cache does not build autograd graphs, so
# it needs no torch.no_grad() context.
raft = None
raft_early = None
gc.collect()
torch.cuda.empty_cache()


In [None]:
# List the GPUs visible to this runtime.
!nvidia-smi -L  