# Single-Image SR-CEST

## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import pandas as pd
import cv2
import glob
import os
import time
import numpy as np
import math
import albumentations as A
import matplotlib.pyplot as plt
from glob import glob
%matplotlib inline

from tqdm import tqdm
from scipy.io import savemat
from sklearn.model_selection import train_test_split
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision.utils import save_image

In [2]:
# Fix the torch and NumPy global RNG seeds so weight init and the data split are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(0)

## Import data from folders

In [23]:
# Training file lists, filled by read_data: folder path -> sorted file names.
train_X, train_Y = {}, {}

In [24]:
def read_data(train_X, train_Y,
              low_res_pattern="/home/mri/Documents/dlsr/dataset/train/low-res/8x/*/data",
              high_res_pattern="/home/mri/Documents/dlsr/dataset/train/high-res/*/data"):
    """Collect the training file lists for low-res inputs and high-res targets.

    For every folder matching a glob pattern, the folder's file names are
    stored (sorted) under the folder path. Folders are visited in sorted
    order because ``glob`` returns them in arbitrary, filesystem-dependent
    order, while ``dataset`` pairs inputs with targets purely by position.
    Sorting keeps the i-th input folder aligned with the i-th target folder.

    Args:
        train_X: dict to fill with ``{folder: [file, ...]}`` for the inputs.
        train_Y: dict to fill with ``{folder: [file, ...]}`` for the targets.
        low_res_pattern: glob pattern selecting the low-res (input) folders.
        high_res_pattern: glob pattern selecting the high-res (target) folders.

    Returns:
        The same ``(train_X, train_Y)`` dicts, updated in place.
    """
    for folder in sorted(glob(low_res_pattern)):
        train_X[folder] = sorted(os.listdir(folder))

    for folder in sorted(glob(high_res_pattern)):
        train_Y[folder] = sorted(os.listdir(folder))

    return train_X, train_Y

# Populate the (initially empty) training dicts from the dataset folders on disk.
train_X, train_Y = read_data(train_X, train_Y)

## Custom dataset

In [25]:
class dataset(Dataset):
    """Paired (input, target) image dataset held entirely in memory.

    Args:
        data_dict: mapping ``folder -> [file name, ...]`` for the inputs.
        label_dict: mapping ``folder -> [file name, ...]`` for the targets.
        transform: albumentations pipeline applied jointly to input and
            target (the target is passed as the extra ``'label'`` image).
            Defaults to a bare ``ToTensorV2``.

    Inputs and targets are paired purely by position. The i-th file of the
    concatenated input folders (dict order, then list order) goes with the
    i-th file of the concatenated target folders.

    Raises:
        ValueError: if the number of inputs and targets differs, since the
            positional pairing would then be misaligned.
    """

    def __init__(self, data_dict, label_dict, transform=None):
        print("Reading images to memory")
        self.datas = self._read_all(data_dict)
        self.labels = self._read_all(label_dict)
        if len(self.datas) != len(self.labels):
            raise ValueError(
                f"found {len(self.datas)} input images but "
                f"{len(self.labels)} label images; they must pair one-to-one"
            )

        if transform is None:
            self.transform = A.Compose([ToTensorV2()],
                                       additional_targets={'label': 'image'})
        else:
            self.transform = transform

    @staticmethod
    def _read_all(file_dict):
        """Read every listed file with ``plt.imread``, adding a new axis at position 2."""
        # For 2-D (grayscale) files this yields (H, W, 1), the HWC layout the
        # albumentations / ToTensorV2 pipeline works on.
        return [np.expand_dims(plt.imread(os.path.join(root, fname)), axis=2)
                for root, fnames in file_dict.items()
                for fname in fnames]

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.item()
        transformed = self.transform(image=self.datas[idx], label=self.labels[idx])
        # NOTE(review): dividing by 255 assumes 8-bit images (e.g. uint8 TIFFs);
        # plt.imread returns floats in [0, 1] for PNGs. Confirm the file type.
        data = torch.div(transformed['image'], 255)
        label = torch.div(transformed['label'], 255)
        return data, label

## Prepare and load data

In [26]:
# Mini-batch size and the maximum number of training epochs.
batch_size, epochs = 32, 10000

In [27]:
# Joint augmentation: the same random flips / 90-degree rotation are applied to
# the input and to its label (registered as an additional 'image' target).
transforms = A.Compose([A.HorizontalFlip(p=0.5),
                      A.VerticalFlip(p=0.5),
                        A.augmentations.geometric.rotate.RandomRotate90(p=0.5),
                      ToTensorV2()
                      ],
                     additional_targets={'label':'image'})

# Same files twice: augmented view for training, un-augmented view for validation.
# NOTE(review): each dataset() call reads every image into memory again, so
# the training images are held twice.
train_data = dataset(train_X, train_Y, transform = transforms)
val_data = dataset(train_X, train_Y, transform = None)


# Random 80/20 index split; reproducible given the np.random seed set at the top.
num_train = len(train_data)
valid_percent = 0.2
valid_size = round(valid_percent * num_train)
train_size = num_train - valid_size
indices = list(range(num_train))
np.random.shuffle(indices)
train_idx, valid_idx = indices[:train_size], indices[train_size:]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)


# Each loader only visits its own index subset via the sampler.
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
valid_loader = DataLoader(val_data, batch_size=batch_size, sampler=valid_sampler)

Reading images to memory
Reading images to memory


In [28]:
# Sanity check of the 20 % validation / 80 % training split sizes.
# (The recorded output "0 0" means no training images were found on disk.)
val = round(len(train_data)*0.2)
tr = len(train_data) - val
print(val, tr)

0 0


## Model

In [4]:
class ConvReLU(nn.Module):
    """3x3 convolution (replicate padding, spatial size preserved) followed by PReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                         padding_mode='replicate')
        # Kept as an nn.Sequential named ``conv_relu`` so checkpoint keys
        # (conv_relu.0.*, conv_relu.1.*) stay unchanged.
        self.conv_relu = nn.Sequential(conv, nn.PReLU())

    def forward(self, x):
        out = self.conv_relu(x)
        return out
    
class DoubleConv(nn.Module):
    """Conv3x3 -> PReLU -> Conv3x3, replicate padding, spatial size preserved.

    There is no activation after the second convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='replicate'),
            nn.PReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='replicate'),
        ]
        # Attribute name and layer order define the checkpoint keys.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class ChannelAttn(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global-average-pools the input to one value per channel, then passes it
    through a 1x1-conv bottleneck (``channels -> channels // factor ->
    channels``) with PReLU. Returns sigmoid weights of shape
    ``(N, channels, 1, 1)``; callers multiply these with the feature map.
    """

    def __init__(self, channels, factor):
        super().__init__()
        self.ca = nn.Sequential(
            nn.Conv2d(channels, channels//factor, kernel_size=1),
            nn.PReLU(),
            nn.Conv2d(channels//factor, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Adaptive pooling reduces any H x W to 1x1. A fixed kernel equal to
        # the width only does so for square inputs, and would leave H/W > 1
        # rows otherwise.
        x = F.adaptive_avg_pool2d(x, 1)
        return self.ca(x)
    
class RCAB(nn.Module):
    """Residual channel-attention block: ``x + body(x) * attention(body(x))``.

    ``body`` is a DoubleConv and ``attention`` a ChannelAttn. The residual sum
    needs the body output to match ``x`` in shape, i.e. in practice
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, factor):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.ca = ChannelAttn(out_channels, factor)

    def forward(self, x):
        features = self.conv(x)
        return x + features * self.ca(features)

class CasBlock(nn.Module):
    """Cascade of three RCAB passes with dense 1x1-conv fusion.

    Stage k feeds the previous fused tensor through the RCAB, then
    concatenates all RCAB outputs so far with the block input and squeezes
    them back to ``channels`` with a 1x1 conv (``conv1``/``conv2``/``conv3``).
    The output of the last fusion is returned.

    NOTE(review): the same ``self.rcab`` module (same weights) is applied at
    all three stages. Confirm weight sharing is intended; changing it would
    alter the checkpoint layout.
    """

    def __init__(self, channels):
        super().__init__()
        self.rcab = RCAB(channels, channels, factor=16)
        self.conv1 = nn.Conv2d(channels*2, channels, kernel_size=1)
        self.conv2 = nn.Conv2d(channels*3, channels, kernel_size=1)
        self.conv3 = nn.Conv2d(channels*4, channels, kernel_size=1)

    def forward(self, x):
        rcab_outputs = []
        fused = x
        for fuse in (self.conv1, self.conv2, self.conv3):
            rcab_outputs.append(self.rcab(fused))
            fused = fuse(torch.cat(rcab_outputs + [x], dim=1))
        return fused
    
class Down(nn.Module):
    """Encoder stage: CasBlock, fuse resized skip features, 1x1 conv, 2x max-pool.

    Args:
        in_channels: channels of the stage input (kept by the CasBlock).
        mid_channels: channels after concatenating every skip tensor with the
            CasBlock output, i.e. the input width of the 1x1 fusion conv.
        out_channels: channels produced by the 1x1 fusion conv.

    ``forward(x1, *args)`` returns ``(pooled, xp)``, where ``xp`` is the
    CasBlock output before fusion (handed to later stages as a skip input).
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.cas = CasBlock(in_channels)
        self.conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
        self.down = nn.MaxPool2d(2)

    @staticmethod
    def _match_width(skip, target):
        """Resample ``skip`` by an integer factor so its width equals ``target``.

        Only the last dimension is compared, so inputs are presumably square.
        Wider tensors are max-pooled; narrower ones are bicubically upsampled.
        """
        width = skip.size(dim=-1)
        if width > target:
            return nn.MaxPool2d(width // target)(skip)
        if width < target:
            return nn.Upsample(scale_factor=target // width, mode='bicubic',
                               align_corners=True)(skip)
        return skip

    def forward(self, x1, *args):
        xp = self.cas(x1)
        target = xp.size(dim=-1)
        fused = xp
        for skip in args:
            # Prepend: the last skip in ``args`` ends up first along channels.
            fused = torch.cat([self._match_width(skip, target), fused], dim=1)
        return self.down(self.conv(fused)), xp
        
class Up(nn.Module):
    """Decoder stage: CasBlock, fuse resized skip features, 1x1 conv, 2x bicubic upsample.

    Args:
        in_channels: channels of the stage input (kept by the CasBlock).
        mid_channels: channels after concatenating every skip tensor with the
            CasBlock output, i.e. the input width of the 1x1 fusion conv.
        out_channels: channels produced by the 1x1 fusion conv.

    ``forward(x1, *args)`` returns ``(upsampled, xp)``, where ``xp`` is the
    CasBlock output before fusion (handed to later stages as a skip input).
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.cas = CasBlock(in_channels)
        self.conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)

    @staticmethod
    def _match_width(skip, target):
        """Resample ``skip`` by an integer factor so its width equals ``target``.

        Only the last dimension is compared, so inputs are presumably square.
        Wider tensors are max-pooled; narrower ones are bicubically upsampled.
        """
        width = skip.size(dim=-1)
        if width > target:
            return nn.MaxPool2d(width // target)(skip)
        if width < target:
            return nn.Upsample(scale_factor=target // width, mode='bicubic',
                               align_corners=True)(skip)
        return skip

    def forward(self, x1, *args):
        xp = self.cas(x1)
        target = xp.size(dim=-1)
        fused = xp
        for skip in args:
            # Prepend: the last skip in ``args`` ends up first along channels.
            fused = torch.cat([self._match_width(skip, target), fused], dim=1)
        return self.up(self.conv(fused)), xp
    
class Vanilla(nn.Module):
    """Final full-resolution stage: CasBlock, fuse resized skip features, 1x1 conv.

    Like ``Down``/``Up`` but without resampling. Returns only the fused
    tensor, not a ``(x, xp)`` pair.

    Args:
        in_channels: channels of the stage input (kept by the CasBlock).
        mid_channels: channels after concatenating every skip tensor with the
            CasBlock output, i.e. the input width of the 1x1 fusion conv.
        out_channels: channels produced by the 1x1 fusion conv.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.cas = CasBlock(in_channels)
        self.conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    @staticmethod
    def _match_width(skip, target):
        """Resample ``skip`` by an integer factor so its width equals ``target``.

        Only the last dimension is compared, so inputs are presumably square.
        Wider tensors are max-pooled; narrower ones are bicubically upsampled.
        """
        width = skip.size(dim=-1)
        if width > target:
            return nn.MaxPool2d(width // target)(skip)
        if width < target:
            return nn.Upsample(scale_factor=target // width, mode='bicubic',
                               align_corners=True)(skip)
        return skip

    def forward(self, x1, *args):
        fused = self.cas(x1)
        target = fused.size(dim=-1)
        for skip in args:
            # Prepend: the last skip in ``args`` ends up first along channels.
            fused = torch.cat([self._match_width(skip, target), fused], dim=1)
        return self.conv(fused)
    
class SingleConv(nn.Module):
    """Pointwise (1x1) convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        y = self.conv(x)
        return y

In [5]:
class Net(nn.Module):
    """Single-channel SR network: U-shaped CasBlock stages with dense skip inputs.

    A 1->64 ConvReLU stem is followed by three ``Down`` stages, three ``Up``
    stages, a ``Vanilla`` stage and a 1x1 conv back to one channel. Every
    stage receives the stem output plus the pre-fusion features (``xp``) of
    all earlier stages; the stages resample these to their own width.
    H and W should be divisible by 8, since three 2x max-pools are undone by
    three 2x upsamplings.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order define the checkpoint keys.
        self.inc = ConvReLU(1, 64)
        self.down1 = Down(64, 128, 128)
        self.down2 = Down(128, 256, 256)
        self.down3 = Down(256, 512, 512)
        self.up1 = Up(512, 1024, 256)
        self.up2 = Up(256, 1280, 128)
        self.up3 = Up(128, 1408, 64)
        self.vanilla = Vanilla(64, 1472, 64)
        self.out = SingleConv(64, 1)

    def forward(self, x):
        feats = self.inc(x)
        skips = [feats]
        for stage in (self.down1, self.down2, self.down3,
                      self.up1, self.up2, self.up3):
            feats, pre_fusion = stage(feats, *skips)
            skips.append(pre_fusion)
        return self.out(self.vanilla(feats, *skips))

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Computation device: ', device)
model = Net().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load("pretrain/best-loss.pt", map_location=device))
print(model)

In [None]:
def count_parameters(model):
    """Return the total number of trainable (``requires_grad``) scalar parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report the number of trainable parameters of the loaded model.
print(count_parameters(model))

## Loss Function

In [None]:
def gaussian_kernel(win_size, win_sigma):
    """Return a normalized 2-D Gaussian window of shape ``(win_size, win_size)``.

    The kernel is the outer product of a 1-D Gaussian sampled at offsets
    ``-(win_size-1)/2 ... (win_size-1)/2`` with standard deviation
    ``win_sigma``. The 1-D profile is normalized to sum to 1, so the 2-D
    kernel sums to 1 as well.
    """
    m = (win_size - 1.) / 2.
    win = torch.arange(-m, m + 1)
    # torch.exp keeps the computation in torch instead of round-tripping the
    # tensor through NumPy's ufunc machinery.
    win = torch.exp(-(win ** 2) / (2 * win_sigma ** 2))
    win_sum = win.sum()
    if win_sum != 0:
        win = win / win_sum
    return torch.outer(win, win)

def ssim(X, Y, win, K, max_val):
    """Compute per-image, per-channel SSIM and contrast-structure (cs) terms.

    Args:
        X, Y: images of shape (N, C, H, W).
        win: window of shape (C, 1, k, k), one copy per channel (as built by
            ``ms_ssim``). Each channel is filtered independently.
        K: (K1, K2) stability constants.
        max_val: dynamic range of the pixel values.

    Returns:
        (ssim, cs): tensors of shape (N, C), the spatial means of the SSIM
        map and of the contrast-structure map.
    """
    C1 = (K[0] * max_val) ** 2
    C2 = (K[1] * max_val) ** 2
    win = win.to(X.device, dtype=X.dtype)
    # Replicate-pad by half the window so the maps keep the input size (5 for
    # the usual 11x11 window) for any window size.
    pad_h, pad_w = win.shape[-2] // 2, win.shape[-1] // 2
    channels = X.shape[1]

    def filt(img):
        # groups=C applies the window depthwise; without it conv2d only
        # accepts single-channel input.
        return F.conv2d(F.pad(img, (pad_w, pad_w, pad_h, pad_h), 'replicate'),
                        weight=win, stride=1, padding=0, groups=channels)

    mu_x = filt(X)
    mu_y = filt(Y)
    mu_xy = mu_x * mu_y
    mu_x2 = mu_x ** 2
    mu_y2 = mu_y ** 2

    sigma_x2 = filt(X * X) - mu_x2
    sigma_y2 = filt(Y * Y) - mu_y2
    sigma_xy = filt(X * Y) - mu_xy

    cs_map = (2 * sigma_xy + C2) / (sigma_x2 + sigma_y2 + C2)
    ssim_map = ((2 * mu_xy + C1) / (mu_x2 + mu_y2 + C1)) * cs_map

    ssim_val = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_val, cs

def ms_ssim(X,Y, win_size = 11, win_sigma = 0.5, K = (0.01,0.03), max_val = 1, weights = None, batch_average = False):
    """Multi-scale SSIM between X and Y, images of shape (N, C, H, W).

    At each of ``len(weights)`` scales the images are compared with
    ``ssim``. The cs terms of all but the last scale and the SSIM of the last
    scale (each clipped at 0 with ReLU) are combined as a weighted geometric
    mean. Between scales both images are 2x average-pooled, with odd sizes
    padded.

    Returns:
        Tensor of shape (N, C), or its scalar mean when ``batch_average``.

    NOTE(review): expects ``ssim`` to return an ``(ssim, cs)`` pair. A later
    notebook cell redefines ``ssim`` to return only the SSIM term.
    """
    if weights is None:
        weights = [0.0517, 0.3295, 0.3462, 0.2726]
    weights = torch.FloatTensor(weights).to(X.device, dtype=X.dtype)

    n_dims = len(X.shape)
    win = gaussian_kernel(win_size, win_sigma).repeat([X.shape[1]] + [1] * (n_dims - 1))

    levels = weights.shape[0]
    cs_terms = []
    for level in range(levels):
        ssim_m, cs = ssim(X, Y, win=win, max_val=max_val, K=K)
        if level == levels - 1:
            break
        cs_terms.append(torch.relu(cs))
        padding = [s % 2 for s in X.shape[2:]]
        X = F.avg_pool2d(X, kernel_size=2, padding=padding)
        Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    stacked = torch.stack(cs_terms + [torch.relu(ssim_m)], dim=0)
    result = torch.prod(stacked ** weights.view(-1, 1, 1), dim=0)
    return result.mean() if batch_average else result
    
def gaussian_l1(X, Y, win_sigma =0.5, win_size = 11, max_val = 1, batch_average = False):
    """Gaussian-weighted L1 distance between X and Y, images of shape (N, C, H, W).

    The per-pixel absolute error is smoothed with a normalized Gaussian
    window using a valid convolution (no padding), so the filtered map is
    ``(H - win_size + 1, W - win_size + 1)``. The map is then averaged
    spatially and divided by ``max_val``.

    Returns:
        Tensor of shape (N, C), or its scalar mean when ``batch_average``.
    """
    channels = X.shape[1]
    win = gaussian_kernel(win_size, win_sigma)
    win = win.repeat([channels] + [1] * (len(X.shape) - 1))
    win = win.to(X.device, dtype=X.dtype)

    l1_map = F.l1_loss(X, Y, reduction='none')
    # groups=C filters each channel with its own copy of the window; without
    # it conv2d only accepts single-channel input.
    gaussian_l1_map = F.conv2d(l1_map, weight=win, stride=1, padding=0, groups=channels)
    per_channel = torch.flatten(gaussian_l1_map, 2).mean(-1)
    if batch_average:
        return per_channel.mean() / max_val
    return per_channel / max_val
    
class hybrid(nn.Module):
    """Mixed MS-SSIM / Gaussian-L1 loss.

    ``loss = compensation * mean_over_dim0(alpha * (1 - MS-SSIM) + (1 - alpha) * GaussianL1)``

    Window and range settings are forwarded to ``ms_ssim`` and
    ``gaussian_l1``. With ``batch_average=False`` both terms are (N, C)
    tensors, so the result has shape (C,).
    """

    def __init__(self,
                 win_size=11,
                 win_sigma=1.5,
                 max_val=1,
                 batch_average=False,
                 weights=None,
                 K=(0.01, 0.03),
                 alpha=0.84,
                 compensation=100.):
        super().__init__()
        self.win_size = win_size
        self.win_sigma = win_sigma
        self.max_val = max_val
        self.batch_average = batch_average
        self.weights = weights
        self.K = K
        # NOTE(review): this precomputed window is not used by forward();
        # ms_ssim and gaussian_l1 each build their own.
        self.win = gaussian_kernel(self.win_size, self.win_sigma)
        self.alpha = alpha
        self.compensation = compensation

    def forward(self, X, Y):
        common = {
            'win_size': self.win_size,
            'win_sigma': self.win_sigma,
            'max_val': self.max_val,
            'batch_average': self.batch_average,
        }
        ms_term = 1 - ms_ssim(X, Y, K=self.K, weights=self.weights, **common)
        l1_term = gaussian_l1(X, Y, **common)
        mixed = self.alpha * ms_term + (1 - self.alpha) * l1_term
        return self.compensation * mixed.mean(0)

## Early Stopping

In [None]:
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    The model's ``state_dict`` is saved to ``path`` whenever the validation
    loss decreases by more than ``delta``. After ``patience`` consecutive
    calls without such an improvement, ``early_stop`` is set to True.

    Args:
        patience: number of non-improving calls tolerated.
        verbose: if True, report every checkpoint save via ``trace_func``.
        delta: minimum decrease of the loss that counts as an improvement.
        path: file the best ``state_dict`` is written to.
        trace_func: callable used for progress messages.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='single/8x/best-loss.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        # Higher score is better, so negate the loss.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

## PSNR function

In [6]:
def psnr(label, outputs, max_val=1):
    """Peak signal-to-noise ratio in dB between two tensors.

    The mean squared error is pooled over every element (i.e. the whole
    batch), not averaged per image. Identical inputs return 100.
    """
    diff = outputs.cpu().detach().numpy() - label.cpu().detach().numpy()
    rmse = math.sqrt(np.mean(diff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(max_val / rmse)

## Training Functions

In [None]:
def train(model, dataloader):
    """Run one training epoch over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Returns:
        (mean loss, mean PSNR), averaged over the batches of the epoch.
    """
    model.train()
    running_loss = 0.0
    running_psnr = 0.0
    # tqdm takes its total from len(dataloader), the number of batches the
    # sampler actually yields. len(train_data) would also count the samples
    # held out for validation.
    for image_data, label in tqdm(dataloader):
        data = image_data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        outputs = model(data)

        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_psnr += psnr(label, outputs)
    final_loss = running_loss / len(dataloader)
    final_psnr = running_psnr / len(dataloader)
    return final_loss, final_psnr

In [None]:
def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Uses the notebook-level ``criterion`` and ``device``.

    Returns:
        (mean loss, mean PSNR), averaged over the batches.
    """
    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    with torch.no_grad():
        # tqdm takes its total from len(dataloader), the number of batches
        # the sampler yields. len(val_data) is the full dataset, not the
        # validation subset.
        for image_data, label in tqdm(dataloader):
            data = image_data.to(device)
            label = label.to(device)
            outputs = model(data)
            loss = criterion(outputs, label)
            running_loss += loss.item()
            running_psnr += psnr(label, outputs)
    final_loss = running_loss / len(dataloader)
    final_psnr = running_psnr / len(dataloader)
    return final_loss, final_psnr

## Model Training

In [None]:
lr = 0.00001
# Batches per epoch actually produced by the training loader (its sampled
# subset). Clamped to at least 1 so an empty dataset cannot cause a
# ZeroDivisionError below.
iterations = max(len(train_loader), 1)
# StepLR counts epochs: halve the LR about every 100k optimizer steps,
# converted to epochs (at least 1).
step = max(int(100000 / iterations), 1)
optimizer = optim.Adam(model.parameters(), lr=lr) # no weight decay
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step, gamma=0.5)
criterion = hybrid()

In [None]:
train_loss, val_loss = [], []
train_psnr, val_psnr = [], []
best_psnr = 0
early_stopping = EarlyStopping(patience=50, verbose=False)
start = time.time()
path = 'single/8x/best-model.pt'
# Both checkpoints (best-PSNR here, best-loss inside EarlyStopping) are saved
# during training. Create their folders up front so a missing directory does
# not crash the run after the first epoch.
for ckpt in (path, early_stopping.path):
    os.makedirs(os.path.dirname(ckpt) or ".", exist_ok=True)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1} of {epochs}")
    train_epoch_loss, train_epoch_psnr = train(model, train_loader)
    scheduler.step()
    val_epoch_loss, val_epoch_psnr = validate(model, valid_loader)
    print(f"Train PSNR: {train_epoch_psnr:.3f}")
    print(f"Val PSNR: {val_epoch_psnr:.3f}")
    train_loss.append(train_epoch_loss)
    train_psnr.append(train_epoch_psnr)
    val_loss.append(val_epoch_loss)
    val_psnr.append(val_epoch_psnr)
    # Best-loss checkpoint + patience counter.
    early_stopping(val_epoch_loss, model)
    # Separate best-PSNR checkpoint.
    if val_epoch_psnr > best_psnr:
        torch.save(model.state_dict(), path)
        best_psnr = val_epoch_psnr
    if early_stopping.early_stop:
        print("Early stopping")
        break
end = time.time()
print(f"Finished training in: {((end-start)/60):.3f} minutes")

In [None]:
# Loss curves: saved as an image and as CSV.
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('single/8x/loss.png')
plt.show()
# columns must be a list: a set has no defined order, so the CSV column
# order would change between runs.
train_ls = pd.DataFrame({'train_loss': train_loss, 'val_loss': val_loss},
                        columns=['train_loss', 'val_loss'])
train_ls.to_csv('single/8x/loss.csv')
# psnr plots
plt.figure(figsize=(10, 7))
plt.plot(train_psnr, color='green', label='train PSNR dB')
plt.plot(val_psnr, color='blue', label='validation PSNR dB')
plt.xlabel('Epochs')
plt.ylabel('PSNR (dB)')
plt.legend()
plt.savefig('single/8x/psnr.png')
plt.show()
psnr_tr = pd.DataFrame({'train_psnr': train_psnr, 'val_psnr': val_psnr},
                       columns=['train_psnr', 'val_psnr'])
psnr_tr.to_csv('single/8x/psnr.csv')

## Testing

In [None]:
scale = '8'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load("single/"+scale+"x/best-model.pt", map_location=device))

In [None]:
def read_test_data(path):
    """Return ``(test_X, test_Y)`` file lists for one test acquisition ``path``.

    Each dict maps a matching ``.../data`` folder to its sorted file names.
    The low-res inputs are taken at the notebook-level ``scale``; the
    high-res targets are the matching high-res folder.
    """
    base = "/home/mri/Documents/dlsr/dataset/test/"
    low_pattern = base + "low-res/" + scale + "x/" + path + "/data"
    high_pattern = base + "high-res/" + path + "/data"
    test_X = {folder: sorted(os.listdir(folder)) for folder in glob(low_pattern)}
    test_Y = {folder: sorted(os.listdir(folder)) for folder in glob(high_pattern)}
    return test_X, test_Y

In [None]:
ls = ['HRCEST07_1', '20220205_135439_DLSRCEST_C164R1_1_5_26_CWCEST_0.6uT_cor_-1p5',
       'HRCEST07_5', 'HRCEST07_3',
       '20200909_190127_DLCEST_C48_L1_5xFAD_1_17_6_10_7_11_examp_10_CWCEST_0.6uT_front',
       '20220205_135439_DLSRCEST_C164R1_1_5_23_CWCEST_0.6uT_sag_0',
       'HRCEST07_4',
       '20220117_113031_SeWeon_7_Day26_L1_1_24_8_cestRARE_fullZ_0.8uT',
       '20201201_114747_ICH_mouse_20201124_C78M2_day1_1_13_10_CWCEST_0.8uT_9696',
       'HRCEST07_2']

# Write the network output for every test image as an 8-bit TIFF under
# single/{scale}x/test/<acquisition>/data/NNN.tif. makedirs(exist_ok=True)
# lets the cell be re-run without FileExistsError.
os.makedirs(f"single/{scale}x/test", exist_ok=True)
model.eval()
for path in ls:
    test_X, test_Y = read_test_data(path)
    test_data = dataset(test_X, test_Y)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
    out_dir = f"single/{scale}x/test/{path}/data"
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        for count, (images, _) in enumerate(test_loader, start=1):
            outputs = model(images.to(device)).cpu()
            tensor = torch.clamp(outputs.squeeze(), min=0.0, max=1.0)
            # Round to the nearest 8-bit level; a plain uint8 cast truncates
            # and biases every pixel downwards.
            image = (tensor * 255).round().numpy().astype(np.uint8)
            cv2.imwrite(f"{out_dir}/{count:03d}.tif", image)

In [None]:
ls = ['HRCEST07_1', '20220205_135439_DLSRCEST_C164R1_1_5_26_CWCEST_0.6uT_cor_-1p5',
       'HRCEST07_5', 'HRCEST07_3',
       '20200909_190127_DLCEST_C48_L1_5xFAD_1_17_6_10_7_11_examp_10_CWCEST_0.6uT_front',
       '20220205_135439_DLSRCEST_C164R1_1_5_23_CWCEST_0.6uT_sag_0',
       'HRCEST07_4',
       '20220117_113031_SeWeon_7_Day26_L1_1_24_8_cestRARE_fullZ_0.8uT',
       '20201201_114747_ICH_mouse_20201124_C78M2_day1_1_13_10_CWCEST_0.8uT_9696',
       'HRCEST07_2']

# os.mkdir(f"python-mat/single/{scale}x")

# Stack the clamped network outputs of each acquisition into one float32
# array (concatenated along the leading axis) for export as a .mat file.
model.eval()
for path in ls:
    test_X, test_Y = read_test_data(path)
    test_data = dataset(test_X, test_Y)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
    slices = []
    with torch.no_grad():
        for images, _ in test_loader:
            output = model(images.to(device)).cpu().squeeze(dim=0)
            output = torch.clamp(output, min=0.0, max=1.0)
            slices.append(np.array(output, dtype=np.single))
    # Concatenate once at the end instead of re-copying the growing array
    # for every image (quadratic). An empty loader leaves an empty list.
    outputs = np.concatenate(slices, axis=0) if slices else []
    output_path = f"python-mat/single/{scale}x/{path}.mat"
    odict = {'img': outputs}
    # savemat(output_path, odict)

## Average test metrics (PSNR and SSIM)

In [20]:
scale = '8'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load("single/"+scale+"x/best-model.pt", map_location=device))

<All keys matched successfully>

In [20]:
def read_test_data(path):
    """Return ``(test_X, test_Y)`` file lists for one test acquisition ``path``.

    Each dict maps a matching ``.../data`` folder to its sorted file names.
    The low-res inputs are taken at the notebook-level ``scale``; the
    high-res targets are the matching high-res folder.
    """
    base = "/home/mri/Documents/dlsr/dataset/test/"
    low_pattern = base + "low-res/" + scale + "x/" + path + "/data"
    high_pattern = base + "high-res/" + path + "/data"
    test_X = {folder: sorted(os.listdir(folder)) for folder in glob(low_pattern)}
    test_Y = {folder: sorted(os.listdir(folder)) for folder in glob(high_pattern)}
    return test_X, test_Y

In [21]:
def read_test_data():
    """Return ``(test_X, test_Y)`` file lists for every test acquisition.

    Each dict maps a ``.../data`` folder to its sorted file names. Low-res
    inputs are taken at the notebook-level ``scale``. Folders are visited in
    sorted order because ``glob`` returns them in arbitrary order, while
    ``dataset`` pairs inputs with targets purely by position. Sorting keeps
    the i-th low-res folder aligned with the i-th high-res folder.
    """
    test_X = {}
    test_Y = {}
    for folder in sorted(glob("/home/mri/Documents/dlsr/dataset/test/low-res/"+scale+"x/*/data")):
        test_X[folder] = sorted(os.listdir(folder))

    for folder in sorted(glob("/home/mri/Documents/dlsr/dataset/test/high-res/*/data")):
        test_Y[folder] = sorted(os.listdir(folder))

    return test_X, test_Y

In [8]:
def gaussian_kernel(win_size, win_sigma):
    """Return a normalized 2-D Gaussian window of shape ``(win_size, win_size)``.

    The kernel is the outer product of a 1-D Gaussian sampled at offsets
    ``-(win_size-1)/2 ... (win_size-1)/2`` with standard deviation
    ``win_sigma``. The 1-D profile is normalized to sum to 1, so the 2-D
    kernel sums to 1 as well.
    """
    m = (win_size - 1.) / 2.
    win = torch.arange(-m, m + 1)
    # torch.exp keeps the computation in torch instead of round-tripping the
    # tensor through NumPy's ufunc machinery.
    win = torch.exp(-(win ** 2) / (2 * win_sigma ** 2))
    win_sum = win.sum()
    if win_sum != 0:
        win = win / win_sum
    return torch.outer(win, win)

def ssim(X, Y, win, K, max_val):
    """Compute the per-image, per-channel mean SSIM.

    Args:
        X, Y: images of shape (N, C, H, W).
        win: window of shape (C, 1, k, k), one copy per channel. Each
            channel is filtered independently.
        K: (K1, K2) stability constants.
        max_val: dynamic range of the pixel values.

    Returns:
        Tensor of shape (N, C): the spatial mean of the SSIM map.
    """
    C1 = (K[0] * max_val) ** 2
    C2 = (K[1] * max_val) ** 2
    win = win.to(X.device, dtype=X.dtype)
    # Replicate-pad by half the window so the maps keep the input size (5 for
    # the usual 11x11 window) for any window size.
    pad_h, pad_w = win.shape[-2] // 2, win.shape[-1] // 2
    channels = X.shape[1]

    def filt(img):
        # groups=C applies the window depthwise; without it conv2d only
        # accepts single-channel input.
        return F.conv2d(F.pad(img, (pad_w, pad_w, pad_h, pad_h), 'replicate'),
                        weight=win, stride=1, padding=0, groups=channels)

    mu_x = filt(X)
    mu_y = filt(Y)
    mu_xy = mu_x * mu_y
    mu_x2 = mu_x ** 2
    mu_y2 = mu_y ** 2

    sigma_x2 = filt(X * X) - mu_x2
    sigma_y2 = filt(Y * Y) - mu_y2
    sigma_xy = filt(X * Y) - mu_xy

    ssim_map = ((2 * mu_xy + C1) * (2 * sigma_xy + C2)) / ((mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2))
    return torch.flatten(ssim_map, 2).mean(-1)

In [22]:
# Per-image PSNR / SSIM over the whole test set, both for the low-res input
# as given and for the network output, written to
# new-results/{scale}x/results.csv.
os.makedirs(f"new-results/{scale}x", exist_ok=True)
test_X, test_Y = read_test_data()
test_data = dataset(test_X, test_Y)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
model.eval()
psnr_list = []
low_psnr_list = []
low_ssim_list = []
ssim_list = []
kernel = gaussian_kernel(11, 1.5)
win = None
with torch.no_grad():
    for image_data, label in test_loader:
        if win is None:
            # Repeat the window once for the channel count. Repeating the
            # already-repeated window every batch would grow it for
            # multi-channel images.
            win = kernel.repeat([label.shape[1]] + [1] * (len(label.shape) - 1))
        data = image_data.to(device)
        label = label.to(device)
        outputs = model(data)
        low_psnr_list.append(psnr(data, label))
        psnr_list.append(psnr(outputs, label))
        # .item() stores plain floats; tensors would be written to the CSV as
        # "tensor([[...]])" strings.
        low_ssim_list.append(ssim(data, label, win, (0.01, 0.03), 1).mean().item())
        ssim_list.append(ssim(outputs, label, win, (0.01, 0.03), 1).mean().item())

# columns must be a list: a set gives a run-dependent column order.
psnr_tr = pd.DataFrame({'low_psnr': low_psnr_list, 'single_psnr': psnr_list,
                        'low_ssim': low_ssim_list, 'single_ssim': ssim_list},
                       columns=['low_psnr', 'single_psnr', 'low_ssim', 'single_ssim'])
psnr_tr.to_csv(f"new-results/{scale}x/results.csv")

Reading images to memory


In [None]:
# Average PSNR / SSIM over test_loader (from the previous cell), both for the
# low-res input and for the network output.
model.eval()
running_psnr = 0.0
running_psnr_list = []
low_running_psnr = 0.0
low_running_psnr_list = []
low_ssim = 0.0
low_ssim_list = []
running_ssim = 0.0
running_ssim_list = []
kernel = gaussian_kernel(11, 1.5)
win = None

with torch.no_grad():
    for image_data, label in test_loader:
        if win is None:
            # Repeat the window once for the channel count. Repeating the
            # already-repeated window every batch would grow it for
            # multi-channel images.
            win = kernel.repeat([label.shape[1]] + [1] * (len(label.shape) - 1))
        data = image_data.to(device)
        label = label.to(device)
        outputs = model(data)
        low_batch_psnr = psnr(data, label)
        low_running_psnr_list.append(low_batch_psnr)
        low_running_psnr += low_batch_psnr
        low_batch_ssim = ssim(data, label, win, (0.01, 0.03), 1).cpu()
        low_ssim_list.append(low_batch_ssim)
        # .item() keeps the running sums as floats, so the prints below show
        # plain numbers rather than tensors.
        low_ssim += low_batch_ssim.mean().item()
        batch_psnr = psnr(label, outputs)
        running_psnr_list.append(batch_psnr)
        running_psnr += batch_psnr
        batch_ssim = ssim(outputs, label, win, (0.01, 0.03), 1).cpu()
        running_ssim_list.append(batch_ssim)
        running_ssim += batch_ssim.mean().item()
final_ssim = running_ssim/len(test_loader)
final_low_ssim = low_ssim/len(test_loader)
final_low_psnr = low_running_psnr/len(test_loader)
final_psnr = running_psnr/len(test_loader)

print("Validation Low-res-PSNR: ", final_low_psnr)
print("Validation Low-Res SSIM: ", final_low_ssim)
print("Validation PSNR: ", final_psnr)
print("Validation SSIM: ", final_ssim)

# psnr_tr = pd.DataFrame({'low_psnr':low_running_psnr_list,'single_psnr':running_psnr_list, 'low_ssim':low_ssim_list, 'single_ssim':running_ssim_list},columns={'low_psnr','single_psnr','low_ssim','single_ssim'})
# psnr_tr.to_csv('outputs/rcab-gca/pretrain-model-single/3x/psnr_ssim.csv')

## Writing high-res and low-res images as MAT files

In [None]:
scale = '8'


def read_test_data(path):
    """Return ``(test_X, test_Y)`` file lists for one test acquisition ``path``.

    Each dict maps a matching ``.../data`` folder to its sorted file names.
    The low-res inputs are taken at the module-level ``scale``; the high-res
    targets are the matching high-res folder.
    """
    base = "/home/mri/Documents/dlsr/dataset/test/"
    low_pattern = base + "low-res/" + scale + "x/" + path + "/data"
    high_pattern = base + "high-res/" + path + "/data"
    test_X = {folder: sorted(os.listdir(folder)) for folder in glob(low_pattern)}
    test_Y = {folder: sorted(os.listdir(folder)) for folder in glob(high_pattern)}
    return test_X, test_Y

In [None]:
ls = ['HRCEST07_1', '20220205_135439_DLSRCEST_C164R1_1_5_26_CWCEST_0.6uT_cor_-1p5',
       'HRCEST07_5', 'HRCEST07_3',
       '20200909_190127_DLCEST_C48_L1_5xFAD_1_17_6_10_7_11_examp_10_CWCEST_0.6uT_front',
       '20220205_135439_DLSRCEST_C164R1_1_5_23_CWCEST_0.6uT_sag_0',
       'HRCEST07_4',
       '20220117_113031_SeWeon_7_Day26_L1_1_24_8_cestRARE_fullZ_0.8uT',
       '20201201_114747_ICH_mouse_20201124_C78M2_day1_1_13_10_CWCEST_0.8uT_9696',
       'HRCEST07_2']

# os.mkdir(f"python-mat/low-res/{scale}x")

# Stack the low-res inputs and high-res labels of each acquisition along the
# leading axis, ready for export as .mat files.
for path in ls:
    test_X, test_Y = read_test_data(path)
    test_data = dataset(test_X, test_Y)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
    image_slices = []
    label_slices = []
    for image_data, label in test_loader:
        image_slices.append(np.array(image_data.squeeze(dim=0)))
        label_slices.append(np.array(label.squeeze(dim=0)))
    # Concatenate once instead of re-copying the growing arrays per image
    # (quadratic). An empty loader leaves empty lists.
    images = np.concatenate(image_slices, axis=0) if image_slices else []
    labels = np.concatenate(label_slices, axis=0) if label_slices else []
    # high_path = f"python-mat/high-res/{path}.mat"
    # low_path = f"python-mat/low-res/{scale}x/{path}.mat"
    # hdict = {'img': labels}
    # ldict = {'img': images}
    # # savemat(high_path, hdict)
    # savemat(low_path, ldict)


























































## Evaluation

In [None]:
from ptflops import get_model_complexity_info

# MACs and parameter count for one 1x96x96 input.
with torch.cuda.device(0):
    # ptflops needs a model *instance*; passing the Net class itself fails
    # as soon as ptflops calls methods on it or runs a forward pass.
    macs, params = get_model_complexity_info(Net(), (1, 96, 96), as_strings=True,
                                             print_per_layer_stat=True, verbose=True,
                                             ignore_modules=[torch.nn.AvgPool2d])
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

## Legacy code


In [None]:
k = 0
act = activation['down3']
# Down.forward returns (pooled, xp); if the hook stored both, plot the
# pooled output.
if isinstance(act, tuple):
    act = act[0]
act = act.squeeze()
print(act.size())
print((act[k].size()))
n_rows = act.size(0) // 4
# squeeze=False keeps ``axarr`` 2-D even when there is only one row of plots.
fig, axarr = plt.subplots(n_rows, 4, figsize=(15, 160), squeeze=False)
for i in range(n_rows):
    for j in range(4):
        axarr[i, j].imshow(act[k].detach().cpu().numpy())
        k += 1

In [None]:
activation = {}


def get_activation(name):
    """Return a forward hook that stores the module's output, detached, in ``activation[name]``.

    Tuple outputs (``Down``/``Up`` return ``(x, xp)``) are stored as a tuple
    of detached tensors; a plain tensor output is stored as a detached tensor.
    """
    def hook(model, inp, out):
        if isinstance(out, tuple):
            activation[name] = tuple(o.detach() for o in out)
        else:
            activation[name] = out.detach()
    return hook

In [None]:
# Capture model.down3's output during one forward pass on a test batch.
# NOTE(review): relies on `model`/`device` (loading cell below), `get_activation`,
# and a `test_loader` whose items start with the input tensor — run those first.
# Remove the hook left by a previous run of this cell so hooks don't stack up.
if 'down3_hook' in globals():
    down3_hook.remove()
down3_hook = model.down3.register_forward_hook(get_activation('down3'))

data = None
# Keep the last batch (as before), but only move that one batch to the device.
for img, *_ in test_loader:
    data = img
if data is None:
    raise RuntimeError("test_loader yielded no batches")
data = data.to(device)

# Inference only: eval() so any BatchNorm/Dropout layers don't use batch statistics
# or update running stats; no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    output = model(data)
print(data.size())
print(output.size())

In [None]:
# Build the U-Net and load the best checkpoint; falls back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet().to(device)
# map_location lets a checkpoint saved on GPU load onto whichever device is in use.
# (The method name previously contained a fullwidth '＿' (U+FF3F); it is plain ASCII now.)
model.load_state_dict(torch.load("abstract/outputs/single-unet-np/best-model.pt", map_location=device))

In [None]:
# class Net(nn.Module):
#     def __init__(self):
#         super().__init__()
        
#         self.inc = ConvReLU(1,64)
#         self.cas = CasBlock(64)
#         self.down1 = Down(64,128)
#         self.down2 = Down(128,256)
#         self.up1 = Up(256,128)
#         self.up2 = Up(128,64)
#         self.conv = SingleConv(64,64)
#         self.out = SingleConv(64,1)
        
#         self.down1px0      = DownProjection(64,64,2)
#         self.down2px0      = DownProjection(64,128,4)
#         self.up1px0        = DownProjection(64,256,2)
#         self.up2px0        = DownProjection(64,128,1)
        
#         self.down2px1 = DownProjection(64,128,4)
#         self.up1px1   = DownProjection(64,256,2)
#         self.up2px1   = DownProjection(64,128,1)
        
#         self.up1px2   = DownProjection(128,256,1)
#         self.up2px2   = UpProjection(128,128,2)
#         self.singleconvpx2 = UpProjection(128,64,2) 
        
#         self.up2px3   = UpProjection(256,128,4)
#         self.singleconvpx3 = UpProjection(256,64,4)
        
#         self.singleconvpx4 = UpProjection(128,64,2)
        
#     def forward(self, x):
#         #in
#         x0 = self.inc(x)
        
#         #casblock
#         x1 = self.cas(x0)
        
#         #down1
#         xp1 = self.down1px0(x0)
#         x2 = self.down1(x1, xp1)
        
#         #down2
#         x0p = self.down2px0(x0)
#         x1p = self.down2px1(x1)
#         xp2 = x0p+x1p
#         x3 = self.down2(x2,xp2)
        
#         #up1
#         x0p = self.up1px0(x0)
#         x1p = self.up1px1(x1)
#         x2p = self.up1px2(x2)
#         xp3 = x0p+x1p+x2p
#         x4 = self.up1(x3,xp3)
        
#         #up2
#         x0p = self.up2px0(x0) 
#         x1p = self.up2px1(x1)
#         x2p = self.up2px2(x2)
#         x3p = self.up2px3(x3)
#         xp4 = x0p+x1p+x2p+x3p
#         x5 = self.up2(x4, xp4)
        
#         #singleconv
#         x0p = x0
#         x1p = x1
#         x2p = self.singleconvpx2(x2)
#         x3p = self.singleconvpx3(x3)
#         x4p = self.singleconvpx4(x4)
#         x6 = x0p+x1p+x2p+x3p+x4p+x5
#         x6 = self.conv(x6)
        
#         #out
#         out = self.out(x6)
#         return out

In [None]:
# class Down(nn.Module):

#     def __init__(self, in_channels, out_channels):
       
#         super().__init__()
#         self.downsample = nn.MaxPool2d(2)
#         self.casblock = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=1),
#             CasBlock(out_channels)
#         )

#     def forward(self, x, xp):
#         x = self.downsample(x)
#         return self.casblock(x+xp)

# class Up(nn.Module):

#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         self.up = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)
#         self.casblock = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=1),
#             CasBlock(out_channels)
#         )

#     def forward(self, x, xp):
#         x = self.up(x)
#         return self.casblock(x+xp)
    
# class DownProjection(nn.Module):
#     def __init__(self, in_channels, out_channels, scale):
#         super().__init__()
#         self.downproject = nn.Sequential(
#             nn.MaxPool2d(scale),
#             nn.Conv2d(in_channels, out_channels, kernel_size=1)
#         )
    
#     def forward(self, x):
#         return self.downproject(x)

# class UpProjection(nn.Module):
#     def __init__(self, in_channels, out_channels, scale):
#         super().__init__()
#         self.upproject = nn.Sequential(
#             nn.Upsample(scale_factor=scale, mode='bicubic', align_corners=True),
#             nn.Conv2d(in_channels, out_channels, kernel_size=1)
#         )
        
#     def forward(self, x):
#         return self.upproject(x)