# SCN Model Training and Testing

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from gen_utils.sr_gen import sr_gen # Custom class for image generation/organization

In [2]:
class SCN(nn.Module):
    def __init__(self,sy,sg, train=True):
        super().__init__()
        C = 5
        L = 5

        Dx = torch.normal(0,1, size = (25,128))
        Dy = torch.normal(0,1, size = (100,128))
        I = torch.eye(128)

        self.conv = nn.Conv2d(1,100,9, bias = False, stride =1, padding = 6, padding_mode='reflect')
        self.mean2 = nn.Conv2d(1,1,13, bias = False, stride = 1, padding = 6, padding_mode='reflect')
        self.diffms = nn.Conv2d(1,25,9, bias=False, stride = 1, padding=6, padding_mode='reflect')

        self.wd = nn.Conv2d(100,128,1,bias = False, stride = 1)
        self.usd1 = nn.Conv2d(128, 128, 1, bias = False, stride=1)
        self.ud = nn.Conv2d(128,25,1,bias=False,stride=1)
        self.addp = nn.Conv2d(16,1,1, bias = False, stride = 1)

        if train: #If you are currently training the model
            self.mean2.weight = torch.nn.Parameter(self.create_gaus(13), requires_grad = False)
            self.diffms.weight = torch.nn.Parameter(self.create_diffms(9,5),requires_grad=False)
            self.wd.weight = torch.nn.Parameter(self.expand_params(C*Dy.T), requires_grad=True)
            self.usd1.weight = torch.nn.Parameter(self.expand_params(I - torch.matmul(Dy.T,Dy)), requires_grad=True)
            self.ud.weight = torch.nn.Parameter(self.expand_params((1/(C*L))*Dx), requires_grad=True)
            self.addp.weight = torch.nn.Parameter(torch.ones(1,16,1,1)*0.06, requires_grad=True)

        else:
            self.conv.weight = torch.nn.Parameter(torch.ones(100,1,9,9),requires_grad=False)
            self.mean2.weight = torch.nn.Parameter(self.create_gaus(13),requires_grad=False)
            self.diffms.weight = torch.nn.Parameter(self.create_diffms(9,5),requires_grad=False)
            self.wd.weight = torch.nn.Parameter(self.expand_params(C*Dy.T),requires_grad=False)
            self.usd1.weight = torch.nn.Parameter(self.expand_params(I - torch.matmul(Dy.T,Dy)),requires_grad=False)
            self.ud.weight = torch.nn.Parameter(self.expand_params((1/(C*L))*Dx),requires_grad=False)
            self.addp.weight = torch.nn.Parameter(torch.ones(1,16,1,1)*0.06,requires_grad=False)


    def forward(self, x, k, sy=9, sg=5):
        #print(f'input: {x.min()}-{x.max()}')
        x = x+0.1

        im_mean = self.mean2(x)
        # print(f'im_mean shape {im_mean.shape}')
        diffms = self.diffms(x)
        # print(f'diffms shape: {diffms.shape}')

        n, c, h, w = x.shape
        # y = torch.zeros(n, 100, h-8, w-8)
        x = self.conv(x)
        # print(f'post conv shape {x.shape}')
        #print(f'conv max {x.max()}')
        #x=x+1

        x = x/torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)
        # print(f'post vector norm shape: {x.shape}')
        #print(f'postnorm max {x.max()}')

        x = self.wd(x)
        #print(f'conv wd {x.max()}')
        z = self.ShLU(x,1)
        #print(f'conv SHLU {x.max()}')

        # Go through LISTA
        for i in range(k):
            z = self.ShLU(self.usd1(z)+x,1)

        x = self.ud(z)
        #print(f'ud max {x.max()}')
        # print(f'post ud shape {x.shape}')
        x = (x/torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True))*torch.linalg.vector_norm(diffms, ord=2, dim=1, keepdim=True)*1.1
        # print(f'prereassembled x shape {x.shape}')
        x = self.reassemble2(x,im_mean,4)
        # print(f'reassembled x shape {x.shape}')
        x = self.addp(x)
        #print(f'x.reassemble.max = {x.max()}')
        x = x+im_mean

        #print(f'output: {x.min()}-{x.max()}')

        return x

    def reassemble2(self, x, im_mean, patch_size):
        img = im_mean
        s, c, h, w = img.shape
        
        # img_stack=torch.zeros(s,25,h,w)
        img_stack=torch.zeros(s,16,h,w)
        
        #go through every sample and reassemble the image
        for q in range(x.shape[0]):
            filt = 0
            for ii in range(patch_size-1, -1, -1):
                for jj in range(patch_size-1, -1, -1):
                    img_stack[q,filt,:,:] = x[q,filt,jj:(jj+h), ii:(ii+w)]
                    filt+=1
        
        return img_stack
    
    def create_diffms(self, kern_size, sy=5):
        diffms = torch.zeros(sy**2,1,kern_size,kern_size)
        
        neg = -1*(1/(sy**2))
        pos = 1+neg
        
        border = int((kern_size-sy)/2)
        base = torch.zeros(sy,sy)+neg
        cnt=0
        
        for i in range(sy**2):
            base = torch.zeros(sy**2)+neg
            base[cnt]=pos
            diffms[i,0,border:(kern_size-border),border:(kern_size-border)] = base.reshape([sy,sy])
            cnt+=1
        return diffms
    
    
    def create_gaus(self, kern_size, sy=9,std=2.15):
        n = torch.arange(0,sy)-(sy-1.0)/2.0
        sig2 = 2 * std * std
        gkern1d = torch.exp(-n ** 2 / sig2)
        gkern1d = gkern1d/torch.sum(gkern1d)
        #print(gkern1d.shape)
        gkern2d = torch.outer(gkern1d, gkern1d)
    

        # Wrap in zeros, if kern_size > sy
        gaussian_filter = torch.zeros(1,1,kern_size,kern_size)
        border = int((kern_size-sy)/2)
        gaussian_filter[0,0,border:(kern_size-border),border:(kern_size-border)] = gkern2d#(sy,std=std)
        #print(gaussian_filter.shape)
        return gaussian_filter
        
    
    def fixed_positions(self, tens, mult, sg):
        f, _ , h, w = tens.shape
        new_filt = torch.zeros(f*mult, 1, sg,sg)
        cnt = 0
        filt = 0
        
        for filt in range(f):
            for j in range((sg-w)+1):
                for i in range((sg-h)+1):
                    new_filt[cnt,0,i:i+h,j:j+w] = tens[filt]
                    cnt+=1
        return new_filt
    
    def expand_params(self,tens):
        return torch.unsqueeze(torch.unsqueeze(tens,2),3)
    
    def ShLU(self,a, th):
        return torch.sign(a)*torch.maximum(abs(a)-th, torch.tensor(0))

# Set Optimization Parameters

In [3]:
net = SCN(9,5,train=True)

#net.load_state_dict(torch.load('./MRI_save_29.p'))

criterion = nn.MSELoss()
optimizer = optim.SGD(
    [
        {"params": net.addp.parameters()},#, "lr": 0.0002, "momentum": 0.00005},
        {"params": net.conv.parameters()},#, "lr": 0.0003, "momentum": 0.0001},
        {"params": net.wd.parameters()},
        {"params": net.usd1.parameters()},
        {"params": net.ud.parameters()},
    ],
    #lr=0.0001, momentum=0.0001
    lr=0.00007, momentum = 0.0001
)

## Generate Data for Training

In [4]:
sr_train = sr_gen('./data/train/GT_corr/','./data/train/HR_corr_patches/','./data/train/LR_corr_patches/')

In [5]:
temp = sr_train.get_template()
temp["patch"]=44
temp["step"]=20
temp["translation_x"]=10
temp["translation_y"]=10
temp["rotation"] = 180
temp["scale"] = 1
sr_train.save_template(temp)

sr_train.run(clear=True)

Clearing existing output directories


## Create Training Dataset and Dataloader

In [6]:
# TODO: There's not reasing I can't combine the Dataset class and my custom class into one thing

class Dataset(torch.utils.data.Dataset):
    def __init__(self, sr_class):
        self.sr_class = sr_class

        # In case I forget to run match_altered before pulling the class
        if not sr_class.HR_files:
            sr_class.match_altered(update=True)

    def __len__(self):
        return len(self.sr_class.HR_files)

    def __getitem__(self, index):
        Y, X = self.sr_class.load_image_pair(index)
        X = torch.unsqueeze(torch.tensor(X, dtype=torch.float32),0)
        Y = torch.unsqueeze(torch.tensor(Y, dtype=torch.float32),0)

        return X, Y

In [7]:
params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 3}

training_set = Dataset(sr_train)
training_generator = torch.utils.data.DataLoader(training_set, **params)

## Training Loop

In [8]:
from tqdm import tqdm
import time
# Loop over epochs

max_epochs = 40
save_rate = 5 #save a version of the model every 5 epochs
epoch_adjust = 0 #how much to add to the saved files in order to not overwrite
save_prefix = "./MRI_reflect_pad_save_"

mean_loss = []

for epoch in tqdm(range(max_epochs)):
    losses = []

    ###### Test running this code where each epoch a new set of random images is made
    sr_train.run(clear=True)

    training_set = Dataset(sr_train)
    training_generator = torch.utils.data.DataLoader(training_set, **params)
    ######


    # Training
    count = 0
    for inp, goal in training_generator:
        optimizer.zero_grad()

        output = net(inp,2) # the 2 is the number of iterations in the LISTA network
        output = torch.clamp(output, 0, 255)

        loss = criterion(output,goal)
        loss.backward()
        optimizer.step()
        #print(f'loss = {loss.item()}')
        losses.append(loss.item())
        #print(f'mini-batch # {count}, mean loss = {sum(losses)/len(losses)}')
        count = count+1

    if (epoch % save_rate == 0) or epoch == (max_epochs-1):
        torch.save(net.state_dict(), f'{save_prefix}{epoch+epoch_adjust}.p')
    print(f'\n\n epoch {epoch}, loss mean: {sum(losses)/len(losses)}, loss: {min(losses)}-{max(losses)}\n')
    mean_loss.append(sum(losses)/len(losses))

    # Give computer time to cool down
    time.sleep(90)

  0%|          | 0/40 [00:00<?, ?it/s]

Clearing existing output directories


 epoch 0, loss mean: 98.32311543551359, loss: 54.4627799987793-167.3396759033203



  2%|▎         | 1/40 [03:28<2:15:19, 208.19s/it]

Clearing existing output directories


 epoch 1, loss mean: 51.35435017672452, loss: 42.5849609375-68.5630874633789



  5%|▌         | 2/40 [06:57<2:12:15, 208.83s/it]

Clearing existing output directories


 epoch 2, loss mean: 41.78303077004173, loss: 34.81265640258789-50.28232192993164



  8%|▊         | 3/40 [10:25<2:08:32, 208.44s/it]

Clearing existing output directories


 epoch 3, loss mean: 36.33278135819869, loss: 33.42055892944336-41.9842643737793



 10%|█         | 4/40 [13:57<2:05:49, 209.71s/it]

Clearing existing output directories


 epoch 4, loss mean: 33.58856357227672, loss: 29.76911735534668-36.285865783691406



 12%|█▎        | 5/40 [17:25<2:01:59, 209.12s/it]

Clearing existing output directories


 epoch 5, loss mean: 31.985302665016867, loss: 27.908756256103516-37.955238342285156



 15%|█▌        | 6/40 [20:46<1:57:01, 206.53s/it]

Clearing existing output directories


 epoch 6, loss mean: 30.406118653037332, loss: 26.195032119750977-35.40159225463867



 18%|█▊        | 7/40 [24:07<1:52:31, 204.58s/it]

Clearing existing output directories


 epoch 7, loss mean: 30.59968558224765, loss: 26.082624435424805-37.712486267089844



 20%|██        | 8/40 [27:27<1:48:27, 203.36s/it]

Clearing existing output directories


 epoch 8, loss mean: 30.362801638516512, loss: 26.34787940979004-36.843509674072266



 22%|██▎       | 9/40 [30:47<1:44:26, 202.13s/it]

Clearing existing output directories


 epoch 9, loss mean: 29.200725728815254, loss: 24.18450927734375-33.6534423828125



 25%|██▌       | 10/40 [34:08<1:40:54, 201.80s/it]

Clearing existing output directories


 epoch 10, loss mean: 28.70171512256969, loss: 23.59297752380371-32.42100524902344



 28%|██▊       | 11/40 [37:30<1:37:30, 201.76s/it]

Clearing existing output directories


 epoch 11, loss mean: 29.119379997253418, loss: 25.405202865600586-33.913997650146484



 30%|███       | 12/40 [40:49<1:33:48, 201.00s/it]

Clearing existing output directories


 epoch 12, loss mean: 29.129417072642934, loss: 24.721755981445312-31.894960403442383



 32%|███▎      | 13/40 [44:09<1:30:18, 200.69s/it]

Clearing existing output directories


 epoch 13, loss mean: 27.569858290932395, loss: 23.46613311767578-32.30982208251953



 35%|███▌      | 14/40 [47:28<1:26:46, 200.26s/it]

Clearing existing output directories


 epoch 14, loss mean: 29.800518642772328, loss: 26.081289291381836-33.39929962158203



 38%|███▊      | 15/40 [50:47<1:23:18, 199.94s/it]

Clearing existing output directories


 epoch 15, loss mean: 29.03403230146928, loss: 25.16359519958496-33.611568450927734



 40%|████      | 16/40 [54:06<1:19:51, 199.63s/it]

Clearing existing output directories


 epoch 16, loss mean: 27.83819432692094, loss: 21.973224639892578-34.02215576171875



 42%|████▎     | 17/40 [57:24<1:16:18, 199.07s/it]

Clearing existing output directories


 epoch 17, loss mean: 26.745600353587758, loss: 22.799522399902344-30.931591033935547



 45%|████▌     | 18/40 [1:00:43<1:12:59, 199.09s/it]

Clearing existing output directories


 epoch 18, loss mean: 27.71150849082253, loss: 24.253095626831055-32.837615966796875



 48%|████▊     | 19/40 [1:04:03<1:09:43, 199.22s/it]

Clearing existing output directories


 epoch 19, loss mean: 27.18992978876287, loss: 23.9979190826416-31.281557083129883



 50%|█████     | 20/40 [1:07:21<1:06:16, 198.80s/it]

Clearing existing output directories


 epoch 20, loss mean: 26.32117020000111, loss: 23.067411422729492-30.41972541809082



 52%|█████▎    | 21/40 [1:10:39<1:02:55, 198.72s/it]

Clearing existing output directories


 epoch 21, loss mean: 26.325577995993875, loss: 24.279020309448242-28.89850425720215



 55%|█████▌    | 22/40 [1:13:59<59:41, 198.95s/it]  

Clearing existing output directories


 epoch 22, loss mean: 27.708439480174672, loss: 23.66593360900879-31.992412567138672



 57%|█████▊    | 23/40 [1:17:19<56:28, 199.30s/it]

Clearing existing output directories


 epoch 23, loss mean: 28.123160535638984, loss: 23.07906150817871-31.367679595947266



 60%|██████    | 24/40 [1:20:38<53:08, 199.28s/it]

Clearing existing output directories


 epoch 24, loss mean: 27.850334600968793, loss: 24.9620304107666-30.15374183654785



 62%|██████▎   | 25/40 [1:23:58<49:51, 199.46s/it]

Clearing existing output directories


 epoch 25, loss mean: 27.12289506738836, loss: 23.964746475219727-30.789180755615234



 65%|██████▌   | 26/40 [1:27:16<46:28, 199.21s/it]

Clearing existing output directories


 epoch 26, loss mean: 25.82103018327193, loss: 21.289731979370117-29.91715431213379



 68%|██████▊   | 27/40 [1:30:36<43:11, 199.32s/it]

Clearing existing output directories


 epoch 27, loss mean: 25.821901321411133, loss: 22.979354858398438-31.2678165435791



 70%|███████   | 28/40 [1:33:57<39:56, 199.71s/it]

Clearing existing output directories


 epoch 28, loss mean: 25.819851875305176, loss: 22.628971099853516-28.7922306060791



 72%|███████▎  | 29/40 [1:37:18<36:41, 200.18s/it]

Clearing existing output directories


 epoch 29, loss mean: 25.583398298783735, loss: 22.019546508789062-28.75263023376465



 75%|███████▌  | 30/40 [1:40:39<33:24, 200.45s/it]

Clearing existing output directories


 epoch 30, loss mean: 25.906915577975187, loss: 23.25433349609375-28.14934539794922



 78%|███████▊  | 31/40 [1:43:58<30:01, 200.15s/it]

Clearing existing output directories


 epoch 31, loss mean: 27.709976196289062, loss: 23.960636138916016-32.19546127319336



 80%|████████  | 32/40 [1:47:18<26:39, 199.90s/it]

Clearing existing output directories


 epoch 32, loss mean: 25.40254020690918, loss: 20.6865177154541-28.738365173339844



 82%|████████▎ | 33/40 [1:50:37<23:17, 199.61s/it]

Clearing existing output directories


 epoch 33, loss mean: 25.636878880587492, loss: 22.34256362915039-29.13888931274414



 85%|████████▌ | 34/40 [1:53:57<19:58, 199.71s/it]

Clearing existing output directories


 epoch 34, loss mean: 25.631733547557484, loss: 21.948619842529297-28.845075607299805



 88%|████████▊ | 35/40 [1:57:17<16:40, 200.07s/it]

Clearing existing output directories


 epoch 35, loss mean: 26.581252444874156, loss: 23.607032775878906-31.140268325805664



 90%|█████████ | 36/40 [2:00:39<13:22, 200.57s/it]

Clearing existing output directories


 epoch 36, loss mean: 25.556747089732777, loss: 22.92979621887207-28.44339942932129



 92%|█████████▎| 37/40 [2:03:59<10:01, 200.47s/it]

Clearing existing output directories


 epoch 37, loss mean: 25.338418440385297, loss: 22.199262619018555-29.149282455444336



 95%|█████████▌| 38/40 [2:07:21<06:41, 200.89s/it]

Clearing existing output directories


 epoch 38, loss mean: 25.892704963684082, loss: 23.212108612060547-29.126216888427734



 98%|█████████▊| 39/40 [2:10:42<03:20, 200.83s/it]

Clearing existing output directories


 epoch 39, loss mean: 26.585525946183637, loss: 21.103351593017578-30.358949661254883



100%|██████████| 40/40 [2:14:03<00:00, 201.10s/it]


In [24]:
print(mean_loss)

[24.94681098244407]


## Testing Loop

In [10]:
# Generate testing data
sr_test = sr_gen('./data/test/GT_corr/','./data/test/HR_corr/','./data/test/LR_corr/')

In [11]:
temp = sr_test.get_template()
temp["patch"]=False #44
# temp["step"]=20
# temp["rotation"] = 180
temp["scale"] = 1
temp["rotation"]=180
sr_test.save_template(temp)

sr_test.run(clear=True)

Clearing existing output directories
Saving image: 155_rot91
Saving image: 140_rot67
Saving image: 145_rot-144
Saving image: 160_rot-1
Saving image: 150_rot150


### OPTIONAL: Load previously trained model

In [15]:
net = SCN(9,5,train=False) #Switch to True if you want to keep training
net.load_state_dict(torch.load('./MRI_reflect_pad_save_39.p'))

<All keys matched successfully>

In [16]:
# Load matched images
im_hr, im_lr = sr_test.match_altered(update = True, paths=True)

save_pred = False # Whether to save the images created by the network during testing
save_dir = "./"
if save_pred:
    os.makedirs(save_dir, exist_ok=True)


with torch.no_grad():
    for i in im_hr: #range(len(im_hr)):

        # Load in image information
        im_h, im_l = sr_test.load_image_pair(i)

        # Take low resolution and upscale using bicubic interpolation
        # (which has already been done due to the image generation process)
        # Thus im_l is the bicubic interpolation to compare to...

        # Use SR model on low resolution image
        im_h_sr = net(torch.unsqueeze(torch.unsqueeze(torch.tensor(im_l, dtype=torch.float32),0),0),2)

        # Calculate PSNR for bicubic
        im_l = np.rint( np.clip(im_l, 0, 255))
        im_h = np.rint( np.clip(im_h, 0, 255))
        diff = im_l - im_h
        rmse = np.sqrt((diff**2).mean())
        psnr = 20*np.log10(255.0/rmse)

        print(f'bicubic evaluation for {i}: rms={rmse}, psnr={psnr}')

        # Calculate PSNR for SR
        im_h_sr = np.rint( np.clip(im_h_sr, 0, 255))
        im_h = np.rint( np.clip(im_h, 0, 255))
        diff = im_h_sr - im_h
        rmse = np.sqrt((diff**2).mean())
        psnr = 20*np.log10(255.0/rmse)
        print(f'SR evaluation for {i}: rms={rmse}, psnr={psnr}')

        if save_pred:
            img_name = os.path.splitext(os.path.basename(i))[0]
            Image.fromarray(np.rint(im_h_sr).astype(np.uint8)).save(f"{save_dir}/{img_name}_SR.png")


HR and LR file locations updated
bicubic evaluation for ./data/test/HR_corr/155_rot91.png: rms=4.91015625, psnr=34.30889736119694
SR evaluation for ./data/test/HR_corr/155_rot91.png: rms=4.818031311035156, psnr=34.473411560058594
bicubic evaluation for ./data/test/HR_corr/145_rot-144.png: rms=4.546875, psnr=34.976543308638696
SR evaluation for ./data/test/HR_corr/145_rot-144.png: rms=4.458677768707275, psnr=35.14668273925781
bicubic evaluation for ./data/test/HR_corr/140_rot67.png: rms=3.361328125, psnr=37.60058542164452
SR evaluation for ./data/test/HR_corr/140_rot67.png: rms=3.3172366619110107, psnr=37.715274810791016
bicubic evaluation for ./data/test/HR_corr/150_rot150.png: rms=5.20703125, psnr=33.79899992663891
SR evaluation for ./data/test/HR_corr/150_rot150.png: rms=5.131485939025879, psnr=33.925941467285156
bicubic evaluation for ./data/test/HR_corr/160_rot-1.png: rms=4.40234375, psnr=35.257124593993964
SR evaluation for ./data/test/HR_corr/160_rot-1.png: rms=4.329180717468262,