In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np
import os
import time
from tqdm import tqdm
from kymatio.torch import Scattering2D
from torch.optim.lr_scheduler import CosineAnnealingLR

import nibabel as nib
from scipy.fftpack import fft, ifft, fft2, ifft2, fftshift, ifftshift
import PIL.Image
import pickle

import csv

In [None]:
def load_pkl(filename):
    """Unpickle and return the object stored in ``filename``.

    Only call this on trusted files: ``pickle.load`` can execute arbitrary code
    embedded in the file.
    """
    with open(filename, mode='rb') as fh:
        obj = pickle.load(fh)
    return obj
def fftshift2d(x, ifft=False):
    """Circularly shift a 2-D array with odd side lengths to (un)centre its zero frequency.

    For odd sizes the forward shift (``ifft=False``) is equivalent to
    ``np.fft.fftshift`` and the inverse shift (``ifft=True``) to
    ``np.fft.ifftshift``. Returns a new array of the same dtype.

    Raises:
        AssertionError: unless ``x`` is 2-D with an odd length along both axes.
    """
    assert (len(x.shape) == 2) and all(n % 2 == 1 for n in x.shape)
    extra = 0 if ifft else 1
    # Rolling by -k brings element k to the front, i.e. the same as
    # concatenating x[k:] with x[:k] along that axis.
    offsets = tuple(-(n // 2 + extra) for n in x.shape)
    return np.roll(x, shift=offsets, axis=(0, 1))

In [None]:
# Dataset location. Override with the N2N_DATA_DIR environment variable instead of
# editing a machine-specific absolute path; the default is the original location.
DATA_DIR = os.environ.get("N2N_DATA_DIR", "C:/Users/javit/Desktop/N2N/datasets")
ruta = f"{DATA_DIR}/ixi_train.pkl"
ruta_test = f"{DATA_DIR}/ixi_valid.pkl"

img, spec = load_pkl(ruta)
img = img[:, :-1, :-1]  # crop each image to 255x255
# Map intensities to [-0.5, 0.5] (assumes raw values in [0, 255]).
img = img.astype(np.float32) / 255.0 - 0.5
test_img, test_spec = load_pkl(ruta_test)
test_img = test_img[:, :-1, :-1]
test_img = test_img.astype(np.float32) / 255.0 - 0.5

# Spectral sampling-probability mask: decays exponentially with the radial
# distance r from the centre sample, reaching p_at_edge at r = h[1] (the middle
# of each edge).
p_at_edge = 0.025
MASK_SHAPE = (255, 255)
h = [s // 2 for s in MASK_SHAPE]  # centre index per axis: [127, 127]
r = [np.arange(s, dtype=np.float32) - c for s, c in zip(MASK_SHAPE, h)]
r = [x ** 2 for x in r]
r = (r[0][:, np.newaxis] + r[1][np.newaxis, :]) ** .5  # distance from the centre
m = (p_at_edge ** (1. / h[1])) ** r
bern_mask = m

# corrupt_data multiplies each spectrum elementwise with a mask of this shape;
# fail here with a clear message rather than deep inside training.
assert np.shape(spec[0]) == bern_mask.shape, "training spectra do not match the sampling-mask shape"
assert np.shape(test_spec[0]) == bern_mask.shape, "validation spectra do not match the sampling-mask shape"

In [None]:
def corrupt_data(img, spec, mask=None, rng=None):
    """Simulate an undersampled acquisition by randomly dropping spectral samples.

    A uniform draw ``u`` keeps a sample when ``u**2 < mask``, i.e. with probability
    ``sqrt(mask)``. The keep pattern is then ANDed with its point-mirrored copy, so
    the kept set is symmetric about the centre (presumably to preserve the Hermitian
    symmetry of a real image's spectrum). For a mirror-symmetric mask, such as the
    radial ``bern_mask``, each off-centre sample therefore survives with probability
    ``mask``. Kept samples are divided by ``mask`` before the inverse FFT.

    Args:
        img: unused; the corrupted image is recomputed from ``spec``.
        spec: centred 2-D spectrum (odd side lengths, as ``fftshift2d`` requires).
        mask: sampling-probability map broadcastable to ``spec``; defaults to the
            module-level ``bern_mask``.
        rng: optional ``numpy.random.Generator``/``RandomState`` for reproducible
            corruption; defaults to the global ``np.random`` state.

    Returns:
        (img, sval, smsk):
        - img: float32 real part of the inverse FFT of the rescaled, kept spectrum.
        - sval: ``spec`` with dropped samples zeroed (not rescaled).
        - smsk: float32 0/1 map of the kept samples.
    """
    if mask is None:
        mask = bern_mask
    draw = np.random.uniform if rng is None else rng.uniform
    keep = draw(0.0, 1.0, size=spec.shape) ** 2 < mask
    keep = keep & keep[::-1, ::-1]
    sval = spec * keep
    smsk = keep.astype(np.float32)
    # ~keep adds 1.0 at dropped positions so the divisor is never zero there
    # (those entries of sval are already zero).
    spec = fftshift2d(sval / (mask + ~keep), ifft=True)
    img = np.real(np.fft.ifft2(spec)).astype(np.float32)
    return img, sval, smsk

In [None]:
class MRIDenoisingDataset(Dataset):
    """Dataset pairing each clean MRI slice with a corrupted version of it.

    ``corrupt_fn`` is invoked on every ``__getitem__`` call, so a random
    ``corrupt_fn`` yields a different corruption each time an item is read.

    Args:
        clean_images: indexable collection of clean images, one per item.
        clean_specs: indexable collection of the matching spectra.
        t: passed as ``t=`` to ``augment_fn``; not otherwise used here.
        corrupt_fn: ``(img, spec) -> (noisy_img, noisy_spec, keep_mask)``.
        augment_fn: optional ``(img, spec, t=...) -> (img, spec)`` run before corruption.

    Each item is the tuple ``(img_clean, spec_clean, img_noisy, spec_noisy, mask)``;
    every tensor gets a leading singleton channel axis. Images and the mask are
    float32, spectra complex64.
    """

    def __init__(self, clean_images, clean_specs, t=64, corrupt_fn=corrupt_data, augment_fn=None):
        super().__init__()
        self.clean_images = clean_images
        self.clean_specs = clean_specs
        self.t = t
        self.corrupt_fn = corrupt_fn
        self.augment_fn = augment_fn

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        image, spectrum = self.clean_images[idx], self.clean_specs[idx]
        if self.augment_fn:
            image, spectrum = self.augment_fn(image, spectrum, t=self.t)
        noisy_image, noisy_spectrum, keep_mask = self.corrupt_fn(image, spectrum)

        def as_channel(arr, dtype):
            # Copy into a new tensor and prepend the channel axis.
            return torch.tensor(arr, dtype=dtype).unsqueeze(0)

        return (
            as_channel(image, torch.float32),
            as_channel(spectrum, torch.complex64),
            as_channel(noisy_image, torch.float32),
            as_channel(noisy_spectrum, torch.complex64),
            as_channel(keep_mask, torch.float32),
        )

In [None]:
# Seed the global RNGs so runs are reproducible: corrupt_data draws its sampling
# masks from np.random, and torch's RNG drives DataLoader shuffling and the
# model's weight initialisation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = MRIDenoisingDataset(img, spec)
# NOTE(review): the validation set is re-corrupted with new random masks on every
# pass, so per-epoch test loss includes corruption noise; consider pre-corrupting it.
test_dataset = MRIDenoisingDataset(test_img, test_spec)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
dataloader_test = DataLoader(test_dataset, batch_size=32, shuffle=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Wavelet Scattering Transform autoencoder

In [None]:
class WSTAutoencoder(nn.Module):
    """Denoising autoencoder with a fixed wavelet-scattering encoder and a conv decoder.

    Expects ``(B, 1, 255, 255)`` images.

    1. The input is padded on the bottom/right to 256x256 with -0.5.
    2. ``Scattering2D(J=2, L=8, shape=(256, 256))`` produces 81 coefficient maps
       of size 64x64.
    3. A trainable decoder restores the spatial size: a 3x3 conv plus five residual
       blocks, and two 2x upsamplings back to 256x256. The padded input is
       concatenated as an extra channel before ``conv2``.
    4. The single-channel output is cropped back to ``(B, 1, 255, 255)``.
    """

    def __init__(self):
        super(WSTAutoencoder, self).__init__()
        self.scattering2d = Scattering2D(J=2, L=8, shape=(256,256)) # [81,64,64]

        # Decoder
        self.conv0 = nn.Conv2d(81,64,3,stride=1,padding=1)
        self.bn0 = nn.BatchNorm2d(64)

        self.res1conv1 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res1bn1 = nn.BatchNorm2d(64)
        self.res1conv2 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res1bn2 = nn.BatchNorm2d(64)

        self.res2conv1 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res2bn1 = nn.BatchNorm2d(64)
        self.res2conv2 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.res2bn2 = nn.BatchNorm2d(64)

        self.upsample1 = nn.Upsample(scale_factor=2) #[64,128,128]
        self.conv1 = nn.Conv2d(64,32,3,stride=1,padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.res3conv1 = nn.Conv2d(32,32,3,stride=1,padding=1)
        self.res3bn1 = nn.BatchNorm2d(32)
        self.res3conv2 = nn.Conv2d(32,32,3,stride=1,padding=1)
        self.res3bn2 = nn.BatchNorm2d(32)

        self.upsample2 = nn.Upsample(scale_factor=2) #[32,256,256]
        # 32 decoder channels + 1 channel of the padded input image.
        self.conv2 = nn.Conv2d(33,16,3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(16)

        self.res4conv1 = nn.Conv2d(16,16,3,stride=1,padding=1)
        self.res4bn1 = nn.BatchNorm2d(16)
        self.res4conv2 = nn.Conv2d(16,16,3,stride=1,padding=1)
        self.res4bn2 = nn.BatchNorm2d(16)

        self.conv3 = nn.Conv2d(16,9,3,stride=1,padding=1)
        self.bn3 = nn.BatchNorm2d(9)

        self.res5conv1 = nn.Conv2d(9,9,3,stride=1,padding=1)
        self.res5bn1 = nn.BatchNorm2d(9)
        self.res5conv2 = nn.Conv2d(9,9,3,stride=1,padding=1)
        self.res5bn2 = nn.BatchNorm2d(9)

        self.conv4 = nn.Conv2d(9,1,3,stride=1,padding=1)

    def _residual(self, x, conv_a, bn_a, conv_b, bn_b):
        """Two conv -> BatchNorm -> LeakyReLU stages followed by an identity skip."""
        skip = x
        x = F.leaky_relu(bn_a(conv_a(x)))
        x = F.leaky_relu(bn_b(conv_b(x)))
        return x + skip

    def forward(self, x):
        # 255x255 -> 256x256, the size Scattering2D was constructed for.
        original_in = F.pad(x, (0,1,0,1), mode='constant', value=-0.5)
        x = self.scattering2d(original_in).squeeze(1)  # (B, 81, 64, 64)

        x = F.leaky_relu(self.bn0(self.conv0(x)))
        x = self._residual(x, self.res1conv1, self.res1bn1, self.res1conv2, self.res1bn2)
        # Every residual block, including this one, normalises and activates after
        # its first conv; without that, res2bn1 is unused and the two convs
        # collapse into a single linear map.
        x = self._residual(x, self.res2conv1, self.res2bn1, self.res2conv2, self.res2bn2)

        x = self.upsample1(x)  # (B, 64, 128, 128)
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = self._residual(x, self.res3conv1, self.res3bn1, self.res3conv2, self.res3bn2)

        x = self.upsample2(x)  # (B, 32, 256, 256)
        x = torch.cat((x, original_in), dim=1)  # (B, 33, 256, 256)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = self._residual(x, self.res4conv1, self.res4bn1, self.res4conv2, self.res4bn2)

        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = self._residual(x, self.res5conv1, self.res5bn1, self.res5conv2, self.res5bn2)

        x = self.conv4(x)
        # Undo the padding: back to (B, 1, 255, 255).
        return x[:,:,:-1,:-1]
model = WSTAutoencoder().to(device)

In [None]:
# Per-layer output shapes and parameter counts for a single (1, 255, 255) image.
# NOTE(review): torchsummary pushes random data through the model in its current
# mode; in train mode that presumably also updates BatchNorm running stats — confirm.
summary(model, input_size=(1, 255, 255))

In [None]:
# Pixel-wise mean-squared error between the network output and the clean image.
loss_fn = nn.MSELoss()
# Adam over all of the model's parameters, initial learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Cosine-anneal the learning rate toward eta_min=0 over T_max=25 scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0)
def train(net, trainLoader, testLoader, loss_fn, optimizer, scheduler, NUM_EPOCHS, device=None):
    """Train ``net`` for ``NUM_EPOCHS`` epochs, evaluating on ``testLoader`` after each.

    Each batch must unpack as ``(img_clean, spec_clean, img_noisy, spec_noisy, mask)``;
    only ``img_noisy`` (network input) and ``img_clean`` (target) are used.

    ``scheduler.step()`` is called once per epoch, after the training pass.
    Validation predictions are clamped to [-0.5, 0.5] before computing the loss.

    Args:
        net: model to optimise.
        trainLoader: training batches.
        testLoader: validation batches.
        loss_fn: loss taking ``(prediction, target)``.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        NUM_EPOCHS: number of epochs.
        device: where to move batches; defaults to the device of ``net``'s parameters.

    Returns:
        ``(train_loss, test_loss)``: per-epoch mean batch losses (NaN for an empty loader).
    """
    if device is None:
        # Send batches wherever the model lives.
        device = next(net.parameters()).device

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        # ---- training pass ----
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for img_clean, spec_clean, img_noisy, spec_noisy, mask in tepoch:
                img_noisy = img_noisy.to(device)
                img_clean = img_clean.to(device)
                optimizer.zero_grad()
                outputs = net(img_noisy)
                loss = loss_fn(outputs, img_clean)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                # Log a plain float; a tensor would print its full repr (device, grad_fn).
                tepoch.set_postfix(loss=batch_loss)
        n_train = len(trainLoader)
        train_loss.append(running_loss / n_train if n_train else float('nan'))
        scheduler.step()

        # ---- validation pass ----
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(testLoader, unit="batch") as tepoch:
                for img_clean, spec_clean, img_noisy, spec_noisy, mask in tepoch:
                    img_noisy = img_noisy.to(device)
                    img_clean = img_clean.to(device)
                    outputs = torch.clamp(net(img_noisy), -.5, .5)
                    batch_loss = loss_fn(outputs, img_clean).item()
                    val_loss += batch_loss
                    tepoch.set_postfix(loss=batch_loss)
        n_test = len(testLoader)
        test_loss.append(val_loss / n_test if n_test else float('nan'))

    return train_loss, test_loss

In [None]:
binary_wsta_train_loss, binary_wsta_test_loss = train(
    model, dataloader, dataloader_test, loss_fn, optimizer, scheduler, 25
)

# Plot both curves against 1-based epoch numbers (matching the CSV below).
epochs = range(1, len(binary_wsta_train_loss) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, binary_wsta_train_loss, label='train loss')
ax.plot(epochs, binary_wsta_test_loss, label='test loss')
ax.set(title='Train and Test Loss', xlabel='Epochs', ylabel='Loss')
ax.legend()

# Create the output folder first so a finished training run is not lost to a
# missing ./Results directory, and save the eval curve alongside the train curve.
os.makedirs('./Results', exist_ok=True)
with open('./Results/train_losses_binary_wsta_with_eval.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Train Loss', 'Test Loss'])  # header row
    for i, (train_l, test_l) in enumerate(zip(binary_wsta_train_loss, binary_wsta_test_loss), start=1):
        writer.writerow([i, train_l, test_l])