In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchsummary import summary
from torch.optim.lr_scheduler import CosineAnnealingLR

import matplotlib.pyplot as plt
import numpy as np
import os
import time
from tqdm import tqdm

import nibabel as nib
from scipy.fftpack import fft, ifft, fft2, ifft2, fftshift, ifftshift
import PIL.Image
import pickle



# Prepare the data

In [None]:
def load_pkl(filename):
    """Unpickle and return the single object stored in ``filename``.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file; only
    call this on trusted, locally produced dataset files.
    """
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle)
    return payload
def fftshift2d(x, ifft=False):
    """Circularly shift both axes of a 2-D array with odd sizes.

    With ``ifft=False`` each axis of length n is rotated so that index n//2 + 1
    comes first; with ``ifft=True`` index n//2 comes first. For odd sizes these are
    the same as ``np.fft.fftshift`` and ``np.fft.ifftshift`` respectively.

    Raises:
        AssertionError: if ``x`` is not 2-D or either axis has even length.
    """
    assert len(x.shape) == 2 and all(s % 2 == 1 for s in x.shape)
    offset = 0 if ifft else 1
    n_rows, n_cols = x.shape
    row_cut = n_rows // 2 + offset
    col_cut = n_cols // 2 + offset
    # Rolling by -cut is the same as concatenating x[cut:] followed by x[:cut].
    return np.roll(x, shift=(-row_cut, -col_cut), axis=(0, 1))

In [None]:
# Location of the IXI train/validation pickles. Set the N2N_DATA_DIR environment
# variable to point elsewhere instead of editing a hard-coded absolute path.
DATA_DIR = os.environ.get("N2N_DATA_DIR", "C:/Users/javit/Desktop/N2N/datasets")
ruta = os.path.join(DATA_DIR, "ixi_train.pkl")
ruta_test = os.path.join(DATA_DIR, "ixi_valid.pkl")


def _prepare_images(images):
    """Crop the last row/column (so images become 255x255) and map [0, 255] to [-0.5, 0.5].

    Assumes 8-bit-range pixel values (division by 255). The odd size is required by
    fftshift2d.
    """
    return images[:, :-1, :-1].astype(np.float32) / 255.0 - 0.5


img, spec = load_pkl(ruta)
img = _prepare_images(img)
test_img, test_spec = load_pkl(ruta_test)
test_img = _prepare_images(test_img)

# Radial k-space keep-probability mask used by corrupt_data. It is 1 at the centre
# and decays as p_at_edge ** (r / h[1]), so it equals p_at_edge at distance h[1].
# Its size comes from the data instead of a hard-coded (255, 255).
IMAGE_HW = img.shape[-2:]
assert np.shape(spec[0]) == IMAGE_HW, "spectra and cropped images must have the same size"

p_at_edge = 0.025
h = [s // 2 for s in IMAGE_HW]  # centre index per axis
r = [np.arange(s, dtype=np.float32) - c for s, c in zip(IMAGE_HW, h)]
r = [x ** 2 for x in r]
r = (r[0][:, np.newaxis] + r[1][np.newaxis, :]) ** .5  # distance from the centre
m = (p_at_edge ** (1. / h[1])) ** r
bern_mask = m

In [None]:
def corrupt_data(img, spec):
    """Randomly undersample a centred spectrum and rebuild an image from what is kept.

    Keep probabilities come from the module-level ``bern_mask``.

    Args:
        img: clean image. It is not read; the corrupted image is rebuilt from ``spec``.
        spec: 2-D centred spectrum with odd sizes on both axes, as fftshift2d requires.
            Presumably the shifted FFT of ``img``; TODO confirm against the dataset pickle.

    Returns:
        (image, kept_spectrum, keep_mask):
            image: float32 real part of the inverse FFT of the kept spectrum, where
                kept coefficients are divided by their keep probability;
            kept_spectrum: ``spec`` with dropped coefficients set to zero;
            keep_mask: float32 array, 1 where a coefficient was kept and 0 elsewhere.
    """
    probs = bern_mask
    # u**2 < p  <=>  u < sqrt(p), so each draw keeps a coefficient with probability sqrt(p).
    draws = np.random.uniform(0.0, 1.0, size=spec.shape)
    kept = draws ** 2 < probs
    # Keep a coefficient only if its point-mirrored partner is also kept. This makes the
    # kept set point-symmetric. For a point-symmetric mask such as the radial bern_mask,
    # the mirrored draw is independent except at the exact centre, so the overall keep
    # probability is sqrt(p) * sqrt(p) = p.
    kept = kept & kept[::-1, ::-1]
    kept_spec = spec * kept
    kept_mask = kept.astype(np.float32)
    # Divide kept coefficients by their keep probability. Adding ~kept (1 where dropped)
    # to the denominator avoids division by zero; dropped entries remain 0.
    weighted = kept_spec / (probs + ~kept)
    recon = np.fft.ifft2(fftshift2d(weighted, ifft=True)).real.astype(np.float32)
    return recon, kept_spec, kept_mask

In [None]:
class MRIDenoisingDataset(Dataset):
    """Dataset pairing clean MRI images/spectra with corrupted versions.

    ``corrupt_fn`` is called on every ``__getitem__``. With the default
    ``corrupt_data`` this draws a new random undersampling mask on each access.

    Each item is a 5-tuple of tensors, each with a leading axis of size 1:
    (clean image, clean spectrum, corrupted image, corrupted spectrum, keep mask).
    Images and masks are float32; spectra are complex64.
    """

    def __init__(self, clean_images, clean_specs, t=64, corrupt_fn=corrupt_data, augment_fn=None):
        super().__init__()
        self.clean_images = clean_images
        self.clean_specs = clean_specs
        self.t = t                    # forwarded to augment_fn as t=
        self.corrupt_fn = corrupt_fn  # (image, spec) -> (image, spec, mask)
        self.augment_fn = augment_fn  # optional (image, spec, t=...) -> (image, spec)

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        image = self.clean_images[idx]
        spectrum = self.clean_specs[idx]
        if self.augment_fn:
            image, spectrum = self.augment_fn(image, spectrum, t=self.t)
        noisy_image, noisy_spectrum, keep_mask = self.corrupt_fn(image, spectrum)

        real_dtype = torch.float32
        complex_dtype = torch.complex64
        tensors = (
            torch.tensor(image, dtype=real_dtype),
            torch.tensor(spectrum, dtype=complex_dtype),
            torch.tensor(noisy_image, dtype=real_dtype),
            torch.tensor(noisy_spectrum, dtype=complex_dtype),
            torch.tensor(keep_mask, dtype=real_dtype),
        )
        # Add a channel axis to every tensor.
        return tuple(tensor.unsqueeze(0) for tensor in tensors)

In [None]:
# Corruption happens inside __getitem__, so every epoch sees freshly corrupted inputs.
dataset = MRIDenoisingDataset(img, spec)
test_dataset = MRIDenoisingDataset(test_img, test_spec)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)  # reshuffled each epoch
dataloader_test = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
# Prefer the first CUDA GPU when available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class AutorEncoder(nn.Module):
    """Residual convolutional encoder/decoder whose output is input + learned correction.

    Layout:
    - two 3x3 convs at full resolution (1 -> 64 channels);
    - two stride-2 convs (64 -> 128 channels);
    - four residual blocks, each conv-BN-ReLU-conv-BN plus a skip connection;
    - two stride-2 transposed convs back to 64 channels;
    - a 3x3 conv to 1 channel, added to the input.

    For a 255x255 input the spatial size goes 255 -> 128 -> 64 -> 128 -> 255.
    """

    def __init__(self):
        super().__init__()

        # Full-resolution feature extraction.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Two stride-2 downsampling stages.
        self.enc1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.enc2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)

        # Residual blocks. The attribute names (res{k}conv{1,2}, bn{1..8}) and their
        # registration order are kept so state_dict keys and parameter order match.
        for block in range(1, 5):
            setattr(self, f"res{block}conv1", nn.Conv2d(128, 128, kernel_size=3, padding=1))
            setattr(self, f"bn{2 * block - 1}", nn.BatchNorm2d(128))
            setattr(self, f"res{block}conv2", nn.Conv2d(128, 128, kernel_size=3, padding=1))
            setattr(self, f"bn{2 * block}", nn.BatchNorm2d(128))

        # Upsampling back to the input resolution.
        self.dec1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1)

        # Projection to a single-channel correction.
        self.out = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def _residual(self, feats, block):
        """Apply residual block ``block`` (1-based): conv-BN-ReLU-conv-BN plus skip."""
        conv_a = getattr(self, f"res{block}conv1")
        norm_a = getattr(self, f"bn{2 * block - 1}")
        conv_b = getattr(self, f"res{block}conv2")
        norm_b = getattr(self, f"bn{2 * block}")
        return norm_b(conv_b(F.relu(norm_a(conv_a(feats))))) + feats

    def forward(self, x):
        skip = x
        feats = F.relu(self.conv2(F.relu(self.conv1(x))))
        feats = F.relu(self.enc2(F.relu(self.enc1(feats))))
        for block in range(1, 5):
            feats = self._residual(feats, block)
        feats = F.relu(self.dec2(F.relu(self.dec1(feats))))
        return self.out(feats) + skip

# Build the undersampling-denoising model and move its parameters to the selected device.
model = AutorEncoder()
model = model.to(device)

In [None]:
# nn.Module's string form is its repr: the nested list of layers.
print(repr(model))

In [None]:
# torchsummary.summary defaults to device="cuda" and creates its dummy input there.
# Pass the actual device type so this cell also runs on CPU-only machines.
# NOTE(review): summary runs a forward pass with random input; confirm this is
# acceptable before training, since BatchNorm running stats may be touched.
summary(model, (1, 255, 255), device=device.type)

In [None]:
# Objective: pixel-wise mean squared error between the output and the clean image.
loss_fn = nn.MSELoss()
# Adam with cosine learning-rate annealing over 25 scheduler steps
# (train() steps the scheduler once per epoch).
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=25, eta_min=0)


In [None]:
def fftshift3d(x, ifft):
    """Circularly shift axes 1 and 2 of a 3-D tensor (a stack of 2-D arrays).

    With ``ifft=False`` each shifted axis of length n is rotated so that index
    n//2 + 1 comes first; with ``ifft=True`` index n//2 comes first. Axis 0 is untouched.

    Raises:
        AssertionError: if ``x`` is not 3-D.
    """
    assert len(x.shape) == 3
    extra = 0 if ifft else 1
    cut_rows = x.shape[1] // 2 + extra
    cut_cols = x.shape[2] // 2 + extra
    # Rolling by -cut is the same as concatenating x[cut:] followed by x[:cut].
    return torch.roll(x, shifts=(-cut_rows, -cut_cols), dims=(1, 2))

In [None]:
def train(net, trainLoader, testLoader, NUM_EPOCHS, opt=None, criterion=None,
          lr_scheduler=None, clamp_range=(-.5, .5)):
    """Train ``net`` for ``NUM_EPOCHS`` epochs, validating after every epoch.

    Batches from both loaders are 5-tuples
    ``(img_clean, spec_clean, img_noisy, spec_noisy, mask)``. Only the noisy image
    (input) and the clean image (target) are used.

    Args:
        net: model mapping a noisy image batch to a denoised batch.
        trainLoader: DataLoader used for optimisation.
        testLoader: DataLoader used for validation.
        NUM_EPOCHS: number of passes over ``trainLoader``.
        opt: optimizer. Defaults to the notebook-level ``optimizer``.
        criterion: loss function. Defaults to the notebook-level ``loss_fn``.
        lr_scheduler: scheduler stepped once per epoch. Defaults to the
            notebook-level ``scheduler``.
        clamp_range: (low, high) range that validation outputs are clamped to
            before the loss is computed. Defaults to the [-0.5, 0.5] image range.

    Returns:
        (train_loss, test_loss): lists with the per-epoch mean of the batch losses.
    """
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion
    lr_scheduler = scheduler if lr_scheduler is None else lr_scheduler
    low, high = clamp_range

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        # Validation below leaves the net in eval mode. Switch back every epoch so
        # BatchNorm uses batch statistics and keeps updating its running averages.
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for data in tepoch:
                img_clean, _, img_noisy, _, _ = data
                img_noisy = img_noisy.to(device)
                img_clean = img_clean.to(device)
                opt.zero_grad()
                outputs = net(img_noisy)
                loss = criterion(outputs, img_clean)
                loss.backward()
                opt.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                tepoch.set_postfix(loss=batch_loss)
        train_loss.append(running_loss / len(trainLoader))
        lr_scheduler.step()

        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(testLoader, unit="batch") as tepoch:
                for data in tepoch:
                    img_clean, _, img_noisy, _, _ = data
                    img_noisy = img_noisy.to(device)
                    img_clean = img_clean.to(device)
                    # Clamp predictions to the valid intensity range before scoring.
                    outputs = torch.clamp(net(img_noisy), low, high)
                    batch_loss = criterion(outputs, img_clean).item()
                    val_loss += batch_loss
                    tepoch.set_postfix(loss=batch_loss)
        test_loss.append(val_loss / len(testLoader))

    return train_loss, test_loss
                

In [None]:
# 25 epochs, matching the scheduler's T_max so the cosine schedule completes.
train_loss, test_loss = train(model, dataloader, dataloader_test, 25)
plt.figure()
plt.plot(train_loss, label='train loss')
plt.plot(test_loss, label='test loss')
plt.title('Train and Test Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend();

In [None]:
# PSNR = 10*log10(peak^2 / MSE). Images live in [-0.5, 0.5], so peak-to-peak is 1.
# Use the final validation MSE from the training run above rather than a hand-copied constant.
10 * np.log10(1 / test_loss[-1])

In [None]:
# Fixed [-0.5, 0.5] display range so clean, noisy and reconstructed images are comparable.
plt.imshow(img[50], cmap='gray', vmin=-.5, vmax=.5)
plt.title('Clean image 50');

In [None]:
# Draw one random undersampling of sample 50 (image, kept spectrum, keep mask).
img_noisy, spec_noisy, mask = corrupt_data(img=img[50], spec=spec[50])

In [None]:
# Same fixed display range as the clean image, for a fair visual comparison.
plt.imshow(np.clip(img_noisy, -.5, .5), cmap='gray', vmin=-.5, vmax=.5)
plt.title('Undersampled (noisy) image 50');

In [None]:
# Restrict the corrupted image to the valid [-0.5, 0.5] intensity range.
img_noisy = img_noisy.clip(-.5, .5)

In [None]:

# Add batch and channel axes, giving shape (1, 1, H, W), and move to the model's device.
img_noisy = torch.tensor(img_noisy)[None, None].to(device)

In [None]:
# Inference only. Eval mode makes BatchNorm use its running statistics, and
# no_grad avoids building an autograd graph for this forward pass.
model.eval()
with torch.no_grad():
    reconstructed = model(img_noisy.to(device))
reconstructed = torch.clamp(reconstructed, -0.5, 0.5)

In [None]:
# Same fixed display range as the clean and noisy images.
plt.imshow(np.squeeze(reconstructed.cpu().detach().numpy()), cmap='gray', vmin=-.5, vmax=.5)
plt.title('Reconstructed image 50');

In [None]:
# Mean squared error between clean image 50 and its reconstruction.
((img[50] - reconstructed.detach().cpu().numpy().squeeze()) ** 2).mean()

## Let's try Rician noise

In [None]:
def rician_noise(img, noise_percent):
    """Corrupt ``img`` with Rician noise using NumPy's global RNG.

    Zero-mean Gaussian noise with standard deviation ``noise_percent`` % of
    ``img.max()`` is added independently to a real channel (holding ``img``) and an
    imaginary channel (holding 0). The magnitude of the result is returned.

    Args:
        img: array exposing ``.max().item()`` and ``.shape`` (e.g. a NumPy array).
        noise_percent: noise standard deviation, as a percentage of ``img.max()``.

    Returns:
        Non-negative array with the same shape as ``img``.
    """
    sigma = img.max().item() * (noise_percent / 100)
    real_channel = img + np.random.normal(0, sigma, img.shape)
    imag_channel = np.random.normal(0, sigma, img.shape)
    return np.sqrt(real_channel ** 2 + imag_channel ** 2)

In [None]:
class RicianMRIDenoisingDataset(Dataset):
    """Dataset pairing clean MRI images with Rician-noise-corrupted copies.

    Noise is regenerated on every ``__getitem__``. Images are shifted by +0.5 before
    corruption; this assumes ``clean_images`` are in [-0.5, 0.5], as produced by the
    loading cell, so they end up in [0, 1]. Both returned tensors are float32, clamped
    to [0, 1], and carry a leading channel axis of size 1.

    Args:
        clean_images: indexable collection of 2-D clean images.
        t: forwarded to ``augment_fn`` as ``t=``.
        corrupt_fn: ``corrupt_fn(image, noise_percent)`` -> noisy image.
        augment_fn: optional ``augment_fn(image, spec, t=...)`` -> ``(image, spec)``.
            This is the same contract MRIDenoisingDataset uses.
        clean_specs: optional spectra passed to ``augment_fn``. ``None`` is passed
            when they are not provided.
        noise_percent: noise level forwarded to ``corrupt_fn`` (default 11).

    Returns (per item):
        ``(img_clean, img_noisy)``.
    """
    def __init__(self, clean_images, t=64, corrupt_fn=rician_noise, augment_fn=None,
                 clean_specs=None, noise_percent=11):
        super(RicianMRIDenoisingDataset, self).__init__()
        self.clean_images = clean_images
        self.t = t
        self.corrupt_fn = corrupt_fn
        self.augment_fn = augment_fn
        self.clean_specs = clean_specs
        self.noise_percent = noise_percent

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        img_clean = self.clean_images[idx] + 0.5  # now img in [0, 1] range
        if self.augment_fn:
            # augment_fn takes and returns (image, spec); there is no spectrum unless
            # clean_specs was supplied, in which case it is passed along (else None).
            spec_clean = None if self.clean_specs is None else self.clean_specs[idx]
            img_clean, _ = self.augment_fn(img_clean, spec_clean, t=self.t)
        cimg = self.corrupt_fn(img_clean, self.noise_percent)
        img_noisy = torch.tensor(cimg, dtype=torch.float32).clamp(0, 1)
        img_clean = torch.tensor(img_clean, dtype=torch.float32).clamp(0, 1)
        return img_clean.unsqueeze(0), img_noisy.unsqueeze(0)

In [None]:
# The dataset's second positional parameter is ``t`` (forwarded to augment_fn), not
# the spectra. Rician noise is generated from the images alone, so pass only images.
Riciandataset = RicianMRIDenoisingDataset(img)
Riciantest_dataset = RicianMRIDenoisingDataset(test_img)
Riciandataloader = DataLoader(Riciandataset, batch_size=32, shuffle=True)
Riciandataloader_test = DataLoader(Riciantest_dataset, batch_size=32, shuffle=False)

In [None]:
# 11% Rician noise on image 100 (shifted to [0, 1]), shown on a fixed [0, 1] scale.
noisy_im = rician_noise(img[100] + 0.5, 11)
plt.imshow(np.clip(noisy_im, 0, 1), cmap='gray', vmin=0, vmax=1)
plt.title('Rician noise (11%), image 100');

In [None]:
# PSNR of the clipped noisy image against the clean image. Both are in [0, 1], so peak = 1.
# MSE is the mean of the squared errors, not the square of the mean error.
mse = np.mean(((img[100] + 0.5) - np.clip(noisy_im, 0, 1)) ** 2)
psnr = 10 * np.log10(1 / mse)
psnr

In [None]:
# Clean image 200 shifted to [0, 1], on the same fixed scale as the Rician images.
plt.imshow(img[200] + 0.5, cmap='gray', vmin=0, vmax=1)
plt.title('Clean image 200');

In [None]:
# Same device choice as before: first CUDA GPU if available, else CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if _has_cuda else "cpu")

In [None]:
class RicciAutoEncoder(nn.Module):
    """Residual convolutional encoder/decoder for Rician denoising.

    The output is the input plus a learned correction. Layout:
    - 2 full-resolution 3x3 convs (1 -> 64);
    - 2 stride-2 convs (64 -> 128);
    - 4 residual blocks, each conv-BN-ReLU-conv-BN plus skip;
    - 2 stride-2 transposed convs (-> 64);
    - a 3x3 conv to 1 channel, added to the input.

    A 255x255 input maps through 255 -> 128 -> 64 -> 128 -> 255.
    """

    def __init__(self):
        super().__init__()

        # Full-resolution feature extraction.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Two stride-2 downsampling stages.
        self.enc1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.enc2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)

        # Residual blocks. Names and registration order are kept so state_dict
        # keys and parameter order are unchanged.
        for idx in range(1, 5):
            setattr(self, f"res{idx}conv1", nn.Conv2d(128, 128, kernel_size=3, padding=1))
            setattr(self, f"bn{2 * idx - 1}", nn.BatchNorm2d(128))
            setattr(self, f"res{idx}conv2", nn.Conv2d(128, 128, kernel_size=3, padding=1))
            setattr(self, f"bn{2 * idx}", nn.BatchNorm2d(128))

        # Upsampling back to the input resolution.
        self.dec1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1)

        # Projection to a single-channel correction.
        self.out = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        identity = x
        y = F.relu(self.conv2(F.relu(self.conv1(x))))
        y = F.relu(self.enc2(F.relu(self.enc1(y))))
        for idx in range(1, 5):
            branch = getattr(self, f"res{idx}conv1")(y)
            branch = F.relu(getattr(self, f"bn{2 * idx - 1}")(branch))
            branch = getattr(self, f"res{idx}conv2")(branch)
            branch = getattr(self, f"bn{2 * idx}")(branch)
            y = branch + y
        y = F.relu(self.dec2(F.relu(self.dec1(y))))
        return self.out(y) + identity

# Build the Rician denoising model and move its parameters to the selected device.
Riccimodel = RicciAutoEncoder()
Riccimodel = Riccimodel.to(device)

In [None]:
# Objective: pixel-wise mean squared error against the clean image.
Riciloss_fn = nn.MSELoss()
# Adam plus cosine learning-rate annealing over 25 scheduler steps (one per epoch).
Ricioptimizer = optim.Adam(params=Riccimodel.parameters(), lr=0.001)
Ricischeduler = CosineAnnealingLR(optimizer=Ricioptimizer, T_max=25, eta_min=0)

In [None]:
def Riccitrain(net, trainLoader, testLoader, NUM_EPOCHS, opt=None, criterion=None,
               lr_scheduler=None, clamp_range=(0., 1.)):
    """Train ``net`` on (clean, noisy) image pairs, validating after every epoch.

    Batches are 2-tuples ``(img_clean, img_noisy)``, as produced by
    RicianMRIDenoisingDataset. The noisy image is the input and the clean one the target.

    Args:
        net: model mapping a noisy image batch to a denoised batch.
        trainLoader: DataLoader used for optimisation.
        testLoader: DataLoader used for validation.
        NUM_EPOCHS: number of passes over ``trainLoader``.
        opt: optimizer. Defaults to the notebook-level ``Ricioptimizer``.
        criterion: loss function. Defaults to the notebook-level ``Riciloss_fn``.
        lr_scheduler: scheduler stepped once per epoch. Defaults to the
            notebook-level ``Ricischeduler``.
        clamp_range: (low, high) range that validation outputs are clamped to.
            Defaults to [0, 1], the range RicianMRIDenoisingDataset clamps its images to.

    Returns:
        (train_loss, test_loss): lists with the per-epoch mean of the batch losses.
    """
    opt = Ricioptimizer if opt is None else opt
    criterion = Riciloss_fn if criterion is None else criterion
    lr_scheduler = Ricischeduler if lr_scheduler is None else lr_scheduler
    low, high = clamp_range

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        # Validation leaves the net in eval mode. Re-enable training mode each epoch
        # so BatchNorm uses batch statistics and updates its running averages.
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for data in tepoch:
                img_clean, img_noisy = data
                img_noisy = img_noisy.to(device)
                img_clean = img_clean.to(device)
                opt.zero_grad()
                outputs = net(img_noisy)
                loss = criterion(outputs, img_clean)
                loss.backward()
                opt.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                tepoch.set_postfix(loss=batch_loss)
        train_loss.append(running_loss / len(trainLoader))
        lr_scheduler.step()

        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(testLoader, unit="batch") as tepoch:
                for data in tepoch:
                    img_clean, img_noisy = data
                    img_noisy = img_noisy.to(device)
                    img_clean = img_clean.to(device)
                    # Targets are in [0, 1]; clamping to [-0.5, 0.5] would cut off
                    # every bright pixel and inflate the validation loss.
                    outputs = torch.clamp(net(img_noisy), low, high)
                    batch_loss = criterion(outputs, img_clean).item()
                    val_loss += batch_loss
                    tepoch.set_postfix(loss=batch_loss)
        test_loss.append(val_loss / len(testLoader))

    return train_loss, test_loss
                

In [None]:
# 25 epochs, matching Ricischeduler's T_max so the cosine schedule completes.
train_loss, test_loss = Riccitrain(Riccimodel, Riciandataloader, Riciandataloader_test, 25)
plt.figure()
plt.plot(train_loss, label='train loss')
plt.plot(test_loss, label='test loss')
plt.title('Rician model: Train and Test Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend();

In [None]:
# Redraw the latest loss curves with a legend, using the explicit Axes API.
fig, ax = plt.subplots()
for series, name in ((train_loss, 'train loss'), (test_loss, 'test loss')):
    ax.plot(series, label=name)
ax.set_title('Train Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()

In [None]:
noisy = rician_noise(img[150] + 0.5, 11)
# RicianMRIDenoisingDataset clamps noisy inputs to [0, 1] during training; do the
# same here so the model sees inputs from the distribution it was trained on.
noisy_input = torch.tensor(np.clip(noisy, 0, 1), dtype=torch.float)[None, None].to(device)
# Inference only: eval mode for BatchNorm running stats, and no autograd graph.
Riccimodel.eval()
with torch.no_grad():
    reconstructed = Riccimodel(noisy_input)
reconstructed = torch.clamp(reconstructed, 0, 1)
plt.imshow(np.squeeze(reconstructed.cpu().numpy()), cmap='gray', vmin=0, vmax=1)
plt.title('Rician: reconstructed image 150')
plt.show()


In [None]:
# Fixed [0, 1] display range, matching the reconstruction above. Values above 1 show as white.
plt.imshow(noisy, cmap='gray', vmin=0, vmax=1)
plt.title('Rician noisy input, image 150');