In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np
import os
import time
from tqdm import tqdm

import nibabel as nib
from scipy.fftpack import fft, ifft, fft2, ifft2, fftshift, ifftshift
import PIL.Image
import pickle


In [None]:
# Helpers as in the Noise2Noise (N2N) paper's reference code.
def load_pkl(filename):
    """Open ``filename`` in binary mode and return the unpickled object.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only call this on dataset files you trust.
    """
    with open(file=filename, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload
def fftshift2d(x, ifft=False):
    """Circularly shift both axes of a 2-D array with odd side lengths.

    With ``ifft=False`` the first ``n // 2 + 1`` rows/columns are rotated to the
    end (same as ``np.fft.fftshift`` for odd ``n``); with ``ifft=True`` only
    ``n // 2`` are rotated (same as ``np.fft.ifftshift`` for odd ``n``), so the
    two calls undo each other.

    Raises:
        AssertionError: if ``x`` is not 2-D or either side length is even.
    """
    assert len(x.shape) == 2 and all(n % 2 == 1 for n in x.shape)
    offset = 0 if ifft else 1
    row_cut = x.shape[0] // 2 + offset
    col_cut = x.shape[1] // 2 + offset
    rows_rolled = np.concatenate((x[row_cut:, :], x[:row_cut, :]), axis=0)
    return np.concatenate((rows_rolled[:, col_cut:], rows_rolled[:, :col_cut]), axis=1)

In [None]:
# Dataset location. Set the N2N_DATA_DIR environment variable to the directory
# holding the IXI pickles instead of editing a machine-specific path.
DATA_DIR = os.environ.get("N2N_DATA_DIR", "C:/Users/javit/Desktop/N2N/datasets")
ruta = f"{DATA_DIR}/ixi_train.pkl"
ruta_test = f"{DATA_DIR}/ixi_valid.pkl"

# NOTE(review): load_pkl uses pickle.load — only point this at trusted files.
img, spec = load_pkl(ruta)
img = img[:, :-1, :-1]  # drop the last row/column (the original note says images become 255x255)
img = img.astype(np.float32) / 255.0 - 0.5  # assumes 8-bit intensities; maps [0, 255] to [-0.5, 0.5]
test_img, test_spec = load_pkl(ruta_test)
test_img = test_img[:, :-1, :-1]
test_img = test_img.astype(np.float32) / 255.0 - 0.5

# Size the mask from the data rather than a hard-coded (255, 255). The mask is
# multiplied element-wise with each spectrum, and reconstructions are compared
# with img, so the spectrum and image sizes must agree.
IMG_SHAPE = img.shape[1:]
assert spec.shape[1:] == IMG_SHAPE, f"spectrum shape {spec.shape[1:]} != image shape {IMG_SHAPE}"

# Radial keep-probability mask: 1 at the centre index h, decaying exponentially
# with radial distance so that it equals p_at_edge at distance h[1].
p_at_edge = 0.025
h = [s // 2 for s in IMG_SHAPE]
r = [np.arange(s, dtype=np.float32) - c for s, c in zip(IMG_SHAPE, h)]
r = [x ** 2 for x in r]
r = (r[0][:, np.newaxis] + r[1][np.newaxis, :]) ** .5  # distance from the centre
m = (p_at_edge ** (1. / h[1])) ** r
bern_mask = m

In [None]:
def corrupt_data(img, spec, mask=None, rng=None):
    """Randomly under-sample a centred 2-D spectrum and return its reconstruction.

    This follows the Bernoulli spectral corruption of the Noise2Noise MRI code.
    Each frequency is kept when ``u**2 < mask`` with ``u ~ U[0, 1)``, i.e. with
    probability ``sqrt(mask)`` for mask values in [0, 1]. The pattern is then
    intersected with its point reflection ``keep[::-1, ::-1]``, so a frequency and
    its mirror are kept or dropped together. For a point-symmetric mask this
    makes the net keep probability of an off-centre frequency equal to ``mask``.
    Kept coefficients are divided by ``mask`` before the inverse FFT.

    Args:
        img: unused — the reconstruction is computed from ``spec`` alone; kept
            for call compatibility.
        spec: 2-D complex spectrum, assumed to have the DC term at the centre
            (fft-shifted). TODO confirm against the dataset pickle.
        mask: keep-probability map broadcast-compatible with ``spec``. Defaults
            to the notebook-level ``bern_mask``.
        rng: optional ``numpy.random.Generator`` for reproducible sampling.
            Defaults to the legacy global ``np.random`` state (original behaviour).

    Returns:
        tuple ``(img, sval, smsk)``:
        - ``img``: float32 real part of the inverse FFT of the re-weighted,
          un-centred spectrum.
        - ``sval``: the masked spectrum ``spec * keep``.
        - ``smsk``: the keep pattern as float32 0/1.
    """
    if mask is None:
        mask = bern_mask  # module-level mask built in the data-loading cell
    sampler = np.random if rng is None else rng
    keep = sampler.uniform(0.0, 1.0, size=spec.shape) ** 2 < mask
    # Mirror the pattern through the centre so each kept frequency keeps its
    # reflected partner (for odd sizes, k pairs with -k).
    keep = keep & keep[::-1, ::-1]
    sval = spec * keep
    smsk = keep.astype(np.float32)
    # Divide kept values by mask. Adding 1.0 (~keep) where values were dropped
    # prevents division by zero; those entries are already 0.
    weighted = sval / (mask + ~keep)
    # Undo the centring before the inverse FFT. For odd sizes this equals
    # fftshift2d(weighted, ifft=True); it also works for even sizes.
    recon = np.real(np.fft.ifft2(np.fft.ifftshift(weighted))).astype(np.float32)
    return recon, sval, smsk

In [None]:
# Apply the Bernoulli k-space under-sampling to every training slice and report
# the mean PSNR (peak value 1, images in [-0.5, 0.5]) of the reconstructions.
np.random.seed(0)  # corrupt_data draws from the global RNG; fix it for reproducible PSNR
corr_img = np.zeros(img.shape)
corr_val = np.zeros(img.shape, dtype=np.complex64)
corr_mask = np.zeros(img.shape)
# Named psnr_bern (cf. psnr_rici) so it does not collide with the psnr()
# helper function defined at the end of the notebook.
psnr_bern = np.zeros(img.shape[0])
for i in range(len(img)):
    corr_img[i], corr_val[i], corr_mask[i] = corrupt_data(img[i], spec[i])
    psnr_bern[i] = 10 * np.log10(1 / np.mean((img[i] - corr_img[i]) ** 2))
print(psnr_bern.mean())
    

In [None]:
# Reconstruction of slice 123 without clipping (colour scale follows the data range).
plt.imshow(corr_img[123], cmap='gray')
plt.axis("off")
plt.show()  # was `plt.show` without parentheses, which never called it

In [None]:
# Same reconstruction, clipped to the valid image range [-0.5, 0.5].
plt.imshow(np.clip(corr_img[123], -0.5, 0.5), cmap="gray")
plt.axis("off")
plt.show()

In [None]:
# Clean reference for slice 123 (same index as the corrupted views above).
plt.imshow(img[123], cmap='gray')
plt.axis("off")
plt.show()  # otherwise the cell echoes the tuple returned by plt.axis

In [None]:
from kymatio import Scattering2D

In [None]:
# 2-D scattering transform with J=1 scale and L=8 orientations. The input shape
# comes from the loaded slices instead of a hard-coded (255, 255).
scattering = Scattering2D(J=1, L=8, shape=img.shape[-2:])

In [None]:
# Scattering coefficients of the clipped reconstruction of slice 123.
corr_coefs = scattering(np.clip(corr_img[123], -0.5, 0.5))

In [None]:
# Scattering coefficients of the clean slice 123, used as the reference.
coefs = scattering(img[123, :, :])

In [None]:
# One figure per scattering channel of the clean image: channel 0 is the
# low-pass output, channels 1-8 are titled by orientation index.
for i in range(9):
    plt.figure()
    plt.imshow(coefs[i], cmap='gray')
    plt.title("Low Pass" if i == 0 else "θ = " + str(i))
    plt.axis("off")

In [None]:
# Recomputes corr_coefs exactly as an earlier cell already does (redundant on Run All).
corr_coefs = scattering(corr_img[123].clip(min=-0.5, max=0.5))

In [None]:
# Same per-channel view for the coefficients of the corrupted (clipped) image.
for i in range(9):
    label = "Low Pass" if i == 0 else f"θ = {i}"
    plt.figure()
    plt.imshow(corr_coefs[i], cmap='gray')
    plt.title(label)
    plt.axis("off")

In [None]:


# Side-by-side comparison grid: the top row shows the clean-image coefficients
# (with titles), the bottom row shows the corrupted-image coefficients.
n_cols = 9
fig, axes = plt.subplots(2, n_cols, figsize=(2*n_cols, 4))

for i in range(n_cols):
    top, bottom = axes[0, i], axes[1, i]
    top.imshow(coefs[i], cmap='gray')
    top.axis('off')
    top.set_title("Low Pass" if i == 0 else f"θ = {i}")  # titles only on the top row
    bottom.imshow(corr_coefs[i], cmap='gray')
    bottom.axis('off')

plt.tight_layout()
plt.show()


In [None]:
def rician_noise(img, noise_percent, rng=None):
    """Corrupt an image with Rician noise.

    Gaussian noise with standard deviation
    ``sigma = noise_percent / 100 * img.max()`` is added to the image (as a real
    channel) and drawn separately for a zero imaginary channel. The magnitude of
    the resulting complex value is returned.

    Args:
        img: NumPy array (anything with ``.max().item()`` and ``.shape``). Its
            maximum must be >= 0, since a negative ``sigma`` makes
            ``normal`` raise ``ValueError``.
        noise_percent: noise level as a percentage of the image maximum.
        rng: optional ``numpy.random.Generator`` for reproducible draws.
            Defaults to the legacy global ``np.random`` state (original behaviour).

    Returns:
        Non-negative NumPy array with the same shape as ``img``.
    """
    sampler = np.random if rng is None else rng
    sigma = (noise_percent / 100) * img.max().item()
    noise_real = sampler.normal(0, sigma, img.shape)
    noise_imag = sampler.normal(0, sigma, img.shape)
    return np.sqrt((img + noise_real) ** 2 + noise_imag ** 2)

In [None]:
# Rician-noise baseline: shift the images back to [0, 1], corrupt each one at
# NOISE_PERCENT of its maximum, and report the mean PSNR (peak value 1).
NOISE_PERCENT = 11
np.random.seed(1)  # rician_noise draws from the global RNG; fix it for reproducible PSNR
corr_rici_img = np.zeros(img.shape)
psnr_rici = np.zeros(img.shape[0])
clean_img = img + 0.5
for i in range(len(img)):
    corr_rici_img[i] = rician_noise(clean_img[i], NOISE_PERCENT)
    psnr_rici[i] = 10 * np.log10(1 / np.mean((clean_img[i] - corr_rici_img[i]) ** 2))
print(psnr_rici.mean())
    

In [None]:
# Rician-corrupted version of slice 50.
plt.imshow(corr_rici_img[50], cmap='gray')
plt.axis("off")
plt.show()  # was `plt.show` without parentheses, which never called it

In [None]:
# Clean reference for slice 50.
plt.imshow(img[50], cmap='gray')
plt.axis("off")
plt.show()  # was `plt.show` without parentheses, which never called it

In [None]:
# Keep-probability map used by corrupt_data (bright = more likely kept).
plt.imshow(X=bern_mask, cmap="gray")
plt.axis('off')
plt.show()

In [None]:
# Sampled (kept) frequency pattern for slice 50.
plt.imshow(corr_mask[50], cmap='gray')
plt.axis("off")
plt.show()  # was `plt.show` without parentheses, which never called it

In [None]:
# Log-magnitude of the under-sampled spectrum of slice 50; log1p(x) == log(1 + x),
# computed accurately for small magnitudes.
plt.imshow(np.log1p(np.abs(corr_val[50])), cmap='gray')
plt.axis("off")
plt.show()  # was `plt.show` without parentheses, which never called it

In [None]:
# Log-magnitude of the full spectrum of slice 50 (same transform as the sampled view).
plt.imshow(np.log1p(np.abs(spec[50])), cmap='gray')
plt.axis("off")
plt.show()  # was `plt.show` without parentheses, which never called it

In [None]:
def psnr(mse, data_range=1.0):
    """Peak signal-to-noise ratio in dB computed from a mean squared error.

    Args:
        mse: mean squared error; a scalar or a NumPy array (element-wise).
        data_range: peak-to-peak range of the signal. The default of 1.0 matches
            images scaled to [-0.5, 0.5] or [0, 1], and reproduces the original
            ``10 * log10(1 / mse)``.

    Returns:
        ``10 * log10(data_range ** 2 / mse)``.

    Notes:
        A zero NumPy MSE gives ``inf`` (with a RuntimeWarning); a plain Python 0
        raises ZeroDivisionError.
        NOTE(review): an earlier cell may bind the global name ``psnr`` to an
        array of per-image values. Whichever cell runs last wins, so keep the
        names distinct.
    """
    return 10 * np.log10(data_range ** 2 / mse)
