## Ensemble Exploration
Questions to find answers to:
- What is a good workflow for passing one image through the ensemble?
- What is a suitable loss? (Data is needed for that.)
- Deep mutual learning losses vs. the standard student-teacher loss

In [None]:
import torch
import torch.nn.functional as F

import numpy as np
import cv2
from math import ceil
import matplotlib.pyplot as plt

from utils.preprocessing import preprocessing_flownet, preprocessing_pwc
from utils.load_models import load_flownet2, load_pwcnet, init_weights
from utils.plotting import flow2img, overlaySegment, showFlow
from utils.layers import warp, warpImage, correlation_layer, min_convolution, meanfield
from utils.encoding import labelMatrixOneHot, dice_coeff


from models.pdd_net.pdd_student import OBELISK2d, deeds2d

In [None]:
# Instantiate both teacher networks and move them onto the GPU (FlowNet2 first, then PWC-Net).
flow2, pwc = load_flownet2().cuda(), load_pwcnet().cuda()

In [None]:
# Classical (non-learned) Dual TV-L1 optical-flow estimator from OpenCV's optflow module;
# its `calc` method is used further down as the reference the networks are compared against.
baseline = cv2.optflow.DualTVL1OpticalFlow_create()

In [None]:
# Small feature CNN: three conv + batch-norm + PReLU stages (the first and third use
# stride 2), followed by a 1x1 projection to 24 channels squashed through a sigmoid.
_layers = [
    torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=2, padding=4, dilation=2),
    torch.nn.BatchNorm2d(num_features=32),
    torch.nn.PReLU(),
    torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
    torch.nn.BatchNorm2d(num_features=32),
    torch.nn.PReLU(),
    torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1),
    torch.nn.BatchNorm2d(num_features=64),
    torch.nn.PReLU(),
    torch.nn.Conv2d(in_channels=64, out_channels=24, kernel_size=1, stride=1, padding=0, dilation=1),
    torch.nn.Sigmoid(),
]
seq = torch.nn.Sequential(*_layers)

In [None]:
# Fresh student: feature extractor (`net`) and registration head (`reg`).
net, reg = OBELISK2d(), deeds2d()
# Only the feature extractor is passed through init_weights here.
init_weights(net)

PWC-Net and FlowNet2 together seem to occupy around 1.5 GB of GPU memory.

In [None]:
# Image and segmentation stacks; the first axis indexes the individual slices.
imgs = torch.load('Data/img.pth')
segs = torch.load('Data/seg.pth')

# Indices of the pair used for the single-pair experiments below.
fix, mov = 1, 2

# Slicing with i:i+1 keeps the leading axis; intensities are scaled by 1/255.
fixed = imgs[fix:fix + 1].float().div(255)
moving = imgs[mov:mov + 1].float().div(255)

fixed_seg = segs[fix:fix + 1].float().contiguous()
moving_seg = segs[mov:mov + 1].float().contiguous()

# Disabled experiment: resample images and segmentations to 100x100 via
# F.interpolate(x.unsqueeze(0), size=(100,100)).view(1,100,100)

# Unpack the leading axis and the two spatial sizes for use in later cells.
C, h, w = fixed.shape

In [None]:
# Build the FlowNet2 input from the (fixed, moving) pair and move it to the GPU.
# NOTE(review): reshape(h, w, C) only equals a channel-last permutation because C == 1 here.
flow_in = preprocessing_flownet(
    fixed.reshape(h, w, C),
    moving.reshape(h, w, C),
).cuda()

In [None]:
# Sanity check: display the size of the prepared FlowNet2 input.
flow_in.size()

In [None]:
# Pure teacher inference: disable autograd so no graph (and no activation memory)
# is kept for the large FlowNet2 model.
# NOTE(review): flow2.eval() is never called in this notebook; confirm load_flownet2()
# already returns the model in evaluation mode.
with torch.no_grad():
    flow2_out = flow2(flow_in).cpu()

In [None]:
# Build the PWC-Net input from the same (fixed, moving) pair; it is moved to the GPU in the next cell.
pwc_in = preprocessing_pwc(
    fixed.reshape(h, w, C),
    moving.reshape(h, w, C),
)

In [None]:
pwc.eval()
# Inference only: no autograd graph is needed for the teacher.
with torch.no_grad():
    pwc_out = pwc(pwc_in.cuda())

# Take the first batch element and rescale by 20.
# NOTE(review): the factor 20 presumably undoes PWC-Net's internal flow scaling - confirm.
pwc_out = (pwc_out[0] * 20).cpu().numpy()

# Channel-first -> channel-last (equivalent to the two chained swapaxes calls used before).
pwc_out = pwc_out.transpose(1, 2, 0)

# Input sizes rounded up to the next multiple of 64; used to rescale the flow magnitudes.
divisor = 64
H_ = int(ceil(h/divisor) * divisor)
W_ = int(ceil(w/divisor) * divisor)

# Resize each flow component back to the image resolution ...
u_ = cv2.resize(pwc_out[:,:,0],(w,h))
v_ = cv2.resize(pwc_out[:,:,1],(w,h))
# ... and scale the displacements by the ratio of original to padded size.
u_ *= w/ float(W_)
v_ *= h/ float(H_)
pwc_out = np.dstack((u_,v_))

pwc_out.shape

In [None]:
# Convert the FlowNet2 result to a channel-last numpy array for plotting.
# Note: this overwrites flow2_out, so the cell is not idempotent (re-run the inference cell first).
flow2_out = np.transpose(flow2_out.squeeze().cpu().detach().numpy(), (1, 2, 0))
flow2_out.shape

In [None]:
displace_range = 11
# Features of both images from the untrained CNN `seq`; indexing with [None] adds the batch axis.
feat00, feat50 = seq(fixed[None]), seq(moving[None])
# Cost volume between moving and fixed features, then mean-field inference on the coarse grid.
ssd_distance = correlation_layer(displace_range, feat50, feat00)
cost_soft, pred_xy = meanfield(ssd_distance, fixed[None], displace_range, h // 4, w // 4)

In [None]:
# Check that the prediction can be viewed as an (h, w, 2) displacement field.
pred_xy.detach().squeeze().view(h, w, 2).size()

In [None]:
# OpenCV expects float32 numpy images; both tensors are viewed as (h, w, 1) first.
in1, in2 = (
    img.view(h, w, 1).detach().numpy().astype(np.float32)
    for img in (fixed, moving)
)

# Dual TV-L1 flow from the fixed to the moving image (no initial flow supplied).
baseline_flow = baseline.calc(in1, in2, None)

In [None]:
fig = plt.figure(figsize=(12, 12))

# (slot in the 1x6 grid, title, image). Slot 3 stays empty: it belongs to the
# disabled "PDD-Untrained" panel (flow2img(pred_xy.squeeze().detach().view(h,w,2))).
panels = [
    (1, "FlowNet2", flow2img(flow2_out)),
    (2, "PWC", flow2img(pwc_out)),
    (4, "Baseline", flow2img(baseline_flow)),
    (5, 'Fixed', fixed.squeeze().detach()),
    (6, 'Moving', moving.squeeze().detach()),
]
for slot, title, image in panels:
    plt.subplot(1, 6, slot)
    plt.title(title)
    plt.imshow(image)

So, not happy with the PWC result.

BUT: PDD with `init_weights` looks quite close to the desired flow. Ish.

# Trying to learn on the data
This is basically overfitting on a single image pair using the labels.

In [None]:
# Overfit the student (feature net + registration head) on a single image pair,
# supervised by the segmentation labels plus a smoothness penalty on the displacements.
net = OBELISK2d()
reg = deeds2d()
init_weights(net)
net.train()
reg.train()

# 20 epochs should be fine for overfitting
epochs = 20
optimizer = torch.optim.Adam(list(net.parameters())+list(reg.parameters()),lr=0.05)

disp_range = 0.25
displacement_width = 15
# NOTE(review): grid_size is not defined in any cell visible above this one. It must be
# set beforehand and match the spatial size of the grid that `reg` predicts on.
# shift_xy: the displacement candidates; grid_xy: the coarse control-point grid.
shift_xy = F.affine_grid(disp_range*torch.eye(2,3).unsqueeze(0),(1,1,displacement_width,displacement_width),align_corners=False).view(1,1,-1,2)
grid_xy = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,grid_size,grid_size),align_corners=False).view(1,-1,1,2)

# Loop-invariant targets: one-hot labels and the fixed labels sampled at the control points.
moving_onehot = labelMatrixOneHot(moving_seg,9).float()
fixed_label = F.grid_sample(labelMatrixOneHot(fixed_seg,9).float(),grid_xy,padding_mode='border',align_corners=False)

# Testing with only one image pair to overfit and check that the pipeline works
diffs = []

for i in range(epochs):
    optimizer.zero_grad()

    # Student forward pass. The input images do not need gradients, only the weights do.
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat00,feat50)
    pred_xy = pred_xy.view(1,grid_size,grid_size,2)

    # Smoothness penalty: squared finite differences along the two spatial grid axes.
    # (A former third term differenced the last axis, i.e. the two displacement
    # components against each other, which is not a spatial smoothness constraint.)
    diffloss = 1.5*((pred_xy[0,:,1:,:]-pred_xy[0,:,:-1,:])**2).mean()+\
            1.5*((pred_xy[0,1:,:,:]-pred_xy[0,:-1,:,:])**2).mean()

    # Non-local label loss: sample the moving labels at every displacement candidate
    # (result: 1 x labels x grid points x candidates), weight by the soft cost and sum over
    # the candidate axis (dim 3). This yields one soft label vector per control point,
    # matching the shape of fixed_label. Summing over dim 1 would collapse the labels instead.
    nonlocal_label = (F.grid_sample(moving_onehot,grid_xy+shift_xy,padding_mode='border',align_corners=False)\
                          *soft_cost.view(1,-1,grid_size**2,displacement_width**2)).sum(3,keepdim=True)

    orig_labelloss = ((fixed_label-nonlocal_label)**2).mean()
    diff = orig_labelloss + diffloss
    diff.backward()
    diffs.append(diff.item())
    optimizer.step()

In [None]:
# The x-axis follows the number of recorded losses, so the plot also works if the
# training loop was interrupted or `epochs` has been changed since.
plt.plot(np.arange(len(diffs)), diffs)
plt.xlabel("iteration")
plt.ylabel("training loss");

In [None]:
# Evaluate the overfitted student against the TV-L1 baseline on the training pair.
net.eval()
reg.eval()

with torch.no_grad():
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    cost_soft, pred_xy = reg(feat00,feat50)
    # Teacher output sampled at the control points (not used further in this cell).
    teacher_out = flow2(flow_in).cpu()
    teacher_grid = F.grid_sample(teacher_out, grid_xy, align_corners=False)

    # Upsample the coarse displacement field to image resolution.
    dense_flow_fit = F.interpolate(pred_xy.transpose(0,1).view(1,2,grid_size,grid_size),size=(h,w),mode='bicubic',align_corners=False)
    # Apply and evaluate the transformation. The training cell pulls the MOVING labels onto
    # the fixed grid, so the moving segmentation is the one to warp and the fixed one is the target.
    identity = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,h,w),align_corners=False)
    warped_seg = F.grid_sample(moving_seg.float().unsqueeze(1),identity+dense_flow_fit.permute(0,2,3,1),mode='nearest',align_corners=False)

# baseline_flow is channel-last (h, w, 2). permute really moves the channel axis to the
# front, whereas view(1,2,h,w) would only reinterpret the interleaved memory and scramble u/v.
baseline_flow_chw = torch.from_numpy(baseline_flow).permute(2,0,1).unsqueeze(0)
warped_baseline = warpImage(moving_seg.unsqueeze(0).cpu(), baseline_flow_chw)

d1 = dice_coeff(fixed_seg,warped_seg.squeeze(),8)
print("Dice between fixed and warped moving: ",d1,d1.mean())
d2 = dice_coeff(fixed_seg,warped_baseline.squeeze(),8)
print("Dice between fixed and baseline-warped moving: ",d2, d2.mean())
d3 = dice_coeff(fixed_seg,moving_seg,8)
print("Dice between fixed and moving: ",d3, d3.mean())

plt.subplot(131)
plt.title("PDD-Output")
plt.imshow(showFlow(pred_xy.transpose(0,1).reshape(2,grid_size,grid_size).detach()))

plt.subplot(132)
plt.title('Fixed')
plt.imshow(fixed.detach().squeeze())

plt.subplot(133)
plt.title('Moving')
plt.imshow(moving.detach().squeeze())

## Let's get a small pdd_net to learn FlowNet2 warpings

Trying to overfit on the data warped by FlowNet2.

In [None]:
# Overfit the student on a single pair so that its warped labels match the labels
# warped by the FlowNet2 teacher.
net = OBELISK2d()
reg = deeds2d()
init_weights(net)
net.train()
reg.train()

epochs = 30
optimizer = torch.optim.Adam(list(net.parameters())+list(reg.parameters()),lr=0.05)

disp_range = 0.25
displacement_width = 15
# NOTE(review): grid_size is not defined in any cell visible above this one. It must be
# set beforehand and match the spatial size of the grid that `reg` predicts on.
shift_xy = F.affine_grid(disp_range*torch.eye(2,3).unsqueeze(0),(1,1,displacement_width,displacement_width),align_corners=False).view(1,1,-1,2)
grid_xy = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,grid_size,grid_size),align_corners=False).view(1,-1,1,2)

# Generate the target we want to approximate. The teacher input never changes inside the
# loop, so the teacher is evaluated once, without building an autograd graph.
fixed_onehot = labelMatrixOneHot(fixed_seg,9)
with torch.no_grad():
    teacher_flow = flow2(flow_in).cpu()
    warped_moving_teacher = warpImage(fixed_onehot, teacher_flow)
    warped_teacher = F.grid_sample(warped_moving_teacher, grid_xy, align_corners=False)

# Testing with only one image pair to overfit on the teacher and check that it works
diffs = []
for i in range(epochs):
    optimizer.zero_grad()

    # Student forward pass. The input images do not need gradients, only the weights do.
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat00,feat50)
    pred_xy = pred_xy.view(1,grid_size,grid_size,2)

    # Soft-warped labels: sample at every displacement candidate, weight by the soft cost
    # and sum over the candidate axis (dim 3) so the result has one label vector per
    # control point, like warped_teacher. Summing over dim 1 would collapse the labels.
    warped_seg = (F.grid_sample(fixed_onehot.float(),grid_xy+shift_xy,padding_mode='border',align_corners=False)\
                          *soft_cost.view(1,-1,grid_size**2,displacement_width**2)).sum(3,keepdim=True)

    # Mean squared error. A plain signed sum is unbounded below and can be "minimised"
    # without ever matching the teacher.
    diff = ((warped_teacher-warped_seg)**2).mean()
    diff.backward()
    diffs.append(diff.item())
    optimizer.step()

In [None]:
# The x-axis follows the number of recorded losses, so the plot also works if the
# training loop was interrupted or `epochs` has been changed since.
plt.plot(np.arange(len(diffs)), diffs)
plt.xlabel("iteration")
plt.ylabel("training loss");

In [None]:
# Evaluate the teacher-overfitted student against the TV-L1 baseline on the training pair.
net.eval()
reg.eval()

with torch.no_grad():
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    cost_soft, pred_xy = reg(feat00,feat50)
    # Teacher output sampled at the control points (not used further in this cell).
    teacher_out = flow2(flow_in).cpu()
    teacher_grid = F.grid_sample(teacher_out, grid_xy, align_corners=False)

    # Upsample the coarse displacement field to image resolution.
    dense_flow_fit = F.interpolate(pred_xy.transpose(0,1).view(1,2,grid_size,grid_size),size=(h,w),mode='bicubic',align_corners=False)
    # Apply and evaluate the transformation.
    identity = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,h,w),align_corners=False)
    warped_seg = F.grid_sample(fixed_seg.float().unsqueeze(0),identity+dense_flow_fit.permute(0,2,3,1),mode='nearest',align_corners=False)

# baseline_flow is channel-last (h, w, 2). permute really moves the channel axis to the
# front, whereas view(1,2,h,w) would only reinterpret the interleaved memory and scramble u/v.
baseline_flow_chw = torch.from_numpy(baseline_flow).permute(2,0,1).unsqueeze(0)
warped_baseline = warpImage(fixed_seg.unsqueeze(0).cpu(), baseline_flow_chw)

d1 = dice_coeff(moving_seg,warped_seg.squeeze(),8)
print(d1,d1.mean())
d2 = dice_coeff(moving_seg,warped_baseline.squeeze(),8)
print(d2, d2.mean())
d3 = dice_coeff(fixed_seg,moving_seg,8)
print(d3, d3.mean())

plt.subplot(131)
plt.title("PDD-Output")
plt.imshow(showFlow(pred_xy.transpose(0,1).reshape(2,grid_size,grid_size).detach()))

plt.subplot(132)
plt.title('Fixed')
plt.imshow(fixed.detach().squeeze())

plt.subplot(133)
plt.title('Moving')
plt.imshow(moving.detach().squeeze())

## Train Pdd with Flow2 as teacher
Only try to learn the flow estimation from FlowNet2.

The goal is to approximate the teacher output, and only the teacher output.

In [None]:
# Per-class weights: inverse square root of the label frequency over all segmentations.
class_counts = torch.bincount(segs.reshape(-1)).float()
# A label id that never occurs has count 0, which would give an infinite weight and turn
# the normalisation below into NaN; clamping to 1 leaves all populated classes unchanged.
class_weight = torch.sqrt(1.0/class_counts.clamp(min=1))
# Normalise to mean 1, then pin the weight of label 0 to a fixed small value.
class_weight = class_weight/class_weight.mean()
class_weight[0] = 0.15
class_weight = class_weight.cpu()

In [None]:
from torch.autograd import Variable  # no longer used in this cell; import kept so the name stays available
epochs = 100
lr = 0.025

# Train the student on random training pairs using only the FlowNet2 teacher signal.
# Relies on grid_size, grid_xy, shift_xy, displacement_width, h, w and C from earlier cells.
net = OBELISK2d()
reg = deeds2d()
net.train()
reg.train()

init_weights(net)

optimizer = torch.optim.Adam(params=list(net.parameters()) + list(reg.parameters()), lr=lr)

torch.manual_seed(10)
# Train/test split: every index in 0..42 that is not in test_set is used for training.
test_set = torch.LongTensor([35, 41, 0, 4, 39, 17])
train_set = torch.arange(43)
for idx in test_set:
    train_set = train_set[train_set != idx]

diffs=[]
for epoch in range(epochs):

    # Draw a random (fixed, moving) training pair.
    rnd_train_idx = torch.randperm(train_set.size(0))
    p_fix = train_set[rnd_train_idx[0]]
    p_mov = train_set[rnd_train_idx[1]]

    # .float() keeps the scaling a true division even if imgs is stored as integers
    # (consistent with the single-pair loading cell).
    fixed = imgs[p_fix:p_fix+1,:,:].float() / 255
    moving = imgs[p_mov:p_mov+1,:,:].float() / 255

    fixed_seg = segs[p_fix:p_fix+1,:,:]
    moving_seg = segs[p_mov:p_mov+1,:,:]
    fixed_onehot = labelMatrixOneHot(fixed_seg,9)

    # Teacher target: labels warped with the FlowNet2 flow, sampled at the control points.
    # No autograd graph is needed for the teacher.
    with torch.no_grad():
        flow_in = preprocessing_flownet(fixed.reshape(h,w,C),moving.reshape(h,w,C)).cuda()
        teacher_flow = flow2(flow_in).cpu()
        warped_moving_teacher = warpImage(fixed_onehot, teacher_flow)
        warped_teacher = F.grid_sample(warped_moving_teacher, grid_xy, align_corners=False)

    optimizer.zero_grad()

    # Student forward pass. The input images do not need gradients, only the weights do.
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat00,feat50)
    pred_xy = pred_xy.view(1,grid_size,grid_size,2)

    # Soft-warped labels: sample at every displacement candidate, weight by the soft cost
    # and sum over the candidate axis (dim 3) so the result has one label vector per
    # control point, like warped_teacher. Summing over dim 1 would collapse the labels.
    warped_seg = (F.grid_sample(fixed_onehot.float(),grid_xy+shift_xy,padding_mode='border',align_corners=False)\
                          *soft_cost.view(1,-1,grid_size**2,displacement_width**2)).sum(3,keepdim=True)

    # Teacher-only objective (no label loss and no smoothness term in this experiment).
    teacher_loss = torch.pow((warped_teacher - warped_seg),2)
    loss = teacher_loss.mean()
    loss.backward()
    diffs.append(loss.item())

    optimizer.step()

In [None]:
# The x-axis follows the number of recorded losses, so the plot also works if the
# training loop was interrupted or `epochs` has been changed since.
plt.plot(np.arange(len(diffs)), diffs)
plt.xlabel("iteration")
plt.ylabel("training loss");

In [None]:
# Evaluate the teacher-trained student on a random pair from the held-out test set.
net.eval()
reg.eval()

rnd_test_idx = torch.randperm(test_set.size(0))
p_fix = test_set[rnd_test_idx[0]]
p_mov = test_set[rnd_test_idx[1]]

# .float() keeps the scaling a true division even if imgs is stored as integers.
fixed = imgs[p_fix:p_fix+1,:,:].float() / 255
moving = imgs[p_mov:p_mov+1,:,:].float() / 255

fixed_seg = segs[p_fix:p_fix+1,:,:]
moving_seg = segs[p_mov:p_mov+1,:,:]

with torch.no_grad():
    feat1 = net(fixed.unsqueeze(0))
    feat2 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat1,feat2)

    # Upsample the coarse displacement field to image resolution.
    dense_flow_fit = F.interpolate(pred_xy.transpose(0,1).view(1,2,grid_size,grid_size),size=(h,w),mode='bicubic',align_corners=False)
    # Apply and evaluate the transformation.
    identity = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,h,w),align_corners=False)
    warped_seg = F.grid_sample(fixed_seg.float().unsqueeze(1),identity+dense_flow_fit.permute(0,2,3,1),mode='nearest',align_corners=False)

in1 = fixed.view(h,w,1).detach().numpy().astype(np.float32)
in2 = moving.view(h,w,1).detach().numpy().astype(np.float32)

baseline_flow = baseline.calc(in1,in2,None)
# baseline_flow is channel-last (h, w, 2). permute really moves the channel axis to the
# front, whereas view(1,2,h,w) would only reinterpret the interleaved memory and scramble u/v.
baseline_flow_chw = torch.from_numpy(baseline_flow).permute(2,0,1).unsqueeze(0)
warped_baseline = warpImage(fixed_seg.unsqueeze(0).cpu().float(), baseline_flow_chw)

d1 = dice_coeff(moving_seg,warped_seg.squeeze(),9)
print("PDD-Student: ", d1,d1.mean())
# squeeze() so the baseline is passed in the same layout as the student result above
d2 = dice_coeff(moving_seg, warped_baseline.squeeze(), 9)
print("Baseline: ", d2, d2.mean())
d3 = dice_coeff(fixed_seg,moving_seg,9)
print("diff fixed, moving: ", d3, d3.mean())

## Using Flow2 as teacher and training on regular data
Combining the teacher loss with the data (label) loss.

With a very high delta (i.e. almost no weight on the FlowNet2 teacher loss) it works...

In [None]:
from torch.autograd import Variable  # no longer used in this cell; import kept so the name stays available
epochs = 1000
lr = 0.005
delta = 0.5  # weight of the label loss; the teacher loss is weighted by (1 - delta)

# Train the student on random training pairs with a mix of teacher loss, label loss
# and a smoothness penalty on the predicted displacements.
net = OBELISK2d()
reg = deeds2d()
net.train()
reg.train()

init_weights(net)

optimizer = torch.optim.Adam(params=list(net.parameters()) + list(reg.parameters()), lr=lr)

disp_range = 0.25
displacement_width = 15
# NOTE(review): grid_size is not defined in any cell visible above this one. It must be
# set beforehand and match the spatial size of the grid that `reg` predicts on.
shift_xy = F.affine_grid(disp_range*torch.eye(2,3).unsqueeze(0),(1,1,displacement_width,displacement_width),align_corners=False).view(1,1,-1,2)
grid_xy = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,grid_size,grid_size),align_corners=False).view(1,-1,1,2)

torch.manual_seed(10)

# NOTE(review): pat_indices is not used by the loop below. The lines are kept because
# torch.randperm advances the seeded RNG and therefore influences which pairs are drawn.
pat_indices = torch.cat((torch.arange(0, 17), torch.arange(18, 43)), 0)

rnd_perm_idc = torch.randperm(pat_indices.size(0))
pat_indices = pat_indices[rnd_perm_idc]

# Train/test split: every index in 0..42 that is not in test_set is used for training.
test_set = torch.LongTensor([35, 41, 0, 4, 39, 17])
train_set = torch.arange(43)
for idx in test_set:
    train_set = train_set[train_set != idx]

diffs = []
for epoch in range(epochs):
    # Draw a random (fixed, moving) training pair.
    rnd_train_idx = torch.randperm(train_set.size(0))
    p_fix = train_set[rnd_train_idx[0]]
    p_mov = train_set[rnd_train_idx[1]]

    # .float() keeps the scaling a true division even if imgs is stored as integers.
    fixed = imgs[p_fix:p_fix+1,:,:].float() / 255
    moving = imgs[p_mov:p_mov+1,:,:].float() / 255

    fixed_seg = segs[p_fix:p_fix+1,:,:]
    moving_seg = segs[p_mov:p_mov+1,:,:]
    moving_onehot = labelMatrixOneHot(moving_seg,9)

    # Teacher target: moving labels warped with the FlowNet2 flow, sampled at the
    # control points. No autograd graph is needed for the teacher.
    with torch.no_grad():
        flow_in = preprocessing_flownet(fixed.reshape(h,w,C),moving.reshape(h,w,C)).cuda()
        teacher_flow = flow2(flow_in).cpu()
        warped_moving_teacher = warpImage(moving_onehot, teacher_flow)
        warped_teacher = F.grid_sample(warped_moving_teacher, grid_xy, align_corners=False)

    optimizer.zero_grad()

    # Student forward pass. The input images do not need gradients, only the weights do.
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat00,feat50)
    pred_xy = pred_xy.view(1,grid_size,grid_size,2)

    # Smoothness penalty: squared finite differences along the two spatial grid axes.
    # (A former third term differenced the last axis, i.e. the two displacement
    # components against each other, which is not a spatial smoothness constraint.)
    diffloss = 1.5*((pred_xy[0,:,1:,:]-pred_xy[0,:,:-1,:])**2).mean()+\
            1.5*((pred_xy[0,1:,:,:]-pred_xy[0,:-1,:,:])**2).mean()

    # Soft-warped moving labels, computed once and shared by both loss terms. Sample at
    # every displacement candidate, weight by the soft cost and sum over the candidate
    # axis (dim 3) so there is one label vector per control point. Summing over dim 1
    # would collapse the labels instead.
    nonlocal_label = (F.grid_sample(moving_onehot.float(),grid_xy+shift_xy,padding_mode='border',align_corners=False)\
                          *soft_cost.view(1,-1,grid_size**2,displacement_width**2)).sum(3,keepdim=True)
    fixed_label = F.grid_sample(labelMatrixOneHot(fixed_seg,9).float(),grid_xy,padding_mode='border',align_corners=False)

    teacher_loss = torch.pow((nonlocal_label - warped_teacher),2).mean()
    orig_labelloss = ((nonlocal_label-fixed_label)**2).mean()

    loss = (diffloss + (1-delta)*teacher_loss + delta*orig_labelloss)
    loss.backward()
    diffs.append(loss.item())

    optimizer.step()

In [None]:
# The x-axis follows the number of recorded losses, so the plot also works if the
# training loop was interrupted or `epochs` has been changed since.
plt.plot(np.arange(len(diffs)), diffs)
plt.xlabel("iteration")
plt.ylabel("training loss");

In [None]:
# Evaluate the student on a random pair from the held-out test set: the moving
# segmentation is warped onto the fixed image and scored against the fixed segmentation.
net.eval()
reg.eval()

rnd_test_idx = torch.randperm(test_set.size(0))
p_fix = test_set[rnd_test_idx[0]]
p_mov = test_set[rnd_test_idx[1]]

# .float() keeps the scaling a true division even if imgs is stored as integers.
fixed = imgs[p_fix:p_fix+1,:,:].float() / 255
moving = imgs[p_mov:p_mov+1,:,:].float() / 255

fixed_seg = segs[p_fix:p_fix+1,:,:]
moving_seg = segs[p_mov:p_mov+1,:,:]

with torch.no_grad():
    feat1 = net(fixed.unsqueeze(0))
    feat2 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat1,feat2)

    # Upsample the coarse displacement field to image resolution.
    dense_flow_fit = F.interpolate(pred_xy.transpose(0,1).view(1,2,grid_size,grid_size),size=(h,w),mode='bicubic',align_corners=False)
    # Apply and evaluate the transformation.
    identity = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,h,w),align_corners=False)
    warped_seg = F.grid_sample(moving_seg.float().unsqueeze(1),identity+dense_flow_fit.permute(0,2,3,1),mode='nearest',align_corners=False)

in1 = fixed.view(h,w,1).detach().numpy().astype(np.float32)
in2 = moving.view(h,w,1).detach().numpy().astype(np.float32)

baseline_flow = baseline.calc(in1,in2,None)
# baseline_flow is channel-last (h, w, 2). permute really moves the channel axis to the
# front, whereas view(1,2,h,w) would only reinterpret the interleaved memory and scramble u/v.
baseline_flow_chw = torch.from_numpy(baseline_flow).permute(2,0,1).unsqueeze(0)
warped_baseline = warpImage(moving_seg.unsqueeze(0).cpu().float(), baseline_flow_chw)

# NOTE(review): d1/d2 pass 8 to dice_coeff while d3 passes 9, so the three means may not
# cover the same set of labels. Confirm the meaning of that argument and make them agree.
d1 = dice_coeff(fixed_seg,warped_seg.squeeze(),8)
print("PDD-Student: ", d1,d1.mean())
# squeeze() so the baseline is passed in the same layout as the student result above
d2 = dice_coeff(fixed_seg, warped_baseline.squeeze(), 8)
print("Baseline: ", d2, d2.mean())
d3 = dice_coeff(fixed_seg,moving_seg,9)
print("diff fixed, moving: ", d3, d3.mean())

## Playing with KLDiv
Using the KL divergence for the teacher term, together with the smoothness (diff) loss and the label loss.
Set delta to different values and check which one learns best.

In [None]:
from torch.autograd import Variable  # no longer used in this cell; import kept so the name stays available
epochs = 1000
lr = 0.0025
delta = 0.5  # weight of the label loss; the teacher (KL) loss is weighted by (1 - delta)

# Train the student on random training pairs with a KL-divergence teacher loss,
# the label loss and a smoothness penalty on the predicted displacements.
net = OBELISK2d()
reg = deeds2d()
net.train()
reg.train()

init_weights(net)

optimizer = torch.optim.Adam(params=list(net.parameters()) + list(reg.parameters()), lr=lr)

disp_range = 0.25
displacement_width = 15
# NOTE(review): grid_size is not defined in any cell visible above this one. It must be
# set beforehand and match the spatial size of the grid that `reg` predicts on.
shift_xy = F.affine_grid(disp_range*torch.eye(2,3).unsqueeze(0),(1,1,displacement_width,displacement_width),align_corners=False).view(1,1,-1,2)
grid_xy = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,grid_size,grid_size),align_corners=False).view(1,-1,1,2)

# Summed KL divergence; it is divided by the number of control points below. The default
# 'mean' reduction averages over every element and does not return a true KL value.
kl_loss = torch.nn.KLDivLoss(reduction='sum')

torch.manual_seed(10)

# NOTE(review): pat_indices is not used by the loop below. The lines are kept because
# torch.randperm advances the seeded RNG and therefore influences which pairs are drawn.
pat_indices = torch.cat((torch.arange(0, 17), torch.arange(18, 43)), 0)

rnd_perm_idc = torch.randperm(pat_indices.size(0))
pat_indices = pat_indices[rnd_perm_idc]

# Train/test split: every index in 0..42 that is not in test_set is used for training.
test_set = torch.LongTensor([35, 41, 0, 4, 39, 17])
train_set = torch.arange(43)
for idx in test_set:
    train_set = train_set[train_set != idx]

diffs = []
for epoch in range(epochs):
    # Draw a random (fixed, moving) training pair.
    rnd_train_idx = torch.randperm(train_set.size(0))
    p_fix = train_set[rnd_train_idx[0]]
    p_mov = train_set[rnd_train_idx[1]]

    # .float() keeps the scaling a true division even if imgs is stored as integers.
    fixed = imgs[p_fix:p_fix+1,:,:].float() / 255
    moving = imgs[p_mov:p_mov+1,:,:].float() / 255

    fixed_seg = segs[p_fix:p_fix+1,:,:]
    moving_seg = segs[p_mov:p_mov+1,:,:]
    moving_onehot = labelMatrixOneHot(moving_seg,9)

    # Teacher target: moving labels warped with the FlowNet2 flow, sampled at the
    # control points. No autograd graph is needed for the teacher.
    with torch.no_grad():
        flow_in = preprocessing_flownet(fixed.reshape(h,w,C),moving.reshape(h,w,C)).cuda()
        teacher_flow = flow2(flow_in).cpu()
        warped_moving_teacher = warpImage(moving_onehot, teacher_flow)
        warped_teacher = F.grid_sample(warped_moving_teacher, grid_xy, align_corners=False)

    optimizer.zero_grad()

    # Student forward pass. The input images do not need gradients, only the weights do.
    feat00 = net(fixed.unsqueeze(0))
    feat50 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat00,feat50)
    pred_xy = pred_xy.view(1,grid_size,grid_size,2)

    # Smoothness penalty: squared finite differences along the two spatial grid axes.
    # (A former third term differenced the last axis, i.e. the two displacement
    # components against each other, which is not a spatial smoothness constraint.)
    diffloss = 1.5*((pred_xy[0,:,1:,:]-pred_xy[0,:,:-1,:])**2).mean()+\
            1.5*((pred_xy[0,1:,:,:]-pred_xy[0,:-1,:,:])**2).mean()

    # Soft-warped moving labels, computed once and shared by both loss terms. Sample at
    # every displacement candidate, weight by the soft cost and sum over the candidate
    # axis (dim 3) so there is one label vector per control point. Summing over dim 1
    # would collapse the labels instead.
    nonlocal_label = (F.grid_sample(moving_onehot.float(),grid_xy+shift_xy,padding_mode='border',align_corners=False)\
                          *soft_cost.view(1,-1,grid_size**2,displacement_width**2)).sum(3,keepdim=True)
    fixed_label = F.grid_sample(labelMatrixOneHot(fixed_seg,9).float(),grid_xy,padding_mode='border',align_corners=False)

    # KLDivLoss expects log-probabilities as input and probabilities as target, and the
    # divergence is minimised (not negated). The clamp avoids log(0) for absent labels.
    # NOTE(review): assumes soft_cost is normalised over the displacement candidates, so
    # that nonlocal_label is a distribution over the labels - confirm in deeds2d.
    student_log_prob = torch.log(nonlocal_label.clamp_min(1e-8))
    teacher_loss = kl_loss(student_log_prob, warped_teacher) / nonlocal_label.shape[2]

    orig_labelloss = ((nonlocal_label-fixed_label)**2).mean()

    loss = diffloss + (1-delta)*teacher_loss + delta*orig_labelloss
    loss.backward()
    diffs.append(loss.item())

    optimizer.step()

In [None]:
# The x-axis follows the number of recorded losses, so the plot also works if the
# training loop was interrupted or `epochs` has been changed since.
plt.plot(np.arange(len(diffs)), diffs)
plt.xlabel("iteration")
plt.ylabel("training loss");

In [None]:
# Evaluate the KL-trained student on a random pair from the held-out test set: the moving
# segmentation is warped onto the fixed image and scored against the fixed segmentation.
net.eval()
reg.eval()

rnd_test_idx = torch.randperm(test_set.size(0))
p_fix = test_set[rnd_test_idx[0]]
p_mov = test_set[rnd_test_idx[1]]

# .float() keeps the scaling a true division even if imgs is stored as integers.
fixed = imgs[p_fix:p_fix+1,:,:].float() / 255
moving = imgs[p_mov:p_mov+1,:,:].float() / 255

fixed_seg = segs[p_fix:p_fix+1,:,:]
moving_seg = segs[p_mov:p_mov+1,:,:]

with torch.no_grad():
    feat1 = net(fixed.unsqueeze(0))
    feat2 = net(moving.unsqueeze(0))

    soft_cost, pred_xy = reg(feat1,feat2)

    # Upsample the coarse displacement field to image resolution.
    dense_flow_fit = F.interpolate(pred_xy.transpose(0,1).view(1,2,grid_size,grid_size),size=(h,w),mode='bicubic',align_corners=False)
    # Apply and evaluate the transformation.
    identity = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,h,w),align_corners=False)
    warped_seg = F.grid_sample(moving_seg.float().unsqueeze(1),identity+dense_flow_fit.permute(0,2,3,1),mode='nearest',align_corners=False)

in1 = fixed.view(h,w,1).detach().numpy().astype(np.float32)
in2 = moving.view(h,w,1).detach().numpy().astype(np.float32)

baseline_flow = baseline.calc(in1,in2,None)
# baseline_flow is channel-last (h, w, 2). permute really moves the channel axis to the
# front, whereas view(1,2,h,w) would only reinterpret the interleaved memory and scramble u/v.
baseline_flow_chw = torch.from_numpy(baseline_flow).permute(2,0,1).unsqueeze(0)
warped_baseline = warpImage(moving_seg.unsqueeze(0).cpu().float(), baseline_flow_chw)

# NOTE(review): 8 is passed to dice_coeff here while the one-hot encoding elsewhere uses 9
# classes. Confirm the meaning of that argument so that no label is silently skipped.
d1 = dice_coeff(fixed_seg,warped_seg.squeeze(),8)
print("PDD-Student: ", d1,d1.mean())
# squeeze() so the baseline is passed in the same layout as the student result above
d2 = dice_coeff(fixed_seg, warped_baseline.squeeze(), 8)
print("Baseline: ", d2, d2.mean())
d3 = dice_coeff(fixed_seg,moving_seg,8)
print("diff fixed, moving: ", d3, d3.mean())

# Verdict
Seems to outperform the baseline somehow...
Actually, is this stuff really working??

# 16.06.21 Comparing obelisk to CNN