# Second Experiment
In this experiment, we want to investigate, the PDD network can improove the performance, when using a knowledge compression method. The teacher will be the FlowNet2 with pretrained weights, which can be found [here](https://drive.google.com/drive/folders/0B5EC7HMbyk3CbjFPb0RuODI3NmM?resourcekey=0-SuPU4qVZzuB4s83ngjrAEg). 

As there are multiple different way in applying the loss function, when using knowledge compression, there will be three different loss functions compared here. 
1) In a standard knowledge compression, the KL divergence is used to approximate the teacher output. We will compare the different warped labels of student and teacher. However, as this is usually done in classificatin, we do not expect it to perform really well in this setting.

2) A different approach is to use the MSE between the estimated flow field of the teacher and the estimated flowfield of the student. This does, however, not incorporate the warped label and dice score. The MSE is expressed as: $$L =\sum_{i,j}(prediction[i,j] - target[i,j])^2 $$ where $[i,j]$ denotes the flowfield at position i,j. 

3) What is missing in 2) will be used in the last setting. We incorporate the loss of the warped label. This can happen either by simply adding the loss to the loss between the flow fields OR by using a scale factor (probably dice score of the teacher) which indicates how well the teacher performs on this specific example. The loss can then be expressed as follows: $$L_i = \delta * L_{flowfield} + (1-\delta) * L_{Label}$$ With $\delta$ being the dice score between warped teacher segment and fixed segmentation. $L_i$ is the loss ofr the $i_{th}$ example, $L_{flowfield}$ is the crossentropy loss between student and teacher flow estimation and $L_{Label}$ is the warping loss between student warped segmentation and the segmentation to approximate.

In [1]:
import numpy as np
import cv2
from PIL import Image
import os
from tqdm import tqdm
%autosave 18000
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from math import ceil

from utils.preprocessing import preprocessing_flownet, preprocessing_pwc
from utils.load_models import load_flownet2, load_pwcnet, init_weights
from utils.plotting import flow2img, overlaySegment, showFlow
from utils.layers import warp, warpImage
from utils.encoding import labelMatrixOneHot, dice_coeff


from models.pdd_net.pdd_student import OBELISK2d

# Select a GPU for the work
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
available_gpus = [(torch.cuda.device(i),torch.cuda.get_device_name(i)) for i in range(torch.cuda.device_count())]
print(available_gpus)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

Autosaving every 18000 seconds
[(<torch.cuda.device object at 0x7fe7609c6520>, 'GeForce RTX 2080 Ti')]


device(type='cuda', index=0)

# Data

In [2]:
imgs = torch.load('data/train_frames.pth')
segs = torch.load('data/train_segs.pth')

#define a training split 
torch.manual_seed(42)
# Now, we prepare our train & test dataset.
train_set = torch.from_numpy(np.random.choice(np.arange(len(imgs)),size=int(len(imgs)*0.95), replace=False))

test_set = torch.arange(len(imgs))
for idx in train_set:
    test_set = test_set[test_set != idx]


print(f"{train_set.shape[0]} train examples")
print(f"{test_set.shape[0]} test examples")

1135 train examples
60 test examples


In [3]:
for i, seg in enumerate(segs):
    f_seg = seg[0]
    m_seg = seg[1]
    if len(torch.where(torch.histc(f_seg) != 0)[0]) == 3 and f_seg.max() <= 2:
        segs[i][0] = segs[i][0]*2 
    if len(torch.where(torch.histc(m_seg) != 0)[0]) == 3 and m_seg.max() <= 2:
        segs[i][1] = segs[i][1]*2 

# Student

In [4]:
W,H = (150,150)
o_m = H//4
o_n = W//4
ogrid_xy = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,o_m,o_n)).view(1,1,-1,2).cuda()
disp_range = 0.25#0.25
displacement_width = 15#15#11#17
shift_xy = F.affine_grid(disp_range*torch.eye(2,3).unsqueeze(0),(1,1,displacement_width,displacement_width)).view(1,1,-1,2).cuda()

grid_size = 32#25#30
grid_xy = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,grid_size,grid_size)).view(1,-1,1,2).cuda()


def init_weights(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
        nn.init.xavier_normal(m.weight)
        if m.bias is not None:
            nn.init.constant(m.bias, 0.0)

class OBELISK2d(nn.Module):
    def __init__(self, chan = 16):

        super(OBELISK2d, self).__init__()
        channels = chan
        self.offsets = nn.Parameter(torch.randn(2,channels *2,2) *0.05)
        self.layer0 = nn.Conv2d(1, 4, 5, stride=2, bias=False, padding=2)
        self.batch0 = nn.BatchNorm2d(4)

        self.layer1 = nn.Conv2d(channels *8, channels *4, 1, bias=False, groups=1)
        self.batch1 = nn.BatchNorm2d(channels *4)
        self.layer2 = nn.Conv2d(channels *4, channels *4, 3, bias=False, padding=1)
        self.batch2 = nn.BatchNorm2d(channels *4)
        self.layer3 = nn.Conv2d(channels *4, channels *1, 1)
        

    def forward(self, input_img):
        img_in = F.avg_pool2d(input_img ,3 ,padding=1 ,stride=2)
        img_in = F.relu(self.batch0(self.layer0(img_in)))
        sampled = F.grid_sample(img_in ,ogrid_xy + self.offsets[0 ,:,:].view(1 ,-1 ,1 ,2)).view(1 ,-1 ,o_m ,o_n)
        sampled -= F.grid_sample(img_in ,ogrid_xy + self.offsets[1 ,:,:].view(1 ,-1 ,1 ,2)).view(1 ,-1 ,o_m ,o_n)

        x = F.relu(self.batch1(self.layer1(sampled)))
        x = F.relu(self.batch2(self.layer2(x)))
        features = self.layer3(x)
        return features

In [5]:
def min_convolution(ssd_distance, displace_range, H, W):
    # Prepare operators for smooth dense displacement space
    pad1 = nn.ReplicationPad2d(5)
    avg1 = nn.AvgPool2d(5,stride=1)
    max1 = nn.MaxPool2d(3,stride=1)
    pad2 = nn.ReplicationPad2d(6)
    # approximate min convolution / displacement compatibility

    ssd_minconv = avg1(avg1(-max1(-pad1(ssd_distance.permute(0,2,3,1).reshape(1,-1,displace_range,displace_range)))))

    ssd_minconv = ssd_minconv.permute(0,2,3,1).view(1,-1,H,W)
    min_conv_cost = avg1(avg1(avg1(pad2(ssd_minconv))))
    
    return min_conv_cost

def meanfield(ssd_distance,img_fixed,displace_range,H,W):

    crnt_dev = ssd_distance.device

    cost = min_convolution(ssd_distance, displace_range, H, W)

    soft_cost = F.softmax(-10*cost.view(displace_range**2,-1).t(),1)
    
    disp_hw = (displace_range-1)//2
    disp_mesh_grid = disp_hw*F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,displace_range,displace_range),align_corners=True)
    disp_mesh_grid /= torch.Tensor([(W-1)*.5,(H-1)*.5])

    disp_xy = torch.sum(soft_cost.view(1,H,W,-1,1)*disp_mesh_grid.view(1,1,1,-1,2).to(crnt_dev),3).permute(0,3,1,2) 
    

    return soft_cost,disp_xy

def correlation_layer(displace_range, feat_moving, feat_fixed):
    
    disp_hw = (displace_range-1)//2
    feat_moving_unfold = F.unfold(feat_moving.transpose(1,0),(displace_range,displace_range),padding=disp_hw)
    B,C,H,W = feat_fixed.size()
    
    ssd_distance = ((feat_moving_unfold-feat_fixed.view(C,1,-1))**2).sum(0).view(1,displace_range**2,H,W)

    return ssd_distance

# Teacher

In [6]:
flownet = load_flownet2().cuda()

# Eval function

In [7]:
def evaluate_model(model):
    model.eval()
    overall_dice = []
    unwarped_dice = []
    
    for i,idx in enumerate(test_set):
        
        fixed = imgs[idx:idx+1,0,:].unsqueeze(0).float()
        moving = imgs[idx:idx+1,1,:].unsqueeze(0).float()

        fixed_seg = segs[idx:idx+1,0,:].contiguous()
        moving_seg = segs[idx:idx+1,1,:].contiguous()
        
        # Some images have no segmentation to them, 
        # even if it was present in the directory
        # We leave these ones out, as they cannot be avaluated
        # we need to check the max and min, so we can be sure to get the two labels, if there are some
        if len(torch.where(torch.histc(fixed_seg) != 0)[0]) == 3 and fixed_seg.max() <= 2:
            fixed_seg = fixed_seg*2
        if len(torch.where(torch.histc(moving_seg) != 0)[0]) == 3 and moving_seg.max() <= 2:
            moving_seg = moving_seg*2
        if fixed_seg.max() < 1:
            pass
        else:
        
            with torch.no_grad():
                fixed_feat = model(fixed.cuda())
                moving_feat = model(moving.cuda())

            ssd_distance = correlation_layer(displace_range, moving_feat, fixed_feat).contiguous()
            #regularise using meanfield inference with approx. min-convolutions
            soft_cost_one,disp_xy = meanfield(ssd_distance, fixed, displace_range, H//4, W//4)
            #upsample field to original resolution
            dense_flow_fit = F.interpolate(disp_xy,size=(H,W),mode='bicubic')


            #apply and evaluate transformation
            identity = F.affine_grid(torch.eye(2,3).unsqueeze(0),(1,1,H,W),align_corners=False).cuda()
            warped_student_seg = F.grid_sample(moving_seg.cuda().float().unsqueeze(1),identity+dense_flow_fit.permute(0,2,3,1),mode='nearest',align_corners=False).cpu()

            d1 = dice_coeff(fixed_seg,warped_student_seg.squeeze(),3)
            d2 = dice_coeff(fixed_seg, moving_seg, 3)
            
            overall_dice.append(d1.mean())
            unwarped_dice.append(d2.mean())
    
    overall_dice = torch.from_numpy(np.array(overall_dice))
    unwarped_dice = torch.from_numpy(np.array(unwarped_dice))
    
    print(f"This model has an average Dice of {round(overall_dice.mean().item(), 5)} mit Variance: {round(overall_dice.var().item(), 5)}. The unwarped Mean dice is: {round(unwarped_dice.mean().item(), 5)} with Var {round(unwarped_dice.var().item(),5)}")

In [8]:
def warp_seg(moving_seg, flow):
    """
    function to warp the segemntation of the teacher and baseline
    
    moving_seg: CxHxW
    flow: size: BxCxHxW
    """
    B, C, H, W = flow.size()
    # mesh grid
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float().to(flow.device)
    
    vgrid = grid + flow

    # scale grid to [-1,1]
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

    vgrid = vgrid.permute(0, 2, 3, 1)
    warped_seg_grid = nn.functional.grid_sample(moving_seg.float().unsqueeze(0), vgrid)
    return warped_seg_grid

# Setting Parameters

In [9]:
disp_hw = 5
displace_range = 11

epochs = 100
lr = 0.0002
# minibatch training
grad_accum = 4

# Experiment 2.1
Using KL_divergence between student and teacher warped label

In [10]:
student = OBELISK2d(24)
init_weights(student)
student.train().cuda()

optimizer = torch.optim.Adam(list(student.parameters()),lr=lr)

In [11]:
kl_loss = torch.nn.KLDivLoss()

In [12]:
for epoch in tqdm(range(epochs)):

    # Shuffle training examples
    rnd_train_idx = torch.randperm(train_set.size(0))

    # show all examples to model
    for rnd_idx in rnd_train_idx:
        
        p_fix = train_set[rnd_idx]

        # Get image and segmentation
        fixed = imgs[p_fix:p_fix+1,0,:].unsqueeze(0).float()
        moving = imgs[p_fix:p_fix+1,1,:].unsqueeze(0).float()

        fixed_seg = segs[p_fix:p_fix+1,0,:].long().contiguous()
        moving_seg = segs[p_fix:p_fix+1,1,:].long().contiguous()

        # Here we have to rescale the images for the flownet
        # the flownet expects intputs that match x*64, when m and n < 200
        teacher_fixed = F.interpolate(fixed, size=(128,128))
        teacher_moving = F.interpolate(moving, size=(128,128))
        # Generate the teacher flow estimation
        flow_in = preprocessing_flownet(teacher_fixed.detach().clone().reshape(128,128,1),teacher_moving.detach().clone().reshape(128,128,1)).cuda()
        teacher_flow = flownet(flow_in)
        teacher_flow = F.interpolate(teacher_flow, size=(H//4,W//4), mode='bilinear')
        
        # Label preparation
        C1,Hf,Wf = moving_seg.size()
        label_moving_onehot = F.one_hot(moving_seg,num_classes=3).permute(0,3,1,2).float()
        label_moving = F.interpolate(label_moving_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        label_fixed = F.one_hot(fixed_seg,num_classes=3).permute(0,3,1,2).float()
        label_fixed = F.interpolate(label_fixed,size=(Hf//4,Wf//4),mode='bilinear')
        # generate the "unfolded" version of the moving encoding that will result in the shifted versions per channel
        # according to the corresponding discrete displacement pair
        label_moving_unfold = F.unfold(label_moving,(displace_range,displace_range),padding=disp_hw).view(1,3,displace_range**2,-1)
        

        feat00 = student(fixed.cuda())
        feat50 = student(moving.cuda())

        # compute the cost tensor using the correlation layer
        ssd_distance = correlation_layer(displace_range, feat50, feat00)

        # compute the MIN-convolution & probabilistic output with the given function
        soft_cost,disp_xy = meanfield(ssd_distance, fixed, displace_range, H//4, W//4)

        # warp the label
        label_warped = torch.sum(soft_cost.cpu().t().unsqueeze(0)*label_moving_unfold.squeeze(0),1)
        # warp segment with teacher flow
        warped_teacher_seg = warp_seg(label_moving.squeeze().float().cuda(),teacher_flow.cuda()).cpu()
        
        loss = kl_loss(label_warped,warped_teacher_seg.view(3,-1))
        loss.backward()
        if (epoch+1)%grad_accum == 0:
            # every grad_accum iterations :Make an optimizer step
            optimizer.step()
            optimizer.zero_grad()

100%|██████████| 100/100 [1:40:33<00:00, 60.34s/it]


In [13]:
evaluate_model(student)

This model has an average Dice of 0.34468 mit Variance: 0.02055. The unwarped Mean dice is: 0.32494 with Var 0.0243


In [14]:
torch.save(student.state_dict(), "models/obel_klDiv.pth")

# Experiment 2.2
Using MSE loss between the flow predictions. 

In [15]:
student = OBELISK2d(24)
init_weights(student)
student.train().cuda()

optimizer = torch.optim.Adam(list(student.parameters()),lr=lr)

In [16]:
for epoch in tqdm(range(epochs)):

    # Shuffle training examples
    rnd_train_idx = torch.randperm(train_set.size(0))

    # show all examples to model
    for rnd_idx in rnd_train_idx:
        
        p_fix = train_set[rnd_idx]

        # Get image and segmentation
        fixed = imgs[p_fix:p_fix+1,0,:].unsqueeze(0).float()
        moving = imgs[p_fix:p_fix+1,1,:].unsqueeze(0).float()

        fixed_seg = segs[p_fix:p_fix+1,0,:].long().contiguous()
        moving_seg = segs[p_fix:p_fix+1,1,:].long().contiguous()

        # Here we have to rescale the images for the flownet
        # the flownet expects intputs that match x*64, when m and n < 200
        teacher_fixed = F.interpolate(fixed, size=(128,128))
        teacher_moving = F.interpolate(moving, size=(128,128))
        # Generate the teacher flow estimation
        flow_in = preprocessing_flownet(teacher_fixed.detach().clone().reshape(128,128,1),teacher_moving.detach().clone().reshape(128,128,1)).cuda()
        teacher_flow = flownet(flow_in)
        teacher_flow = F.interpolate(teacher_flow, size=(H//4,W//4), mode='bilinear')
        
        # Label preparation
        #C1,Hf,Wf = moving_seg.size()
        #label_moving_onehot = F.one_hot(moving_seg,num_classes=3).permute(0,3,1,2).float()
        #label_moving = F.interpolate(label_moving_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        #label_fixed_onehot = F.one_hot(fixed_seg,num_classes=3).permute(0,3,1,2).float()
        #label_fixed = F.interpolate(label_fixed_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        # generate the "unfolded" version of the moving encoding that will result in the shifted versions per channel
        # according to the corresponding discrete displacement pair
        #label_moving_unfold = F.unfold(label_moving,(displace_range,displace_range),padding=disp_hw).view(1,3,displace_range**2,-1)

        feat00 = student(fixed.cuda())
        feat50 = student(moving.cuda())

        # compute the cost tensor using the correlation layer
        ssd_distance = correlation_layer(displace_range, feat50, feat00)

        # compute the MIN-convolution & probabilistic output with the given function
        soft_cost,disp_xy = meanfield(ssd_distance, fixed, displace_range, H//4, W//4)
        
        #loss = cross_entropy_loss(disp_xy,teacher_flow)
        loss = torch.sum(torch.pow(disp_xy - teacher_flow, 2))
        loss.backward()
        if (epoch+1)%grad_accum == 0:
            # every grad_accum iterations :Make an optimizer step
            optimizer.step()
            optimizer.zero_grad()

100%|██████████| 100/100 [1:34:52<00:00, 56.93s/it]


In [17]:
evaluate_model(student)

This model has an average Dice of 0.33651 mit Variance: 0.02122. The unwarped Mean dice is: 0.32494 with Var 0.0243


In [18]:
torch.save(student.state_dict(), "models/obel_MSE_teacher_flow.pth")

# Experiment 2.3
Using warped label loss as simple addition to the MSE loss

In [19]:
student = OBELISK2d(24)
init_weights(student)
student.train().cuda()

optimizer = torch.optim.Adam(list(student.parameters()),lr=lr)

In [20]:
for epoch in tqdm(range(epochs)):

    # Shuffle training examples
    rnd_train_idx = torch.randperm(train_set.size(0))

    # show all examples to model
    for rnd_idx in rnd_train_idx:
        
        p_fix = train_set[rnd_idx]

        # Get image and segmentation
        fixed = imgs[p_fix:p_fix+1,0,:].unsqueeze(0).float()
        moving = imgs[p_fix:p_fix+1,1,:].unsqueeze(0).float()

        fixed_seg = segs[p_fix:p_fix+1,0,:].long().contiguous()
        moving_seg = segs[p_fix:p_fix+1,1,:].long().contiguous()

        # Here we have to rescale the images for the flownet
        # the flownet expects intputs that match x*64, when m and n < 200
        teacher_fixed = F.interpolate(fixed, size=(128,128))
        teacher_moving = F.interpolate(moving, size=(128,128))
        # Generate the teacher flow estimation
        flow_in = preprocessing_flownet(teacher_fixed.detach().clone().reshape(128,128,1),teacher_moving.detach().clone().reshape(128,128,1)).cuda()
        teacher_flow = flownet(flow_in)
        teacher_flow = F.interpolate(teacher_flow, size=(H//4,W//4), mode='bilinear')
        
        # Label preparation
        C1,Hf,Wf = moving_seg.size()
        label_moving_onehot = F.one_hot(moving_seg,num_classes=3).permute(0,3,1,2).float()
        label_moving = F.interpolate(label_moving_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        label_fixed_onehot = F.one_hot(fixed_seg,num_classes=3).permute(0,3,1,2).float()
        label_fixed = F.interpolate(label_fixed_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        # generate the "unfolded" version of the moving encoding that will result in the shifted versions per channel
        # according to the corresponding discrete displacement pair
        label_moving_unfold = F.unfold(label_moving,(displace_range,displace_range),padding=disp_hw).view(1,3,displace_range**2,-1)

        feat00 = student(fixed.cuda())
        feat50 = student(moving.cuda())

        # compute the cost tensor using the correlation layer
        ssd_distance = correlation_layer(displace_range, feat50, feat00)

        # compute the MIN-convolution & probabilistic output with the given function
        soft_cost,disp_xy = meanfield(ssd_distance, fixed, displace_range, H//4, W//4)
        
        # warp the label
        label_warped = torch.sum(soft_cost.cpu().t().unsqueeze(0)*label_moving_unfold.squeeze(0),1)
        # warp segment with teacher flow
        warped_teacher_seg = warp_seg(label_moving.squeeze().float().cuda(),teacher_flow.cuda()).cpu()
        
        teacher_loss = torch.sum(torch.pow(disp_xy - teacher_flow, 2))
        label_distance1 = torch.sum(torch.pow(label_fixed.reshape(3,-1)-label_warped.reshape(3,-1),2),0)
        loss = teacher_loss + label_distance1.mean()
        if (epoch+1)%grad_accum == 0:
            # every grad_accum iterations :Make an optimizer step
            optimizer.step()
            optimizer.zero_grad()

100%|██████████| 100/100 [48:20<00:00, 29.00s/it]


In [21]:
evaluate_model(student)

This model has an average Dice of 0.30596 mit Variance: 0.02538. The unwarped Mean dice is: 0.32494 with Var 0.0243


In [22]:
torch.save(student.state_dict(), "models/obel_MSE_warped.pth")

# Experiment 2.4
Using the weighted label loss

In [10]:
student = OBELISK2d(24)
init_weights(student)
student.train().cuda()

optimizer = torch.optim.Adam(list(student.parameters()),lr=lr)

In [11]:
for epoch in tqdm(range(epochs)):

    # Shuffle training examples
    rnd_train_idx = torch.randperm(train_set.size(0))

    # show all examples to model
    for rnd_idx in rnd_train_idx:
        
        p_fix = train_set[rnd_idx]

        # Get image and segmentation
        fixed = imgs[p_fix:p_fix+1,0,:].unsqueeze(0).float()
        moving = imgs[p_fix:p_fix+1,1,:].unsqueeze(0).float()

        fixed_seg = segs[p_fix:p_fix+1,0,:].long().contiguous()
        moving_seg = segs[p_fix:p_fix+1,1,:].long().contiguous()

        # Here we have to rescale the images for the flownet
        # the flownet expects intputs that match x*64, when m and n < 200
        teacher_fixed = F.interpolate(fixed, size=(128,128))
        teacher_moving = F.interpolate(moving, size=(128,128))
        # Generate the teacher flow estimation
        flow_in = preprocessing_flownet(teacher_fixed.detach().clone().reshape(128,128,1),teacher_moving.detach().clone().reshape(128,128,1)).cuda()
        teacher_flow = flownet(flow_in)
        teacher_flow = F.interpolate(teacher_flow, size=(H//4,W//4), mode='bilinear')
        
        # Label preparation
        C1,Hf,Wf = moving_seg.size()
        label_moving_onehot = F.one_hot(moving_seg,num_classes=3).permute(0,3,1,2).float()
        label_moving = F.interpolate(label_moving_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        label_fixed_onehot = F.one_hot(fixed_seg,num_classes=3).permute(0,3,1,2).float()
        label_fixed = F.interpolate(label_fixed_onehot,size=(Hf//4,Wf//4),mode='bilinear')
        # generate the "unfolded" version of the moving encoding that will result in the shifted versions per channel
        # according to the corresponding discrete displacement pair
        label_moving_unfold = F.unfold(label_moving,(displace_range,displace_range),padding=disp_hw).view(1,3,displace_range**2,-1)

        feat00 = student(fixed.cuda())
        feat50 = student(moving.cuda())

        # compute the cost tensor using the correlation layer
        ssd_distance = correlation_layer(displace_range, feat50, feat00)

        # compute the MIN-convolution & probabilistic output with the given function
        soft_cost,disp_xy = meanfield(ssd_distance, fixed, displace_range, H//4, W//4)
        
        # warp the label
        label_warped = torch.sum(soft_cost.cpu().t().unsqueeze(0)*label_moving_unfold.squeeze(0),1)
        # warp segment with teacher flow
        warped_teacher_seg = warp_seg(moving_seg.float().cuda(),F.interpolate(teacher_flow, size=(H,W), mode='bilinear').cuda()).cpu()
        
        # calculate losses
        teacher_loss = torch.sum(torch.pow(disp_xy - teacher_flow, 2))
        label_distance1 = torch.sum(torch.pow(label_fixed.reshape(3,-1)-label_warped.reshape(3,-1),2),0)
        
        # calculate the dice score for the warped segmentation
        d0 = dice_coeff(warped_teacher_seg.squeeze(), fixed_seg.unsqueeze(0).float(), 3)
        
        # The loss can then be weighted
        loss = (d0.mean().item()) * teacher_loss + (1- d0.mean().item()) * label_distance1.mean()
        if (epoch+1)%grad_accum == 0:
            # every grad_accum iterations :Make an optimizer step
            optimizer.step()
            optimizer.zero_grad()

100%|██████████| 100/100 [50:04<00:00, 30.05s/it]


In [12]:
evaluate_model(student)

This model has an average Dice of 0.28963 mit Variance: 0.02797. The unwarped Mean dice is: 0.33263 with Var 0.02592


In [13]:
torch.save(student.state_dict(), "models/obel_MSE_weighted_warped.pth")

In [14]:
import IPython

IPython.Application.instance().kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}