# [Robust Single Image Super-Resolution via Deep Networks with Sparse Prior](https://ieeexplore.ieee.org/document/7466062)

This notebook is a replication/exploration of the paper linked above. The model was trained and tested on the ABIDEII-BNI1 anatomical scans (subjects 29006 - 29011 for training, 29012 - 29015 for testing). This data can be found [here](http://fcon_1000.projects.nitrc.org/indi/abide/abide_II.html) under the Barrow Neurological Institute `Scan Data` link.

## Setup
This notebook assumes that the MRI files from ABIDEII-BNI1 are organized as shown below, relative to this folder. With the files in the correct folders, you should get the same (or similar) results. The files are not provided in this repository because of their size. If your files are stored elsewhere, change the paths passed to `SrGen` in the following code to the correct locations.

```
../data/CNNIL_nifti/
                Raw_train/
                    subject_29006.nii
                    subject_29007.nii
                    subject_29008.nii
                    subject_29009.nii
                    subject_29010.nii
                    subject_29011.nii
                Raw_test/
                    subject_29012.nii
                    subject_29013.nii
                    subject_29014.nii
                    subject_29015.nii
```

### Imports
Like all good Python scripts, we start by importing a few libraries. Note the import of my own custom class `SrGen`, which handles data loading, saving, and organization. It is also in this repository; see `/gen_utils/SrGen.py` for the source code.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import sys
sys.path.append('..')
from gen_utils.SrGen import SrGen

## Define Model

In [None]:
class SCN(nn.Module):
    """Sparse-coding-based network for single-image super-resolution.

    Pipeline implemented by ``forward``:

    1. ``mean2`` (fixed Gaussian filter) and ``diffms`` (fixed mean-removal
       filters) are applied to the input.
    2. ``conv`` (9x9, 100 channels) is applied and its output is L2-normalised
       along the channel axis.
    3. ``wd`` / ``usd1`` with the shrinkage function ``ShLU`` form the LISTA
       recurrence; ``ud`` decodes the resulting code into 25 channels.
    4. ``reassemble2`` + ``addp`` merge the shifted channels into a single
       image, which is added to the Gaussian-filtered input.

    Args:
        sy: Accepted for interface compatibility; not read by the constructor.
        sg: Accepted for interface compatibility; not read by the constructor.
    """

    def __init__(self, sy, sg):
        super().__init__()
        C = 5
        L = 5

        # Random matrices used only to initialise the 1x1 convolutions below.
        Dx = torch.normal(0, 1, size=(25, 128))
        Dy = torch.normal(0, 1, size=(100, 128))
        I = torch.eye(128)

        # With padding=6, the 9x9 kernels enlarge each spatial dim by 4 while
        # the 13x13 kernel keeps the input size.
        self.conv = nn.Conv2d(1, 100, 9, bias=False, stride=1, padding=6, padding_mode='reflect')
        self.mean2 = nn.Conv2d(1, 1, 13, bias=False, stride=1, padding=6, padding_mode='reflect')
        self.diffms = nn.Conv2d(1, 25, 9, bias=False, stride=1, padding=6, padding_mode='reflect')

        self.wd = nn.Conv2d(100, 128, 1, bias=False, stride=1)
        self.usd1 = nn.Conv2d(128, 128, 1, bias=False, stride=1)
        self.ud = nn.Conv2d(128, 25, 1, bias=False, stride=1)
        self.addp = nn.Conv2d(16, 1, 1, bias=False, stride=1)

        # Fixed (non-trainable) filters.
        self.mean2.weight = torch.nn.Parameter(self.create_gaus(13), requires_grad=False)
        self.diffms.weight = torch.nn.Parameter(self.create_diffms(9, 5), requires_grad=False)
        # Trainable layers, initialised from the random matrices above.
        self.wd.weight = torch.nn.Parameter(self.expand_params(C * Dy.T), requires_grad=True)
        self.usd1.weight = torch.nn.Parameter(
            self.expand_params(I - torch.matmul(Dy.T, Dy)), requires_grad=True
        )
        self.ud.weight = torch.nn.Parameter(self.expand_params((1 / (C * L)) * Dx), requires_grad=True)
        self.addp.weight = torch.nn.Parameter(torch.ones(1, 16, 1, 1) * 0.06, requires_grad=True)

    def forward(self, x, k, sy=9, sg=5):
        """Run the network on a batch of single-channel images.

        Args:
            x: Tensor of shape (N, 1, H, W).
            k: Number of LISTA iterations applied after the initial shrinkage.
            sy: Unused; kept for interface compatibility.
            sg: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape (N, 1, H, W).
        """
        x = x + 0.1

        im_mean = self.mean2(x)
        diffms = self.diffms(x)

        # Feature extraction followed by channel-wise L2 normalisation.
        x = self.conv(x)
        x = x / torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)

        x = self.wd(x)
        z = self.ShLU(x, 1)

        # Go through LISTA
        for _ in range(k):
            z = self.ShLU(self.usd1(z) + x, 1)

        x = self.ud(z)
        # Rescale the decoded channels to the norm of the mean-removed input
        # (times 1.1).
        x = (
            (x / torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True))
            * torch.linalg.vector_norm(diffms, ord=2, dim=1, keepdim=True)
            * 1.1
        )
        x = self.reassemble2(x, im_mean, 4)
        x = self.addp(x)
        x = x + im_mean

        return x

    def reassemble2(self, x, im_mean, patch_size):
        """Stack shifted H x W windows of the first ``patch_size**2`` channels.

        Output channel ``f`` is ``x[:, f, jj:jj+H, ii:ii+W]`` where ``(ii, jj)``
        runs from ``patch_size - 1`` down to 0, with ``jj`` varying fastest, and
        ``H, W`` are the spatial dimensions of ``im_mean``.

        The result is built with ``torch.stack`` so that it lives on the same
        device and has the same dtype as ``x``, and the whole batch is handled
        at once instead of sample by sample.

        Args:
            x: Tensor of shape (N, C, H', W') with ``C >= patch_size**2``,
                ``H' >= H + patch_size - 1`` and ``W' >= W + patch_size - 1``.
            im_mean: Tensor whose last two dimensions give the output size.
            patch_size: Number of shifts per axis (``forward`` passes 4, which
                yields the 16 channels expected by ``addp``).

        Returns:
            Tensor of shape (N, patch_size**2, H, W).
        """
        _, _, h, w = im_mean.shape

        planes = []
        for ii in range(patch_size - 1, -1, -1):
            for jj in range(patch_size - 1, -1, -1):
                # len(planes) is the index of the channel being filled.
                planes.append(x[:, len(planes), jj:(jj + h), ii:(ii + w)])

        return torch.stack(planes, dim=1)

    def create_diffms(self, kern_size, sy=5):
        """Build ``sy**2`` mean-removal filters of size ``kern_size``.

        Filter ``i`` holds ``-1/sy**2`` everywhere inside a centred ``sy x sy``
        window except at flat position ``i``, where it holds ``1 - 1/sy**2``;
        the surrounding border is zero.

        Returns:
            Tensor of shape (sy**2, 1, kern_size, kern_size).
        """
        n_pix = sy ** 2
        diffms = torch.zeros(n_pix, 1, kern_size, kern_size)

        neg = -1 * (1 / n_pix)
        pos = 1 + neg

        border = int((kern_size - sy) / 2)

        for i in range(n_pix):
            base = torch.zeros(n_pix) + neg
            base[i] = pos
            diffms[i, 0, border:(kern_size - border), border:(kern_size - border)] = base.reshape([sy, sy])
        return diffms

    def create_gaus(self, kern_size, sy=9, std=2.15):
        """Build a normalised 2-D Gaussian filter, zero-padded to ``kern_size``.

        Args:
            kern_size: Side length of the returned kernel.
            sy: Side length of the Gaussian window placed at the centre.
            std: Standard deviation of the Gaussian.

        Returns:
            Tensor of shape (1, 1, kern_size, kern_size).
        """
        n = torch.arange(0, sy) - (sy - 1.0) / 2.0
        sig2 = 2 * std * std
        gkern1d = torch.exp(-n ** 2 / sig2)
        gkern1d = gkern1d / torch.sum(gkern1d)
        gkern2d = torch.outer(gkern1d, gkern1d)

        # Wrap in zeros, if kern_size > sy
        gaussian_filter = torch.zeros(1, 1, kern_size, kern_size)
        border = int((kern_size - sy) / 2)
        gaussian_filter[0, 0, border:(kern_size - border), border:(kern_size - border)] = gkern2d
        return gaussian_filter

    def fixed_positions(self, tens, mult, sg):
        """Copy each filter of ``tens`` to every offset inside an sg x sg window.

        Args:
            tens: Tensor of shape (F, 1, h, w).
            mult: Number of output filters reserved per input filter; must be
                at least ``(sg - w + 1) * (sg - h + 1)``.
            sg: Side length of the output filters.

        Returns:
            Tensor of shape (F * mult, 1, sg, sg).
        """
        f, _, h, w = tens.shape
        new_filt = torch.zeros(f * mult, 1, sg, sg)
        cnt = 0

        for filt in range(f):
            for j in range((sg - w) + 1):
                for i in range((sg - h) + 1):
                    new_filt[cnt, 0, i:i + h, j:j + w] = tens[filt]
                    cnt += 1
        return new_filt

    def expand_params(self, tens):
        """Append two singleton dimensions so a 2-D matrix becomes a 1x1 conv weight."""
        return torch.unsqueeze(torch.unsqueeze(tens, 2), 3)

    def ShLU(self, a, th):
        """Soft-threshold ``a`` by ``th``: ``sign(a) * max(|a| - th, 0)``."""
        # clamp avoids allocating a fresh CPU zero tensor on every call.
        return torch.sign(a) * torch.clamp(torch.abs(a) - th, min=0)

## Set Optimization Parameters
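The network is trained by minimizing the mean-squared error between its output and the high-resolution target, using SGD with a learning rate of 7e-5 and a momentum of 1e-4. Only the trainable layers (`addp`, `conv`, `wd`, `usd1`, `ud`) are handed to the optimizer; the fixed `mean2` and `diffms` filters are left out.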

In [None]:
# Build the network. SCN.__init__ only takes (sy, sg); the previous extra
# `train=True` keyword raised a TypeError.
net = SCN(9, 5)

# To resume from a checkpoint instead of starting from scratch:
#net.load_state_dict(torch.load('./MRI_save_29.p'))

criterion = nn.MSELoss()

# One parameter group per trainable layer so per-layer settings can be added
# later; all groups currently share the lr/momentum given below. The fixed
# `mean2` and `diffms` filters are not passed to the optimizer.
optimizer = optim.SGD(
    [
        {"params": net.addp.parameters()},
        {"params": net.conv.parameters()},
        {"params": net.wd.parameters()},
        {"params": net.usd1.parameters()},
        {"params": net.ud.parameters()},
    ],
    lr=0.00007, momentum=0.0001
)

## Generate Data for Training

In [None]:
# Image generator for the training set, built from three data directories.
sr_train = SrGen(
    './data/train/GT_corr/',
    './data/train/HR_corr_patches/',
    './data/train/LR_corr_patches/',
)

# Generation settings written into the generator's template, in this order.
train_settings = {
    "patch": 44,
    "step": 20,
    "translation_x": 10,
    "translation_y": 10,
    "rotation": 180,
    "scale": 1,
}

temp = sr_train.get_template()
for key, value in train_settings.items():
    temp[key] = value
sr_train.save_template(temp)

# Generate the images (clearing earlier output), then refresh the file lists.
sr_train.run(clear=True)
sr_train.match_altered(update=True, paths=False, sort=False)

## Create Dataloader

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that serves image pairs loaded through ``sr_class``.

    ``sr_class`` must provide ``HR_files``, ``match_altered(update=...)`` and
    ``load_image_pair(index)``.
    """

    def __init__(self, sr_class):
        self.sr_class = sr_class

        # Populate the file lists if match_altered was not run beforehand.
        if not self.sr_class.HR_files:
            self.sr_class.match_altered(update=True)

    def __len__(self):
        """Return the number of entries in ``HR_files``."""
        return len(self.sr_class.HR_files)

    def __getitem__(self, index):
        """Return ``(X, Y)`` as float32 tensors with a leading channel axis.

        ``load_image_pair`` returns the pair as ``(Y, X)``; the order is
        swapped here so the network input comes first.
        """
        target, source = self.sr_class.load_image_pair(index)
        source_t = torch.tensor(source, dtype=torch.float32).unsqueeze(0)
        target_t = torch.tensor(target, dtype=torch.float32).unsqueeze(0)
        return source_t, target_t

## Create Dataloader for Training

In [None]:
# DataLoader settings; `params` is reused when the loader is rebuilt inside
# the training loop.
params = dict(batch_size=64, shuffle=True, num_workers=3)

training_set = Dataset(sr_train)
training_generator = torch.utils.data.DataLoader(training_set, **params)

## Training Loop

In [None]:
from tqdm import tqdm
import time

# Training schedule and checkpoint settings.
max_epochs = 40
save_rate = 5  # save a version of the model every 5 epochs
epoch_adjust = 0  # offset added to the epoch number in file names to avoid overwriting
save_prefix = "./MRI_reflect_pad_save_"

# Mean training loss of every epoch, in order.
mean_loss = []

for epoch in tqdm(range(max_epochs)):
    losses = []

    # Experiment: generate a new set of random images for each epoch, then
    # rebuild the dataset and loader on top of them.
    sr_train.run(clear=True)
    training_set = Dataset(sr_train)
    training_generator = torch.utils.data.DataLoader(training_set, **params)

    # One pass over the mini-batches.
    count = 0
    for inp, goal in training_generator:
        optimizer.zero_grad()

        # The second argument is the number of LISTA iterations; the output is
        # clamped to [0, 255] before the loss is computed.
        output = torch.clamp(net(inp, 2), 0, 255)

        loss = criterion(output, goal)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        count += 1

    # Checkpoint every `save_rate` epochs and always after the last epoch.
    is_save_epoch = epoch % save_rate == 0
    is_last_epoch = epoch == (max_epochs - 1)
    if is_save_epoch or is_last_epoch:
        torch.save(net.state_dict(), f'{save_prefix}{epoch+epoch_adjust}.p')

    print(f'\n\n epoch {epoch}, loss mean: {sum(losses)/len(losses)}, loss: {min(losses)}-{max(losses)}\n')
    mean_loss.append(sum(losses) / len(losses))

    # Give computer time to cool down
    time.sleep(90)

### Sanity Check

In [None]:
# Plot the per-epoch mean training loss collected by the training loop.
# `mean_loss` is a plain list, so it is plotted directly (indexing it with a
# string key raised a TypeError).
# NOTE(review): `plt` is not imported in the visible notebook; this cell needs
# `import matplotlib.pyplot as plt` to have been run first.
fig, axs = plt.subplots(1, 2)
axs[0].plot(range(len(mean_loss)), mean_loss)
axs[0].set_title('Mean Loss for First CNN Block')

## Testing Model

In [None]:
# Image generator for the test set, built from three data directories.
sr_test = SrGen(
    './data/test/GT_corr/',
    './data/test/HR_corr/',
    './data/test/LR_corr/',
)

# Generation settings written into the template, in this order. "patch" is
# switched off here (the training set used 44) and no "step" is set.
test_settings = {
    "patch": False,
    "scale": 1,
    "rotation": 180,
}

temp = sr_test.get_template()
for key, value in test_settings.items():
    temp[key] = value
sr_test.save_template(temp)

# Generate the images (clearing earlier output), then refresh the file lists.
sr_test.run(clear=True)
sr_test.match_altered(update=True, paths=False, sort=False)

In [None]:
# DataLoader settings for evaluation: one image at a time, in file order.
params_t = {'batch_size': 1,
        'shuffle': False,
        'num_workers': 2}

# Dataset.__init__ only accepts the generator object; the previous
# `axs = 'hw'` keyword raised a TypeError.
testing_set = Dataset(sr_test)
testing_generator = torch.utils.data.DataLoader(testing_set, **params_t)

In [None]:
# Load matched images
im_hr, im_lr = sr_test.match_altered(update = True, paths=True)

save_pred = False # Whether to save the images created by the network during testing
save_dir = "./"
if save_pred:
    os.makedirs(save_dir, exist_ok=True)

# Per-image difference (SR minus bicubic) in PSNR and RMSE; this is what the
# histogram cell that follows reads.
comp = {'psnr': [], 'rmse': []}

with torch.no_grad():
    for i in im_hr:

        # Load in image information
        im_h, im_l = sr_test.load_image_pair(i)

        # The low-resolution image has already been upscaled with bicubic
        # interpolation by the image generation process, so im_l itself is
        # the bicubic baseline to compare against.

        # Use SR model on low resolution image. The input gets batch and
        # channel axes; the (1, 1, H, W) output is reduced back to a 2-D
        # numpy array so the numpy metrics and the image export below operate
        # on plain image data.
        net_input = torch.tensor(im_l, dtype=torch.float32)[None, None]
        im_h_sr = net(net_input, 2)[0, 0].cpu().numpy()

        # Round/clip everything to the 8-bit range once.
        im_l = np.rint(np.clip(im_l, 0, 255))
        im_h = np.rint(np.clip(im_h, 0, 255))
        im_h_sr = np.rint(np.clip(im_h_sr, 0, 255))

        # Calculate PSNR for bicubic
        diff = im_l - im_h
        rmse_b = np.sqrt((diff**2).mean())
        psnr_b = 20*np.log10(255.0/rmse_b)

        print(f'bicubic evaluation for {i}: rms={rmse_b}, psnr={psnr_b}')

        # Calculate PSNR for SR
        diff = im_h_sr - im_h
        rmse_s = np.sqrt((diff**2).mean())
        psnr_s = 20*np.log10(255.0/rmse_s)
        print(f'SR evaluation for {i}: rms={rmse_s}, psnr={psnr_s}')

        comp['psnr'].append(psnr_s - psnr_b)
        comp['rmse'].append(rmse_s - rmse_b)

        if save_pred:
            # NOTE(review): `Image` is not imported in the visible notebook;
            # this branch needs `from PIL import Image`.
            img_name = os.path.splitext(os.path.basename(i))[0]
            Image.fromarray(im_h_sr.astype(np.uint8)).save(f"{save_dir}/{img_name}_SR.png")

In [None]:
# Sanity check: histogram of the per-image PSNR difference between the SR
# result and the bicubic baseline, as stored in comp['psnr'].
fig, axs = plt.subplots(1, 2)
psnr_diffs = [float(x) for x in comp['psnr']]
axs[0].hist(psnr_diffs, bins=10)
axs[0].set_title('PSNR: SR - BiC')

## Final Application

In [None]:
# Generate testing data for both axes
sr = SrGen('../data/CNNIL_nifti/Raw_test/','../data/CNNIL_nifti/Full_test/','../data/CNNIL_nifti/LR_Full_test/')

# Template for full NIfTI volumes.
temp = sr.get_template()
temp['out_type'] = 'nii'
temp['resolution'] = [2,2,2]
temp['translation'] = [0, 0, 0]
temp['rotation'] = [0, 0, 0]
temp['keep_blank'] = False
temp['same_size'] = False

# NOTE(review): the other cells store the template with `save_template`;
# confirm that `set_template` exists on SrGen.
sr.set_template(temp)

# Generation is skipped here, so the volumes are expected to exist already.
#sr.run(clear=True, save=True)
sr.match_altered(update=True, paths=False, sort=False)


params_t = {'batch_size': 1,
        'shuffle': False,
        'num_workers': 2}

# The original line was an unterminated call fused with the markdown heading
# kept below, which is a syntax error. The `axs = 'hw'` keyword is dropped
# because the Dataset class in this notebook accepts only the generator object.
# NOTE(review): if a Dataset variant with an axis option was intended, restore
# it together with that class.
full_img_generator = Dataset(sr)

### Use the models for SR of MRI images and calculate PSNR
# Load trained models:
# NOTE(review): `net_1` and `net_2` are not defined anywhere in the visible
# notebook; they must be created before this cell runs.
net_1.load_state_dict(torch.load('CNNIL_save_network1_39.p'))
net_2.load_state_dict(torch.load('CNNIL_save_network2_39.p'))

# Whether or not to save the output SR images, if None then don't save any:
save_pre = './SR_images'


with torch.no_grad():
    # Per-volume difference (SR minus baseline) in PSNR and RMSE.
    comp={'psnr' : [], 'rmse' : []}
    for idx, [im_l, im_h] in enumerate(full_img_generator):

        # Run the two networks one after the other, transposing between them;
        # only the second output of each call is used afterwards.
        # NOTE(review): presumably each network upsamples along a different
        # pair of axes - confirm against the definitions of net_1/net_2.
        output_1, output_2 = net_1(torch.transpose(im_l,0,1))
        output_1, output_2 = net_2(torch.transpose(output_2,0,3))
        output_2 = torch.transpose(torch.squeeze(output_2,1),2,0)

        # Upscale im_l to the same size as im_h
        # NOTE(review): `resize` is not imported in the visible notebook. If it
        # is skimage.transform.resize, `order=1` selects linear interpolation,
        # so the "bicubic" label below may be inaccurate - confirm.
        im_l = torch.tensor(resize(im_l, im_h.shape, order=1, mode = 'symmetric'))

        # Calculate PSNR for bicubic
        # (the peak value is the maximum of the reference volume im_h)
        diff = im_l - im_h
        rmse_b = np.sqrt((diff**2).mean())
        psnr_b = 20*np.log10(im_h.max()/rmse_b)


        # Calculate PSNR for SR
        diff = output_2 - im_h
        rmse_s = np.sqrt((diff**2).mean())
        psnr_s = 20*np.log10(im_h.max()/rmse_s)

        comp['psnr'].append(psnr_s-psnr_b)
        comp['rmse'].append(rmse_s-rmse_b)

        if save_pre:
            # Save the resulting images as .nii files, named after the
            # corresponding low-resolution file
            os.makedirs(save_pre,exist_ok=True)
            sr.save_image(f'{save_pre}/{sr.LR_files[idx].split("/")[-1]}', output_2.numpy())

# Histograms of the SR-minus-baseline differences collected above.
# NOTE(review): `plt` is not imported in the visible notebook.
fig, axs = plt.subplots(1, 2)
axs[0].hist([float(x) for x in comp['psnr']], bins=10)
axs[0].set_title('PSNR: SR - BiC')
axs[1].hist([float(x) for x in comp['rmse']], bins=10)
axs[1].set_title('RMSE: SR - BiC')
### Compare Slices from SR and HR image
# Show slice 40 of the last super-resolved volume next to the same slice of
# the last file listed in sr.HR_files.
fig, axs = plt.subplots(1, 2)
sr_axis, truth_axis = axs[0], axs[1]

sr_slice = torch.squeeze(output_2[40, :, :], 0)
sr_axis.imshow(sr_slice)
sr_axis.set_title('Super Resolution')

truth_slice = sr.load_image(sr.HR_files[-1])[40, :, :]
truth_axis.imshow(truth_slice)
truth_axis.set_title('Truth')
## Conclusion

There we have it! A nice version of the pipeline discussed using Pytorch!

Potential directions for future development:
- Determine how well the model works with different MRI collection parameters
- Train it on other medical image types aside from MRI