In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm
from kymatio.torch import Scattering2D
from torch.optim.lr_scheduler import CosineAnnealingLR


import pickle

import csv

# Some Utilities

In [None]:
def load_pkl(filename):
    """Return the object stored in the pickle file at ``filename``.

    Security note: unpickling can execute arbitrary code, so only call this on
    trusted files (here: the locally generated n2n dataset pickles).
    """
    with open(filename, 'rb') as handle:
        obj = pickle.load(handle)
    return obj
def fftshift2d(x, ifft=False):
    """Centre (``ifft=False``) or un-centre (``ifft=True``) the spectrum of a 2-D array.

    Only odd-sized arrays are accepted; for those this is equivalent to
    ``np.fft.fftshift`` / ``np.fft.ifftshift``. A new array is returned.
    """
    assert (len(x.shape) == 2) and all(n % 2 == 1 for n in x.shape)
    offset = 0 if ifft else 1
    row_split, col_split = (n // 2 + offset for n in x.shape)
    # Rolling by -k moves x[k:] to the front, i.e. concatenate(x[k:], x[:k]).
    return np.roll(x, shift=(-row_split, -col_split), axis=(0, 1))
def fftshift3d(x, ifft=False):
    """Torch counterpart of ``fftshift2d``: shift the last two axes of a 4-D tensor.

    For odd spatial sizes, ``ifft=False`` matches ``torch.fft.fftshift`` and
    ``ifft=True`` matches ``torch.fft.ifftshift`` over dims (2, 3). Works for real
    and complex tensors and is differentiable.

    Args:
        x: 4-D tensor whose dims 2 and 3 hold the spectrum, with odd sizes.
        ifft: select the inverse shift (defaults to False, as in ``fftshift2d``).

    Returns:
        The shifted tensor (a new tensor).

    Raises:
        AssertionError: if ``x`` is not 4-D or a spatial size is even. The split
            arithmetic below only reproduces a centred shift for odd sizes, matching
            the check in ``fftshift2d``.
    """
    assert (len(x.shape) == 4) and all(s % 2 == 1 for s in x.shape[2:])
    s0 = (x.shape[2] // 2) + (0 if ifft else 1)
    s1 = (x.shape[3] // 2) + (0 if ifft else 1)
    # roll by -s == cat(x[s:], x[:s]) along that dim.
    return torch.roll(x, shifts=(-s0, -s1), dims=(2, 3))

def get_reduced_dataset(images, specs, fraction, datasetclass, rng=None):
    """Build ``datasetclass(subset_images, subset_specs)`` from a random fraction of the data.

    Args:
        images: array of images, indexable with an integer index array.
        specs: array of spectra aligned with ``images``.
        fraction: share of samples to keep, in (0, 1]. The subset has
            ``int(len(images) * fraction)`` samples, drawn without replacement.
        datasetclass: callable invoked as ``datasetclass(subset_images, subset_specs)``.
            NOTE: ``RicianDenoisingDataset``'s second positional parameter is ``t``,
            not the spectra; wrap it, e.g. ``lambda im, _sp: RicianDenoisingDataset(im)``.
        rng: optional ``np.random.Generator`` / ``RandomState`` for reproducible
            subsets. Defaults to the global ``np.random`` state, as before.

    Returns:
        Whatever ``datasetclass`` returns.

    Raises:
        ValueError: if ``fraction`` is outside (0, 1], ``images`` and ``specs`` differ
            in length, or the subset would be empty (an empty loader would later
            divide by zero when averaging losses).
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    n_total = len(images)
    if len(specs) != n_total:
        raise ValueError(f"images and specs differ in length: {n_total} vs {len(specs)}")
    n_subset = int(n_total * fraction)
    if n_subset == 0:
        raise ValueError(f"fraction {fraction} of {n_total} samples selects no samples")
    sampler = np.random if rng is None else rng
    indices = sampler.choice(n_total, n_subset, replace=False)
    return datasetclass(images[indices], specs[indices])
# Run on the first GPU when available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed both RNGs used below. NumPy's global RNG drives subset selection and the
# on-the-fly corruptions; torch's drives weight init and DataLoader shuffling.
# Without seeding, runs (and the model comparisons) are not reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

## Load the datasets produced by the Noise2Noise (n2n) preprocessing

In [None]:
# Location of the IXI pickles written by the n2n preprocessing.
# NOTE(review): machine-specific absolute path; change DATA_DIR for your setup.
DATA_DIR = "C:/Users/javit/Desktop/N2N/datasets"
ruta = f"{DATA_DIR}/ixi_train.pkl"
ruta_test = f"{DATA_DIR}/ixi_valid.pkl"


def _to_model_range(images):
    """Drop the last row and column of every image, then map [0, 255] pixels to float32 in [-0.5, 0.5]."""
    return images[:, :-1, :-1].astype(np.float32) / 255.0 - 0.5


img, spec = load_pkl(ruta)
img = _to_model_range(img)  # images are now 255x255
test_img, test_spec = load_pkl(ruta_test)
test_img = _to_model_range(test_img)

# Downstream code requires 255x255 data: the Bernoulli mask is built for (255, 255),
# fftshift2d only accepts odd sizes, and the models are summarised at 255x255.
assert img.shape[1:] == test_img.shape[1:] == (255, 255), (img.shape, test_img.shape)
assert spec.shape[1:] == test_spec.shape[1:] == (255, 255), (spec.shape, test_spec.shape)
assert len(img) == len(spec) and len(test_img) == len(test_spec)

## Bernoulli masking corruption

In [None]:
# Radially decaying keep-probability map for the centred k-space: 1.0 at the centre
# (DC) and exactly p_at_edge at distance h[1] from it (the middle of an edge).
p_at_edge = 0.025
spec_shape = (255, 255)  # spectra are 255x255
h = [n // 2 for n in spec_shape]  # centre index along each axis (127, 127)
dy, dx = (np.arange(n, dtype=np.float32) - c for n, c in zip(spec_shape, h))
r = (dy[:, np.newaxis] ** 2 + dx[np.newaxis, :] ** 2) ** .5  # distance from the centre
m = (p_at_edge ** (1. / h[1])) ** r
bern_mask = m

def corrupt_data_binary(img, spec, mask=None, rng=None):
    """Randomly drop k-space samples from ``spec`` and rebuild the corrupted image.

    A sample is drawn as kept when ``u**2 < mask`` with ``u ~ U(0, 1)``, which has
    probability ``sqrt(mask)``. The keep pattern is then ANDed with its 180-degree
    rotation. Presumably this keeps the drop pattern Hermitian-symmetric for the
    centred, odd-sized spectrum.

    Args:
        img: unused; the corrupted image is recomputed from ``spec``. Kept so the
            function matches the ``corrupt_fn(img, spec)`` call in the dataset.
        spec: centred complex spectrum with the same shape as ``mask``.
        mask: keep-probability map. Defaults to the module-level ``bern_mask``.
        rng: optional ``np.random.Generator`` / ``RandomState``. Defaults to the
            global ``np.random`` state, as before.

    Returns:
        (img, sval, smsk): float32 corrupted image; masked spectrum (dropped samples
        set to 0); float32 keep-mask (1 = kept).
    """
    if mask is None:
        mask = bern_mask
    gen = np.random if rng is None else rng
    # print('Bernoulli probability at edge = %.5f' % mask[h[0], 0])
    # print('Average Bernoulli probability = %.5f' % np.mean(mask))
    keep = gen.uniform(0.0, 1.0, size=spec.shape) ** 2 < mask
    keep = keep & keep[::-1, ::-1]
    sval = spec * keep
    smsk = keep.astype(np.float32)
    # Kept samples are divided by the mask value. `~keep` adds 1.0 to dropped entries
    # (where sval is 0) only to avoid a division by zero.
    shifted = fftshift2d(sval / (mask + ~keep), ifft=True)
    corrupted = np.real(np.fft.ifft2(shifted)).astype(np.float32)
    return corrupted, sval, smsk

class BinaryDenoisingDataset(Dataset):
    """Dataset of (clean, k-space-masked) image/spectrum pairs, corrupted on the fly.

    Every ``__getitem__`` call runs ``corrupt_fn(clean_image, clean_spec)``, so each
    access draws a new random mask. All returned tensors get a leading channel axis.

    Returns per item (in this order):
        clean image (float32), clean spectrum (complex64), corrupted image (float32),
        corrupted spectrum (complex64), keep-mask (float32).
    """

    def __init__(self, clean_images, clean_specs, corrupt_fn=corrupt_data_binary):
        super().__init__()
        self.clean_images = clean_images
        self.clean_specs = clean_specs
        self.corrupt_fn = corrupt_fn

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        clean_img = self.clean_images[idx]
        clean_spec = self.clean_specs[idx]
        noisy_img, noisy_spec, keep_mask = self.corrupt_fn(clean_img, clean_spec)
        arrays_and_dtypes = (
            (clean_img, torch.float32),
            (clean_spec, torch.complex64),
            (noisy_img, torch.float32),
            (noisy_spec, torch.complex64),
            (keep_mask, torch.float32),
        )
        return tuple(torch.tensor(arr, dtype=dt).unsqueeze(0) for arr, dt in arrays_and_dtypes)

## Rician Noise Corruption

In [None]:
def corrupt_data_rician(img, noise_percent, rng=None):
    """Add Rician noise to ``img``.

    The image is treated as the real part of a complex signal. Independent Gaussian
    noise with ``sigma = noise_percent / 100 * img.max()`` is added to the real and
    imaginary parts, and the magnitude is returned.

    Args:
        img: numeric array (the dataset passes images in [0, 1]).
        noise_percent: noise standard deviation as a percentage of the image maximum.
        rng: optional ``np.random.Generator`` / ``RandomState``. Defaults to the
            global ``np.random`` state, as before.

    Returns:
        Non-negative array with the same shape as ``img``.
    """
    gen = np.random if rng is None else rng
    sigma = (noise_percent / 100) * img.max().item()
    noise_real = gen.normal(0, sigma, img.shape)
    noise_imag = gen.normal(0, sigma, img.shape)
    return np.sqrt((img + noise_real) ** 2 + noise_imag ** 2)
class RicianDenoisingDataset(Dataset):
    """Dataset of (clean, Rician-corrupted) image pairs, corrupted on the fly.

    Images are expected in [-0.5, 0.5]. Each one is shifted to [0, 1], corrupted
    with ``corrupt_fn(img, noise_percent)``, clamped to [0, 1] and shifted back. A
    new noise realisation is drawn on every ``__getitem__`` call.

    Args:
        clean_images: indexable stack of clean images.
        t: forwarded to ``augment_fn`` as ``t=``; unused when ``augment_fn`` is None.
        corrupt_fn: ``corrupt_fn(img, noise_percent) -> noisy_img``.
        augment_fn: optional ``augment_fn(img, spec, t=t) -> (img, spec)``.
        noise_percent: noise level passed to ``corrupt_fn``. Defaults to 11, the value
            that used to be hard-coded.
        clean_specs: optional spectra aligned with ``clean_images``. Only
            ``augment_fn`` needs them; the augment branch previously referenced an
            undefined ``spec_clean`` and raised NameError.
    """

    def __init__(self, clean_images, t=64, corrupt_fn=corrupt_data_rician, augment_fn=None,
                 noise_percent=11, clean_specs=None):
        super(RicianDenoisingDataset, self).__init__()
        self.clean_images = clean_images
        self.t = t
        self.corrupt_fn = corrupt_fn
        self.augment_fn = augment_fn
        self.noise_percent = noise_percent
        self.clean_specs = clean_specs

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        img_clean = self.clean_images[idx] + 0.5  # now in [0, 1]
        # Data augmentation
        if self.augment_fn:
            spec_clean = None if self.clean_specs is None else self.clean_specs[idx]
            img_clean, spec_clean = self.augment_fn(img_clean, spec_clean, t=self.t)
        # Corrupt data
        cimg = self.corrupt_fn(img_clean, self.noise_percent)
        img_noisy = torch.tensor(cimg, dtype=torch.float32).clamp(0, 1)
        img_clean = torch.tensor(img_clean, dtype=torch.float32).clamp(0, 1)
        return img_clean.unsqueeze(0) - 0.5, img_noisy.unsqueeze(0) - 0.5  # back to [-0.5, 0.5]

# Model Definitions

## UNet architecture

In [None]:
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()

        #Encoder
        self.conv0 = nn.Conv2d(1, 48, 3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv5 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.conv6 = nn.Conv2d(48, 48, 3, stride=1, padding=1)

        #Decoder
        self.upsample5 = nn.Upsample(scale_factor=2)
        self.deconv5a = nn.Conv2d(96, 96, 3, stride=1, padding=1)
        self.deconv5b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample4 = nn.Upsample(scale_factor=2)
        self.deconv4a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv4b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample3 = nn.Upsample(scale_factor=2)
        self.deconv3a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv3b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=2)
        self.deconv2a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv2b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.deconv1a = nn.Conv2d(97, 64, 3, stride=1, padding=1)
        self.deconv1b = nn.Conv2d(64, 32, 3, stride=1, padding=1)
        self.deconv1c = nn.Conv2d(32, 1, 3, stride=1, padding=1)


        
    def forward(self, x):
        original_in = F.pad(x, (0,1,0,1), mode='constant', value=-0.5)
        x = F.leaky_relu(self.conv0(original_in))
        x = F.leaky_relu(self.conv1(x))
        x = self.pool1(x)
        pool1 = x
        x = F.leaky_relu(self.conv2(x))
        x = self.pool2(x)
        pool2 = x
        x = F.leaky_relu(self.conv3(x))
        x = self.pool3(x)
        pool3 = x
        x = F.leaky_relu(self.conv4(x))
        x = self.pool4(x)
        pool4 = x
        x = F.leaky_relu(self.conv5(x))
        x = self.pool5(x)
        x = F.leaky_relu(self.conv6(x))
        x = self.upsample5(x)
        x = torch.cat((x, pool4), 1)
        x = F.leaky_relu(self.deconv5a(x))
        x = F.leaky_relu(self.deconv5b(x))
        x = self.upsample4(x)
        x = torch.cat((x, pool3), 1)
        x = F.leaky_relu(self.deconv4a(x))
        x = F.leaky_relu(self.deconv4b(x))
        x = self.upsample3(x)
        x = torch.cat((x, pool2), 1)
        x = F.leaky_relu(self.deconv3a(x))
        x = F.leaky_relu(self.deconv3b(x))
        x = self.upsample2(x)
        x = torch.cat((x, pool1), 1)
        x = F.leaky_relu(self.deconv2a(x))
        x = F.leaky_relu(self.deconv2b(x))
        x = self.upsample1(x)
        x = torch.cat((x, original_in), 1)
        x = F.leaky_relu(self.deconv1a(x))
        x = F.leaky_relu(self.deconv1b(x))
        x = self.deconv1c(x)
        return x[:,:,:-1,:-1]

In [None]:
model = Unet().to(device)
# torchsummary builds its dummy input on "cuda" unless told otherwise; pass the
# device actually in use so this cell also runs on CPU-only machines.
summary(model, (1, 255, 255), device=device.type)

## CNNDMRI architecture

In [None]:
class CNNDMRI(nn.Module):
    def __init__(self):
        super(CNNDMRI, self).__init__()

        #Initial convolutions
        self.conv1 = nn.Conv2d(1,64,kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(64,64,kernel_size=3,padding=1)

        # Encoder layers
        self.enc1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.enc2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)

        #Residual blocks
        self.res1conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.res1conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.res2conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.res2conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.res3conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.res3conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(128)
        self.res4conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        self.res4conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(128)
        
        # Decoder layers
        self.dec1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,padding=1,output_padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2,padding=1)

        #Final convolution
        self.out = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        
        
    def forward(self, x):
        original_noisy = x
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        #Encoder
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        #Residual blocks
        res1 = x
        x = self.res1conv1(x)
        x = F.relu(self.bn1(x))
        x = self.res1conv2(x)
        x = self.bn2(x)
        x += res1

        res2 = x
        x = self.res2conv1(x)
        x = F.relu(self.bn3(x))
        x = self.res2conv2(x)
        x = self.bn4(x)
        x += res2

        res3 = x
        x = self.res3conv1(x)
        x = F.relu(self.bn5(x))
        x = self.res3conv2(x)
        x = self.bn6(x)
        x += res3

        res4 = x
        x = self.res4conv1(x)
        x = F.relu(self.bn7(x))
        x = self.res4conv2(x)
        x = self.bn8(x)
        x += res4
        # Decoder
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        #Final convolution
        x = self.out(x)
        x += original_noisy
        return x

In [None]:
model = CNNDMRI().to(device)
# torchsummary builds its dummy input on "cuda" unless told otherwise; pass the
# device actually in use so this cell also runs on CPU-only machines.
summary(model, (1, 255, 255), device=device.type)

## WST autoencoder architecture

In [None]:
class WSTAutoencoder(nn.Module):
    """Denoiser that decodes wavelet-scattering coefficients back to an image.

    The input is padded by one row and one column (value -0.5) to 256x256 and fed to
    a kymatio ``Scattering2D(J=2, L=8)``. Per the comment on that layer this gives 81
    channels at 64x64. A convolutional decoder follows: five residual blocks, two 2x
    upsamplings, and the padded input concatenated before ``conv2``. It maps the
    coefficients back to one channel; the output is cropped to the input size
    (255x255).

    Every residual block is conv -> BN -> leaky ReLU -> conv -> BN -> leaky ReLU plus
    the identity skip. Block 2 previously skipped ``res2bn1`` and its activation, which
    left that BatchNorm unused and chained two linear convolutions.
    """

    def __init__(self):
        super(WSTAutoencoder, self).__init__()
        self.scattering2d = Scattering2D(J=2, L=8, shape=(256,256)) # [81,64,64]

        #Decoder
        self.conv0 = nn.Conv2d(81,64,3,stride=1,padding=1)
        self.bn0 = nn.BatchNorm2d(64)

        self.res1conv1 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res1bn1 = nn.BatchNorm2d(64)
        self.res1conv2 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res1bn2 = nn.BatchNorm2d(64)

        self.res2conv1 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res2bn1 = nn.BatchNorm2d(64)
        self.res2conv2 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res2bn2 = nn.BatchNorm2d(64)

        self.upsample1 = nn.Upsample(scale_factor=2) #[64,128,128]
        self.conv1 = nn.Conv2d(64,32,3,stride=1,padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.res3conv1 = nn.Conv2d(32,32,3,stride=1,padding=1)
        self.res3bn1 = nn.BatchNorm2d(32)
        self.res3conv2 = nn.Conv2d(32,32,3,stride=1,padding=1)
        self.res3bn2 = nn.BatchNorm2d(32)

        self.upsample2 = nn.Upsample(scale_factor=2) #[32,256,256]
        self.conv2 = nn.Conv2d(33,16,3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(16)

        self.res4conv1 = nn.Conv2d(16,16,3,stride=1,padding=1)
        self.res4bn1 = nn.BatchNorm2d(16)
        self.res4conv2 = nn.Conv2d(16,16,3,stride=1,padding=1)
        self.res4bn2 = nn.BatchNorm2d(16)

        self.conv3 = nn.Conv2d(16,9,3,stride=1,padding=1)
        self.bn3 = nn.BatchNorm2d(9)

        self.res5conv1 = nn.Conv2d(9,9,3,stride=1,padding=1)
        self.res5bn1 = nn.BatchNorm2d(9)
        self.res5conv2 = nn.Conv2d(9,9,3,stride=1,padding=1)
        self.res5bn2 = nn.BatchNorm2d(9)

        self.conv4 = nn.Conv2d(9,1,3,stride=1,padding=1)
        # self.bn4 = nn.BatchNorm2d(1)
    
    def forward(self, x):
        # Pad 255x255 -> 256x256 to match the scattering shape; -0.5 is the image minimum.
        original_in = F.pad(x, (0,1,0,1), mode='constant', value=-0.5)
        # Scattering output has a singleton axis at dim 1; drop it to get (B, C, H, W).
        x = self.scattering2d(original_in).squeeze(1)
        
        x = self.conv0(x)
        x = F.leaky_relu(self.bn0(x))

        resblock1_input = x
        x = self.res1conv1(x)
        x = F.leaky_relu(self.res1bn1(x))
        x = self.res1conv2(x)
        x = F.leaky_relu(self.res1bn2(x))
        x = x + resblock1_input

        resblock2_input = x
        x = self.res2conv1(x)
        # Previously missing: res2bn1 was defined but never applied.
        x = F.leaky_relu(self.res2bn1(x))
        x = self.res2conv2(x)
        x = F.leaky_relu(self.res2bn2(x))
        x = x + resblock2_input

        x = self.upsample1(x)
        x = self.conv1(x)
        x = F.leaky_relu(self.bn1(x))

        resblock3_input = x
        x = self.res3conv1(x)
        x = F.leaky_relu(self.res3bn1(x))
        x = self.res3conv2(x)
        x = F.leaky_relu(self.res3bn2(x))
        x = x + resblock3_input

        x = self.upsample2(x)
        # Skip connection from the padded input image.
        x = torch.cat((x, original_in), dim=1)
        x = self.conv2(x)
        x = F.leaky_relu(self.bn2(x))

        resblock4_input = x
        x = self.res4conv1(x)
        x = F.leaky_relu(self.res4bn1(x))
        x = self.res4conv2(x)
        x = F.leaky_relu(self.res4bn2(x))
        x = x + resblock4_input
        
        x = self.conv3(x)
        x = F.leaky_relu(self.bn3(x))

        resblock5_input = x
        x = self.res5conv1(x)
        x = F.leaky_relu(self.res5bn1(x))
        x = self.res5conv2(x)
        x = F.leaky_relu(self.res5bn2(x))
        x = x + resblock5_input
        x = self.conv4(x)
        # x = F.sigmoid(x)-0.5
        # Crop the padding row/column back off.
        return x[:,:,:-1,:-1]

In [None]:
model = WSTAutoencoder().to(device)
# torchsummary builds its dummy input on "cuda" unless told otherwise; pass the
# device actually in use so this cell also runs on CPU-only machines.
summary(model, (1, 255, 255), device=device.type)

# Binary mask denoising

## Load Dataset and Train function

In [None]:
# Corruption runs inside __getitem__, so every epoch (train and test) sees freshly drawn masks.
binary_dataset = BinaryDenoisingDataset(img, spec)
binary_test_dataset = BinaryDenoisingDataset(test_img, test_spec)
binary_dataloader, binary_dataloader_test = (
    DataLoader(binary_dataset, batch_size=32, shuffle=True),
    DataLoader(binary_test_dataset, batch_size=32, shuffle=False),
)

In [None]:
def _binary_reconstruct(net, img_noisy, spec_noisy, mask):
    """Run ``net`` on the corrupted images and re-insert the measured k-space samples.

    The network output is moved to shifted k-space with ``fftshift3d``. Wherever
    ``mask`` is 1, the measured value from ``spec_noisy`` replaces the network's
    prediction. The merged spectrum is shifted back and inverse-transformed, and its
    real part is returned.
    """
    denoised_spec = fftshift3d(torch.fft.fft2(net(img_noisy)), ifft=False)
    spec_mask = mask.type(torch.complex64)
    merged_spec = spec_noisy * spec_mask + denoised_spec * (1 - spec_mask)
    return torch.real(torch.fft.ifft2(fftshift3d(merged_spec, ifft=True)))


def binary_train(net, trainLoader, testLoader, loss_fn, optimizer, scheduler, NUM_EPOCHS, device=None):
    """Train ``net`` on the binary-mask task and evaluate it after every epoch.

    Batches are ``(img_clean, spec_clean, img_noisy, spec_noisy, mask)``, as produced
    by ``BinaryDenoisingDataset``. The loss compares the k-space-consistent
    reconstruction (``_binary_reconstruct``) with the clean image, so training is
    supervised against clean targets. Test reconstructions are clamped to
    [-0.5, 0.5] before the loss; training ones are not. ``scheduler.step()`` runs
    once per epoch.

    Args:
        net, loss_fn, optimizer, scheduler: model and training objects.
        trainLoader, testLoader: loaders yielding the batches described above.
        NUM_EPOCHS: number of epochs.
        device: where batches are moved. Defaults to the device of ``net``'s
            parameters (CPU if it has none).

    Returns:
        (train_loss, test_loss): per-epoch losses, averaged over samples. A smaller
        last batch is therefore weighted correctly for mean-reduced losses.

    Raises:
        ValueError: if a loader yields no batches.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        net.train()
        running_loss, n_seen = 0.0, 0
        with tqdm(trainLoader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for img_clean, _spec_clean, img_noisy, spec_noisy, mask in tepoch:
                img_clean, img_noisy, spec_noisy, mask = (
                    t.to(device) for t in (img_clean, img_noisy, spec_noisy, mask))
                optimizer.zero_grad()
                outputs = _binary_reconstruct(net, img_noisy, spec_noisy, mask)
                loss = loss_fn(outputs, img_clean)
                loss.backward()
                optimizer.step()
                batch_size = img_clean.size(0)
                running_loss += loss.item() * batch_size
                n_seen += batch_size
                tepoch.set_postfix(loss=loss.item())
        if n_seen == 0:
            raise ValueError("trainLoader yielded no batches")
        train_loss.append(running_loss / n_seen)
        scheduler.step()

        net.eval()
        val_loss, n_val = 0.0, 0
        with torch.no_grad(), tqdm(testLoader, unit="batch", desc=f"Epoch {epoch+1} (test)") as tepoch:
            for img_clean, _spec_clean, img_noisy, spec_noisy, mask in tepoch:
                img_clean, img_noisy, spec_noisy, mask = (
                    t.to(device) for t in (img_clean, img_noisy, spec_noisy, mask))
                outputs = _binary_reconstruct(net, img_noisy, spec_noisy, mask)
                outputs = torch.clamp(outputs, -.5, .5)
                loss = loss_fn(outputs, img_clean)
                batch_size = img_clean.size(0)
                val_loss += loss.item() * batch_size
                n_val += batch_size
                tepoch.set_postfix(loss=loss.item())
        if n_val == 0:
            raise ValueError("testLoader yielded no batches")
        test_loss.append(val_loss / n_val)

    return train_loss, test_loss
                

## Unet Train

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binaryunet = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binaryunet.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_unet_train_loss, binary_unet_test_loss = binary_train(binaryunet, binary_dataloader, binary_dataloader_test, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_unet_with_freq_enh.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_unet_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_unet_train_loss, label='train loss')
plt.plot(binary_unet_test_loss, label='test loss')
plt.legend()
plt.title('UNet, binary mask: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()


## CNNDMRI train

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binarycnndmri = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarycnndmri.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_cnndmri_train_loss, binary_cnndmri_test_loss = binary_train(binarycnndmri, binary_dataloader, binary_dataloader_test, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_cnndmri_with_freq_enh.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_cnndmri_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_cnndmri_train_loss, label='train loss')
plt.plot(binary_cnndmri_test_loss, label='test loss')
plt.legend()
plt.title('CNNDMRI, binary mask: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

## WSTAutoencoder train

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binarywsta = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarywsta.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_wsta_train_loss, binary_wsta_test_loss = binary_train(binarywsta, binary_dataloader, binary_dataloader_test, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_wsta_with_freq_enh.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_wsta_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_wsta_train_loss, label='train loss')
plt.plot(binary_wsta_test_loss, label='test loss')
plt.legend()
plt.title('WSTAutoencoder, binary mask: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

## Comparison of the three models

In [None]:
# Training-loss curves of the three architectures on the binary-mask task.
fig, ax = plt.subplots()
for losses, name in ((binary_wsta_train_loss, 'WSTA architecture'),
                     (binary_unet_train_loss, 'UNet architecture'),
                     (binary_cnndmri_train_loss, 'CNNDMRI architecture')):
    ax.plot(losses, label=name)
ax.set_title('Model comparison')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

# Rician Noise Denoising

## Dataset and train function

In [None]:
# RicianDenoisingDataset's second positional parameter is `t` (augmentation), not the
# spectra, so passing `spec` positionally would silently bind it to `t`. Only images are needed.
rici_dataset = RicianDenoisingDataset(img)
rici_test_dataset = RicianDenoisingDataset(test_img)
rici_dataloader = DataLoader(rici_dataset, batch_size=32, shuffle=True)
rici_dataloader_test = DataLoader(rici_test_dataset, batch_size=32, shuffle=False)

In [None]:
def ricitrain(net, trainLoader, testLoader, loss_fn, optimizer, scheduler, NUM_EPOCHS, device=None):
    """Train ``net`` on the Rician-noise task and evaluate it after every epoch.

    Batches are ``(img_clean, img_noisy)``, as produced by ``RicianDenoisingDataset``.
    The network maps noisy images to clean ones directly. Test outputs are clamped
    to [-0.5, 0.5] before the loss; training outputs are not. ``scheduler.step()``
    runs once per epoch.

    Args:
        net, loss_fn, optimizer, scheduler: model and training objects.
        trainLoader, testLoader: loaders yielding the batches described above.
        NUM_EPOCHS: number of epochs.
        device: where batches are moved. Defaults to the device of ``net``'s
            parameters (CPU if it has none).

    Returns:
        (train_loss, test_loss): per-epoch losses, averaged over samples. A smaller
        last batch is therefore weighted correctly for mean-reduced losses.

    Raises:
        ValueError: if a loader yields no batches.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        net.train()
        running_loss, n_seen = 0.0, 0
        with tqdm(trainLoader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for img_clean, img_noisy in tepoch:
                img_clean, img_noisy = img_clean.to(device), img_noisy.to(device)
                optimizer.zero_grad()
                loss = loss_fn(net(img_noisy), img_clean)
                loss.backward()
                optimizer.step()
                batch_size = img_clean.size(0)
                running_loss += loss.item() * batch_size
                n_seen += batch_size
                tepoch.set_postfix(loss=loss.item())
        if n_seen == 0:
            raise ValueError("trainLoader yielded no batches")
        train_loss.append(running_loss / n_seen)
        scheduler.step()

        net.eval()
        val_loss, n_val = 0.0, 0
        with torch.no_grad(), tqdm(testLoader, unit="batch", desc=f"Epoch {epoch+1} (test)") as tepoch:
            for img_clean, img_noisy in tepoch:
                img_clean, img_noisy = img_clean.to(device), img_noisy.to(device)
                outputs = torch.clamp(net(img_noisy), -.5, .5)
                loss = loss_fn(outputs, img_clean)
                batch_size = img_clean.size(0)
                val_loss += loss.item() * batch_size
                n_val += batch_size
                tepoch.set_postfix(loss=loss.item())
        if n_val == 0:
            raise ValueError("testLoader yielded no batches")
        test_loss.append(val_loss / n_val)

    return train_loss, test_loss

## Unet

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
ricianunet = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(ricianunet.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
rici_unet_train_loss, rici_unet_test_loss = ricitrain(ricianunet, rici_dataloader, rici_dataloader_test, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_rician_unet_complete.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_unet_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_unet_train_loss, label='train loss')
plt.plot(rici_unet_test_loss, label='test loss')
plt.legend()
plt.title('UNet, Rician noise: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

## CNNDMRI

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
ricicnndmri = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(ricicnndmri.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
rici_cnndmri_train_loss, rici_cnndmri_test_loss = ricitrain(ricicnndmri, rici_dataloader, rici_dataloader_test, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_rician_cnndmri_complete.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_cnndmri_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_cnndmri_train_loss, label='train loss')
plt.plot(rici_cnndmri_test_loss, label='test loss')
plt.legend()
plt.title('CNNDMRI, Rician noise: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

## WSTAutoencoder

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
riciwsta = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciwsta.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
rici_wsta_train_loss, rici_wsta_test_loss = ricitrain(riciwsta, rici_dataloader, rici_dataloader_test, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_rician_wsta_complete.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_wsta_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_wsta_train_loss, label='train loss')
plt.plot(rici_wsta_test_loss, label='test loss')
plt.legend()
plt.title('WSTAutoencoder, Rician noise: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

## Comparison of the three models

In [None]:
# Training-loss curves of the three architectures on the Rician-noise task.
fig, ax = plt.subplots()
for losses, name in ((rici_wsta_train_loss, 'WSTA architecture'),
                     (rici_unet_train_loss, 'UNet architecture'),
                     (rici_cnndmri_train_loss, 'CNNDMRI architecture')):
    ax.plot(losses, label=name)
ax.set_title('Model comparison')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

# Now let's test with reduced datasets

In [None]:
# Random subsets (drawn without replacement) of both the train AND test splits at
# 75 / 50 / 25 / 10 % of the data.
# NOTE(review): each fraction is evaluated on a different random test subset, so test
# losses are not strictly comparable across fractions; consider keeping the full test sets.


def _rician_from_subset(images, _specs):
    """Adapter for get_reduced_dataset, which calls datasetclass(images, specs).

    RicianDenoisingDataset's second positional parameter is `t`, not the spectra, so
    the spectra are dropped here instead of being bound to `t`.
    """
    return RicianDenoisingDataset(images)


binary_dataset_75 = get_reduced_dataset(img,spec,0.75,BinaryDenoisingDataset)
binary_test_dataset_75 = get_reduced_dataset(test_img,test_spec,0.75,BinaryDenoisingDataset)
binary_dataloader_75 = DataLoader(binary_dataset_75, batch_size=32, shuffle=True)
binary_dataloader_test_75 = DataLoader(binary_test_dataset_75, batch_size=32, shuffle=False)

binary_dataset_50 = get_reduced_dataset(img,spec,0.5,BinaryDenoisingDataset)
binary_test_dataset_50 = get_reduced_dataset(test_img,test_spec,0.5,BinaryDenoisingDataset)
binary_dataloader_50 = DataLoader(binary_dataset_50, batch_size=32, shuffle=True)
binary_dataloader_test_50 = DataLoader(binary_test_dataset_50, batch_size=32, shuffle=False)

binary_dataset_25 = get_reduced_dataset(img,spec,0.25,BinaryDenoisingDataset)
binary_test_dataset_25 = get_reduced_dataset(test_img,test_spec,0.25,BinaryDenoisingDataset)
binary_dataloader_25 = DataLoader(binary_dataset_25, batch_size=32, shuffle=True)
binary_dataloader_test_25 = DataLoader(binary_test_dataset_25, batch_size=32, shuffle=False)

binary_dataset_10 = get_reduced_dataset(img,spec,0.1,BinaryDenoisingDataset)
binary_test_dataset_10 = get_reduced_dataset(test_img,test_spec,0.1,BinaryDenoisingDataset)
binary_dataloader_10 = DataLoader(binary_dataset_10, batch_size=32, shuffle=True)
binary_dataloader_test_10 = DataLoader(binary_test_dataset_10, batch_size=32, shuffle=False)

rici_dataset_75 = get_reduced_dataset(img,spec,0.75,_rician_from_subset)
rici_test_dataset_75 = get_reduced_dataset(test_img,test_spec,0.75,_rician_from_subset)
rici_dataloader_75 = DataLoader(rici_dataset_75, batch_size=32, shuffle=True)
rici_dataloader_test_75 = DataLoader(rici_test_dataset_75, batch_size=32, shuffle=False)

rici_dataset_50 = get_reduced_dataset(img,spec,0.5,_rician_from_subset)
rici_test_dataset_50 = get_reduced_dataset(test_img,test_spec,0.5,_rician_from_subset)
rici_dataloader_50 = DataLoader(rici_dataset_50, batch_size=32, shuffle=True)
rici_dataloader_test_50 = DataLoader(rici_test_dataset_50, batch_size=32, shuffle=False)

rici_dataset_25 = get_reduced_dataset(img,spec,0.25,_rician_from_subset)
rici_test_dataset_25 = get_reduced_dataset(test_img,test_spec,0.25,_rician_from_subset)
rici_dataloader_25 = DataLoader(rici_dataset_25, batch_size=32, shuffle=True)
rici_dataloader_test_25 = DataLoader(rici_test_dataset_25, batch_size=32, shuffle=False)

rici_dataset_10 = get_reduced_dataset(img,spec,0.1,_rician_from_subset)
rici_test_dataset_10 = get_reduced_dataset(test_img,test_spec,0.1,_rician_from_subset)
rici_dataloader_10 = DataLoader(rici_dataset_10, batch_size=32, shuffle=True)
rici_dataloader_test_10 = DataLoader(rici_test_dataset_10, batch_size=32, shuffle=False)

## Binary mask denoising

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binaryunet75 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binaryunet75.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_unet75_train_loss, binary_unet75_test_loss = binary_train(binaryunet75, binary_dataloader_75, binary_dataloader_test_75, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_unet_75.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_unet75_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_unet75_train_loss, label='train loss')
plt.plot(binary_unet75_test_loss, label='test loss')
plt.legend()
plt.title('UNet, binary mask, 75% of data: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()


In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binarycnndmri75 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarycnndmri75.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_cnndmri75_train_loss, binary_cnndmri75_test_loss = binary_train(binarycnndmri75, binary_dataloader_75, binary_dataloader_test_75, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_cnndmri_75.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_cnndmri75_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_cnndmri75_train_loss, label='train loss')
plt.plot(binary_cnndmri75_test_loss, label='test loss')
plt.legend()
plt.title('CNNDMRI, binary mask, 75% of data: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binarywsta75 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarywsta75.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_wsta75_train_loss, binary_wsta75_test_loss = binary_train(binarywsta75, binary_dataloader_75, binary_dataloader_test_75, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_wsta_75.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_wsta75_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_wsta75_train_loss, label='train loss')
plt.plot(binary_wsta75_test_loss, label='test loss')
plt.legend()
plt.title('WSTAutoencoder, binary mask, 75% of data: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Training-loss curves of the three architectures trained on 75 % of the data.
fig, ax = plt.subplots()
for losses, name in ((binary_wsta75_train_loss, 'WSTA architecture'),
                     (binary_unet75_train_loss, 'UNet architecture'),
                     (binary_cnndmri75_train_loss, 'CNNDMRI architecture')):
    ax.plot(losses, label=name)
ax.set_title('Model comparison .75 dataset')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
NUM_EPOCHS = 25  # also the cosine-annealing period: one half-cosine over the whole run
binaryunet50 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binaryunet50.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0)
binary_unet50_train_loss, binary_unet50_test_loss = binary_train(binaryunet50, binary_dataloader_50, binary_dataloader_test_50, loss_fn, optimizer, scheduler, NUM_EPOCHS)
# with open('./Results/train_losses_binary_unet_50.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_unet50_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_unet50_train_loss, label='train loss')
plt.plot(binary_unet50_test_loss, label='test loss')
plt.legend()
plt.title('UNet, binary mask, 50% of data: train vs. test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .50 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri50 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarycnndmri50.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_cnndmri50_train_loss,binary_cnndmri50_test_loss = binary_train(binarycnndmri50, binary_dataloader_50, binary_dataloader_test_50, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_50.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_cnndmri50_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_cnndmri50_train_loss, label='train loss')
plt.plot(binary_cnndmri50_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .50 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .50 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarywsta50 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarywsta50.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_wsta50_train_loss,binary_wsta50_test_loss = binary_train(binarywsta50, binary_dataloader_50, binary_dataloader_test_50, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_wsta_50.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_wsta50_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_wsta50_train_loss, label='train loss')
plt.plot(binary_wsta50_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - binary .50 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .50 dataset (binary corruption).
# Test curves are shown in each model's own cell.
plt.figure()
plt.plot(binary_wsta50_train_loss, label='WSTA architecture')
plt.plot(binary_unet50_train_loss, label='UNet architecture')
plt.plot(binary_cnndmri50_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .50 dataset (binary, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .25 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binaryunet25 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binaryunet25.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_unet25_train_loss,binary_unet25_test_loss = binary_train(binaryunet25, binary_dataloader_25, binary_dataloader_test_25, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_unet_25.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_unet25_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_unet25_train_loss, label='train loss')
plt.plot(binary_unet25_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - binary .25 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .25 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri25 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarycnndmri25.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_cnndmri25_train_loss,binary_cnndmri25_test_loss = binary_train(binarycnndmri25, binary_dataloader_25, binary_dataloader_test_25, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_25.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_cnndmri25_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_cnndmri25_train_loss, label='train loss')
plt.plot(binary_cnndmri25_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .25 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .25 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarywsta25 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarywsta25.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_wsta25_train_loss,binary_wsta25_test_loss = binary_train(binarywsta25, binary_dataloader_25, binary_dataloader_test_25, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_wsta_25.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_wsta25_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_wsta25_train_loss, label='train loss')
plt.plot(binary_wsta25_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - binary .25 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .25 dataset (binary corruption).
# Test curves are shown in each model's own cell.
plt.figure()
plt.plot(binary_wsta25_train_loss, label='WSTA architecture')
plt.plot(binary_unet25_train_loss, label='UNet architecture')
plt.plot(binary_cnndmri25_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .25 dataset (binary, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .10 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binaryunet10 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binaryunet10.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_unet10_train_loss,binary_unet10_test_loss = binary_train(binaryunet10, binary_dataloader_10, binary_dataloader_test_10, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_unet_10.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_unet10_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_unet10_train_loss, label='train loss')
plt.plot(binary_unet10_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - binary .10 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .10 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri10 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarycnndmri10.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_cnndmri10_train_loss,binary_cnndmri10_test_loss = binary_train(binarycnndmri10, binary_dataloader_10, binary_dataloader_test_10, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_10.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_cnndmri10_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_cnndmri10_train_loss, label='train loss')
plt.plot(binary_cnndmri10_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .10 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .10 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarywsta10 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(binarywsta10.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binary_wsta10_train_loss,binary_wsta10_test_loss = binary_train(binarywsta10, binary_dataloader_10, binary_dataloader_test_10, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_wsta_10.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binary_wsta10_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binary_wsta10_train_loss, label='train loss')
plt.plot(binary_wsta10_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - binary .10 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .10 dataset (binary corruption).
# Test curves are shown in each model's own cell.
plt.figure()
plt.plot(binary_wsta10_train_loss, label='WSTA architecture')
plt.plot(binary_unet10_train_loss, label='UNet architecture')
plt.plot(binary_cnndmri10_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .10 dataset (binary, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

## Rician Noise Denoising

In [None]:
# UNet on the .75 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciunet75 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciunet75.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_unet75_train_loss,rici_unet75_test_loss = ricitrain(riciunet75, rici_dataloader_75, rici_dataloader_test_75, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_unet_75.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_unet75_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_unet75_train_loss, label='train loss')
plt.plot(rici_unet75_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - Rician .75 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .75 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
ricicnndmri75 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(ricicnndmri75.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_cnndmri75_train_loss,rici_cnndmri75_test_loss = ricitrain(ricicnndmri75, rici_dataloader_75, rici_dataloader_test_75, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_cnndmri_75.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_cnndmri75_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_cnndmri75_train_loss, label='train loss')
plt.plot(rici_cnndmri75_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - Rician .75 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .75 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciwsta75 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciwsta75.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_wsta75_train_loss,rici_wsta75_test_loss = ricitrain(riciwsta75, rici_dataloader_75, rici_dataloader_test_75, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_wsta_75.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_wsta75_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_wsta75_train_loss, label='train loss')
plt.plot(rici_wsta75_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - Rician .75 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .75 dataset (Rician noise).
# The title names the corruption so this figure is distinguishable from the binary one.
plt.figure()
plt.plot(rici_wsta75_train_loss, label='WSTA architecture')
plt.plot(rici_unet75_train_loss, label='UNet architecture')
plt.plot(rici_cnndmri75_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .75 dataset (Rician, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .50 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciunet50 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciunet50.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_unet50_train_loss,rici_unet50_test_loss = ricitrain(riciunet50, rici_dataloader_50, rici_dataloader_test_50, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_unet_50.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_unet50_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_unet50_train_loss, label='train loss')
plt.plot(rici_unet50_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - Rician .50 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .50 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
ricicnndmri50 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(ricicnndmri50.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_cnndmri50_train_loss,rici_cnndmri50_test_loss = ricitrain(ricicnndmri50, rici_dataloader_50, rici_dataloader_test_50, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_cnndmri_50.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_cnndmri50_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_cnndmri50_train_loss, label='train loss')
plt.plot(rici_cnndmri50_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - Rician .50 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .50 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciwsta50 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciwsta50.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_wsta50_train_loss,rici_wsta50_test_loss = ricitrain(riciwsta50, rici_dataloader_50, rici_dataloader_test_50, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_wsta_50.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_wsta50_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_wsta50_train_loss, label='train loss')
plt.plot(rici_wsta50_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - Rician .50 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .50 dataset (Rician noise).
# The title names the corruption so this figure is distinguishable from the binary one.
plt.figure()
plt.plot(rici_wsta50_train_loss, label='WSTA architecture')
plt.plot(rici_unet50_train_loss, label='UNet architecture')
plt.plot(rici_cnndmri50_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .50 dataset (Rician, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .25 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciunet25 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciunet25.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_unet25_train_loss,rici_unet25_test_loss = ricitrain(riciunet25, rici_dataloader_25, rici_dataloader_test_25, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_unet_25.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_unet25_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_unet25_train_loss, label='train loss')
plt.plot(rici_unet25_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - Rician .25 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .25 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
ricicnndmri25 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(ricicnndmri25.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_cnndmri25_train_loss,rici_cnndmri25_test_loss = ricitrain(ricicnndmri25, rici_dataloader_25, rici_dataloader_test_25, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_cnndmri_25.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_cnndmri25_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_cnndmri25_train_loss, label='train loss')
plt.plot(rici_cnndmri25_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - Rician .25 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .25 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciwsta25 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciwsta25.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_wsta25_train_loss,rici_wsta25_test_loss = ricitrain(riciwsta25, rici_dataloader_25, rici_dataloader_test_25, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_wsta_25.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_wsta25_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_wsta25_train_loss, label='train loss')
plt.plot(rici_wsta25_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - Rician .25 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .25 dataset (Rician noise).
# The title names the corruption so this figure is distinguishable from the binary one.
plt.figure()
plt.plot(rici_wsta25_train_loss, label='WSTA architecture')
plt.plot(rici_unet25_train_loss, label='UNet architecture')
plt.plot(rici_cnndmri25_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .25 dataset (Rician, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .10 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciunet10 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciunet10.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_unet10_train_loss,rici_unet10_test_loss = ricitrain(riciunet10, rici_dataloader_10, rici_dataloader_test_10, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_unet_10.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_unet10_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_unet10_train_loss, label='train loss')
plt.plot(rici_unet10_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - Rician .10 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .10 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
ricicnndmri10 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(ricicnndmri10.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_cnndmri10_train_loss,rici_cnndmri10_test_loss = ricitrain(ricicnndmri10, rici_dataloader_10, rici_dataloader_test_10, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_cnndmri_10.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_cnndmri10_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_cnndmri10_train_loss, label='train loss')
plt.plot(rici_cnndmri10_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - Rician .10 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .10 dataset with Rician noise (ricitrain, 25 epochs); rebuilt from scratch on every run.
riciwsta10 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(riciwsta10.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes ricitrain steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
rici_wsta10_train_loss,rici_wsta10_test_loss = ricitrain(riciwsta10, rici_dataloader_10, rici_dataloader_test_10, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_rici_wsta_10.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(rici_wsta10_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(rici_wsta10_train_loss, label='train loss')
plt.plot(rici_wsta10_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - Rician .10 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .10 dataset (Rician noise).
# The title names the corruption so this figure is distinguishable from the binary one.
plt.figure()
plt.plot(rici_wsta10_train_loss, label='WSTA architecture')
plt.plot(rici_unet10_train_loss, label='UNet architecture')
plt.plot(rici_cnndmri10_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .10 dataset (Rician, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

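## Even fewer data points

Repeat the binary and Rician experiments on random 5%, 2.5%, 1.25% and 0.5% subsets of the training set; the validation set is subsampled by the same fraction.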

In [None]:
binary_dataset_05 = get_reduced_dataset(img,spec,0.05,BinaryDenoisingDataset)
binary_test_dataset_05 = get_reduced_dataset(test_img,test_spec,0.05,BinaryDenoisingDataset)
binary_dataloader_05 = DataLoader(binary_dataset_05, batch_size=32, shuffle=True)
binary_dataloader_test_05 = DataLoader(binary_test_dataset_05, batch_size=32, shuffle=False)

rici_dataset_05 = get_reduced_dataset(img,spec,0.05,RicianDenoisingDataset)
rici_test_dataset_05 = get_reduced_dataset(test_img,test_spec,0.05,RicianDenoisingDataset)
rici_dataloader_05 = DataLoader(rici_dataset_05, batch_size=32, shuffle=True)
rici_dataloader_test_05 = DataLoader(rici_test_dataset_05, batch_size=32, shuffle=False)

binary_dataset_025 = get_reduced_dataset(img,spec,0.025,BinaryDenoisingDataset)
binary_test_dataset_025 = get_reduced_dataset(test_img,test_spec,0.025,BinaryDenoisingDataset)
binary_dataloader_025 = DataLoader(binary_dataset_025, batch_size=32, shuffle=True)
binary_dataloader_test_025 = DataLoader(binary_test_dataset_025, batch_size=32, shuffle=False)

rici_dataset_025 = get_reduced_dataset(img,spec,0.025,RicianDenoisingDataset)
rici_test_dataset_025 = get_reduced_dataset(test_img,test_spec,0.025,RicianDenoisingDataset)
rici_dataloader_025 = DataLoader(rici_dataset_025, batch_size=32, shuffle=True)
rici_dataloader_test_025 = DataLoader(rici_test_dataset_025, batch_size=32, shuffle=False)

binary_dataset_0125 = get_reduced_dataset(img,spec,0.0125,BinaryDenoisingDataset)
binary_test_dataset_0125 = get_reduced_dataset(test_img,test_spec,0.0125,BinaryDenoisingDataset)
binary_dataloader_0125 = DataLoader(binary_dataset_0125, batch_size=32, shuffle=True)
binary_dataloader_test_0125 = DataLoader(binary_test_dataset_0125, batch_size=32, shuffle=False)

rici_dataset_0125 = get_reduced_dataset(img,spec,0.0125,RicianDenoisingDataset)
rici_test_dataset_0125 = get_reduced_dataset(test_img,test_spec,0.0125,RicianDenoisingDataset)
rici_dataloader_0125 = DataLoader(rici_dataset_0125, batch_size=32, shuffle=True)
rici_dataloader_test_0125 = DataLoader(rici_test_dataset_0125, batch_size=32, shuffle=False)

binary_dataset_005 = get_reduced_dataset(img,spec,0.005,BinaryDenoisingDataset)
binary_test_dataset_005 = get_reduced_dataset(test_img,test_spec,0.005,BinaryDenoisingDataset)
binary_dataloader_005 = DataLoader(binary_dataset_005, batch_size=32, shuffle=True)
binary_dataloader_test_005 = DataLoader(binary_test_dataset_005, batch_size=32, shuffle=False)

rici_dataset_005 = get_reduced_dataset(img,spec,0.005,RicianDenoisingDataset)
rici_test_dataset_005 = get_reduced_dataset(test_img,test_spec,0.005,RicianDenoisingDataset)
rici_dataloader_005 = DataLoader(rici_dataset_005, batch_size=32, shuffle=True)
rici_dataloader_test_005 = DataLoader(rici_test_dataset_005, batch_size=32, shuffle=False)



In [None]:
# UNet on the .05 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binaryunet05 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binaryunet05.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binaryunet05_train_loss,binaryunet05_test_loss = binary_train(binaryunet05, binary_dataloader_05, binary_dataloader_test_05, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_unet_05.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binaryunet05_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binaryunet05_train_loss, label='train loss')
plt.plot(binaryunet05_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - binary .05 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .05 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri05 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarycnndmri05.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarycnndmri05_train_loss,binarycnndmri05_test_loss = binary_train(binarycnndmri05, binary_dataloader_05, binary_dataloader_test_05, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_05.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarycnndmri05_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarycnndmri05_train_loss, label='train loss')
plt.plot(binarycnndmri05_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .05 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .05 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarywsta05 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarywsta05.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarywsta05_train_loss,binarywsta05_test_loss = binary_train(binarywsta05, binary_dataloader_05, binary_dataloader_test_05, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_wsta_05.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarywsta05_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarywsta05_train_loss, label='train loss')
plt.plot(binarywsta05_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - binary .05 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .05 dataset (binary corruption).
# Test curves are shown in each model's own cell.
plt.figure()
plt.plot(binarywsta05_train_loss, label='WSTA architecture')
plt.plot(binaryunet05_train_loss, label='UNet architecture')
plt.plot(binarycnndmri05_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .05 dataset (binary, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .025 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binaryunet025 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binaryunet025.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binaryunet025_train_loss,binaryunet025_test_loss = binary_train(binaryunet025, binary_dataloader_025, binary_dataloader_test_025, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_unet_025.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binaryunet025_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binaryunet025_train_loss, label='train loss')
plt.plot(binaryunet025_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - binary .025 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .025 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri025 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarycnndmri025.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarycnndmri025_train_loss,binarycnndmri025_test_loss = binary_train(binarycnndmri025, binary_dataloader_025, binary_dataloader_test_025, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_025.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarycnndmri025_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarycnndmri025_train_loss, label='train loss')
plt.plot(binarycnndmri025_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .025 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .025 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarywsta025 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarywsta025.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarywsta025_train_loss,binarywsta025_test_loss = binary_train(binarywsta025, binary_dataloader_025, binary_dataloader_test_025, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_wsta_025.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarywsta025_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarywsta025_train_loss, label='train loss')
plt.plot(binarywsta025_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - binary .025 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .025 dataset (binary corruption).
# Test curves are shown in each model's own cell.
plt.figure()
plt.plot(binarywsta025_train_loss, label='WSTA architecture')
plt.plot(binaryunet025_train_loss, label='UNet architecture')
plt.plot(binarycnndmri025_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .025 dataset (binary, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .0125 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binaryunet0125 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binaryunet0125.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binaryunet0125_train_loss,binaryunet0125_test_loss = binary_train(binaryunet0125, binary_dataloader_0125, binary_dataloader_test_0125, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_unet_0125.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binaryunet0125_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binaryunet0125_train_loss, label='train loss')
plt.plot(binaryunet0125_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - binary .0125 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .0125 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri0125 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarycnndmri0125.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarycnndmri0125_train_loss,binarycnndmri0125_test_loss = binary_train(binarycnndmri0125, binary_dataloader_0125, binary_dataloader_test_0125, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_0125.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarycnndmri0125_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarycnndmri0125_train_loss, label='train loss')
plt.plot(binarycnndmri0125_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .0125 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# WSTA autoencoder on the .0125 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarywsta0125 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarywsta0125.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarywsta0125_train_loss,binarywsta0125_test_loss = binary_train(binarywsta0125, binary_dataloader_0125, binary_dataloader_test_0125, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_wsta_0125.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarywsta0125_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarywsta0125_train_loss, label='train loss')
plt.plot(binarywsta0125_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('WSTA - binary .0125 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train-loss curves of the three architectures on the .0125 dataset (binary corruption).
# Test curves are shown in each model's own cell.
plt.figure()
plt.plot(binarywsta0125_train_loss, label='WSTA architecture')
plt.plot(binaryunet0125_train_loss, label='UNet architecture')
plt.plot(binarycnndmri0125_train_loss, label='CNN-DMRI architecture')
plt.title('Model comparison .0125 dataset (binary, train loss)')
plt.xlabel('Epochs')
plt.ylabel('Train loss')
plt.legend()
plt.show()

In [None]:
# UNet on the .005 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binaryunet005 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binaryunet005.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binaryunet005_train_loss,binaryunet005_test_loss = binary_train(binaryunet005, binary_dataloader_005, binary_dataloader_test_005, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_unet_005.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binaryunet005_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binaryunet005_train_loss, label='train loss')
plt.plot(binaryunet005_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('UNet - binary .005 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# CNN-DMRI on the .005 dataset with binary corruption (binary_train, 25 epochs); rebuilt from scratch on every run.
binarycnndmri005 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarycnndmri005.parameters(), lr=0.001)
# T_max=25 mirrors the 25 epochs below (assumes binary_train steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
binarycnndmri005_train_loss,binarycnndmri005_test_loss = binary_train(binarycnndmri005, binary_dataloader_005, binary_dataloader_test_005, loss_fn, optimizer, scheduler, 25)
# with open('./Results/train_losses_binary_cnndmri_005.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarycnndmri005_train_loss, start=1):
#         writer.writerow([i, loss])
plt.figure()
plt.plot(binarycnndmri005_train_loss, label='train loss')
plt.plot(binarycnndmri005_test_loss, label='test loss')
plt.legend()
# Both curves are drawn, so the title names both and identifies the run.
plt.title('CNN-DMRI - binary .005 dataset: train/test loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

In [None]:
# Train a WST autoencoder on the binary ".005" split (binary_*_005 loaders,
# presumably the Bernoulli-masked data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
binarywsta005 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(binarywsta005.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
binarywsta005_train_loss, binarywsta005_test_loss = binary_train(
    binarywsta005, binary_dataloader_005, binary_dataloader_test_005,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_binary_wsta_005.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(binarywsta005_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(binarywsta005_train_loss, label='train loss')
ax.plot(binarywsta005_test_loss, label='test loss')
ax.set(title='WSTA, binary .005 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Compare the per-epoch TRAINING losses of the three architectures on the
# binary ".005" split. Test losses are not plotted here.
# The title names the corruption type because the Rician section draws a
# figure for the same split.
fig, ax = plt.subplots()
ax.plot(binarywsta005_train_loss, label='WSTA architecture')
ax.plot(binaryunet005_train_loss, label='UNet architecture')
ax.plot(binarycnndmri005_train_loss, label='CNN-DMRI architecture')
ax.set(title='Model comparison .005 dataset (binary)',
       xlabel='Epochs', ylabel='Train Loss')
ax.legend()
plt.show()

In [None]:
# Train a UNet on the Rician ".05" split (rici_*_05 loaders, presumably the
# Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both the cosine schedule length (T_max) and the epoch count
# passed to ricitrain, so the two cannot drift apart.
n_epochs = 25
riciunet05 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciunet05.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciunet05_train_loss, riciunet05_test_loss = ricitrain(
    riciunet05, rici_dataloader_05, rici_dataloader_test_05,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_unet_05.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciunet05_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciunet05_train_loss, label='train loss')
ax.plot(riciunet05_test_loss, label='test loss')
ax.set(title='UNet, Rician .05 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a CNN-DMRI model on the Rician ".05" split (rici_*_05 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
ricicnndmri05 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(ricicnndmri05.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
ricicnndmri05_train_loss, ricicnndmri05_test_loss = ricitrain(
    ricicnndmri05, rici_dataloader_05, rici_dataloader_test_05,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_cnndmri_05.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(ricicnndmri05_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(ricicnndmri05_train_loss, label='train loss')
ax.plot(ricicnndmri05_test_loss, label='test loss')
ax.set(title='CNN-DMRI, Rician .05 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a WST autoencoder on the Rician ".05" split (rici_*_05 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciwsta05 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciwsta05.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciwsta05_train_loss, riciwsta05_test_loss = ricitrain(
    riciwsta05, rici_dataloader_05, rici_dataloader_test_05,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_wsta_05.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciwsta05_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciwsta05_train_loss, label='train loss')
ax.plot(riciwsta05_test_loss, label='test loss')
ax.set(title='WSTA, Rician .05 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Compare the per-epoch TRAINING losses of the three architectures on the
# Rician ".05" split. Test losses are not plotted here.
# The title names the corruption type so it cannot be confused with the
# binary-mask comparison figures.
fig, ax = plt.subplots()
ax.plot(riciwsta05_train_loss, label='WSTA architecture')
ax.plot(riciunet05_train_loss, label='UNet architecture')
ax.plot(ricicnndmri05_train_loss, label='CNN-DMRI architecture')
ax.set(title='Model comparison .05 dataset (Rician)',
       xlabel='Epochs', ylabel='Train Loss')
ax.legend()
plt.show()

In [None]:
# Train a UNet on the Rician ".025" split (rici_*_025 loaders, presumably the
# Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciunet025 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciunet025.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciunet025_train_loss, riciunet025_test_loss = ricitrain(
    riciunet025, rici_dataloader_025, rici_dataloader_test_025,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_unet_025.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciunet025_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciunet025_train_loss, label='train loss')
ax.plot(riciunet025_test_loss, label='test loss')
ax.set(title='UNet, Rician .025 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a CNN-DMRI model on the Rician ".025" split (rici_*_025 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
ricicnndmri025 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(ricicnndmri025.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
ricicnndmri025_train_loss, ricicnndmri025_test_loss = ricitrain(
    ricicnndmri025, rici_dataloader_025, rici_dataloader_test_025,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_cnndmri_025.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(ricicnndmri025_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(ricicnndmri025_train_loss, label='train loss')
ax.plot(ricicnndmri025_test_loss, label='test loss')
ax.set(title='CNN-DMRI, Rician .025 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a WST autoencoder on the Rician ".025" split (rici_*_025 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciwsta025 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciwsta025.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciwsta025_train_loss, riciwsta025_test_loss = ricitrain(
    riciwsta025, rici_dataloader_025, rici_dataloader_test_025,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_wsta_025.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciwsta025_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciwsta025_train_loss, label='train loss')
ax.plot(riciwsta025_test_loss, label='test loss')
ax.set(title='WSTA, Rician .025 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Compare the per-epoch TRAINING losses of the three architectures on the
# Rician ".025" split. Test losses are not plotted here.
# The title names the corruption type so it cannot be confused with the
# binary-mask comparison figures.
fig, ax = plt.subplots()
ax.plot(riciwsta025_train_loss, label='WSTA architecture')
ax.plot(riciunet025_train_loss, label='UNet architecture')
ax.plot(ricicnndmri025_train_loss, label='CNN-DMRI architecture')
ax.set(title='Model comparison .025 dataset (Rician)',
       xlabel='Epochs', ylabel='Train Loss')
ax.legend()
plt.show()

In [None]:
# Train a UNet on the Rician ".0125" split (rici_*_0125 loaders, presumably
# the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciunet0125 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciunet0125.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciunet0125_train_loss, riciunet0125_test_loss = ricitrain(
    riciunet0125, rici_dataloader_0125, rici_dataloader_test_0125,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_unet_0125.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciunet0125_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciunet0125_train_loss, label='train loss')
ax.plot(riciunet0125_test_loss, label='test loss')
ax.set(title='UNet, Rician .0125 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a CNN-DMRI model on the Rician ".0125" split (rici_*_0125 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
ricicnndmri0125 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(ricicnndmri0125.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
ricicnndmri0125_train_loss, ricicnndmri0125_test_loss = ricitrain(
    ricicnndmri0125, rici_dataloader_0125, rici_dataloader_test_0125,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_cnndmri_0125.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(ricicnndmri0125_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(ricicnndmri0125_train_loss, label='train loss')
ax.plot(ricicnndmri0125_test_loss, label='test loss')
ax.set(title='CNN-DMRI, Rician .0125 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a WST autoencoder on the Rician ".0125" split (rici_*_0125 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciwsta0125 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciwsta0125.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciwsta0125_train_loss, riciwsta0125_test_loss = ricitrain(
    riciwsta0125, rici_dataloader_0125, rici_dataloader_test_0125,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_wsta_0125.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciwsta0125_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciwsta0125_train_loss, label='train loss')
ax.plot(riciwsta0125_test_loss, label='test loss')
ax.set(title='WSTA, Rician .0125 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Compare the per-epoch TRAINING losses of the three architectures on the
# Rician ".0125" split. Test losses are not plotted here.
# The title names the corruption type because the binary section draws a
# figure for the same split.
fig, ax = plt.subplots()
ax.plot(riciwsta0125_train_loss, label='WSTA architecture')
ax.plot(riciunet0125_train_loss, label='UNet architecture')
ax.plot(ricicnndmri0125_train_loss, label='CNN-DMRI architecture')
ax.set(title='Model comparison .0125 dataset (Rician)',
       xlabel='Epochs', ylabel='Train Loss')
ax.legend()
plt.show()

In [None]:
# Train a UNet on the Rician ".005" split (rici_*_005 loaders, presumably the
# Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciunet005 = Unet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciunet005.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciunet005_train_loss, riciunet005_test_loss = ricitrain(
    riciunet005, rici_dataloader_005, rici_dataloader_test_005,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_unet_005.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciunet005_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciunet005_train_loss, label='train loss')
ax.plot(riciunet005_test_loss, label='test loss')
ax.set(title='UNet, Rician .005 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a CNN-DMRI model on the Rician ".005" split (rici_*_005 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
ricicnndmri005 = CNNDMRI().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(ricicnndmri005.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
ricicnndmri005_train_loss, ricicnndmri005_test_loss = ricitrain(
    ricicnndmri005, rici_dataloader_005, rici_dataloader_test_005,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_cnndmri_005.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(ricicnndmri005_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(ricicnndmri005_train_loss, label='train loss')
ax.plot(ricicnndmri005_test_loss, label='test loss')
ax.set(title='CNN-DMRI, Rician .005 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Train a WST autoencoder on the Rician ".005" split (rici_*_005 loaders,
# presumably the Rician-noise data) and plot its per-epoch train/test loss.
# n_epochs feeds both T_max and the epoch count so they stay in sync.
n_epochs = 25
riciwsta005 = WSTAutoencoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(riciwsta005.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
riciwsta005_train_loss, riciwsta005_test_loss = ricitrain(
    riciwsta005, rici_dataloader_005, rici_dataloader_test_005,
    loss_fn, optimizer, scheduler, n_epochs,
)
# Optional CSV export of the training losses (disabled):
# with open('./Results/train_losses_rician_wsta_005.csv', mode='w', newline='') as file:
#     writer = csv.writer(file)
#     writer.writerow(['Epoch', 'Train Loss'])
#     for i, loss in enumerate(riciwsta005_train_loss, start=1):
#         writer.writerow([i, loss])
fig, ax = plt.subplots()
ax.plot(riciwsta005_train_loss, label='train loss')
ax.plot(riciwsta005_test_loss, label='test loss')
ax.set(title='WSTA, Rician .005 dataset: train vs. test loss',
       xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Compare the per-epoch TRAINING losses of the three architectures on the
# Rician ".005" split. Test losses are not plotted here.
# The title names the corruption type because the binary section draws a
# figure for the same split.
fig, ax = plt.subplots()
ax.plot(riciwsta005_train_loss, label='WSTA architecture')
ax.plot(riciunet005_train_loss, label='UNet architecture')
ax.plot(ricicnndmri005_train_loss, label='CNN-DMRI architecture')
ax.set(title='Model comparison .005 dataset (Rician)',
       xlabel='Epochs', ylabel='Train Loss')
ax.legend()
plt.show()