# Trying to finetune flownet2 on US data

This did not yield good results and should be investigated further.

This notebook is not part of the thesis. It was a starting point I wanted to investigate, but I was kindly directed to pursue other areas.

In [None]:
import numpy as np
import cv2
from PIL import Image
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import ceil

from utils.preprocessing import preprocessing_flownet, preprocessing_pwc
from utils.plotting import flow2img, overlaySegment, showFlow
from utils.layers import warp, warpImage
from utils.encoding import labelMatrixOneHot, dice_coeff

from models.flownet2_pytorch.flownet2_mph import *
from models.flownet2_pytorch.flownet2_components import *

import warnings
warnings.filterwarnings('ignore')

# Restrict this process to the CUDA device with index 3 and report what is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

available_gpus = []
for gpu_idx in range(torch.cuda.device_count()):
    gpu_entry = (torch.cuda.device(gpu_idx), torch.cuda.get_device_name(gpu_idx))
    available_gpus.append(gpu_entry)
print(available_gpus)

# Fall back to the CPU when no CUDA device can be used.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Load the stored frame pairs and their segmentations.
imgs = torch.load('/share/data_ultraschall/nicke_ma/data/train_frames.pth')
segs = torch.load('/share/data_ultraschall/nicke_ma/data/train_segs.pth')

# Define a reproducible 95% / 5% train/test split.
# The permutation is drawn with torch so that torch.manual_seed really
# controls it (a torch seed has no effect on NumPy's random generator).
torch.manual_seed(42)
num_samples = len(imgs)
num_train = int(num_samples * 0.95)
shuffled = torch.randperm(num_samples)

train_set = shuffled[:num_train]
# Everything not used for training is held out; kept in ascending index order.
test_set = shuffled[num_train:].sort()[0]


print(f"{train_set.shape[0]} train examples")
print(f"{test_set.shape[0]} test examples")

In [None]:
def warp_seg(moving_seg, flow):
    """
    Warp a segmentation with a dense displacement field.

    moving_seg: CxHxW tensor; it is cast to float and the same map is used
        for every element of the flow batch.
    flow: Bx2xHxW displacement in pixels; channel 0 is added to the x (width)
        coordinate and channel 1 to the y (height) coordinate.

    returns: BxCxHxW float tensor sampled (bilinear, zero padding) from
        moving_seg at the positions ``pixel grid + flow``.
    """
    B, _, H, W = flow.size()

    # Pixel-coordinate mesh grid; the views broadcast against the flow.
    xs = torch.arange(0, W, device=flow.device).float().view(1, 1, 1, W)
    ys = torch.arange(0, H, device=flow.device).float().view(1, 1, H, 1)

    x_new = xs + flow[:, 0:1, :, :]
    y_new = ys + flow[:, 1:2, :, :]

    # Scale the sampling positions to [-1, 1]; dividing by (size - 1) maps the
    # first/last pixel centre to -1/+1.
    x_norm = 2.0 * x_new / max(W - 1, 1) - 1.0
    y_norm = 2.0 * y_new / max(H - 1, 1) - 1.0

    # grid_sample expects the grid as B x H x W x 2 with (x, y) in the last dim.
    vgrid = torch.cat((x_norm, y_norm), 1).permute(0, 2, 3, 1)

    source = moving_seg.float().unsqueeze(0)
    if B > 1:
        source = source.expand(B, -1, -1, -1)

    # align_corners=True matches the (size - 1) normalisation above, so that a
    # zero flow reproduces the input exactly.
    warped_seg_grid = nn.functional.grid_sample(source, vgrid, align_corners=True)
    return warped_seg_grid

In [None]:
# Read the checkpoint first, then build the network and restore its weights
# from the checkpoint's 'state_dict' entry.
checkpoint_path = "models/flownet2_pytorch/FlowNet2_checkpoint.pth.tar"
state_dict = torch.load(checkpoint_path)
flownet = FlowNet2()
flownet.load_state_dict(state_dict['state_dict'])

In [None]:
# Disable gradients for the four sub-networks below, so only the remaining
# parameters (the fusion block, per the original note) stay trainable.
frozen_subnets = (
    flownet.flownetc,
    flownet.flownets_1,
    flownet.flownets_2,
    flownet.flownets_d,
)
for subnet in frozen_subnets:
    for param in subnet.parameters():
        param.requires_grad = False
    

In [None]:
# Move the model's parameters and buffers to the GPU; the rest of the notebook
# sends its inputs to the GPU with .cuda() as well.
flownet.cuda()

# Before finetuning
Before fine-tuning, we need to measure the baseline performance of FlowNet2 on the test set.

In [None]:
# eval Flownet
def eval_flownet(model, scale=4):
    """
    Evaluate a flow model on the image pairs indexed by ``test_set``.

    For every test index the fixed/moving frames and segmentations are resized
    to (scale*64, scale*64), ``model`` predicts a flow from the preprocessed
    pair, the moving segmentation is warped with that flow, and the result is
    scored against the fixed segmentation with ``dice_coeff``.

    model: network called on the output of ``preprocessing_flownet``.
    scale: side length of the evaluated images, in multiples of 64.

    returns: tuple ``(mean score after warping, mean score without warping)``,
        each a 0-d tensor.
    """
    side = scale * 64
    # Run on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device

    overall_dice = []
    unwarped_dice = []

    # Evaluation needs no gradients; this avoids building autograd graphs.
    with torch.no_grad():
        for idx in test_set:

            # Get image and segmentation
            fixed = imgs[idx:idx+1, 0, :].unsqueeze(0).float()
            moving = imgs[idx:idx+1, 1, :].unsqueeze(0).float()

            fixed_seg = segs[idx:idx+1, 0, :].contiguous()
            moving_seg = segs[idx:idx+1, 1, :].contiguous()

            fixed = F.interpolate(fixed, size=(side, side), mode='bicubic')
            moving = F.interpolate(moving, size=(side, side), mode='bicubic')

            # NOTE(review): bicubic resampling of label maps yields non-integer
            # values; confirm that dice_coeff handles this as intended.
            fixed_seg = F.interpolate(fixed_seg.unsqueeze(0), size=(side, side), mode='bicubic')
            moving_seg = F.interpolate(moving_seg.unsqueeze(0), size=(side, side), mode='bicubic')

            flow_in = preprocessing_flownet(
                fixed.detach().clone().reshape(side, side, 1),
                moving.clone().reshape(side, side, 1),
            ).to(model_device)

            # Use the model that was passed in, not a module-level one.
            flow_out = model(flow_in)

            warped_seg = warp_seg(moving_seg.view(1, side, side).to(model_device), flow_out).cpu()

            # Third argument is passed through as 3 (presumably the number of
            # labels; confirm against utils.encoding).
            d1 = dice_coeff(warped_seg, fixed_seg, 3)
            d2 = dice_coeff(moving_seg, fixed_seg, 3)

            overall_dice.append(float(d1.mean()))
            unwarped_dice.append(float(d2.mean()))

    overall_dice = torch.tensor(overall_dice)
    unwarped_dice = torch.tensor(unwarped_dice)

    return overall_dice.mean(), unwarped_dice.mean()

In [None]:
# Baseline before fine-tuning: prints the tuple returned by eval_flownet
# (mean score after warping, mean score without warping).
print(eval_flownet(flownet))

In [None]:
# Settings for the fine-tuning run.
epochs = 500
lr = 1e-5
# Gradient-accumulation interval used by the training loop.
grad_accum = 30

# Adam over all model parameters.
model_params = list(flownet.parameters())
optimizer = torch.optim.Adam(model_params, lr=lr)

In [None]:
def _load_pair(index, side):
    """Return (fixed, moving, fixed_seg, moving_seg) of pair ``index``, each resized to side x side."""
    fixed = imgs[index:index+1, 0, :].unsqueeze(0).float()
    moving = imgs[index:index+1, 1, :].unsqueeze(0).float()

    fixed_seg = segs[index:index+1, 0, :].contiguous()
    moving_seg = segs[index:index+1, 1, :].contiguous()

    fixed = F.interpolate(fixed, size=(side, side), mode='bicubic')
    moving = F.interpolate(moving, size=(side, side), mode='bicubic')

    fixed_seg = F.interpolate(fixed_seg.unsqueeze(0), size=(side, side), mode='bicubic')
    moving_seg = F.interpolate(moving_seg.unsqueeze(0), size=(side, side), mode='bicubic')
    return fixed, moving, fixed_seg, moving_seg


def _warp_with_flownet(fixed, moving, moving_seg, side):
    """Predict the flow for one pair with ``flownet`` and return the warped moving segmentation on the CPU."""
    flow_in = preprocessing_flownet(
        fixed.detach().clone().reshape(side, side, 1),
        moving.clone().reshape(side, side, 1),
    ).cuda()
    flow_out = flownet(flow_in)
    return warp_seg(moving_seg.view(1, side, side).cuda(), flow_out).cpu()


losses = []  # mean training loss per epoch
acc = []     # mean test score per epoch
scale = 4
train_side = scale * 64
# NOTE(review): the per-epoch check runs at 128x128 while training uses
# scale*64 = 256; kept as in the original experiment, confirm it is intended.
eval_side = 128
seen_samples = 0

# Start from clean gradients before accumulating.
optimizer.zero_grad()

for epoch in tqdm(range(epochs)):
    rnd_train_idx = torch.randperm(train_set.size(0))
    # Collected over the whole epoch so the logged value is the epoch mean.
    tmp_loss = []

    # show all examples to model
    for rnd_idx in rnd_train_idx:
        p_fix = train_set[rnd_idx]

        fixed, moving, fixed_seg, moving_seg = _load_pair(p_fix, train_side)
        warped_seg = _warp_with_flownet(fixed, moving, moving_seg, train_side)

        # Sum of squared differences between warped and fixed segmentation.
        loss = torch.sum(torch.pow(warped_seg - fixed_seg, 2))
        loss.backward()
        tmp_loss.append(loss.item())

        # Every grad_accum samples: apply the accumulated gradients.
        seen_samples += 1
        if seen_samples % grad_accum == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Score the current weights on the held-out pairs.
    with torch.no_grad():
        d0 = []
        for idx in test_set:
            fixed, moving, fixed_seg, moving_seg = _load_pair(idx, eval_side)
            warped_seg = _warp_with_flownet(fixed, moving, moving_seg, eval_side)

            d1 = dice_coeff(warped_seg, fixed_seg, 3)
            d0.append(float(d1.mean()))
    acc.append(np.mean(d0))
    losses.append(np.mean(tmp_loss))

In [None]:
# Per-epoch test score recorded in `acc`, plotted against the epoch index and
# written to a PNG in the working directory.
acc_epochs = np.arange(len(acc))
plt.plot(acc_epochs, acc)
plt.savefig('Flownet_finetune_acc_500epochs.png')

In [None]:
# Per-epoch training loss recorded in `losses`, plotted against the epoch index
# and written to a PNG in the working directory.
loss_epochs = np.arange(len(losses))
plt.plot(loss_epochs, losses)
plt.savefig('Flownet_finetune_loss_500epochs.png')

In [None]:
# Re-run the evaluation after fine-tuning; as the last expression of the cell
# the returned tuple is shown as the cell output.
eval_flownet(flownet)

In [None]:
# Store only the model's state_dict (not the whole module) in the working directory.
torch.save(flownet.state_dict(), "flownet_finetuned_500.pth")

In [None]:
# Ask the running IPython kernel to shut down. The True argument is passed to
# do_shutdown (presumably its restart flag, used here to release GPU memory;
# TODO confirm).
import IPython
IPython.Application.instance().kernel.do_shutdown(True)