# Convolutional Neural Networks with Intermediate Loss for 3D Super-Resolution of CT and MRI Scans

This notebook is a replication/exploration of the paper listed above.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
import sys
sys.path.append('..') # Stupid thing Python makes you do to import from a sibling directory
from gen_utils.SrGen import SrGen # Custom class for image generation

## Define model

In [165]:
class CNNIL(nn.Module):
    def __init__(self, upscale=2, axs = 'hw'):
        super().__init__()
        self.axs = axs

        if axs == 'hw':
            self.rel = nn.ReLU()
            self.conv1 = nn.Conv2d(1,32,3, padding='same', bias=False)
            self.conv2 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv3 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv4 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv5 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv6 = nn.Conv2d(32,4,3, padding='same', bias=False)
            
            # Upscale step occurs here

            self.conv7 = nn.Conv2d(1,32,3, padding='same', bias=False)
            self.conv8 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv9 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv10 = nn.Conv2d(32,1,3, padding='same', bias = False)
        elif axs == 'h' or axs == 'w':
            self.conv1 = nn.Conv2d(1,32,3, padding='same', bias=False)
            self.conv2 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv3 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv4 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv5 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv6 = nn.Conv2d(32,2,3, padding='same', bias=False)
            
            # Upscale step occurs here

            self.conv7 = nn.Conv2d(1,32,3, padding='same', bias=False)
            self.conv8 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv9 = nn.Conv2d(32,32,3, padding='same', bias=False)
            self.conv10 = nn.Conv2d(32,1,3, padding='same', bias = False)

    def forward(self, x):

        x = self.rel(self.conv1(x))
        
        x_l = self.rel(self.conv2(x))
        x_l = self.rel(self.conv3(x_l))
        x_l = self.rel(self.conv4(x_l+x))
        x_l = self.rel(self.conv5(x_l))
        x_l = self.rel(self.conv6(x_l+x))

        x_l = self.kern_upscale(x_l, axs = self.axs)

        x = self.rel(self.conv7(x_l))
        x_h = self.rel(self.conv8(x))
        x_h = self.rel(self.conv9(x_h))
        x_h = self.rel(self.conv10(x_h+x))

        return x_l, x_h #Return both results for the loss function

    @staticmethod
    def kern_upscale(x, axs='hw'):
        # Function to do the unique upscaling pattern they propose in
        # the paper
        #TODO: have them input a tuple for scale of the dimensions they wish to expand along
        s, c, h, w = [int(_) for _ in list(x.shape)]

        if axs == 'hw':
            c = int(c/2)
            x_up = torch.cat(torch.unbind(x,1),2)
            x_up = torch.reshape(x_up,(s,1,h*c,w*c))
            x_up = torch.transpose(x_up,2,3)
            x_up = torch.cat(torch.split(x_up,2,2),3)
            x_up = torch.transpose(torch.reshape(x_up,(s,1,h*c,w*c)),2,3)
        elif axs == 'h':
            x_up = torch.cat(torch.unbind(x,2),1)
            #x_up = torch.unsqueeze(x_up,0)
        elif axs == 'w':
            x_up = torch.cat(torch.unbind(x,3),1)
            x_up = torch.transpose(x_up,1,2)
            #x_up = torch.unsqueeze(x_up,0)
        else:
            print('No valid scaling dimension selected, returning False')
            x_up = False

        return x_up


In [2]:
# Testing of kern_upscale function with only 3D tensor

a = torch.tensor([[[1,2],[3,4]],
                [[5,6],[7,8]],
                [[9,10],[11,12]],
                [[13,14],[15,16]]])
a.shape

torch.Size([4, 2, 2])

In [3]:
# Concatinate each of the layers next to eachother
print(a.shape)
b = torch.cat(torch.unbind(a),1)
print(b.shape)
b = torch.reshape(b,(4,4))
print(b.shape)
b = torch.transpose(b,0,1)
b = torch.cat(torch.split(b,2,0),1)
torch.transpose(torch.reshape(b,(4,4)),0,1)

torch.Size([4, 2, 2])
torch.Size([2, 8])
torch.Size([4, 4])


tensor([[ 1,  5,  2,  6],
        [ 9, 13, 10, 14],
        [ 3,  7,  4,  8],
        [11, 15, 12, 16]])

In [4]:
# Testing of above kern_upscale function with 4D tensor (what the model will acually use)
a = torch.tensor([[[[1,2],[3,4]],
                [[5,6],[7,8]],
                [[9,10],[11,12]],
                [[13,14],[15,16]]]])
print(f'shape of a is {a.shape}')

b = torch.cat(torch.unbind(a,1),2)
b = torch.reshape(b,(1,1,4,4))
b = torch.transpose(b,2,3)
b = torch.cat(torch.split(b,2,2),3)
b = torch.transpose(torch.reshape(b,(1,1,4,4)),2,3)
b

shape of a is torch.Size([1, 4, 2, 2])


tensor([[[[ 1,  5,  2,  6],
          [ 9, 13, 10, 14],
          [ 3,  7,  4,  8],
          [11, 15, 12, 16]]]])

In [27]:
# Upscale along just one axis
a = torch.tensor([[[[1,2],[3,4]],
                [[5,6],[7,8]]]])
print(f'shape of a is {a.shape}')

# Version for doubling height
b = torch.cat(torch.unbind(a,2),2)
b = torch.reshape(b,(1,1,2,4))
b = torch.cat(torch.split(b,2,3),2)
# b = torch.unsqueeze(torch.cat(torch.unbind(a,2),1),0)


# Version for doubling width
# b = torch.cat(torch.unbind(a,3),1)
# b = torch.transpose(b, 1,2)

b.shape

shape of a is torch.Size([1, 2, 2, 2])


torch.Size([1, 1, 4, 2])

## Set Optimization Parameters

In [166]:
net_1 = CNNIL(axs = 'hw')
net_2 = CNNIL(axs = 'w')

# "... trained the CNN for 40 epochs, starting with a learning rate of 0.001 and decreasing
# the learning rate to 0.0001 after the first 20 epochs"
optimizer_1 = optim.Adam(net_1.parameters(), lr=0.001)
optimizer_2 = optim.Adam(net_2.parameters(), lr=0.001)

# They have a custom loss function that incorporates the final results and the result
# right after the upscaling step
# https://discuss.pytorch.org/t/custom-loss-functions/29387

def intermediate_loss(output_intermediate, output_final, target):
    mae_loss = nn.L1Loss() #Built in mean absolute error loss function
    loss = mae_loss(output_intermediate, target)+mae_loss(output_final, target)
    return loss


## Generate Data for Training

In [167]:
sr_train = SrGen('../data/CNNIL_nifti/Raw/','../data/CNNIL_nifti/HR_patches_ax/','../data/CNNIL_nifti/LR_patches_ax/')

In [168]:
sr_train.match_altered(update=True, paths=False, sort=False)

HR and LR file locations updated


# Create Dataloader

In [169]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, sr_class):
        self.sr_class = sr_class

        # In case I forget to run match_altered before pulling the class
        if not sr_class.HR_files:
            sr_class.match_altered(update=True)

    def __len__(self):
        return len(self.sr_class.HR_files)

    def __getitem__(self, index):
        Y, X = self.sr_class.load_image_pair(index)
        X = torch.unsqueeze(torch.squeeze(torch.tensor(X, dtype=torch.float32),-1),0)
        Y = torch.unsqueeze(torch.squeeze(torch.tensor(Y, dtype=torch.float32),-1),0)

        return X,Y

In [170]:
params = {'batch_size': 64,
        'shuffle': True,
        'num_workers': 3}

training_set = Dataset(sr_train)
training_generator = torch.utils.data.DataLoader(training_set, **params)

# Training Loop

In [171]:
from tqdm import tqdm
import time

max_epochs = 2
save_rate = 5
epoch_adjust = 0
save_prefix = "./CNNIL_save_"

mean_loss = []

for epoch in tqdm(range(max_epochs)):
    losses = []

    ###### Test running this code where each epoch a new set of random images is made
    # sr_train.run(clear=True)

    # training_set = Dataset(sr_train)
    # training_generator = torch.utils.data.DataLoader(training_set, **params)
    ######


    # Training
    count = 0
    for inp, goal in training_generator:
        optimizer_1.zero_grad()
        print(inp.shape)

        output_1, output_2 = net_1(inp) # the 2 is the number of iterations in the LISTA network
        #output = torch.clamp(output, 0, 255)

        loss = intermediate_loss(output_1,output_2,goal)
        loss.backward()
        optimizer_1.step()
        #print(f'loss = {loss.item()}')
        losses.append(loss.item())
        print(f'mini-batch # {count}, mean loss = {sum(losses)/len(losses)}')
        count = count+1

    # if (epoch % save_rate == 0) or epoch == (max_epochs-1):
    #     torch.save(net.state_dict(), f'{save_prefix}{epoch+epoch_adjust}.p')
    # print(f'\n\n epoch {epoch}, loss mean: {sum(losses)/len(losses)}, loss: {min(losses)}-{max(losses)}\n')
    # mean_loss.append(sum(losses)/len(losses))

    # Give computer time to cool down
    time.sleep(20)


  0%|          | 0/2 [00:00<?, ?it/s]

torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([

 50%|█████     | 1/2 [00:13<00:13, 13.24s/it]

torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([24, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([

100%|██████████| 2/2 [00:23<00:00, 11.76s/it]

torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([64, 1, 7, 7])
14 has type <class 'int'>
torch.Size([24, 1, 7, 7])
14 has type <class 'int'>





In [172]:
losses

[18.611053466796875,
 16.44835662841797,
 18.057254791259766,
 15.806312561035156,
 14.840869903564453,
 15.104841232299805,
 18.200843811035156,
 17.35634994506836,
 16.777503967285156,
 16.564451217651367,
 17.927997589111328,
 16.7811222076416,
 18.046438217163086,
 18.86160659790039,
 15.222896575927734,
 16.464311599731445,
 16.976226806640625,
 16.133575439453125,
 14.355478286743164,
 16.752981185913086,
 17.30026626586914,
 16.056882858276367,
 16.375024795532227,
 15.206205368041992,
 17.256237030029297,
 18.23826789855957,
 13.501985549926758,
 16.928953170776367,
 16.33367347717285,
 16.995126724243164,
 15.832449913024902,
 17.739089965820312,
 16.742361068725586,
 16.124393463134766,
 17.364212036132812,
 17.047029495239258,
 15.539438247680664,
 16.265506744384766,
 13.967340469360352,
 14.469034194946289,
 14.403247833251953,
 15.60812759399414,
 15.245738983154297,
 17.989978790283203,
 16.227439880371094,
 14.53143310546875,
 16.348905563354492,
 15.271605491638184,
 1