# Convolutional Neural Networks with Intermediate Loss for 3D Super-Resolution of CT and MRI Scans

This notebook is a replication/exploration of the paper listed above.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
import sys
sys.path.append('..') # Stupid thing Python makes you do to import from a sibling directory
from gen_utils.sr_gen import sr_gen # Custom class for image generation

## Define model

In [None]:
class CNNIL(nn.Module):
    """CNN with intermediate loss (CNNIL) for super-resolution of 2D slices.

    Architecture: six 3x3 convolutions at the input resolution, a
    parameter-free upscaling step (`kern_upscale`) that rearranges channels
    into pixels, then four 3x3 convolutions at the output resolution.
    `forward` returns both the image right after upscaling and the final
    image so that a loss can be applied to each of them.

    Args:
        upscale: Integer upscaling factor applied along every axis in `axs`
            (default 2).
        axs: Axes to upscale: 'hw' (height and width), 'h' (height only) or
            'w' (width only).

    Raises:
        ValueError: If `axs` is not one of 'hw', 'h', 'w', or `upscale` < 1.
    """

    def __init__(self, upscale=2, axs='hw'):
        super().__init__()
        if axs not in ('hw', 'h', 'w'):
            raise ValueError(f"axs must be 'hw', 'h' or 'w', got {axs!r}")
        if upscale < 1:
            raise ValueError(f"upscale must be a positive integer, got {upscale!r}")
        self.axs = axs
        self.upscale = upscale

        # kern_upscale folds every `n_sub` channels of conv6 into one channel
        # with `upscale` times more pixels along each axis in `axs`
        # (upscale**2 sub-pixels for 'hw', upscale for a single axis).
        n_sub = upscale ** 2 if axs == 'hw' else upscale

        def conv(c_in, c_out):
            # padding=1 keeps H and W unchanged, which the residual additions
            # in forward() require.
            return nn.Conv2d(c_in, c_out, 3, padding=1, bias=False)

        # Low-resolution stage
        self.conv1 = conv(1, 32)
        self.conv2 = conv(32, 32)
        self.conv3 = conv(32, 32)
        self.conv4 = conv(32, 32)
        self.conv5 = conv(32, 32)
        self.conv6 = conv(32, n_sub)

        # Upscale step (kern_upscale) occurs here

        # High-resolution stage
        self.conv7 = conv(1, 32)
        self.conv8 = conv(32, 32)
        self.conv9 = conv(32, 32)
        self.conv10 = conv(32, 1)

    def forward(self, x):
        """Run both stages of the network.

        Args:
            x: Input batch of shape (batch, 1, H, W) (conv1 takes one channel).

        Returns:
            Tuple ``(x_l, x_h)``: the output right after the upscaling step and
            the final output. Both have shape (batch, 1, H', W'), where H' and
            W' are H and W multiplied by `upscale` along the axes in `axs`.
            Both are returned so the loss can use them together.
        """
        # Low-resolution stage; the conv1 activations are added back in as
        # skip connections before conv4 and conv6.
        x = F.relu(self.conv1(x))
        x_l = F.relu(self.conv2(x))
        x_l = F.relu(self.conv3(x_l))
        x_l = F.relu(self.conv4(x_l + x))
        x_l = F.relu(self.conv5(x_l))
        x_l = F.relu(self.conv6(x_l + x))

        x_l = self.kern_upscale(x_l, self.axs)

        # High-resolution stage, with a skip connection from conv7.
        x = F.relu(self.conv7(x_l))
        x_h = F.relu(self.conv8(x))
        x_h = F.relu(self.conv9(x_h))
        x_h = F.relu(self.conv10(x_h + x))

        return x_l, x_h  # Return both results for the loss function

    @staticmethod
    def kern_upscale(x, axs='hw'):
        """Rearrange channels into pixels (the paper's upscaling pattern).

        For an input of shape (s, c, h, w):

        * 'hw': c must be a perfect square r*r. Output shape is
          (s, 1, h*r, w*r) with ``out[:, 0, i*r + a, j*r + b] = x[:, a*r + b, i, j]``.
          This is the same layout as ``torch.nn.functional.pixel_shuffle``.
        * 'h': output shape (s, 1, h*c, w) with ``out[:, 0, i*c + k, j] = x[:, k, i, j]``.
          Channels are interleaved row by row.
        * 'w': output shape (s, 1, h, w*c) with ``out[:, 0, i, j*c + k] = x[:, k, i, j]``.
          Channels are interleaved column by column.

        Raises:
            ValueError: For an unknown `axs`, or when c is not a perfect
                square for 'hw'.
        """
        s, c, h, w = x.shape

        if axs == 'hw':
            r = int(round(c ** 0.5))
            if r * r != c:
                raise ValueError(f"'hw' upscaling needs a square channel count, got {c}")
            # (s, a, b, h, w) -> (s, h, a, w, b) -> (s, 1, h*r, w*r)
            return x.reshape(s, r, r, h, w).permute(0, 3, 1, 4, 2).reshape(s, 1, h * r, w * r)
        if axs == 'h':
            # (s, k, h, w) -> (s, h, k, w) -> rows ordered i*c + k
            return x.permute(0, 2, 1, 3).reshape(s, 1, h * c, w)
        if axs == 'w':
            # (s, k, h, w) -> (s, h, w, k) -> columns ordered j*c + k
            return x.permute(0, 2, 3, 1).reshape(s, 1, h, w * c)
        raise ValueError(f"No valid scaling dimension selected: {axs!r}")


In [2]:
# Test input for the kern_upscale logic: a plain 3D tensor made of four
# 2x2 layers holding the values 1..16 in order
a = torch.arange(1, 17).reshape(4, 2, 2)
a.shape

torch.Size([4, 2, 2])

In [3]:
# Concatenate the layers next to each other, then rearrange them into one 4x4 grid
print(a.shape)
b = torch.cat(a.unbind(0), 1)
print(b.shape)
b = b.reshape(4, 4)
print(b.shape)
b = torch.cat(b.transpose(0, 1).split(2, 0), 1)
b.reshape(4, 4).transpose(0, 1)

torch.Size([4, 2, 2])
torch.Size([2, 8])
torch.Size([4, 4])


tensor([[ 1,  5,  2,  6],
        [ 9, 13, 10, 14],
        [ 3,  7,  4,  8],
        [11, 15, 12, 16]])

In [4]:
# Same experiment on a 4D (batch, channel, height, width) tensor, the layout
# the model actually uses
a = torch.arange(1, 17).reshape(1, 4, 2, 2)
print(f'shape of a is {a.shape}')

b = torch.cat(a.unbind(1), 2).reshape(1, 1, 4, 4)
b = torch.cat(b.transpose(2, 3).split(2, 2), 3)
b = b.reshape(1, 1, 4, 4).transpose(2, 3)
b

shape of a is torch.Size([1, 4, 2, 2])


tensor([[[[ 1,  5,  2,  6],
          [ 9, 13, 10, 14],
          [ 3,  7,  4,  8],
          [11, 15, 12, 16]]]])

In [27]:
# Upscale along just one axis: the two channels are interleaved so that
# output row 2*i + k is row i of channel k (height doubles)
a = torch.arange(1, 9).reshape(1, 2, 2, 2)
print(f'shape of a is {a.shape}')

# Height-doubling version
b = torch.cat(a.unbind(2), 2).reshape(1, 1, 2, 4)
b = torch.cat(b.split(2, 3), 2)

# Alternatives tried earlier (not executed):
#   height, 3D result: b = torch.unsqueeze(torch.cat(torch.unbind(a, 2), 1), 0)
#   width doubling:    b = torch.cat(torch.unbind(a, 3), 1).transpose(1, 2)

b.shape

shape of a is torch.Size([1, 2, 2, 2])


torch.Size([1, 1, 4, 2])

## Set Optimization Parameters

In [2]:
# Two models: one upscales height and width together, the other only width
net_1 = CNNIL(axs='hw')
net_2 = CNNIL(axs='w')

# "... trained the CNN for 40 epochs, starting with a learning rate of 0.001 and decreasing
# the learning rate to 0.0001 after the first 20 epochs"
# Both optimizers start at the initial rate; no decay is configured in this cell.
optimizer_1, optimizer_2 = (optim.Adam(model.parameters(), lr=0.001) for model in (net_1, net_2))

def intermediate_loss(output_intermediate, output_final, target):
    """Return the paper's combined loss: L1 on both network outputs.

    The loss adds the mean absolute error of the result right after the
    upscaling step (`output_intermediate`) to that of the final result
    (`output_final`). Each is measured against the same `target`.
    See https://discuss.pytorch.org/t/custom-loss-functions/29387
    """
    intermediate_term = F.l1_loss(output_intermediate, target)
    final_term = F.l1_loss(output_final, target)
    return intermediate_term + final_term


## Generate Data for Training

In [None]:
# Image generator used to build the training set (custom class from gen_utils.sr_gen)
sr_train = sr_gen()

In [None]:
# Load the generator's configuration template and override its "patch" entry.
# NOTE(review): "patch" is presumably the patch size; confirm against sr_gen.
# Then save the template back before generating data.
temp = sr_train.get_template()
temp["patch"] = 20

sr_train.save_template(temp)

# Generate the training images.
# NOTE(review): clear=True presumably removes previously generated files first; confirm.
sr_train.run(clear=True)

## Create DataLoader

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over the image pairs managed by an sr_gen object.

    Item ``index`` is ``(X, Y)``. Here ``Y, X = sr_class.load_image_pair(index)``,
    and each array is converted to a float32 tensor with a new leading
    channel axis of size 1.
    """

    def __init__(self, sr_class):
        self.sr_class = sr_class
        # In case match_altered was not run before handing the object over,
        # build the file list now
        if not sr_class.HR_files:
            sr_class.match_altered(update=True)

    def __len__(self):
        return len(self.sr_class.HR_files)

    @staticmethod
    def _as_channel_tensor(image):
        """Convert an array to float32 and prepend a channel axis of size 1."""
        return torch.tensor(image, dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, index):
        target_img, input_img = self.sr_class.load_image_pair(index)
        return self._as_channel_tensor(input_img), self._as_channel_tensor(target_img)

In [None]:
# DataLoader settings; the training loop reuses them when it rebuilds the loader
params = dict(batch_size=64, shuffle=True, num_workers=3)

training_set = Dataset(sr_train)
training_generator = torch.utils.data.DataLoader(training_set, **params)

## Training Loop

In [None]:
from tqdm import tqdm
import time

max_epochs = 20
save_rate = 5
epoch_adjust = 0  # offset added to the epoch index, e.g. 20 when resuming for epochs 20-39
save_prefix = "./CNNIL_save_"

# Learning-rate schedule quoted from the paper in the optimizer cell:
# 0.001 for the first 20 epochs, then 0.0001.
lr_initial = 0.001
lr_decayed = 0.0001
lr_decay_epoch = 20

# Model/optimizer pair trained by this cell (use net_2 / optimizer_2 to train
# the width-only model instead).
net = net_1
optimizer = optimizer_1

mean_loss = []

for epoch in tqdm(range(max_epochs)):
    losses = []
    global_epoch = epoch + epoch_adjust

    # Apply the step learning-rate schedule
    lr = lr_initial if global_epoch < lr_decay_epoch else lr_decayed
    for group in optimizer.param_groups:
        group['lr'] = lr

    # Each epoch a new set of random images is generated
    sr_train.run(clear=True)
    training_set = Dataset(sr_train)
    training_generator = torch.utils.data.DataLoader(training_set, **params)

    # Training
    net.train()
    for inp, goal in training_generator:
        optimizer.zero_grad()

        # CNNIL returns (output right after upscaling, final output), and the
        # loss compares both with the target. The outputs are not clamped:
        # torch.clamp passes zero gradient to values outside the range, which
        # would stop those pixels from being corrected.
        out_intermediate, out_final = net(inp)
        loss = intermediate_loss(out_intermediate, out_final, goal)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if not losses:
        raise RuntimeError(f'epoch {epoch}: the data loader produced no batches')

    if (epoch % save_rate == 0) or epoch == (max_epochs - 1):
        torch.save(net.state_dict(), f'{save_prefix}{global_epoch}.p')
    epoch_mean = sum(losses) / len(losses)
    print(f'\n\n epoch {epoch}, loss mean: {epoch_mean}, loss: {min(losses)}-{max(losses)}\n')
    mean_loss.append(epoch_mean)

    # Give computer time to cool down
    time.sleep(90)
