# Convolutional Neural Networks with Intermediate Loss for 3D Super-Resolution of CT and MRI Scans

This notebook is a replication/exploration of the paper listed above.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from ..gen_utils.sr_gen import sr_gen # Custom class for image generation

## Define model

In [None]:
class CNNIL(nn.Module):
    """Ten-layer CNN that returns an intermediate and a final upscaled image.

    The network has two stages separated by an upscaling step:

    * ``conv1`` .. ``conv6`` operate at the input resolution and end with
      4 feature maps, which ``kern_upscale`` rearranges into a single map
      that is twice as tall and twice as wide.
    * ``conv7`` .. ``conv10`` refine that upscaled map.

    All convolutions are 3x3 without bias and are followed by a ReLU.
    ``forward`` returns both the intermediate (post-upscale) and the final
    result so that a loss can be computed on each of them.
    """

    def __init__(self):
        super().__init__()

        # padding=1 keeps height/width unchanged through every 3x3
        # convolution, which the skip additions in forward() rely on.
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
        self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
        self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
        # 4 output channels = 2 * 2 sub-pixel positions for the 2x upscale.
        self.conv6 = nn.Conv2d(32, 4, 3, padding=1, bias=False)

        # Upscale step occurs here

        self.conv7 = nn.Conv2d(1, 32, 3, padding=1, bias=False)
        self.conv8 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
        self.conv9 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
        self.conv10 = nn.Conv2d(32, 1, 3, padding=1, bias=False)

    def forward(self, x):
        """Run the network on a batch of single-channel images.

        Args:
            x: Tensor of shape ``(N, 1, H, W)``.

        Returns:
            A tuple ``(x_l, x_h)`` of tensors, each of shape
            ``(N, 1, 2H, 2W)``: ``x_l`` is the output right after the
            upscaling step and ``x_h`` is the final refined output.
        """
        # Low-resolution stage: the conv1 activation is re-added before
        # conv4 and conv6 as a skip connection.
        skip = F.relu(self.conv1(x))

        x_l = F.relu(self.conv2(skip))
        x_l = F.relu(self.conv3(x_l))
        x_l = F.relu(self.conv4(x_l + skip))
        x_l = F.relu(self.conv5(x_l))
        x_l = F.relu(self.conv6(x_l + skip))

        x_l = self.kern_upscale(x_l)

        # High-resolution stage: the conv7 activation is re-added before
        # conv10 as a skip connection.
        skip = F.relu(self.conv7(x_l))
        x_h = F.relu(self.conv8(skip))
        x_h = F.relu(self.conv9(x_h))
        x_h = F.relu(self.conv10(x_h + skip))

        return x_l, x_h  # Return both results for the loss function

    @staticmethod
    def kern_upscale(x):
        """Interleave the channels of ``x`` into a larger single-channel map.

        For ``r * r`` input channels, output pixel ``(h*r + i, w*r + j)`` is
        taken from channel ``i*r + j`` at position ``(h, w)``. This is the
        pixel-shuffle rearrangement, so the work is delegated to
        ``F.pixel_shuffle``.

        Args:
            x: Tensor of shape ``(..., r*r, H, W)``.

        Returns:
            Tensor of shape ``(..., 1, r*H, r*W)``.

        Raises:
            ValueError: If ``x`` has fewer than 3 dimensions or its channel
                count is not a perfect square.
        """
        if x.dim() < 3:
            raise ValueError(
                "kern_upscale expects a tensor of shape (..., C, H, W), "
                f"got {tuple(x.shape)}"
            )

        channels = x.shape[-3]
        factor = int(round(channels ** 0.5))
        if factor * factor != channels:
            raise ValueError(
                f"channel count must be a perfect square, got {channels}"
            )

        return F.pixel_shuffle(x, factor)


In [54]:
# Toy input for exercising kern_upscale: four 2x2 "channels" holding the
# integers 1..16 in row-major order.

a = torch.arange(1, 17).reshape(4, 2, 2)
a.shape

torch.Size([4, 2, 2])

In [55]:
# Place the four 2x2 channels side by side along the column axis -> (2, 8)
print(a.shape)
b = torch.cat(a.unbind(), dim=1)
print(b.shape)
# Refold into a 4x4 grid, then flip rows and columns
b = b.reshape(4, 4)
print(b.shape)
b = b.transpose(0, 1)
# Split into two 2-row halves and join them along the column axis -> (2, 8)
b = torch.cat(b.split(2, dim=0), dim=1)
# Final refold + transpose gives the interleaved 4x4 result shown below
b.reshape(4, 4).transpose(0, 1)

torch.Size([4, 2, 2])
torch.Size([2, 8])
torch.Size([4, 4])


tensor([[ 1,  5,  2,  6],
        [ 9, 13, 10, 14],
        [ 3,  7,  4,  8],
        [11, 15, 12, 16]])

## Set Optimization Parameters

In [None]:
# Instantiate the network; CNNIL.__init__ takes no arguments.
net = CNNIL()

# The paper uses a custom loss function that incorporates both the final result and
# the intermediate result right after the upscaling step (the two outputs of forward)
# https://discuss.pytorch.org/t/custom-loss-functions/29387

## Generate Data for Training