### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
from PIL import Image
import cv2
import pickle as pickle
from torchvision import transforms

## Model trained on BSD100 data

If you want to load the trained model set ?? to True

In [2]:
class SCN(nn.Module):
    def __init__(self,sy,sg, model_file=False, train=True):
        super().__init__()
        #Initialize layer weights:
        C = 5
        L = 5
        
        Dx = torch.normal(0,1, size = (25,128))
        Dy = torch.normal(0,1,size = (100,128))
        I = torch.eye(128)
        
        self.conv = nn.Conv2d(1,100,9, bias=False, stride=1, padding=0)
        self.mean2 = nn.Conv2d(1,1,13, bias=False, stride=1, padding=0)
        self.diffms = nn.Conv2d(1,25, 9, bias=False, stride=1, padding=0)
        
        self.wd = nn.Conv2d(100,128, 1,bias=False,stride=1)
        self.usd1 = nn.Conv2d(128, 128, 1, bias=False, stride=1)
        self.ud = nn.Conv2d(128,25, 1, bias=False, stride=1)
        self.addp = nn.Conv2d(25,1,1,bias=False, stride=1)#, padding=4)

        if train:
            self.mean2.weight = torch.nn.Parameter(self.create_gaus(13),requires_grad=False)
            self.diffms.weight = torch.nn.Parameter(self.create_diffms(9,5),requires_grad=False)
            self.wd.weight = torch.nn.Parameter(self.expand_params(C*Dy.T))
            self.usd1.weight = torch.nn.Parameter(self.expand_params(I - torch.matmul(Dy.T,Dy)))
            self.ud.weight = torch.nn.Parameter(self.expand_params((1/(C*L))*Dx))
            self.addp.weight = torch.nn.Parameter(torch.ones(1,25,1,1)*0.04)#, requires_grad=False)
        else:
            self.mean2.weight = torch.nn.Parameter(self.create_gaus(13),requires_grad=False)
            self.diffms.weight = torch.nn.Parameter(self.create_diffms(9,5),requires_grad=False)
            self.wd.weight = torch.nn.Parameter(self.expand_params(C*Dy.T),requires_grad=False)
            self.usd1.weight = torch.nn.Parameter(self.expand_params(I - torch.matmul(Dy.T,Dy)),requires_grad=False)
            self.ud.weight = torch.nn.Parameter(self.expand_params((1/(C*L))*Dx),requires_grad=False)
            self.addp.weight = torch.nn.Parameter(torch.ones(1,25,1,1)*0.04,requires_grad=False)#, requires_grad=False)
        
    def forward(self, x, k, sy=9,sg=5):
        print(f'x = {x.shape}')
        print({x.max()})
        im_mean = self.mean2(x)
        diffms = self.diffms(x)
        print(f'x max = {x.max()}')
        print(f'x min = {x.min()}')
        
        
        #pad data out
        x = self.conv(x)
        
        x = x/torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)
        
        x = self.wd(x)
        z = self.ShLU(x,1)        
        
        #Go through LISTA
        for i in range(k):
            z = self.ShLU(self.usd1(z)+x,1)
        x = self.ud(z)

        x = (x/torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True))*torch.linalg.vector_norm(diffms, ord=2, dim=1, keepdim=True)*1.1
        x = self.reassemble2(x,im_mean,sg)
        x = self.addp(x)

        x = x+im_mean
        print(f'final addition max = {x.max()}')
        print(f'final addition min = {x.min()}')
        return x
    
    
    def reassemble(self, x, im_mean, patch_size):
        img = im_mean
        s, c, h, w = img.shape
        cnt = 0
        
        img_stack=torch.zeros(s,c,h,w)
        
        #go through every sample and reassemble the image
        for q in range(x.shape[0]):
            filt = 0
            for ii in range(patch_size-1, -1, -1):
                for jj in range(patch_size-1, -1, -1):
                    #print(x[q,filt,jj:(jj+h), ii:(ii+w)].shape)
                    img_stack[q,0,:,:] = img_stack[q,0,:,:]+x[q,filt,jj:(jj+h), ii:(ii+w)]
                    filt+=1
        
        return img_stack+img
    
    def reassemble2(self, x, im_mean, patch_size):
        img = im_mean
        s, c, h, w = img.shape
        cnt = 0
        
        img_stack=torch.zeros(s,25,h,w)
        
        #go through every sample and reassemble the image
        for q in range(x.shape[0]):
            filt = 0
            for ii in range(patch_size-1, -1, -1):
                for jj in range(patch_size-1, -1, -1):
                    img_stack[q,filt,:,:] = x[q,filt,jj:(jj+h), ii:(ii+w)]
                    filt+=1
        
        return img_stack
    
    
    def create_diffms(self, kern_size, sy=5):
        diffms = torch.zeros(sy**2,1,kern_size,kern_size)
        
        neg = -1*(1/(sy**2))
        pos = 1+neg
        
        border = int((kern_size-sy)/2)
        base = torch.zeros(sy,sy)+neg
        cnt=0
        
        for i in range(sy**2):
            base = torch.zeros(sy**2)+neg
            base[cnt]=pos
            diffms[i,0,border:(kern_size-border),border:(kern_size-border)] = base.reshape([sy,sy])
            cnt+=1
        return diffms
    
    
    def create_gaus(self, kern_size, sy=9,std=2.15):
        n = torch.arange(0,sy)-(sy-1.0)/2.0
        sig2 = 2 * std * std
        gkern1d = torch.exp(-n ** 2 / sig2)
        gkern1d = gkern1d/torch.sum(gkern1d)
        #print(gkern1d.shape)
        gkern2d = torch.outer(gkern1d, gkern1d)
    

        # Wrap in zeros, if kern_size > sy
        gaussian_filter = torch.zeros(1,1,kern_size,kern_size)
        border = int((kern_size-sy)/2)
        gaussian_filter[0,0,border:(kern_size-border),border:(kern_size-border)] = gkern2d#(sy,std=std)
        #print(gaussian_filter.shape)
        return gaussian_filter
    
    def fixed_positions(self, tens, mult, sg):
        f, _ , h, w = tens.shape
        new_filt = torch.zeros(f*mult, 1, sg,sg)
        #new_filt = torch.zeros(1,f*mult,sg,sg)
        cnt = 0
        filt = 0
        
        for filt in range(f):
            for j in range((sg-w)+1):
                for i in range((sg-h)+1):
                    new_filt[cnt,0,i:i+h,j:j+w] = tens[filt]
                    #new_filt[0,cnt,i:i+h,j:j+w] = tens[filt]
                    cnt+=1
        return new_filt
    
    
    
    def expand_params(self,tens):
        return torch.unsqueeze(torch.unsqueeze(tens,2),3)
    
    def ShLU(self,a, th):
        return torch.sign(a)*torch.maximum(abs(a)-th, torch.tensor(0))

## Instantiate network and load saved weights

In [3]:
net = SCN(9,5,train=False)
net.load_state_dict(torch.load('/home/ross/Documents/CODE/sparse/SCN_BSD_final.p'))

## Model Trained on MRI data

If you want to load the traind model set ?? to True

In [16]:
class SCN(nn.Module):
    def __init__(self,sy,sg, model_file=False, train=True):
        super().__init__()
        #Initialize layer weights:
        C = 5
        L = 5
        
        Dx = torch.normal(0,1, size = (25,128))
        Dy = torch.normal(0,1,size = (100,128))
        I = torch.eye(128)
        #self.conv = nn.Conv2d(1,4,sg, bias=False, stride=1, padding=0)
        #self.conv = nn.Conv2d(1,6,sg+1, bias=False, stride=1, padding=0)
        self.conv = nn.Conv2d(1,100,9,bias=False, stride=1,padding=0)
        self.mean2 = nn.Conv2d(1,1,13, bias=False, stride=1, padding=0)
        self.diffms = nn.Conv2d(1,25, 9, bias=False, stride=1, padding=0)
        
        self.wd = nn.Conv2d(100,128, 1,bias=False,stride=1)
        self.usd1 = nn.Conv2d(128, 128, 1, bias=False, stride=1)
        self.ud = nn.Conv2d(128,25, 1, bias=False, stride=1)
        #self.addp = nn.Conv2d(25,1,1,bias=False, stride=1)#, padding=4)
        self.addp = nn.Conv2d(16,1,1,bias=False, stride=1)#, padding=4)
        
        if train:
            #self.conv.weight = torch.nn.Parameter(self.create_haar(sg,4))
            #self.conv.weight = torch.nn.Parameter(self.create_haar(6,6),requires_grad=True)
            self.mean2.weight = torch.nn.Parameter(self.create_gaus(13),requires_grad=False)
            self.diffms.weight = torch.nn.Parameter(self.create_diffms(9,5),requires_grad=False)
            self.wd.weight = torch.nn.Parameter(self.expand_params(C*Dy.T), requires_grad=True)
            self.usd1.weight = torch.nn.Parameter(self.expand_params(I - torch.matmul(Dy.T,Dy)), requires_grad=True)
            self.ud.weight = torch.nn.Parameter(self.expand_params((1/(C*L))*Dx), requires_grad=True)
            #self.addp.weight = torch.nn.Parameter(torch.ones(1,25,1,1)*0.04)#, requires_grad=False)
            self.addp.weight = torch.nn.Parameter(torch.ones(1,16,1,1)*0.06, requires_grad=True)

        else:
            #self.conv.weight = torch.nn.Parameter(self.create_haar(sg,4),requires_grad=False)
            #self.conv.weight = torch.nn.Parameter(self.create_haar(6,6),requires_grad=False)
            self.conv.weight = torch.nn.Parameter(torch.ones(100,1,9,9),requires_grad=False)
            self.mean2.weight = torch.nn.Parameter(self.create_gaus(13),requires_grad=False)
            self.diffms.weight = torch.nn.Parameter(self.create_diffms(9,5),requires_grad=False)
            self.wd.weight = torch.nn.Parameter(self.expand_params(C*Dy.T),requires_grad=False)
            self.usd1.weight = torch.nn.Parameter(self.expand_params(I - torch.matmul(Dy.T,Dy)),requires_grad=False)
            self.ud.weight = torch.nn.Parameter(self.expand_params((1/(C*L))*Dx),requires_grad=False)
            #self.addp.weight = torch.nn.Parameter(torch.ones(1,25,1,1)*0.04,requires_grad=False)
            self.addp.weight = torch.nn.Parameter(torch.ones(1,16,1,1)*0.06,requires_grad=False)

            #self.upstep = nn.Conv2d(1,100,sy,bias=False, stride=1, padding=0)

        
    def forward(self, x, k, sy=9,sg=5):
        im_mean = self.mean2(x)
        diffms = self.diffms(x)

        cnt=0
        n, c, h, w = x.shape
        y = torch.zeros(n,100,h-8,w-8)
        x = self.conv(x)
        x = x/torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)
        
        x = self.wd(x)
        z = self.ShLU(x,1)
        
        #Go through LISTA
        for i in range(k):
            z = self.ShLU(self.usd1(z)+x,1)

        x = self.ud(z)
        x = (x/torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True))*torch.linalg.vector_norm(diffms, ord=2, dim=1, keepdim=True)*1.1
        x = self.reassemble2(x,im_mean,4)
        x = self.addp(x)
        print(f'x.reassemble.max = {x.max()}')
        x = x+im_mean
        
        return x
    
    
    def reassemble(self, x, im_mean, patch_size):
        img = im_mean
        s, c, h, w = img.shape
        cnt = 0
        img_stack=torch.zeros(s,c,h,w)
        
        #go through every sample and reassemble the image
        for q in range(x.shape[0]):
            filt = 0
            for ii in range(patch_size-1, -1, -1):
                for jj in range(patch_size-1, -1, -1):
                    img_stack[q,0,:,:] = img_stack[q,0,:,:]+x[q,filt,jj:(jj+h), ii:(ii+w)]
                    filt+=1
        
        return img_stack+img
    
    def reassemble2(self, x, im_mean, patch_size):
        img = im_mean
        s, c, h, w = img.shape
        cnt = 0
        
        
#        img_stack=torch.zeros(s,25,h,w)
        img_stack=torch.zeros(s,16,h,w)
        
        #go through every sample and reassemble the image
        for q in range(x.shape[0]):
            filt = 0
            for ii in range(patch_size-1, -1, -1):
                for jj in range(patch_size-1, -1, -1):
                    img_stack[q,filt,:,:] = x[q,filt,jj:(jj+h), ii:(ii+w)]
                    filt+=1
        
        return img_stack
    
    
    def create_diffms(self, kern_size, sy=5):
        diffms = torch.zeros(sy**2,1,kern_size,kern_size)
        
        neg = -1*(1/(sy**2))
        pos = 1+neg
        
        border = int((kern_size-sy)/2)
        base = torch.zeros(sy,sy)+neg
        cnt=0
        
        for i in range(sy**2):
            base = torch.zeros(sy**2)+neg
            base[cnt]=pos
            diffms[i,0,border:(kern_size-border),border:(kern_size-border)] = base.reshape([sy,sy])
            cnt+=1
        return diffms
    
    
    def create_gaus(self, kern_size, sy=9,std=2.15):
        n = torch.arange(0,sy)-(sy-1.0)/2.0
        sig2 = 2 * std * std
        gkern1d = torch.exp(-n ** 2 / sig2)
        gkern1d = gkern1d/torch.sum(gkern1d)
        #print(gkern1d.shape)
        gkern2d = torch.outer(gkern1d, gkern1d)
    

        # Wrap in zeros, if kern_size > sy
        gaussian_filter = torch.zeros(1,1,kern_size,kern_size)
        border = int((kern_size-sy)/2)
        gaussian_filter[0,0,border:(kern_size-border),border:(kern_size-border)] = gkern2d#(sy,std=std)
        #print(gaussian_filter.shape)
        return gaussian_filter
    
    def create_haar(self, kern_size, num):
        haar = torch.zeros(num,1,kern_size,kern_size)
        haar[0,0,:,:] = torch.tensor([[-1,-1,-1,1,1,1],
                                      [-1,-1,-1,1,1,1],
                                      [-1,-1,-1,1,1,1],
                                      [-1,-1,-1,1,1,1],
                                     [-1,-1,-1,1,1,1],
                                     [-1,-1,-1,1,1,1]])
        
        haar[1,0,:,:] = torch.tensor([[-1,-1,-1,-1,-1,-1],
                                      [-1,-1,-1,-1,-1,-1],
                                      [-1,-1,-1,-1,-1,-1],
                                      [1,1,1,1,1,1],
                                      [1,1,1,1,1,1],
                                     [1,1,1,1,1,1]])
        
        haar[2,0,:,:] = torch.tensor([[1,1,-1,-1,1,1],
                                      [1,1,-1,-1,1,1],
                                      [1,1,-1,-1,1,1],
                                      [1,1,-1,-1,1,1],
                                     [1,1,-1,-1,1,1],
                                     [1,1,-1,-1,1,1]])
        
        haar[3,0,:,:] = torch.tensor([[1,1,1,1,1,1],
                                      [1,1,1,1,1,1],
                                      [-1,-1,-1,-1,-1,-1],
                                      [-1,-1,-1,-1,-1,-1],
                                     [1,1,1,1,1,1],
                                     [1,1,1,1,1,1]])
        
        haar[4,0,:,:] = torch.tensor([[1,1,1,-1,-1,-1],
                                      [1,1,1,-1,-1,-1],
                                      [1,1,1,-1,-1,-1],
                                      [-1,-1,-1,1,1,1],
                                     [-1,-1,-1,1,1,1],
                                     [-1,-1,-1,1,1,1]])
        haar[5,0,:,:] = torch.tensor([[1,1,1,1,1,1],
                                      [1,1,1,1,1,1],
                                      [1,1,-1,-1,1,1],
                                      [1,1,-1,-1,1,1],
                                      [1,1,1,1,1,1],
                                     [1,1,1,1,1,1]])
        return haar
        
    
    def fixed_positions(self, tens, mult, sg):
        f, _ , h, w = tens.shape
        new_filt = torch.zeros(f*mult, 1, sg,sg)
        cnt = 0
        filt = 0
        
        for filt in range(f):
            for j in range((sg-w)+1):
                for i in range((sg-h)+1):
                    new_filt[cnt,0,i:i+h,j:j+w] = tens[filt]
                    cnt+=1
        return new_filt
    
    def expand_params(self,tens):
        return torch.unsqueeze(torch.unsqueeze(tens,2),3)
    
    def ShLU(self,a, th):
        return torch.sign(a)*torch.maximum(abs(a)-th, torch.tensor(0))

In [17]:
net = SCN(9,5,train=False)
net.load_state_dict(torch.load('/home/ross/Documents/CODE/sparse/sparse_SR/nn-code/MRI_run_SGD_v100_57.p'))

<All keys matched successfully>

## Load Dataset Trainer

In [None]:
from torchvision import transforms
class Dataset(torch.utils.data.Dataset):
  #'Characterizes a dataset for PyTorch'
    def __init__(self, in_dir, inputs, tar_dir, targets):
        #'Initialization'
        self.list_input = inputs
        self.list_target = targets
        self.in_dir = in_dir
        self.tar_dir = tar_dir
        
    def __len__(self):
        #'Denotes the total number of samples'
        return len(self.list_input)

    def __getitem__(self, index):
        #'Generates one sample of data'
        # Select sample
        inp = self.list_input[index]
        target = self.list_target[index]

        # Load data and get label
        #X = torch.load(self.in_dir + inp)
        #Y = torch.load(self.tar_dir + target)
        im_l = np.array(Image.open(self.in_dir + inp))[:,:,:3]
        im_gt = np.array(Image.open(self.tar_dir + target))[:,:,:3]
        
        s=2
        
        im_l = im_l/255.0
        if len(im_l.shape)==3 and im_l.shape[2]==3:
            im_l_ycbcr = rgb2ycbcr(im_l)
        else:
            im_l_ycbcr = np.zeros([im_l.shape[0], im_l.shape[1], 3])
            im_l_ycbcr[:, :, 0] = im_l
            im_l_ycbcr[:, :, 1] = im_l
            im_l_ycbcr[:, :, 2] = im_l


        #im_l_y is the luminance values of the image
        im_l_y = im_l_ycbcr[:, :, 0]*255 #[16 235]
        #im_l_y = imresize(im_l_y,s)
        im_l_y = ExtendBorder(im_l_y,6)


        if len(im_gt.shape)==3:
            im_gt_ycbcr = rgb2ycbcr(im_gt/255.0)*255.0
            im_gt_y = im_gt_ycbcr[:, :, 0]
        else:
            im_gt_y = im_gt
        
        
        convert_tensor = transforms.ToTensor()
        X = torch.unsqueeze(torch.tensor(im_l_y, dtype=torch.float32),0)
        #X = convert_tensor(im_l_y)
        Y = torch.unsqueeze(torch.tensor(im_gt_y, dtype=torch.float32),0)
        #Y = convert_tensor(im_gt_y)
        
        return X, Y

## Training Loop

In [20]:
def upscale(im_l, s, nn=True):
    """
    % im_l: LR image, float np array in [0, 255]
    % im_h: HR image, float np array in [0, 255]
    """
    im_l = im_l/255.0
    if len(im_l.shape)==3 and im_l.shape[2]==3:
        im_l_ycbcr = rgb2ycbcr(im_l)
    else:
        im_l_ycbcr = np.zeros([im_l.shape[0], im_l.shape[1], 3])
        im_l_ycbcr[:, :, 0] = im_l
        im_l_ycbcr[:, :, 1] = im_l
        im_l_ycbcr[:, :, 2] = im_l
    

    #im_l_y is the luminance values of the image
    im_l_y = im_l_ycbcr[:, :, 0]*255 #[16 235]
    if nn==True:
        im_l_y = imresize(im_l_y,s)
        print(im_l_y.max())
        print(im_l_y.min())
        im_l_y = ExtendBorder(im_l_y,6)
        im_l_y = torch.unsqueeze(torch.unsqueeze(torch.tensor(im_l_y, dtype=torch.float32),0),0)
        im_h_y = net(im_l_y,2)
        #im_h_y = self.upscale_alg(im_l_y, s)
        im_h_y = torch.squeeze(torch.squeeze(im_h_y,0),0)
        im_h_y = im_h_y.detach().numpy()
    else:
        im_h_y = imresize(im_l_y, s)

    # recover color
    if len(im_l.shape)==3:
        im_ycbcr = imresize(im_l_ycbcr, s)
        im_ycbcr[:, :, 0] = im_h_y/255.0; #[16/255 235/255]
        im_h = ycbcr2rgb(im_ycbcr)*255.0
    else:
        im_h = im_h_y

    im_h = np.clip(im_h, 0, 255)
    im_h_y = np.clip(im_h_y, 0, 255)
    return im_h,im_h_y

def ExtendBorder(im, offset):
    sz = im.shape
    assert(len(sz)==2)

    im2 = np.zeros([sz[0]+offset*2, sz[1]+offset*2])
    im2[ offset:-offset, offset:-offset ] = im
    im2[ offset:-offset, 0:offset ] = im[:, offset:0:-1]
    im2[ offset:-offset, -offset: ] = im[:, -2:-(offset+2):-1]
    im2[ 0:offset, :] = im2[2*offset:offset:-1, :]
    im2[ -offset:, :] = im2[-(offset+2):-(2*offset+2):-1, :]

    return im2

def imresize(im_l, s):
    if s<1:
        im_l = cv2.GaussianBlur(im_l, (7,7), s)
    im_h = cv2.resize(im_l, (0,0), fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
    return im_h

def rgb2ycbcr(im_rgb):
    im_rgb = im_rgb.astype(np.float32)
    im_ycrcb = cv2.cvtColor(im_rgb, cv2.COLOR_RGB2YCR_CB) #converts the RBG colors into YCbCr colorspace
    im_ycbcr = im_ycrcb[:,:,(0,2,1)].astype(np.float32)
    im_ycbcr[:,:,0] = (im_ycbcr[:,:,0]*(235-16)+16)/255.0 #to [16/255, 235/255]
    im_ycbcr[:,:,1:] = (im_ycbcr[:,:,1:]*(240-16)+16)/255.0 #to [16/255, 240/255]
    return im_ycbcr

def ycbcr2rgb(im_ycbcr):
    im_ycbcr = im_ycbcr.astype(np.float32)
    im_ycbcr[:,:,0] = (im_ycbcr[:,:,0]*255.0-16)/(235-16) #to [0, 1]
    im_ycbcr[:,:,1:] = (im_ycbcr[:,:,1:]*255.0-16)/(240-16) #to [0, 1]
    im_ycrcb = im_ycbcr[:,:,(0,2,1)].astype(np.float32)
    im_rgb = cv2.cvtColor(im_ycrcb, cv2.COLOR_YCR_CB2RGB)
    return im_rgb

def Shave(im, border):
    if isinstance(border, int):
        border=[border, border]
    im = im[border[0]:-border[0], border[1]:-border[1], ...]
    return im

def modcrop(im, modulo):
    sz = im.shape
    h = int(sz[0]/modulo*modulo)
    w = int(sz[1]/modulo*modulo)
    ims = im[0:h, 0:w, ...]
    return ims

def evalimg(im_h_y, im_gt, shav=0):
    if len(im_gt.shape)==3:
        im_gt_ycbcr = rgb2ycbcr(im_gt/255.0)*255.0
        im_gt_y = im_gt_ycbcr[:, :, 0]
    else:
        im_gt_y = im_gt

    im_h_y_uint8 = np.rint( np.clip(im_h_y, 0, 255))
    im_gt_y_uint8 = np.rint( np.clip(im_gt_y, 0, 255))
    diff = im_h_y_uint8 - im_gt_y_uint8
    #diff = im_h_y - im_gt_y
    if shav>0:
        diff = Shave(diff, [shav, shav])
    res = {}
    res['rmse'] = np.sqrt((diff**2).mean())
    res['psnr'] = 20*np.log10(255.0/res['rmse'])
    return res

## Set Optimization parameters

In [None]:
net = SCN(9,5,train=True)
criterion = nn.MSELoss()

optimizer = optim.SGD(
    [
        {"params": net.addp.parameters()},#, "lr": 0.0002, "momentum": 0.00005},
        {"params": net.conv.parameters()},#, "lr": 0.0003, "momentum": 0.0001},
        {"params": net.wd.parameters()},
        {"params": net.usd1.parameters()},
        {"params": net.ud.parameters()},
    ],
    lr=0.00007, momentum = 0.0001
)

params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 0}

In [None]:
pat = '/home/ross/Documents/CODE/sparse/sparse_SR/nn-code/data'

training_set = Dataset(f'{pat}/shuff_nii_sub_LR_patches/',os.listdir(f'{pat}/shuff_nii_sub_LR_patches'),f'{pat}/shuff_nii_sub_HR_patches/', os.listdir(f'{pat}/shuff_nii_sub_LR_patches'))
training_generator = torch.utils.data.DataLoader(training_set, **params)

In [None]:
import time
UP_SCALE=2
SHAVE=1 #set 1 to be consistant with SRCNN

# load inputs
im_gt = []
im_l = []

max_epochs = 3

from tqdm import tqdm
# Loop over epochs
for epoch in tqdm(range(max_epochs)):
    losses=[]
    losses_per = []
    # Training
    count=0
    for inp, goal in training_generator:
        # Transfer to GPU
        optimizer.zero_grad()
    
        #print(inp.type())
        output = net(inp,2)
        output = torch.clamp(output, 0, 255)
        
        #print(f'goal max = {goal.max()}')
        loss = criterion(output,goal)
        loss.backward()
        optimizer.step()
        print(f'loss = {loss.item()}')
        losses.append(loss.item())
        print(f'mini-batch # {count}, mean loss = {sum(losses)/len(losses)}')
        count= count+1
    
    torch.save(net.state_dict(), f'/home/ross/Documents/CODE/sparse/sparse_SR/nn-code/MRI_run_SGD_v100_{epoch+69}.p')
    print(f'\n\n epoch {epoch}, loss mean: {sum(losses)/len(losses)}, loss: {min(losses)}-{max(losses)}\n')
    #print(net.conv.weight)
    time.sleep(160)
    

## Testing Loop

In [21]:
UP_SCALE=2
SHAVE=1 #set 1 to be consistant with SRCNN
save_SR = False

# load inputs
im_gt = []
im_l = []

lr_dir = '/home/ross/Documents/CODE/sparse/sparse_SR/test_images_LR'
gt_dir = '/home/ross/Documents/CODE/sparse/sparse_SR/test_images_HR'
out_dir = '/home/ross/Documents/CODE/sparse/sparse_SR/nn-code/data/outputs_sag'


IMAGE_FILE=os.listdir(lr_dir)#'./example_test/data/slena.bmp'
IMAGE_GT_FILE=os.listdir(gt_dir)#'./example_test/data/mlena.bmp'

#files_gt = IMAGE_GT_FILE
for f in IMAGE_GT_FILE:
    #print 'loading', f
    im = np.array(Image.open(f'{gt_dir}/{f}'))[:,:,:3]
    im = modcrop(im, UP_SCALE).astype(np.float32)
    im_gt += [im]

for f in IMAGE_FILE:
    #assert(len(im_gt)==1)
    #im_l += [np.array(Image.open(IMAGE_FILE)).astype(np.float32)]
    im_l += [np.array(Image.open(f'{lr_dir}/{f}')).astype(np.float32)[:,:,:3]]
    
res_all = []
res_all_b=[]
for i in range(len(im_l)):
    im_h, im_h_y=upscale(im_l[i], UP_SCALE)
    
    # evaluation
    if SHAVE==1:
        shave = round(UP_SCALE)
    else:
        shave = 0
    print(IMAGE_FILE[i])
    print(IMAGE_GT_FILE[i])
    res = evalimg(im_h_y, im_gt[i], shave)
    res_all += [res]
    print('evaluation against {}, rms={:.4f}, psnr={:.4f}'.format(IMAGE_GT_FILE[i], res['rmse'], res['psnr']))

    ##Make Bicubic Image as comparison
    im_h_b, im_h_b_y = upscale(im_l[i], UP_SCALE, nn=False)

    if SHAVE==1:
        shave = round(UP_SCALE)
    else:
        shave = 0
    res_b = evalimg(im_h_b_y, im_gt[i], shave)
    res_all_b += [res_b]
    print('bicubic evaluation against {}, rms={:.4f}, psnr={:.4f}'.format(IMAGE_FILE[i], res_b['rmse'], res_b['psnr']))

    ##DONE ADDED

    # save
    if save_SR:
        img_name = os.path.splitext(os.path.basename(IMAGE_GT_FILE[i]))[0]
        Image.fromarray(np.rint(im_h).astype(np.uint8)).save(f'{out_dir}/{img_name}_x{UP_SCALE}.png')
        Image.fromarray(np.rint(im_h_b).astype(np.uint8)).save(f'{out_dir}/{img_name}_x{UP_SCALE}_bic.png')

        
print('mean SCN PSNR:', np.array([_['psnr'] for _ in res_all]).mean())
print('mean bicubic PSNR:', np.array([_['psnr'] for _ in res_all_b]).mean())

175.4183
15.002569
x.reassemble.max = 51.64139938354492
138_2-0-0-0.png
138_2-0-0-0.png
evaluation against 138_2-0-0-0.png, rms=2.7250, psnr=39.4236
bicubic evaluation against 138_2-0-0-0.png, rms=3.0962, psnr=38.3143
167.75838
14.892883
x.reassemble.max = 56.16862106323242
109_2-0-0-0.png
109_2-0-0-0.png
evaluation against 109_2-0-0-0.png, rms=3.5079, psnr=37.2298
bicubic evaluation against 109_2-0-0-0.png, rms=3.9509, psnr=36.1970
167.88416
14.963648
x.reassemble.max = 51.8637580871582
132_2-0-0-0.png
132_2-0-0-0.png
evaluation against 132_2-0-0-0.png, rms=3.5806, psnr=37.0517
bicubic evaluation against 132_2-0-0-0.png, rms=3.9567, psnr=36.1841
194.13962
14.185234
x.reassemble.max = 44.68369674682617
152_3-2-0-14.png
152_3-2-0-14.png
evaluation against 152_3-2-0-14.png, rms=4.5912, psnr=34.8923
bicubic evaluation against 152_3-2-0-14.png, rms=5.4251, psnr=33.4427
189.92552
14.75666
x.reassemble.max = 62.28560256958008
176_2-0-0-0.png
176_2-0-0-0.png
evaluation against 176_2-0-0-0.png