In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm
from kymatio.torch import Scattering2D
from torch.optim.lr_scheduler import CosineAnnealingLR


import pickle

import csv

In [None]:
def load_pkl(filename):
    """Return the object stored in the pickle file at `filename`.

    Security note: unpickling can execute arbitrary code, so only call this on
    trusted, locally produced dataset files.
    """
    with open(filename, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded
def fftshift2d(x, ifft=False):
    """Circularly shift both axes of an odd-sized 2-D array.

    With ifft=False each axis of length n is rotated left by n//2 + 1 (for odd n this
    matches np.fft.fftshift); with ifft=True it is rotated left by n//2 (matching
    np.fft.ifftshift). Both dimensions must be odd.
    """
    assert len(x.shape) == 2 and all(dim % 2 == 1 for dim in x.shape)
    offset = 0 if ifft else 1
    shifts = tuple(-(dim // 2 + offset) for dim in x.shape)
    return np.roll(x, shifts, axis=(0, 1))
def fftshift3d(x, ifft):
    """Circularly shift the last two axes of a 4-D tensor.

    Presumably used on (batch, channel, H, W) tensors. With ifft=False each of the last
    two axes (length n) is rotated left by n//2 + 1; with ifft=True by n//2. For odd n
    these match torch.fft.fftshift / torch.fft.ifftshift on those axes.
    """
    assert len(x.shape) == 4
    offset = 0 if ifft else 1
    row_shift = -(x.shape[2] // 2 + offset)
    col_shift = -(x.shape[3] // 2 + offset)
    return torch.roll(x, shifts=(row_shift, col_shift), dims=(2, 3))

In [None]:
# IXI dataset pickles, each holding an (images, spectra) pair.
# An earlier layout used "C:/Users/javit/Desktop/MRI datasets/datasets/ixi_train-001.pkl"
# and ".../ixi_valid.pkl"; point DATA_DIR elsewhere to switch location.
DATA_DIR = "C:/Users/javit/Desktop/N2N/datasets"
ruta = f"{DATA_DIR}/ixi_train.pkl"
ruta_test = f"{DATA_DIR}/ixi_valid.pkl"


def _to_unit_centered(images):
    """Drop the last row/column (256x256 -> 255x255) and map pixel values /255 into [-0.5, 0.5].

    Assumes 8-bit pixel values (0..255) -- TODO confirm against the pickle contents.
    """
    cropped = images[:, :-1, :-1]
    return cropped.astype(np.float32) / 255.0 - 0.5


img, spec = load_pkl(ruta)
img = _to_unit_centered(img)
test_img, test_spec = load_pkl(ruta_test)
test_img = _to_unit_centered(test_img)

In [None]:
# Radial keep-probability map for 255x255 k-space: 1.0 at the centre pixel, decaying
# geometrically with distance so that it equals p_at_edge at the middle of each edge.
p_at_edge = 0.025
_KSPACE_SHAPE = (255, 255)
h = [n // 2 for n in _KSPACE_SHAPE]  # centre index on each axis (127, 127)
rows_sq, cols_sq = ((np.arange(n, dtype=np.float32) - c) ** 2 for n, c in zip(_KSPACE_SHAPE, h))
r = (rows_sq[:, np.newaxis] + cols_sq[np.newaxis, :]) ** .5  # distance from the centre
decay_per_pixel = p_at_edge ** (1. / h[1])  # chosen so decay ** h[1] == p_at_edge
m = decay_per_pixel ** r
bern_mask = m

def corrupt_data_binary(img, spec):
    """Randomly undersample a k-space spectrum and rebuild the corrupted image.

    `img` is not used; the corrupted image is reconstructed from `spec` alone.
    Keep decisions are drawn from np.random against the module-level `bern_mask`.

    Returns:
        (corrupted image as float32, spectrum with dropped entries zeroed,
         float32 keep mask of 0/1 values)
    """
    probs = bern_mask
    draws = np.random.uniform(0.0, 1.0, size=spec.shape)
    keep = draws ** 2 < probs
    # Keep a coefficient only if its point-mirrored partner is also kept
    # (presumably to preserve the conjugate symmetry of a real image's spectrum).
    keep = keep & keep[::-1, ::-1]
    kept_vals = spec * keep
    keep_f32 = keep.astype(np.float32)
    # Kept entries are divided by their bern_mask probability; `~keep` adds 1 to the
    # dropped entries only to avoid dividing by zero (their numerator is already 0).
    rescaled = kept_vals / (probs + ~keep)
    corrupted = np.real(np.fft.ifft2(fftshift2d(rescaled, ifft=True))).astype(np.float32)
    return corrupted, kept_vals, keep_f32

class BinaryDenoisingDataset(Dataset):
    """Dataset of clean / k-space-undersampled MRI pairs.

    Each item is (clean image, clean spectrum, corrupted image, corrupted spectrum, keep mask),
    each as a tensor with a leading channel axis of size 1. `corrupt_fn(image, spectrum)` is
    called on every access, so with the default random corruption the same index yields a
    different undersampling pattern each time.
    """

    def __init__(self, clean_images, clean_specs, corrupt_fn=corrupt_data_binary):
        super().__init__()
        self.clean_images = clean_images
        self.clean_specs = clean_specs
        self.corrupt_fn = corrupt_fn

    def __len__(self):
        return len(self.clean_images)

    @staticmethod
    def _real(arr):
        # Copy into a float32 tensor and add the channel axis.
        return torch.tensor(arr, dtype=torch.float32).unsqueeze(0)

    @staticmethod
    def _complex(arr):
        # Copy into a complex64 tensor and add the channel axis.
        return torch.tensor(arr, dtype=torch.complex64).unsqueeze(0)

    def __getitem__(self, idx):
        clean_img = self.clean_images[idx]
        clean_spec = self.clean_specs[idx]
        noisy_img, noisy_spec, keep_mask = self.corrupt_fn(clean_img, clean_spec)
        return (
            self._real(clean_img),
            self._complex(clean_spec),
            self._real(noisy_img),
            self._complex(noisy_spec),
            self._real(keep_mask),
        )

In [None]:
def corrupt_data_rician(img, noise_percent):
    """Add Rician noise to a magnitude image.

    Two independent Gaussian fields with standard deviation
    sigma = noise_percent/100 * max(img) are drawn (np.random.normal). One is added to
    the image and one is treated as an imaginary part; the magnitude is returned.
    """
    peak = img.max().item()
    sigma = (noise_percent / 100) * peak
    real_part = img + np.random.normal(0, sigma, img.shape)
    imag_part = np.random.normal(0, sigma, img.shape)
    return np.sqrt(real_part ** 2 + imag_part ** 2)
class RicianDenoisingDataset(Dataset):
    """Pairs of (clean, Rician-corrupted) images, both returned in the [-0.5, 0.5] range.

    Args:
        clean_images: indexable collection of 2-D images, expected in [-0.5, 0.5].
        t: forwarded to `augment_fn` as the `t` keyword; unused when `augment_fn` is None.
        corrupt_fn: callable `(image_in_0_1, noise_percent) -> noisy image`.
        augment_fn: optional callable `(img, spec, t=...) -> (img, spec)`. This dataset has
            no k-space data, so `None` is passed as `spec` and the returned spec is ignored.
        noise_percent: noise level passed to `corrupt_fn` (default 11, the previous hard-coded value).
    """

    def __init__(self, clean_images, t=64, corrupt_fn=corrupt_data_rician, augment_fn=None,
                 noise_percent=11):
        super().__init__()
        self.clean_images = clean_images
        self.t = t
        self.corrupt_fn = corrupt_fn
        self.augment_fn = augment_fn
        self.noise_percent = noise_percent

    def __len__(self):
        return len(self.clean_images)

    def __getitem__(self, idx):
        img_clean = self.clean_images[idx] + 0.5  # shift into [0, 1] for the noise model
        if self.augment_fn is not None:
            # Previously referenced an undefined `spec_clean` here (UnboundLocalError).
            img_clean, _ = self.augment_fn(img_clean, None, t=self.t)
        cimg = self.corrupt_fn(img_clean, self.noise_percent)
        img_noisy = torch.tensor(cimg, dtype=torch.float32).clamp(0, 1)
        img_clean = torch.tensor(img_clean, dtype=torch.float32).clamp(0, 1)
        # Add a channel axis and shift back into [-0.5, 0.5].
        return img_clean.unsqueeze(0) - 0.5, img_noisy.unsqueeze(0) - 0.5

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Unet Model

In [None]:
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()

        #Encoder
        self.conv0 = nn.Conv2d(1, 48, 3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.conv5 = nn.Conv2d(48, 48, 3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.conv6 = nn.Conv2d(48, 48, 3, stride=1, padding=1)

        #Decoder
        self.upsample5 = nn.Upsample(scale_factor=2)
        self.deconv5a = nn.Conv2d(96, 96, 3, stride=1, padding=1)
        self.deconv5b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample4 = nn.Upsample(scale_factor=2)
        self.deconv4a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv4b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample3 = nn.Upsample(scale_factor=2)
        self.deconv3a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv3b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=2)
        self.deconv2a = nn.Conv2d(144, 96, 3, stride=1, padding=1)
        self.deconv2b = nn.Conv2d(96, 96, 3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.deconv1a = nn.Conv2d(97, 64, 3, stride=1, padding=1)
        self.deconv1b = nn.Conv2d(64, 32, 3, stride=1, padding=1)
        self.deconv1c = nn.Conv2d(32, 1, 3, stride=1, padding=1)


        
    def forward(self, x):
        original_in = F.pad(x, (0,1,0,1), mode='constant', value=-0.5)
        x = F.leaky_relu(self.conv0(original_in))
        x = F.leaky_relu(self.conv1(x))
        x = self.pool1(x)
        pool1 = x
        x = F.leaky_relu(self.conv2(x))
        x = self.pool2(x)
        pool2 = x
        x = F.leaky_relu(self.conv3(x))
        x = self.pool3(x)
        pool3 = x
        x = F.leaky_relu(self.conv4(x))
        x = self.pool4(x)
        pool4 = x
        x = F.leaky_relu(self.conv5(x))
        x = self.pool5(x)
        x = F.leaky_relu(self.conv6(x))
        x = self.upsample5(x)
        x = torch.cat((x, pool4), 1)
        x = F.leaky_relu(self.deconv5a(x))
        x = F.leaky_relu(self.deconv5b(x))
        x = self.upsample4(x)
        x = torch.cat((x, pool3), 1)
        x = F.leaky_relu(self.deconv4a(x))
        x = F.leaky_relu(self.deconv4b(x))
        x = self.upsample3(x)
        x = torch.cat((x, pool2), 1)
        x = F.leaky_relu(self.deconv3a(x))
        x = F.leaky_relu(self.deconv3b(x))
        x = self.upsample2(x)
        x = torch.cat((x, pool1), 1)
        x = F.leaky_relu(self.deconv2a(x))
        x = F.leaky_relu(self.deconv2b(x))
        x = self.upsample1(x)
        x = torch.cat((x, original_in), 1)
        x = F.leaky_relu(self.deconv1a(x))
        x = F.leaky_relu(self.deconv1b(x))
        x = self.deconv1c(x)
        return x[:,:,:-1,:-1]
    

class CNNDMRI(nn.Module):
    def __init__(self):
        super(CNNDMRI, self).__init__()

        #Initial convolutions
        self.conv1 = nn.Conv2d(1,64,kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(64,64,kernel_size=3,padding=1)

        # Encoder layers
        self.enc1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.enc2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)

        #Residual blocks
        self.res1conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.res1conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.res2conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.res2conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.res3conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.res3conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(128)
        self.res4conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        self.res4conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(128)
        
        # Decoder layers
        self.dec1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,padding=1,output_padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2,padding=1)

        #Final convolution
        self.out = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        
        
    def forward(self, x):
        original_noisy = x
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        #Encoder
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        #Residual blocks
        res1 = x
        x = self.res1conv1(x)
        x = F.relu(self.bn1(x))
        x = self.res1conv2(x)
        x = self.bn2(x)
        x += res1

        res2 = x
        x = self.res2conv1(x)
        x = F.relu(self.bn3(x))
        x = self.res2conv2(x)
        x = self.bn4(x)
        x += res2

        res3 = x
        x = self.res3conv1(x)
        x = F.relu(self.bn5(x))
        x = self.res3conv2(x)
        x = self.bn6(x)
        x += res3

        res4 = x
        x = self.res4conv1(x)
        x = F.relu(self.bn7(x))
        x = self.res4conv2(x)
        x = self.bn8(x)
        x += res4
        # Decoder
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        #Final convolution
        x = self.out(x)
        x += original_noisy
        return x
    
class WSTAutoencoder(nn.Module):
    """Decoder that reconstructs an image from its 2-D wavelet scattering transform.

    The input is padded by one row/column (value -0.5) to 256x256, transformed with a fixed
    Scattering2D (J=2, L=8; the shape comment below gives 81 channels at 64x64), and decoded by
    conv/BN residual stages with two 2x upsamplings. The padded input is concatenated back
    before `conv2`. The output is cropped back to the input size.
    """

    def __init__(self):
        super().__init__()
        self.scattering2d = Scattering2D(J=2, L=8, shape=(256, 256))  # [81,64,64]

        # Decoder
        self.conv0 = nn.Conv2d(81, 64, 3, stride=1, padding=1)
        self.bn0 = nn.BatchNorm2d(64)

        self.res1conv1 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.res1bn1 = nn.BatchNorm2d(64)
        self.res1conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.res1bn2 = nn.BatchNorm2d(64)

        self.res2conv1 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.res2bn1 = nn.BatchNorm2d(64)
        self.res2conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.res2bn2 = nn.BatchNorm2d(64)

        self.upsample1 = nn.Upsample(scale_factor=2)  # [64,128,128]
        self.conv1 = nn.Conv2d(64, 32, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.res3conv1 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.res3bn1 = nn.BatchNorm2d(32)
        self.res3conv2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.res3bn2 = nn.BatchNorm2d(32)

        self.upsample2 = nn.Upsample(scale_factor=2)  # [32,256,256]
        self.conv2 = nn.Conv2d(33, 16, 3, stride=1, padding=1)  # 32 features + 1 padded input
        self.bn2 = nn.BatchNorm2d(16)

        self.res4conv1 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.res4bn1 = nn.BatchNorm2d(16)
        self.res4conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.res4bn2 = nn.BatchNorm2d(16)

        self.conv3 = nn.Conv2d(16, 9, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(9)

        self.res5conv1 = nn.Conv2d(9, 9, 3, stride=1, padding=1)
        self.res5bn1 = nn.BatchNorm2d(9)
        self.res5conv2 = nn.Conv2d(9, 9, 3, stride=1, padding=1)
        self.res5bn2 = nn.BatchNorm2d(9)

        self.conv4 = nn.Conv2d(9, 1, 3, stride=1, padding=1)

    def forward(self, x):
        original_in = F.pad(x, (0, 1, 0, 1), mode='constant', value=-0.5)
        # Drop the singleton channel axis so scattering coefficients become conv channels.
        x = self.scattering2d(original_in).squeeze(1)

        x = F.leaky_relu(self.bn0(self.conv0(x)))

        resblock1_input = x
        x = F.leaky_relu(self.res1bn1(self.res1conv1(x)))
        x = F.leaky_relu(self.res1bn2(self.res1conv2(x)))
        x = x + resblock1_input

        # Residual block 2. The previous forward skipped res2bn1 + activation after res2conv1
        # (the layer was registered but never used), leaving two stacked linear convs.
        # Checkpoints saved before this fix hold an untrained res2bn1; retrain rather than reload.
        resblock2_input = x
        x = F.leaky_relu(self.res2bn1(self.res2conv1(x)))
        x = F.leaky_relu(self.res2bn2(self.res2conv2(x)))
        x = x + resblock2_input

        x = self.upsample1(x)
        x = F.leaky_relu(self.bn1(self.conv1(x)))

        resblock3_input = x
        x = F.leaky_relu(self.res3bn1(self.res3conv1(x)))
        x = F.leaky_relu(self.res3bn2(self.res3conv2(x)))
        x = x + resblock3_input

        x = self.upsample2(x)
        # Re-inject the padded input image at full resolution.
        x = torch.cat((x, original_in), dim=1)
        x = F.leaky_relu(self.bn2(self.conv2(x)))

        resblock4_input = x
        x = F.leaky_relu(self.res4bn1(self.res4conv1(x)))
        x = F.leaky_relu(self.res4bn2(self.res4conv2(x)))
        x = x + resblock4_input

        x = F.leaky_relu(self.bn3(self.conv3(x)))

        resblock5_input = x
        x = F.leaky_relu(self.res5bn1(self.res5conv1(x)))
        x = F.leaky_relu(self.res5bn2(self.res5conv2(x)))
        x = x + resblock5_input

        x = self.conv4(x)
        # Undo the one-pixel padding.
        return x[:, :, :-1, :-1]

In [None]:
# Binary k-space undersampling datasets; each item is re-corrupted on every access.
BATCH_SIZE = 32
binary_dataset = BinaryDenoisingDataset(img, spec)
binary_test_dataset = BinaryDenoisingDataset(test_img, test_spec)
binary_dataloader = DataLoader(binary_dataset, batch_size=BATCH_SIZE, shuffle=True)
binary_dataloader_test = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def _binary_reconstruct(net, img_noisy, spec_noisy, mask):
    """Denoise `img_noisy` with `net`, then re-insert the measured k-space samples.

    The prediction is moved to k-space (fft2, then `fftshift3d(..., ifft=False)`). Wherever
    `mask` is 1 the coefficient is replaced by the measured one from `spec_noisy`. The merged
    spectrum is shifted back and inverse-transformed, and only the real part is returned.
    """
    denoised_spec = fftshift3d(torch.fft.fft2(net(img_noisy)), ifft=False)
    spec_mask = mask.type(torch.complex64)
    merged_spec = spec_noisy * spec_mask + denoised_spec * (1 - spec_mask)
    return torch.real(torch.fft.ifft2(fftshift3d(merged_spec, ifft=True)))


def binary_train(net, trainLoader, testLoader, loss_fn, optimizer, scheduler, NUM_EPOCHS):
    """Train `net` on k-space-undersampled images and validate after every epoch.

    Batches must be the 5-tuples yielded by BinaryDenoisingDataset:
    (img_clean, spec_clean, img_noisy, spec_noisy, mask). The loss is computed on the
    data-consistent reconstruction (measured samples re-inserted), not on the raw network
    output. Validation reconstructions are clamped to [-0.5, 0.5] before scoring; training
    ones are not. `scheduler.step()` is called once per epoch.

    Returns:
        (train_loss, test_loss): lists with the mean batch loss of every epoch.
    """
    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for img_clean, _spec_clean, img_noisy, spec_noisy, mask in tepoch:
                img_clean = img_clean.to(device)
                img_noisy = img_noisy.to(device)
                spec_noisy = spec_noisy.to(device)
                mask = mask.to(device)
                optimizer.zero_grad()
                outputs = _binary_reconstruct(net, img_noisy, spec_noisy, mask)
                loss = loss_fn(outputs, img_clean)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                # A plain float so tqdm shows a number instead of a tensor repr.
                tepoch.set_postfix(loss=batch_loss)
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        train_loss.append(running_loss / max(len(trainLoader), 1))
        scheduler.step()

        net.eval()
        val_loss = 0.0
        with torch.no_grad(), tqdm(testLoader, unit="batch") as tepoch:
            for img_clean, _spec_clean, img_noisy, spec_noisy, mask in tepoch:
                outputs = _binary_reconstruct(
                    net, img_noisy.to(device), spec_noisy.to(device), mask.to(device)
                )
                outputs = torch.clamp(outputs, -.5, .5)
                batch_loss = loss_fn(outputs, img_clean.to(device)).item()
                val_loss += batch_loss
                tepoch.set_postfix(loss=batch_loss)
        test_loss.append(val_loss / max(len(testLoader), 1))

    return train_loss, test_loss
                

In [None]:
# Binary undersampling: train the U-Net with Adam and cosine learning-rate decay.
EPOCHS = 25
loss_fn = nn.MSELoss()
binaryunet = Unet().to(device)
optimizer = optim.Adam(binaryunet.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)
binary_unet_train_loss, binary_unet_test_loss = binary_train(
    binaryunet, binary_dataloader, binary_dataloader_test, loss_fn, optimizer, scheduler, EPOCHS
)


In [None]:
import os

# torch.save does not create missing parent directories, so make sure ./models exists first.
os.makedirs('./models', exist_ok=True)
torch.save(binaryunet.state_dict(), './models/binary_unet_model.pth')

In [None]:
# Binary undersampling: train CNNDMRI with Adam and cosine learning-rate decay.
EPOCHS = 25
loss_fn = nn.MSELoss()
binarycnndmri = CNNDMRI().to(device)
optimizer = optim.Adam(binarycnndmri.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)
binary_cnndmri_train_loss, binary_cnndmri_test_loss = binary_train(
    binarycnndmri, binary_dataloader, binary_dataloader_test, loss_fn, optimizer, scheduler, EPOCHS
)

In [None]:
import os

# torch.save does not create missing parent directories, so make sure ./models exists first.
os.makedirs('./models', exist_ok=True)
torch.save(binarycnndmri.state_dict(), './models/binary_cnndmri_model.pth')

In [None]:
# Binary undersampling: train the scattering-transform autoencoder with Adam and cosine decay.
EPOCHS = 25
loss_fn = nn.MSELoss()
binarywsta = WSTAutoencoder().to(device)
optimizer = optim.Adam(binarywsta.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)
binary_wsta_train_loss, binary_wsta_test_loss = binary_train(
    binarywsta, binary_dataloader, binary_dataloader_test, loss_fn, optimizer, scheduler, EPOCHS
)


In [None]:
import os

# torch.save does not create missing parent directories, so make sure ./models exists first.
os.makedirs('./models', exist_ok=True)
torch.save(binarywsta.state_dict(), './models/binary_wsta_model.pth')

In [None]:
# Rician-noise datasets. RicianDenoisingDataset's second positional parameter is `t`
# (an augmentation setting), not a spectrum, so spec/test_spec must not be passed here.
rici_dataset = RicianDenoisingDataset(img)
rici_test_dataset = RicianDenoisingDataset(test_img)
rici_dataloader = DataLoader(rici_dataset, batch_size=32, shuffle=True)
rici_dataloader_test = DataLoader(rici_test_dataset, batch_size=32, shuffle=False)

In [None]:
def ricitrain(net, trainLoader, testLoader, loss_fn, optimizer, scheduler, NUM_EPOCHS):
    """Train `net` to map Rician-noisy images to clean ones, validating after every epoch.

    Batches must be (img_clean, img_noisy) pairs as yielded by RicianDenoisingDataset.
    Validation outputs are clamped to [-0.5, 0.5] before scoring; training outputs are not.
    `scheduler.step()` is called once per epoch.

    Returns:
        (train_loss, test_loss): lists with the mean batch loss of every epoch.
    """
    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        net.train()
        running_loss = 0.0
        with tqdm(trainLoader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for img_clean, img_noisy in tepoch:
                img_clean = img_clean.to(device)
                img_noisy = img_noisy.to(device)
                optimizer.zero_grad()
                loss = loss_fn(net(img_noisy), img_clean)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                # A plain float so tqdm shows a number instead of a tensor repr.
                tepoch.set_postfix(loss=batch_loss)
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        train_loss.append(running_loss / max(len(trainLoader), 1))
        scheduler.step()

        net.eval()
        val_loss = 0.0
        with torch.no_grad(), tqdm(testLoader, unit="batch") as tepoch:
            for img_clean, img_noisy in tepoch:
                outputs = torch.clamp(net(img_noisy.to(device)), -.5, .5)
                batch_loss = loss_fn(outputs, img_clean.to(device)).item()
                val_loss += batch_loss
                tepoch.set_postfix(loss=batch_loss)
        test_loss.append(val_loss / max(len(testLoader), 1))

    return train_loss, test_loss

In [None]:
# Rician noise: train the U-Net, then write the checkpoint to the working directory.
EPOCHS = 25
loss_fn = nn.MSELoss()
ricianunet = Unet().to(device)
optimizer = optim.Adam(ricianunet.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)
rici_unet_train_loss, rici_unet_test_loss = ricitrain(
    ricianunet, rici_dataloader, rici_dataloader_test, loss_fn, optimizer, scheduler, EPOCHS
)
torch.save(ricianunet.state_dict(), './rician_unet_model.pth')

In [None]:
# Rician noise: train CNNDMRI, then write the checkpoint to the working directory.
EPOCHS = 25
loss_fn = nn.MSELoss()
ricicnndmri = CNNDMRI().to(device)
optimizer = optim.Adam(ricicnndmri.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)
rici_cnndmri_train_loss, rici_cnndmri_test_loss = ricitrain(
    ricicnndmri, rici_dataloader, rici_dataloader_test, loss_fn, optimizer, scheduler, EPOCHS
)
torch.save(ricicnndmri.state_dict(), './rician_cnndmri_model.pth')

In [None]:
# Rician noise: train the scattering-transform autoencoder, then write the checkpoint
# to the working directory.
EPOCHS = 25
loss_fn = nn.MSELoss()
riciwsta = WSTAutoencoder().to(device)
optimizer = optim.Adam(riciwsta.parameters(), lr=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)
rici_wsta_train_loss, rici_wsta_test_loss = ricitrain(
    riciwsta, rici_dataloader, rici_dataloader_test, loss_fn, optimizer, scheduler, EPOCHS
)
torch.save(riciwsta.state_dict(), './rician_wsta_model.pth')

In [None]:
# Preview the clean training slice used in the demos below.
sample_idx = 150
plt.imshow(img[sample_idx], cmap='gray')

In [None]:
# Binary undersampling demo on slice 150: corrupted input, clean target, CNNDMRI reconstruction.
sample_idx = 150
corrupted_image, corr_sval, corr_msk = corrupt_data_binary(img[sample_idx], spec[sample_idx])
for title, picture in (("Corrupted input", corrupted_image), ("Clean target", img[sample_idx])):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis("off")
    plt.show()

# Set eval mode explicitly (BatchNorm) instead of relying on the training loop's last state.
binarycnndmri.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    denoised_spec = fftshift3d(torch.fft.fft2(binarycnndmri(corrupted_image)), ifft=False)
    spec_mask = torch.tensor(corr_msk).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    corr_sval = torch.tensor(corr_sval).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    # Keep the measured k-space samples; fill the rest from the network prediction.
    denoised_spec = corr_sval * spec_mask + denoised_spec * (1 - spec_mask)
    output = torch.real(torch.fft.ifft2(fftshift3d(denoised_spec, ifft=True))).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("CNNDMRI reconstruction")
plt.axis("off")
plt.show()

In [None]:
# Binary undersampling demo on slice 150: corrupted input, clean target, U-Net reconstruction.
sample_idx = 150
corrupted_image, corr_sval, corr_msk = corrupt_data_binary(img[sample_idx], spec[sample_idx])
for title, picture in (("Corrupted input", corrupted_image), ("Clean target", img[sample_idx])):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis("off")
    plt.show()

# Set eval mode explicitly instead of relying on the training loop's last state.
binaryunet.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    denoised_spec = fftshift3d(torch.fft.fft2(binaryunet(corrupted_image)), ifft=False)
    spec_mask = torch.tensor(corr_msk).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    corr_sval = torch.tensor(corr_sval).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    # Keep the measured k-space samples; fill the rest from the network prediction.
    denoised_spec = corr_sval * spec_mask + denoised_spec * (1 - spec_mask)
    output = torch.real(torch.fft.ifft2(fftshift3d(denoised_spec, ifft=True))).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("U-Net reconstruction")
plt.axis("off")
plt.show()

In [None]:
# Binary undersampling demo on slice 150: corrupted input, clean target, WST-autoencoder reconstruction.
sample_idx = 150
corrupted_image, corr_sval, corr_msk = corrupt_data_binary(img[sample_idx], spec[sample_idx])
for title, picture in (("Corrupted input", corrupted_image), ("Clean target", img[sample_idx])):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis("off")
    plt.show()

# Set eval mode explicitly (BatchNorm) instead of relying on the training loop's last state.
binarywsta.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    denoised_spec = fftshift3d(torch.fft.fft2(binarywsta(corrupted_image)), ifft=False)
    spec_mask = torch.tensor(corr_msk).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    corr_sval = torch.tensor(corr_sval).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    # Keep the measured k-space samples; fill the rest from the network prediction.
    denoised_spec = corr_sval * spec_mask + denoised_spec * (1 - spec_mask)
    output = torch.real(torch.fft.ifft2(fftshift3d(denoised_spec, ifft=True))).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("WST autoencoder reconstruction")
plt.axis("off")
plt.show()

In [None]:
# Rician noise demo on slice 150: noisy input, clean target, U-Net output.
sample_idx = 150
# corrupt_data_rician works in [0, 1]; shift in and back out of the [-0.5, 0.5] range.
corrupted_image = corrupt_data_rician(img[sample_idx] + 0.5, 11).clip(0, 1) - 0.5
for title, picture in (("Noisy input", corrupted_image), ("Clean target", img[sample_idx])):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis("off")
    plt.show()

# Set eval mode explicitly instead of relying on the training loop's last state.
ricianunet.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    output = ricianunet(corrupted_image.float()).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("U-Net output")
plt.axis("off")
plt.show()

In [None]:
# Rician noise demo on slice 150: noisy input, clean target, CNNDMRI output.
sample_idx = 150
# corrupt_data_rician works in [0, 1]; shift in and back out of the [-0.5, 0.5] range.
corrupted_image = corrupt_data_rician(img[sample_idx] + 0.5, 11).clip(0, 1) - 0.5
for title, picture in (("Noisy input", corrupted_image), ("Clean target", img[sample_idx])):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis("off")
    plt.show()

# Set eval mode explicitly (BatchNorm) instead of relying on the training loop's last state.
ricicnndmri.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    output = ricicnndmri(corrupted_image.float()).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("CNNDMRI output")
plt.axis("off")
plt.show()

In [None]:
# Rician noise demo on slice 150: noisy input, clean target, WST-autoencoder output.
sample_idx = 150
# corrupt_data_rician works in [0, 1]; shift in and back out of the [-0.5, 0.5] range.
corrupted_image = corrupt_data_rician(img[sample_idx] + 0.5, 11).clip(0, 1) - 0.5
for title, picture in (("Noisy input", corrupted_image), ("Clean target", img[sample_idx])):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis("off")
    plt.show()

# Set eval mode explicitly (BatchNorm) instead of relying on the training loop's last state.
riciwsta.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    output = riciwsta(corrupted_image.float()).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("WST autoencoder output")
plt.axis("off")
plt.show()

In [None]:
# Scratch inspection cell: run one binary-undersampling model on slice 150.
# NOTE(review): `model` was never defined in this notebook (NameError on a fresh kernel);
# it is now bound explicitly -- change it to the network you want to inspect.
model = binaryunet
sample_idx = 150
corrupted_image, corr_sval, corr_msk = corrupt_data_binary(img[sample_idx], spec[sample_idx])
plt.imshow(corrupted_image, cmap='gray')
plt.title("Corrupted input")
plt.axis("off")
plt.show()
# Show the clean version of the SAME slice (this previously displayed img[50]).
plt.imshow(img[sample_idx], cmap='gray')
plt.title("Clean target")
plt.axis("off")
plt.show()

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    corrupted_image = torch.tensor(corrupted_image).unsqueeze(0).unsqueeze(0).to(device)
    denoised_spec = fftshift3d(torch.fft.fft2(model(corrupted_image)), ifft=False)
    spec_mask = torch.tensor(corr_msk).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    corr_sval = torch.tensor(corr_sval).unsqueeze(0).unsqueeze(0).type(torch.complex64).to(device)
    # Keep the measured k-space samples; fill the rest from the network prediction.
    denoised_spec = corr_sval * spec_mask + denoised_spec * (1 - spec_mask)
    output = torch.real(torch.fft.ifft2(fftshift3d(denoised_spec, ifft=True))).clamp(-.5, .5).squeeze(0).squeeze(0)
output = output.cpu().numpy()
plt.imshow(output, cmap='gray')
plt.title("Reconstruction")
plt.axis("off")
plt.show()