In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchsummary import summary
from torch.optim.lr_scheduler import CosineAnnealingLR

import matplotlib.pyplot as plt
import numpy as np
import os
import time
from tqdm import tqdm
from kymatio.torch import Scattering2D
import nibabel as nib
from scipy.fftpack import fft, ifft, fft2, ifft2, fftshift, ifftshift
import PIL.Image
import pickle

In [None]:
def load_pkl(filename):
    """Unpickle and return the single object stored in `filename`.

    WARNING: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(filename, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj
def fftshift2d(x, ifft=False):
    """Circularly shift a 2-D, odd-sized array to move the zero frequency to/from the centre.

    Args:
        x: 2-D array whose two dimensions are both odd (asserted).
        ifft: False -> forward shift (same result as numpy.fft.fftshift);
              True  -> inverse shift (same result as numpy.fft.ifftshift).

    Returns:
        A new array with the same shape and dtype as `x`.
    """
    assert (len(x.shape) == 2) and all(dim % 2 == 1 for dim in x.shape)
    extra = 0 if ifft else 1
    row_cut = x.shape[0] // 2 + extra
    col_cut = x.shape[1] // 2 + extra
    # Moving the first `cut` entries of an axis to its end is a circular roll by -cut.
    return np.roll(x, shift=(-row_cut, -col_cut), axis=(0, 1))

In [None]:
# Pickled IXI slices: each file holds an (images, spectra) pair.
# NOTE(review): absolute, machine-specific paths - adjust before running.
# ruta = "C:/Users/javit/Desktop/MRI datasets/datasets/ixi_train-001.pkl"
# ruta_test = "C:/Users/javit/Desktop/MRI datasets/datasets/ixi_valid.pkl"
ruta = "C:/Users/javit/Desktop/N2N/datasets/ixi_train.pkl"
ruta_test = "C:/Users/javit/Desktop/N2N/datasets/ixi_valid.pkl"

img, spec = load_pkl(ruta)
img = img[:, :-1, :-1]  # drop the last row/column -> odd-sized slices (255x255 for this data)
img = img.astype(np.float32) / 255.0 - 0.5  # [0, 255] -> [-0.5, 0.5] (assumes 8-bit input)
test_img, test_spec = load_pkl(ruta_test)
test_img = test_img[:, :-1, :-1]
test_img = test_img.astype(np.float32) / 255.0 - 0.5

# 2-D size of one k-space slice. The Bernoulli mask below must match it exactly,
# and the reconstructed images are compared against the cropped images, so all
# sizes have to agree (the mask grid used to be hard-coded as (255, 255)).
spec_shape = tuple(np.shape(spec[0]))
assert img.shape[1:] == spec_shape and test_img.shape[1:] == spec_shape, \
    "cropped images and k-space spectra must have the same 2-D size"
assert tuple(np.shape(test_spec[0])) == spec_shape, "train/test spectra sizes differ"

# Radially decaying keep-probability mask for k-space undersampling:
# 1.0 at the centre, p_at_edge at distance h[1] from the centre (the middle of the
# left/right edges, cf. mask[h[0], 0]), and smaller still towards the corners.
p_at_edge = 0.025
h = [s // 2 for s in spec_shape]  # centre index along each axis
r = [np.arange(s, dtype=np.float32) - c for s, c in zip(spec_shape, h)]
r = [x ** 2 for x in r]
r = (r[0][:, np.newaxis] + r[1][np.newaxis, :]) ** .5  # distance from the centre
m = (p_at_edge ** (1. / h[1])) ** r
bern_mask = m

In [None]:
def corrupt_data(img, spec, mask=None):
    """Simulate random k-space undersampling of one slice.

    Each frequency of the centred spectrum `spec` is kept at random with the
    probability given by `mask`. Kept samples are rescaled by 1/probability, so the
    expectation matches the full spectrum. The result is inverse-shifted and
    inverse-FFT'd back to image space.

    Args:
        img: clean image; not used (kept for the (img, spec) call convention).
        spec: 2-D spectrum with the zero frequency at the centre (it is
            ifftshift-ed before ifft2); must broadcast against `mask`.
        mask: optional keep-probability array. Defaults to the notebook-level
            `bern_mask`, which matches the previous behaviour.

    Returns:
        (img, sval, smsk): the corrupted real-valued float32 image, the masked
        (still centred) spectrum, and the float32 0/1 keep mask.
    """
    if mask is None:
        mask = bern_mask
    # P(u**2 < m) = sqrt(m). After AND-ing with the point-mirrored pattern (k and -k
    # sit at the same radius), an off-centre frequency is kept with probability m.
    keep = (np.random.uniform(0.0, 1.0, size=spec.shape)**2 < mask)
    keep = keep & keep[::-1, ::-1]
    sval = spec * keep
    smsk = keep.astype(np.float32)
    # Add 1.0 to not-kept values to prevent div-by-zero; they are zero in sval anyway.
    # np.fft.ifftshift is identical to fftshift2d(..., ifft=True) for odd sizes.
    shifted = np.fft.ifftshift(sval / (mask + ~keep))
    img = np.real(np.fft.ifft2(shifted)).astype(np.float32)
    return img, sval, smsk

In [None]:
class MRIDenoisingDataset(Dataset):
    """Pairs of clean and k-space-corrupted MRI slices, generated on the fly.

    For each index the clean (image, spectrum) pair is optionally augmented with
    `augment_fn(image, spectrum, t=t)` and then corrupted with
    `corrupt_fn(image, spectrum)`, which must return
    (corrupted_image, corrupted_spectrum, keep_mask).

    Each item is the 5-tuple
    (img_clean, spec_clean, img_noisy, spec_noisy, mask). Every element is a tensor
    with a new leading channel axis of size 1. Images and the mask are float32;
    spectra are complex64.
    """

    def __init__(self, clean_images, clean_specs, t=64, corrupt_fn=corrupt_data, augment_fn=None):
        super().__init__()
        self.clean_images, self.clean_specs = clean_images, clean_specs
        self.t = t  # only forwarded to augment_fn
        self.corrupt_fn, self.augment_fn = corrupt_fn, augment_fn

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        image, spectrum = self.clean_images[idx], self.clean_specs[idx]
        if self.augment_fn:
            image, spectrum = self.augment_fn(image, spectrum, t=self.t)
        noisy_image, noisy_spectrum, keep_mask = self.corrupt_fn(image, spectrum)
        # (array, dtype) in output order.
        items = (
            (image, torch.float32),
            (spectrum, torch.complex64),
            (noisy_image, torch.float32),
            (noisy_spectrum, torch.complex64),
            (keep_mask, torch.float32),
        )
        return tuple(torch.tensor(arr, dtype=dt).unsqueeze(0) for arr, dt in items)

In [None]:
# k-space-corrupted train/validation sets; only the training loader is shuffled.
dataset = MRIDenoisingDataset(img, spec)
test_dataset = MRIDenoisingDataset(test_img, test_spec)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)
dataloader_test = DataLoader(test_dataset, shuffle=False, batch_size=32)

In [None]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()

        #Encoder
        self.conv0 = nn.Conv2d(1, 48, 3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv5 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.conv6 = nn.Conv2d(48, 48, 3, stride=1, padding=1)

        #Decoder
        self.upsample5 = nn.Upsample(scale_factor=2)
        self.deconv5a = nn.Conv2d(96, 96, 3, stride=1, padding=1)
        self.deconv5b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample4 = nn.Upsample(scale_factor=2)
        self.deconv4a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv4b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample3 = nn.Upsample(scale_factor=2)
        self.deconv3a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv3b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=2)
        self.deconv2a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv2b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.deconv1a = nn.Conv2d(97, 64, 3, stride=1, padding=1)
        self.deconv1b = nn.Conv2d(64, 32, 3, stride=1, padding=1)
        self.deconv1c = nn.Conv2d(32, 1, 3, stride=1, padding=1)


        
    def forward(self, x):
        original_in = F.pad(x, (0,1,0,1), mode='constant', value=-0.5)
        x = F.leaky_relu(self.conv0(original_in))
        x = F.leaky_relu(self.conv1(x))
        x = self.pool1(x)
        pool1 = x
        x = F.leaky_relu(self.conv2(x))
        x = self.pool2(x)
        pool2 = x
        x = F.leaky_relu(self.conv3(x))
        x = self.pool3(x)
        pool3 = x
        x = F.leaky_relu(self.conv4(x))
        x = self.pool4(x)
        pool4 = x
        x = F.leaky_relu(self.conv5(x))
        x = self.pool5(x)
        x = F.leaky_relu(self.conv6(x))
        x = self.upsample5(x)
        x = torch.cat((x, pool4), 1)
        x = F.leaky_relu(self.deconv5a(x))
        x = F.leaky_relu(self.deconv5b(x))
        x = self.upsample4(x)
        x = torch.cat((x, pool3), 1)
        x = F.leaky_relu(self.deconv4a(x))
        x = F.leaky_relu(self.deconv4b(x))
        x = self.upsample3(x)
        x = torch.cat((x, pool2), 1)
        x = F.leaky_relu(self.deconv3a(x))
        x = F.leaky_relu(self.deconv3b(x))
        x = self.upsample2(x)
        x = torch.cat((x, pool1), 1)
        x = F.leaky_relu(self.deconv2a(x))
        x = F.leaky_relu(self.deconv2b(x))
        x = self.upsample1(x)
        x = torch.cat((x, original_in), 1)
        x = F.leaky_relu(self.deconv1a(x))
        x = F.leaky_relu(self.deconv1b(x))
        x = self.deconv1c(x)
        return x[:,:,:-1,:-1]

model = Unet()
model = model.to(device)  # nn.Module.to returns the module itself

In [None]:
# torchsummary's `summary` defaults to device="cuda"; pass the actual device type so
# the layer summary also works on CPU-only machines.
summary(model, (1, 255, 255), device=device.type)

In [None]:
# the loss function
loss_fn = nn.MSELoss()
# the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)


In [None]:
def fftshift3d(x, ifft):
    """Batched 2-D (i)fftshift over the last two axes of a 3-D tensor.

    Args:
        x: 3-D tensor, e.g. (batch, H, W) (the rank is asserted).
        ifft: False -> fftshift (move the zero frequency to the centre);
              True  -> ifftshift (undo it).

    Returns:
        A new tensor of the same shape, equal to numpy.fft.fftshift / ifftshift
        over axes (1, 2).

    For odd sizes this matches the previous concat-based code exactly. For even
    sizes the old forward shift was off by one (it rolled by n/2 - 1 instead of n/2).
    """
    assert len(x.shape) == 3
    rows, cols = x.shape[1], x.shape[2]
    sign = -1 if ifft else 1
    return torch.roll(x, shifts=(sign * (rows // 2), sign * (cols // 2)), dims=(1, 2))

In [None]:
def train(net, trainLoader, testLoader, NUM_EPOCHS,
          opt=None, criterion=None, lr_scheduler=None, dev=None):
    """Supervised training loop for the k-space (Bernoulli undersampling) U-Net.

    Each batch is the 5-tuple (img_clean, spec_clean, img_noisy, spec_noisy, mask).
    Only the two images are used: the network maps img_noisy to img_clean.

    Args:
        net: model to optimise (switched between train() and eval() here).
        trainLoader, testLoader: DataLoaders yielding the 5-tuples above.
        NUM_EPOCHS: number of passes over trainLoader.
        opt, criterion, lr_scheduler, dev: optional overrides. When None they fall
            back to the notebook-level `optimizer`, `loss_fn`, `scheduler` and
            `device`, so existing calls behave exactly as before. The scheduler is
            stepped once per epoch.

    Returns:
        (train_loss, test_loss): one entry per epoch, each the mean of that epoch's
        per-batch losses. Validation predictions are clamped to [-0.5, 0.5], the
        range of the normalised images, before the loss is computed.
    """
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion
    lr_scheduler = scheduler if lr_scheduler is None else lr_scheduler
    dev = device if dev is None else dev

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for img_clean, _, img_noisy, _, _ in tepoch:
                img_noisy = img_noisy.to(dev)
                img_clean = img_clean.to(dev)
                opt.zero_grad()
                loss = criterion(net(img_noisy), img_clean)
                loss.backward()  # backpropagation
                opt.step()       # update the parameters
                batch_loss = loss.item()
                running_loss += batch_loss
                # Report a plain float: tqdm str()-formats a tensor (full repr incl. grad_fn).
                tepoch.set_postfix(loss=batch_loss)
        train_loss.append(running_loss / max(len(trainLoader), 1))
        lr_scheduler.step()

        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(testLoader, unit="batch") as tepoch:
                for img_clean, _, img_noisy, _, _ in tepoch:
                    img_noisy = img_noisy.to(dev)
                    img_clean = img_clean.to(dev)
                    outputs = torch.clamp(net(img_noisy), -.5, .5)
                    loss = criterion(outputs, img_clean)
                    batch_loss = loss.item()
                    val_loss += batch_loss
                    tepoch.set_postfix(loss=batch_loss)
        test_loss.append(val_loss / max(len(testLoader), 1))

    return train_loss, test_loss
                

In [None]:
train_loss, test_loss = train(model, dataloader, dataloader_test, NUM_EPOCHS=25)
fig, ax = plt.subplots()
for series, name in ((train_loss, 'train loss'), (test_loss, 'test loss')):
    ax.plot(series, label=name)
ax.legend()
ax.set_title('Train Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')

In [None]:
# PSNR in dB implied by the last epoch's validation loss. The normalised images span
# [-0.5, 0.5], so the peak signal range is 1.0 and PSNR = 10 * log10(1 / MSE).
# NOTE(review): replaces a hard-coded MSE of 0.0008 copied from an earlier run.
# test_loss[-1] is a mean of per-batch MSEs, not an average of per-image PSNRs.
10 * np.log10(1.0 / test_loss[-1])

In [None]:
# Clean, normalised training slice 50, shown for visual comparison.
plt.imshow(img[50, :, :], cmap='gray')

In [None]:
# Undersample the k-space of slice 50 once, to visualise the model's input.
img_noisy, spec_noisy, mask = corrupt_data(spec=spec[50], img=img[50])

In [None]:
# Denoise the corrupted slice from the previous cell and display the result.
# New names keep this cell idempotent. Previously `img_noisy` was rebound to a 4-D
# tensor, so a second run added two more dimensions and the model call failed.
# NOTE(review): training/validation inputs are not clipped; the clip is kept to
# match the original display - confirm this is intended.
noisy_input = np.clip(img_noisy, -.5, .5)
noisy_input = torch.tensor(noisy_input).to(device).unsqueeze(0).unsqueeze(0)
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    reconstructed = torch.clamp(model(noisy_input), -0.5, 0.5)
plt.imshow(np.squeeze(reconstructed.cpu().numpy()), cmap='gray')

## Rician Noise

In [None]:
def rician_noise(img, noise_percent):
    """Corrupt a magnitude image with Rician noise.

    The image is treated as the real channel. Independent zero-mean Gaussian noise
    with std = noise_percent % of img.max() is added to the real and imaginary
    channels, and the magnitude of the result is returned.

    Args:
        img: array-like image exposing `.max().item()` (e.g. a NumPy array).
        noise_percent: noise std as a percentage of the image maximum.

    Returns:
        Non-negative array with the same shape as `img`.
    """
    scale = (noise_percent / 100) * img.max().item()
    real_noise = np.random.normal(0, scale, img.shape)
    imag_noise = np.random.normal(0, scale, img.shape)
    real_part = img + real_noise
    return np.sqrt(real_part ** 2 + imag_noise ** 2)

In [None]:
class RicianMRIDenoisingDataset(Dataset):
    """Pairs of clean and Rician-noisy MRI slices, generated on the fly.

    Input images are shifted by +0.5, i.e. from [-0.5, 0.5] to [0, 1], before noise
    is added. Each item is (img_clean, img_noisy): float32 tensors with a new leading
    channel axis of size 1, clamped to [0, 1].

    Args:
        clean_images: indexable collection of images (presumably 2-D slices).
        t: only forwarded to augment_fn as `t=`.
        corrupt_fn: called as corrupt_fn(image, noise_percent) -> noisy image.
        augment_fn: optional; called as augment_fn(image, None, t=t) and must return
            a 2-tuple whose first element is the augmented image. This mirrors the
            (image, spectrum) convention; no spectrum exists here.
        noise_percent: noise std as a percentage of each image's maximum
            (default 11, the previously hard-coded level).
    """

    def __init__(self, clean_images, t=64, corrupt_fn=rician_noise, augment_fn=None, noise_percent=11):
        super(RicianMRIDenoisingDataset, self).__init__()
        self.clean_images = clean_images
        self.t = t
        self.corrupt_fn = corrupt_fn
        self.augment_fn = augment_fn
        self.noise_percent = noise_percent

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        img_clean = self.clean_images[idx] + 0.5  # now img in [0, 1] range
        # Data augmentation. The old code passed an undefined `spec_clean` here (NameError).
        if self.augment_fn is not None:
            img_clean, _ = self.augment_fn(img_clean, None, t=self.t)
        # Corrupt data
        cimg = self.corrupt_fn(img_clean, self.noise_percent)
        img_noisy = torch.tensor(cimg, dtype=torch.float32).clamp(0, 1)
        img_clean = torch.tensor(img_clean, dtype=torch.float32).clamp(0, 1)
        return img_clean.unsqueeze(0), img_noisy.unsqueeze(0)

In [None]:
# The Rician datasets need no spectra. Passing `spec` positionally (as before)
# silently bound it to the `t` augmentation argument.
Riciandataset = RicianMRIDenoisingDataset(img)
Riciantest_dataset = RicianMRIDenoisingDataset(test_img)
Riciandataloader = DataLoader(Riciandataset, batch_size=32, shuffle=True)
Riciandataloader_test = DataLoader(Riciantest_dataset, batch_size=32, shuffle=False)

In [None]:
# Same device choice as the earlier cell: first CUDA GPU if visible, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class RicianUnet(nn.Module):
    """Noise2Noise-style U-Net for the Rician-noise experiment.

    Five 2x max-pool encoder stages (48 channels), a mirrored decoder with skip
    concatenations, and a single-channel output with leaky-ReLU hidden layers.
    The input gets one extra bottom row and right column before the encoder, and
    the output is cropped by the same amount. An H x W input therefore needs
    (H + 1) and (W + 1) divisible by 32, e.g. 255 x 255 -> 256 x 256.

    Args:
        pad_value: constant written into the extra row/column. It defaults to 0.0,
            the black level of the [0, 1] images this model is trained on. The
            previously hard-coded -0.5 is the black level of [-0.5, 0.5] data and
            lies outside the Rician input range.
    """

    def __init__(self, pad_value=0.0):
        super(RicianUnet, self).__init__()
        self.pad_value = pad_value

        # Encoder
        self.conv0 = nn.Conv2d(1, 48, 3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv5 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.conv6 = nn.Conv2d(48, 48, 3, stride=1, padding=1)

        # Decoder: 96 = 48 + 48 skip, 144 = 96 + 48 skip, 97 = 96 + 1 (padded input)
        self.upsample5 = nn.Upsample(scale_factor=2)
        self.deconv5a = nn.Conv2d(96, 96, 3, stride=1, padding=1)
        self.deconv5b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample4 = nn.Upsample(scale_factor=2)
        self.deconv4a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv4b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample3 = nn.Upsample(scale_factor=2)
        self.deconv3a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv3b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=2)
        self.deconv2a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv2b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.deconv1a = nn.Conv2d(97, 64, 3, stride=1, padding=1)
        self.deconv1b = nn.Conv2d(64, 32, 3, stride=1, padding=1)
        self.deconv1c = nn.Conv2d(32, 1, 3, stride=1, padding=1)

    def forward(self, x):
        # Pad bottom/right by one pixel so the five 2x poolings divide evenly.
        original_in = F.pad(x, (0, 1, 0, 1), mode='constant', value=self.pad_value)
        x = F.leaky_relu(self.conv0(original_in))
        x = F.leaky_relu(self.conv1(x))
        x = self.pool1(x)
        pool1 = x
        x = F.leaky_relu(self.conv2(x))
        x = self.pool2(x)
        pool2 = x
        x = F.leaky_relu(self.conv3(x))
        x = self.pool3(x)
        pool3 = x
        x = F.leaky_relu(self.conv4(x))
        x = self.pool4(x)
        pool4 = x
        x = F.leaky_relu(self.conv5(x))
        x = self.pool5(x)
        x = F.leaky_relu(self.conv6(x))
        x = self.upsample5(x)
        x = torch.cat((x, pool4), 1)
        x = F.leaky_relu(self.deconv5a(x))
        x = F.leaky_relu(self.deconv5b(x))
        x = self.upsample4(x)
        x = torch.cat((x, pool3), 1)
        x = F.leaky_relu(self.deconv4a(x))
        x = F.leaky_relu(self.deconv4b(x))
        x = self.upsample3(x)
        x = torch.cat((x, pool2), 1)
        x = F.leaky_relu(self.deconv3a(x))
        x = F.leaky_relu(self.deconv3b(x))
        x = self.upsample2(x)
        x = torch.cat((x, pool1), 1)
        x = F.leaky_relu(self.deconv2a(x))
        x = F.leaky_relu(self.deconv2b(x))
        x = self.upsample1(x)
        x = torch.cat((x, original_in), 1)
        x = F.leaky_relu(self.deconv1a(x))
        x = F.leaky_relu(self.deconv1b(x))
        x = self.deconv1c(x)
        # Crop the padded row/column back off.
        return x[:, :, :-1, :-1]

rician_model = RicianUnet()
rician_model = rician_model.to(device)  # nn.Module.to returns the module itself

In [None]:
# the loss function
rician_loss_fn = nn.MSELoss()
# the optimizer
rician_optimizer = optim.Adam(rician_model.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)

In [None]:
def ricitrain(net, trainLoader, testLoader, NUM_EPOCHS,
              opt=None, criterion=None, lr_scheduler=None, dev=None,
              clamp_range=(0.0, 1.0)):
    """Train `net` on (img_clean, img_noisy) batches and validate after every epoch.

    Args:
        net: model mapping noisy images to clean ones.
        trainLoader, testLoader: loaders yielding (img_clean, img_noisy) pairs.
        NUM_EPOCHS: number of passes over trainLoader.
        opt, criterion, dev: optional overrides. They default to the notebook-level
            `rician_optimizer`, `rician_loss_fn` and `device`.
        lr_scheduler: optional LR scheduler, stepped once per epoch. The default
            None keeps the learning rate constant, as before.
        clamp_range: (lo, hi) bounds applied to validation outputs before the loss.
            The default (0, 1) matches the Rician dataset's targets. The old
            clamp to (-0.5, 0.5) cut every prediction above 0.5 and inflated the
            validation loss.

    Returns:
        (train_loss, test_loss): one entry per epoch, each the mean of that epoch's
        per-batch losses.
    """
    opt = rician_optimizer if opt is None else opt
    criterion = rician_loss_fn if criterion is None else criterion
    dev = device if dev is None else dev
    lo, hi = clamp_range

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for img_clean, img_noisy in tepoch:
                img_noisy = img_noisy.to(dev)
                img_clean = img_clean.to(dev)
                opt.zero_grad()
                loss = criterion(net(img_noisy), img_clean)
                loss.backward()  # backpropagation
                opt.step()       # update the parameters
                batch_loss = loss.item()
                running_loss += batch_loss
                # Report a plain float: tqdm str()-formats a tensor (full repr incl. grad_fn).
                tepoch.set_postfix(loss=batch_loss)
        train_loss.append(running_loss / max(len(trainLoader), 1))
        if lr_scheduler is not None:
            lr_scheduler.step()

        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(testLoader, unit="batch") as tepoch:
                for img_clean, img_noisy in tepoch:
                    img_noisy = img_noisy.to(dev)
                    img_clean = img_clean.to(dev)
                    outputs = torch.clamp(net(img_noisy), lo, hi)
                    loss = criterion(outputs, img_clean)
                    batch_loss = loss.item()
                    val_loss += batch_loss
                    tepoch.set_postfix(loss=batch_loss)
        test_loss.append(val_loss / max(len(testLoader), 1))

    return train_loss, test_loss
                

In [None]:
train_loss, test_loss = ricitrain(rician_model, Riciandataloader, Riciandataloader_test, 25)
plt.figure()
plt.plot(train_loss, label='train loss')
plt.plot(test_loss, label='test loss')
plt.legend()  # the curves carry labels; draw the legend so they can be told apart
plt.title('Train Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

In [None]:
# Train/test loss curves of the most recent training run, with legend.
fig, ax = plt.subplots()
ax.plot(train_loss, label='train loss')
ax.plot(test_loss, label='test loss')
ax.set_title('Train Loss')
ax.set_xlabel('Epochs')
ax.legend()
ax.set_ylabel('Loss')