# SCN Model Training and Testing

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from gen_utils.sr_gen import sr_gen # Custom class for image generation/organization

In [None]:
class SCN(nn.Module):
    """SCN super-resolution model: fixed patch filters around a LISTA encoder.

    The forward pass extracts 9x9 patch features with ``conv``, encodes them
    with ``wd`` and ``k`` soft-thresholded ``usd1`` iterations, decodes the
    code with ``ud``, re-aligns the decoded channels with ``reassemble2``,
    blends them with ``addp`` and adds back the Gaussian-smoothed input
    produced by ``mean2``.
    """

    def __init__(self, sy, sg, model_file=False, train=True):
        """Create the layers and set their initial weights.

        Args:
            sy: Not read by the constructor; kept for interface compatibility.
            sg: Not read by the constructor; kept for interface compatibility.
            model_file: Not read by the constructor; kept for interface
                compatibility.
            train: When truthy, ``wd``, ``usd1``, ``ud`` and ``addp`` are
                trainable and ``conv`` keeps PyTorch's default (trainable)
                initialisation. Otherwise every weight is frozen and ``conv``
                is set to all ones.
        """
        super().__init__()
        C = 5
        L = 5

        # Random dictionaries used only to seed the encoder/decoder weights.
        Dx = torch.normal(0, 1, size=(25, 128))
        Dy = torch.normal(0, 1, size=(100, 128))
        identity = torch.eye(128)

        self.conv = nn.Conv2d(1, 100, 9, bias=False, stride=1, padding=6)
        self.mean2 = nn.Conv2d(1, 1, 13, bias=False, stride=1, padding=6)
        self.diffms = nn.Conv2d(1, 25, 9, bias=False, stride=1, padding=6)

        self.wd = nn.Conv2d(100, 128, 1, bias=False, stride=1)
        self.usd1 = nn.Conv2d(128, 128, 1, bias=False, stride=1)
        self.ud = nn.Conv2d(128, 25, 1, bias=False, stride=1)
        self.addp = nn.Conv2d(16, 1, 1, bias=False, stride=1)

        learnable = bool(train)

        if not learnable:
            # Outside training the patch-extraction filters are fixed to ones.
            self.conv.weight = nn.Parameter(
                torch.ones(100, 1, 9, 9), requires_grad=False
            )

        # The smoothing and mean-removal filters are never trained.
        self.mean2.weight = nn.Parameter(
            self.create_gaus(13), requires_grad=False
        )
        self.diffms.weight = nn.Parameter(
            self.create_diffms(9, 5), requires_grad=False
        )

        # Encoder / decoder weights start from the same values in both modes;
        # only whether they receive gradients differs.
        self.wd.weight = nn.Parameter(
            self.expand_params(C * Dy.T), requires_grad=learnable
        )
        self.usd1.weight = nn.Parameter(
            self.expand_params(identity - torch.matmul(Dy.T, Dy)),
            requires_grad=learnable,
        )
        self.ud.weight = nn.Parameter(
            self.expand_params((1 / (C * L)) * Dx), requires_grad=learnable
        )
        self.addp.weight = nn.Parameter(
            torch.ones(1, 16, 1, 1) * 0.06, requires_grad=learnable
        )

    @staticmethod
    def _unit_norm(t, eps=1e-12):
        """Scale ``t`` to unit L2 norm along dim 1, flooring the norm at ``eps``.

        The floor makes an all-zero channel vector map to zeros instead of
        producing 0/0 = NaN.
        """
        norm = torch.linalg.vector_norm(t, ord=2, dim=1, keepdim=True)
        return t / norm.clamp_min(eps)

    def forward(self, x, k, sy=9, sg=5):
        """Run the network on a batch of single-channel images.

        Args:
            x: Tensor of shape ``(n, 1, h, w)``.
            k: Number of LISTA refinement iterations.
            sy: Unused; kept for interface compatibility.
            sg: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(n, 1, h, w)``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (n, c, h, w) input, got {x.dim()} dimensions"
            )

        im_mean = self.mean2(x)
        diffms = self.diffms(x)

        x = self.conv(x)
        x = x + 1
        x = self._unit_norm(x)

        x = self.wd(x)
        z = self.ShLU(x, 1)

        # Go through LISTA
        for _ in range(k):
            z = self.ShLU(self.usd1(z) + x, 1)

        x = self.ud(z)
        # Rescale each decoded vector to 1.1x the norm of the mean-removed
        # input patch at the same location.
        target_norm = torch.linalg.vector_norm(
            diffms, ord=2, dim=1, keepdim=True
        )
        x = self._unit_norm(x) * target_norm * 1.1
        x = self.reassemble2(x, im_mean, 4)
        x = self.addp(x)

        return x + im_mean

    def reassemble2(self, x, im_mean, patch_size):
        """Crop shifted windows out of the first ``patch_size**2`` channels.

        Channel ``f`` of the result is an ``(h, w)`` window of ``x[:, f]``
        where ``(h, w)`` is the spatial size of ``im_mean``. The window offset
        walks from ``patch_size - 1`` down to ``0``, rows varying fastest.

        Args:
            x: Tensor of shape ``(n, c, H, W)`` with ``c >= patch_size**2``.
            im_mean: 4-D tensor whose last two dimensions give ``(h, w)``.
            patch_size: Number of offsets per axis (``forward`` passes 4).

        Returns:
            Tensor of shape ``(n, patch_size**2, h, w)`` on the same device
            and with the same dtype as ``x``.
        """
        _, _, h, w = im_mean.shape
        offsets = range(patch_size - 1, -1, -1)
        shifts = [(row, col) for col in offsets for row in offsets]

        # Slicing the whole batch at once replaces the per-sample loop and
        # keeps the result on x's device.
        planes = [
            x[:, idx, row:row + h, col:col + w]
            for idx, (row, col) in enumerate(shifts)
        ]
        return torch.stack(planes, dim=1)

    def create_diffms(self, kern_size, sy=5):
        """Build filters that return one patch pixel minus the patch mean.

        Filter ``i`` holds ``1 - 1/sy**2`` at flattened patch position ``i``
        and ``-1/sy**2`` at every other position of the ``sy x sy`` patch. The
        patch is embedded in a zero ``kern_size x kern_size`` kernel with an
        ``int((kern_size - sy) / 2)`` border.

        Args:
            kern_size: Side length of the returned kernels.
            sy: Side length of the patch.

        Returns:
            Tensor of shape ``(sy**2, 1, kern_size, kern_size)``.
        """
        n_filters = sy ** 2
        neg = -1 * (1 / n_filters)
        pos = 1 + neg
        border = int((kern_size - sy) / 2)

        # Row i is the flattened patch for filter i.
        patches = torch.full((n_filters, n_filters), neg)
        patches.fill_diagonal_(pos)

        diffms = torch.zeros(n_filters, 1, kern_size, kern_size)
        diffms[
            :, 0, border:(kern_size - border), border:(kern_size - border)
        ] = patches.reshape(n_filters, sy, sy)
        return diffms

    def create_gaus(self, kern_size, sy=9, std=2.15):
        """Build a normalised 2-D Gaussian kernel padded with zeros.

        Args:
            kern_size: Side length of the returned kernel.
            sy: Side length of the Gaussian support.
            std: Standard deviation of the Gaussian, in pixels.

        Returns:
            Tensor of shape ``(1, 1, kern_size, kern_size)``.
        """
        n = torch.arange(0, sy) - (sy - 1.0) / 2.0
        sig2 = 2 * std * std
        gkern1d = torch.exp(-n ** 2 / sig2)
        gkern1d = gkern1d / torch.sum(gkern1d)
        gkern2d = torch.outer(gkern1d, gkern1d)

        # Wrap in zeros, if kern_size > sy
        gaussian_filter = torch.zeros(1, 1, kern_size, kern_size)
        border = int((kern_size - sy) / 2)
        gaussian_filter[
            0, 0, border:(kern_size - border), border:(kern_size - border)
        ] = gkern2d
        return gaussian_filter

    def fixed_positions(self, tens, mult, sg):
        """Place every filter of ``tens`` at every offset of an ``sg x sg`` kernel.

        Args:
            tens: Tensor of shape ``(f, 1, h, w)`` (assumed single-channel
                filters — TODO confirm with callers).
            mult: Number of placements reserved per filter; must be at least
                ``(sg - w + 1) * (sg - h + 1)`` or the write index overruns.
            sg: Side length of the output kernels.

        Returns:
            Tensor of shape ``(f * mult, 1, sg, sg)``.
        """
        f, _, h, w = tens.shape
        new_filt = torch.zeros(f * mult, 1, sg, sg)
        cnt = 0

        for filt in range(f):
            for j in range((sg - w) + 1):
                for i in range((sg - h) + 1):
                    new_filt[cnt, 0, i:i + h, j:j + w] = tens[filt]
                    cnt += 1
        return new_filt

    def expand_params(self, tens):
        """Append two singleton dimensions at positions 2 and 3 of ``tens``."""
        return torch.unsqueeze(torch.unsqueeze(tens, 2), 3)

    def ShLU(self, a, th):
        """Soft-threshold ``a``: shrink magnitudes by ``th`` and zero the rest."""
        magnitude = a.abs() - th
        # A zero with magnitude's dtype/device avoids building an int64 CPU
        # tensor on every call.
        return torch.sign(a) * torch.maximum(
            magnitude, magnitude.new_zeros(())
        )

# Set Optimization Parameters

In [None]:
# Model in training mode, pixel-wise MSE objective.
net = SCN(9, 5, train=True)
criterion = nn.MSELoss()

# One parameter group per layer so per-layer hyper-parameters can be
# overridden later. Commented-out overrides in the original cell were
# addp: lr 0.0002 / momentum 0.00005 and conv: lr 0.0003 / momentum 0.0001.
trainable_layers = (net.addp, net.conv, net.wd, net.usd1, net.ud)
optimizer = optim.SGD(
    [{"params": layer.parameters()} for layer in trainable_layers],
    lr=0.00007,
    momentum=0.0001,
)

## Generate Data for Training

In [None]:
# Paths are passed positionally; judging by their names they are the raw HR
# NIfTI directory, the HR output directory and the LR output directory
# (confirm the order against sr_gen).
sr_train = sr_gen(
    './data/raw/nii_sub_HR/',
    './data/raw/HR_output/',
    './data/raw/LR_output/',
)

In [None]:
# Fetch the generator's template and override two entries before regenerating
# the training images.
temp = sr_train.get_template()
temp["patch"]=20  # presumably the patch edge length in pixels — TODO confirm against sr_gen
temp["step"]=10  # presumably the stride between patches — TODO confirm against sr_gen
# NOTE(review): save_template is called on the template object itself with the
# template as its argument. If get_template() returns a plain dict this raises
# AttributeError and sr_train.save_template(temp) was probably intended —
# confirm against sr_gen.
temp.save_template(temp)

# Regenerate the image set; clear=True presumably wipes earlier output first
# (TODO confirm against sr_gen.run).
sr_train.run(clear=True)

## Create Training Dataset and Dataloader

In [None]:
# TODO: There's no reason I can't combine the Dataset class and my custom class into one thing

class Dataset(torch.utils.data.Dataset):
    """Torch dataset backed by an object exposing ``HR_files`` and ``load_image_pair``."""

    def __init__(self, sr_class):
        """Keep a reference to ``sr_class`` and make sure its file list is filled."""
        self.sr_class = sr_class

        # Safety net: build the file list here if match_altered was not run
        # before the object was handed over.
        if not sr_class.HR_files:
            sr_class.match_altered(update=True)

    def __len__(self):
        """Return the number of entries in ``HR_files``."""
        return len(self.sr_class.HR_files)

    @staticmethod
    def _as_plane(image):
        """Convert ``image`` to a float32 tensor with a leading singleton axis."""
        return torch.tensor(image, dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, index):
        """Return ``(X, Y)`` tensors for the pair at ``index``.

        ``load_image_pair`` yields the pair as ``(Y, X)``; the order is
        swapped here so the first returned item is X.
        """
        Y, X = self.sr_class.load_image_pair(index)
        return self._as_plane(X), self._as_plane(Y)

In [None]:
# DataLoader settings: shuffled mini-batches of 64 loaded by two workers.
params = dict(batch_size=64, shuffle=True, num_workers=2)

# Wrap the generated image pairs and hand them to the loader.
training_set = Dataset(sr_train)
training_generator = torch.utils.data.DataLoader(training_set, **params)

## Training Loop

In [None]:
from tqdm import tqdm
import time
# Loop over epochs

max_epochs = 20

for epoch in tqdm(range(max_epochs)):
    # Per-epoch bookkeeping.
    losses = []
    losses_per = []
    count = 0

    # Training
    for inp, goal in training_generator:
        optimizer.zero_grad()

        # Second argument is the number of LISTA iterations.
        output = net(inp, 2)
        print(output.shape)
        # Clamp predictions to the 0-255 range before scoring them.
        output = output.clamp(0, 255)

        loss = criterion(output, goal)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        print(f'loss = {batch_loss}')
        losses.append(batch_loss)
        running_mean = sum(losses) / len(losses)
        print(f'mini-batch # {count}, mean loss = {running_mean}')
        count += 1

    # Checkpoint the weights once per epoch, then report the epoch summary.
    torch.save(net.state_dict(), f'./MRI_save_{epoch}.p')
    print(f'\n\n epoch {epoch}, loss mean: {sum(losses)/len(losses)}, loss: {min(losses)}-{max(losses)}\n')

    # Give computer time to cool down
    time.sleep(10)

## Testing Loop