## Imports and auxiliary functions


In [None]:
import os
import time
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import skimage
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from skimage.filters import gaussian
from skimage.metrics import peak_signal_noise_ratio
from skimage.transform import radon, iradon
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

def get_mgrid(length1, dim=2):
    '''Generate a flattened grid of coordinates spanning [-1, 1] on every axis.

    length1: int, number of evenly spaced samples per axis
    dim: int, number of axes (2 gives (x, y) pairs)

    Returns a tensor of shape (length1 ** dim, dim); each row is one grid
    point and the last coordinate varies fastest.'''
    axis = torch.linspace(-1, 1, steps=length1)
    # 'ij' is the layout meshgrid has always produced by default; naming it
    # explicitly keeps the result unchanged and avoids the warning that recent
    # PyTorch releases emit when the argument is omitted.
    mesh = torch.meshgrid(*([axis] * dim), indexing='ij')
    return torch.stack(mesh, dim=-1).reshape(-1, dim)

def plot_images(original, transformed, num_realizations):
    '''Show the original image, its radon transform and the image rebuilt from it.

    original: numpy array, original image
    transformed: numpy array, radon transform projection (one column per angle)
    num_realizations: int, upper limit of the vertical axis of the middle panel'''
    _, (ax_original, ax_sinogram, ax_rebuilt) = plt.subplots(1, 3, figsize=(12, 4))

    ax_original.imshow(original, cmap='gray')
    ax_original.set_title('Shepp-Logan Phantom')

    ax_sinogram.imshow(
        transformed,
        cmap='gray',
        extent=(0, 180, 0, num_realizations),
        aspect='auto',
    )
    ax_sinogram.set_title('Radon Transform')
    ax_sinogram.set_xlabel('Projection Angle (degrees)')
    ax_sinogram.set_ylabel('Realizations')

    # Inverse radon transform, using one angle per column of the projection.
    projection_angles = np.linspace(0., 180., transformed.shape[1])
    reconstructed = iradon(transformed, theta=projection_angles, circle=True)
    ax_rebuilt.imshow(reconstructed, cmap='gray')
    ax_rebuilt.set_title('Reconstructed Image')

    plt.show()

def plot_images2(img, mse, psnr, mssim):
    '''Plot the current image next to the MSE, PSNR and MSSIM curves (4 panels).

    img: numpy array, reconstructed image
    mse: sequence, mse value recorded at every step
    psnr: sequence, psnr value recorded at every step
    mssim: sequence, mssim value recorded at every step'''
    # NOTE: an extra panel showing |img - previous img| was used during
    # development to see which areas the model kept changing when learning
    # went badly; it is currently disabled.
    plt.figure(figsize=(14, 4))

    plt.subplot(1, 4, 1)
    plt.imshow(img, cmap='gray_r')
    plt.colorbar()
    plt.title('Image')

    curves = (('MSE', mse), ('PSNR', psnr), ('MSSIM', mssim))
    for panel, (title, values) in enumerate(curves, start=2):
        plt.subplot(1, 4, panel)
        # The x axis is the 1-based index of the recorded step.
        plt.plot(range(1, len(values) + 1), values)
        plt.title(title)

    plt.show()

def radon_transform(image, num_angles): # It isn't really needed here
    '''Approximate a radon transform by rotating the image and summing it.

    image: torch tensor (the code calls .unsqueeze(0) on it, so a numpy array
        will not work); assumed to be a single 2-D image - TODO confirm
    num_angles: int, number of angles sampled evenly between 0 and 180 degrees
        (both ends included)

    Returns a tensor with one row per angle; each row is the rotated image
    summed along its first dimension.'''
    angles = torch.linspace(0.0, 180.0, steps=num_angles)
    projections = [
        # rotate expects a leading channel dimension, which is dropped again
        # before the sum
        torch.sum(TF.rotate(image.unsqueeze(0), angle.item()).squeeze(0), dim=0)
        for angle in angles
    ]
    return torch.stack(projections, dim=0)

def get_img_tensor(im, length1):
    '''Convert a numpy image into a resized torch tensor.

    im: numpy array
    length1: int, size handed to torchvision's Resize

    NOTE(review): Resize receives a single int, not a (height, width) pair.
    Presumably a square input comes out as length1 x length1; confirm what is
    wanted for non-square images.'''
    to_tensor = Compose([
        Resize(length1),
        ToTensor(),
        #Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
    ])
    return to_tensor(Image.fromarray(im))

def Gaussian_Noise(img_shape, sigma=1):
    '''Generate reproducible zero-mean gaussian noise.

    img_shape: int or tuple of ints, shape of the returned array
    sigma: float, desired std value

    The generator is seeded with 0 on every call, so equal arguments always
    return the same array. A private generator is used so that the global
    numpy random state is left untouched for the rest of the program.'''
    rng = np.random.RandomState(0)
    return rng.normal(0, sigma, size=img_shape)

def normalize(data):
    '''Rescale data linearly so that its minimum maps to 0 and its maximum to 1.

    data: numpy array

    A constant input (maximum == minimum) is returned as all zeros instead of
    dividing zero by zero.'''
    lowest = data.min()
    span = data.max() - lowest
    # Dividing by 1 for a flat input yields zeros; any other input is scaled
    # by its real value range.
    return (data - lowest) / (span if span != 0 else 1)

#Mean Structural Similarity Index Map
def mssim(img1, img2, alpha, beta, gamma):
    """Return the mean structural similarity (MSSIM) between img1 and img2.

    A local similarity value is computed around every pixel from
    gaussian-weighted means, variances and covariance; the average over the
    whole image is returned as a single number.

    img1, img2: numpy arrays of identical shape
    alpha, beta, gamma: exponents applied to the luminance, contrast and
        structure terms respectively

    The data range used for the stabilising constants is taken from img2.
    The covariance is not clipped, so the structure term can be negative and
    a non-integer gamma may then produce NaN.
    """

    # Convert to float64 to avoid floating point error and negative values in sigma1_sq or sigma2_sq
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # Data range
    L = np.max(img2) - np.min(img2)

    # Parameters from Wang et al. 2004
    sigma = 1.5
    K1 = 0.01
    K2 = 0.03
    C1 = (K1*L)**2
    C2 = (K2*L)**2

    # sigma=1.5 truncated at 3.5 sigmas gives the 11-pixel window of Wang et
    # al.; the arguments must actually be passed to every filter call,
    # otherwise gaussian() falls back to its default sigma.
    filter_args = {'sigma': sigma, 'truncate': 3.5}

    # Local means
    mu1 = gaussian(img1, **filter_args)
    mu2 = gaussian(img2, **filter_args)

    mu1_sq = mu1*mu1
    mu2_sq = mu2*mu2
    mu1_mu2 = mu1*mu2

    # Local variances and covariance. Negative variances can only come from
    # floating point error, so they are clamped to zero.
    sigma1_sq = gaussian(img1*img1, **filter_args) - mu1_sq
    sigma1_sq[sigma1_sq < 0] = 0
    sigma2_sq = gaussian(img2*img2, **filter_args) - mu2_sq
    sigma2_sq[sigma2_sq < 0] = 0
    sigma12 = gaussian(img1*img2, **filter_args) - mu1_mu2

    # Product of the two local standard deviations, shared by the contrast
    # and structure terms
    sigma1_sigma2 = np.sqrt(sigma1_sq*sigma2_sq)

    # Compute luminance, contrast and structure for each patch
    luminance = ((2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1))**alpha
    contrast = ((2*sigma1_sigma2 + C2)/(sigma1_sq + sigma2_sq + C2))**beta
    structure = ((2*sigma12 + C2)/(2*sigma1_sigma2 + C2))**gamma

    # Compute MSSIM
    return np.mean(luminance*contrast*structure)

## Siren Model

In [None]:
class SineLayer(nn.Module):
    """Linear layer followed by a sine: sin(omega_0 * linear(input)).

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for
    discussion of omega_0.

    is_first=True: omega_0 is a frequency factor which simply multiplies the
    activations before the nonlinearity. Different signals may require a
    different omega_0 in the first layer - this is a hyperparameter.

    is_first=False: the weights are divided by omega_0 at initialisation so
    as to keep the magnitude of activations constant, but boost gradients to
    the weight matrix (see supplement Sec. 1.5).
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Redraw the weights uniformly in a symmetric, layer-dependent interval."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        activation, _ = self.forward_with_intermediate(input)
        return activation

    def forward_with_intermediate(self, input):
        """Return (sine output, pre-sine value); used to visualise activation distributions."""
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation


class Siren(nn.Module):
    """Stack of SineLayers mapping input coordinates to output values.

    in_features: int, size of each input coordinate
    hidden_features: int, width of every hidden layer
    hidden_layers: int, number of hidden SineLayers after the first layer
    out_features: int, size of each output
    outermost_linear: bool, if True the last layer is a plain nn.Linear
        (no sine), otherwise it is another SineLayer
    first_omega_0: omega_0 of the first layer
    hidden_omega_0: omega_0 of all later layers
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        # Layers are created in network order so that random initialisation
        # consumes the generator in a fixed sequence.
        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers.extend(
            SineLayer(hidden_features, hidden_features,
                      is_first=False, omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)

            # Same weight range as a hidden SineLayer
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_linear.weight.uniform_(-bound, bound)

            layers.append(final_linear)
        else:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return (output, coords), where coords is a detached copy of the input that requires grad."""
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!

        The result is an OrderedDict: 'input' first, then for every SineLayer
        its pre-sine value followed by its output, and for any other layer
        its output. Keys are "<layer class>_<running counter>".'''
        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for layer in self.net:
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            # Output of the layer itself (post-sine for a SineLayer)
            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations

## Data Model

In [None]:
class ImageFitting(Dataset):
    """Dataset with exactly one item: all pixel coordinates and all pixel values.

    img: numpy array, image to fit
    length1: int, size passed both to the image resize and to the coordinate grid
    """

    def __init__(self, img, length1):
        super().__init__()
        image_tensor = get_img_tensor(img, length1)
        # Move the first dimension last, then flatten to one value per row.
        channel_last = image_tensor.permute(1, 2, 0)
        self.pixels = channel_last.view(-1, 1)
        self.coords = get_mgrid(length1, 2)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # Only one item exists; any positive index is out of range.
        if idx > 0:
            raise IndexError
        return self.coords, self.pixels

## Initializing model and data

In [None]:
# Load the input image (raw float32 values, 400 x 400) and rescale it to [0, 1].
# The context manager closes the file as soon as the data has been read.
with open("/content/out_beta00102522_it30.img", "rb") as fo:
    img = np.fromfile(fo, dtype=np.float32).reshape((400,400))
img_norm = normalize(img)
plt.imshow(img_norm, cmap="gray_r")
plt.colorbar()

In [None]:
# Load the ground-truth image (raw float32 values, 400 x 400) and rescale it to [0, 1].
# The context manager closes the file as soon as the data has been read.
with open("/content/gt_tep.img", "rb") as fo:
    img_gt = np.fromfile(fo, dtype=np.float32).reshape((400,400))
img_gt_norm = normalize(img_gt)
# noise = Gaussian_Noise(img.shape, 10).astype("float32")
# img_bruite = np.clip(img + noise, img.min(), img.max())
# img_bruite = normalize(img_bruite)
# The un-normalized ground truth is what gets displayed here.
plt.imshow(img_gt, cmap="gray_r")
plt.colorbar()

In [None]:
# PSNR of the input image against the ground truth, both rescaled to [0, 1].
# The reference image goes first, matching the calls in the training loop.
peak_signal_noise_ratio(img_gt_norm, normalize(img))

In [None]:
# Fix the torch seed so that the run is reproducible
torch.manual_seed(42)

#img = skimage.data.shepp_logan_phantom()

# The normalized image is wrapped as a one-item dataset and served in a
# single batch.
phantom = ImageFitting(img_norm, 400)
dataloader = DataLoader(phantom, batch_size=1, pin_memory=True, num_workers=0)

# 2-D coordinate in, one value out, linear last layer
siren_kwargs = dict(
    in_features=2,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True,
    first_omega_0=50,
    hidden_omega_0=50,
)
img_siren = Siren(**siren_kwargs)
img_siren.cuda()

## Running Model

In [None]:
total_steps = 3001 # The whole image is a single batch, so this is simply the number of gradient descent steps.
steps_til_summary = 50

optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())
#optim = torch.optim.Adagrad(lr=1e-3, params=img_siren.parameters())

model_input, ground_truth = next(iter(dataloader))
model_input, ground_truth = model_input.cuda(), ground_truth.cuda()

# Loop invariants: the fitting target as a numpy array and the 2-D shape used
# to turn the flat network output back into an image.
ground_truth_np = ground_truth.cpu().detach().numpy()
img_shape = img_gt_norm.shape

mse_losses = []
psnr_losses = []
img_mse_losses = []
img_psnr_losses = []
img_mssim_losses = []

for step in range(total_steps):

    model_output, coords = img_siren(model_input)

    # Training objective: MSE against the image being fitted
    mse_loss = ((model_output - ground_truth)**2).mean()

    # Everything below is monitoring only. It works on a graph-free tensor and
    # on one CPU copy per step instead of converting the output repeatedly.
    output_detached = model_output.detach()
    output_flat_np = output_detached.cpu().numpy()
    output_img = output_flat_np.reshape(img_shape)
    output_img_norm = normalize(output_img.astype('float32'))

    psnr_loss = peak_signal_noise_ratio(ground_truth_np, output_flat_np)

    mse_losses.append(mse_loss.item())
    psnr_losses.append(psnr_loss)

    # Metrics against the ground-truth image. MSE uses the raw output, PSNR
    # and MSSIM use the output rescaled to [0, 1].
    img_mse_loss = ((output_img - img_gt_norm)**2).mean()
    img_psnr_loss = peak_signal_noise_ratio(img_gt_norm, output_img_norm)
    img_mssim_loss = mssim(output_img_norm, img_gt_norm, 0.5, 1, 1)

    # Keep the best outputs; storing the detached tensor avoids holding on to
    # the autograd graph of that step.
    if step == 0 or best_psnr < img_psnr_loss:
        best_img_psnr = output_detached # for saving best img with psnr
        best_psnr = img_psnr_loss

    if step == 0 or best_MSSIM < img_mssim_loss:
        best_img_MSSIM = output_detached # for saving best img with MSSIM
        best_MSSIM = img_mssim_loss

    img_mse_losses.append(img_mse_loss.item())
    img_psnr_losses.append(img_psnr_loss)
    img_mssim_losses.append(img_mssim_loss)

    if step > 0 and not step % steps_til_summary:
        print("Step %d, Total loss %0.6f" % (step, mse_loss))

        plot_images2(output_img, img_mse_losses, img_psnr_losses, img_mssim_losses)

    optim.zero_grad()
    mse_loss.backward()
    optim.step()

print("MSE:")
print("Last epoch - %0.5f" %(img_mse_losses[-1]))
print("Minimum value - %0.5f" %(min(img_mse_losses)))
print()

print("PSNR:")
print("Last epoch - %0.5f" %(img_psnr_losses[-1]))
print("Maximum value - %0.5f" %(max(img_psnr_losses)))
print()

print("MSSIM - IMG:")
print("Last epoch - %0.5f" %(img_mssim_losses[-1]))
print("Maximum value - %0.5f" %(max(img_mssim_losses)))
print()

plt.imshow(best_img_psnr.cpu().numpy().reshape(img_shape), cmap='gray_r', vmin=0, vmax=1)
plt.title('Best Reconstructed Image - PSNR')
plt.axis('off')
plt.show()

plt.imshow(best_img_MSSIM.cpu().numpy().reshape(img_shape), cmap='gray_r', vmin=0, vmax=1)
plt.title('Best Reconstructed Image - MSSIM')
plt.axis('off')
plt.show()


# The second number on each line is the other metric at the step where the
# first one peaked.
print("When the best PSNR is : ", best_psnr, "The best MSSIM is : ", img_mssim_losses[img_psnr_losses.index(best_psnr)])

print("When the best MSSIM is : ", best_MSSIM, "The best PSNR is : ", img_psnr_losses[img_mssim_losses.index(best_MSSIM)])

print("Additionally the best mse and psnr from the sinogram respectively")

print(min(img_mse_losses))
print(max(img_psnr_losses))

In [None]:
# Continue training for 1000 more steps with the existing model, optimizer,
# data and metric histories.
# Loop invariants: the fitting target as a numpy array and the 2-D shape used
# to turn the flat network output back into an image.
ground_truth_np = ground_truth.cpu().detach().numpy()
img_shape = img_gt_norm.shape

for step in range(1000):

    model_output, coords = img_siren(model_input)

    # Training objective: MSE against the image being fitted
    mse_loss = ((model_output - ground_truth)**2).mean()

    # Everything below is monitoring only. It works on a graph-free tensor and
    # on one CPU copy per step instead of converting the output repeatedly.
    output_detached = model_output.detach()
    output_flat_np = output_detached.cpu().numpy()
    output_img = output_flat_np.reshape(img_shape)
    output_img_norm = normalize(output_img.astype('float32'))

    psnr_loss = peak_signal_noise_ratio(ground_truth_np, output_flat_np)

    mse_losses.append(mse_loss.item())
    psnr_losses.append(psnr_loss)

    # Metrics against the ground-truth image. MSE uses the raw output, PSNR
    # and MSSIM use the output rescaled to [0, 1].
    img_mse_loss = ((output_img - img_gt_norm)**2).mean()
    img_psnr_loss = peak_signal_noise_ratio(img_gt_norm, output_img_norm)
    img_mssim_loss = mssim(output_img_norm, img_gt_norm, 0.5, 1, 1)

    # NOTE(review): the best values restart at step 0 of this loop, so they
    # describe this continuation only, while the histories and the
    # "Maximum value" lines below cover every step recorded so far - confirm
    # that this is intended.
    # Storing the detached tensor avoids holding on to the autograd graph.
    if step == 0 or best_psnr < img_psnr_loss:
        best_img_psnr = output_detached # for saving best img with psnr
        best_psnr = img_psnr_loss

    if step == 0 or best_MSSIM < img_mssim_loss:
        best_img_MSSIM = output_detached # for saving best img with MSSIM
        best_MSSIM = img_mssim_loss

    img_mse_losses.append(img_mse_loss.item())
    img_psnr_losses.append(img_psnr_loss)
    img_mssim_losses.append(img_mssim_loss)

    if step > 0 and not step % steps_til_summary:
        print("Step %d, Total loss %0.6f" % (step, mse_loss))

        plot_images2(output_img, img_mse_losses, img_psnr_losses, img_mssim_losses)

    optim.zero_grad()
    mse_loss.backward()
    optim.step()

print("MSE:")
print("Last epoch - %0.5f" %(img_mse_losses[-1]))
print("Minimum value - %0.5f" %(min(img_mse_losses)))
print()

print("PSNR:")
print("Last epoch - %0.5f" %(img_psnr_losses[-1]))
print("Maximum value - %0.5f" %(max(img_psnr_losses)))
print()

print("MSSIM - IMG:")
print("Last epoch - %0.5f" %(img_mssim_losses[-1]))
print("Maximum value - %0.5f" %(max(img_mssim_losses)))
print()

plt.imshow(best_img_psnr.cpu().numpy().reshape(img_shape), cmap='gray_r', vmin=0, vmax=1)
plt.title('Best Reconstructed Image - PSNR')
plt.axis('off')
plt.show()

plt.imshow(best_img_MSSIM.cpu().numpy().reshape(img_shape), cmap='gray_r', vmin=0, vmax=1)
plt.title('Best Reconstructed Image - MSSIM')
plt.axis('off')
plt.show()


# The second number on each line is the other metric at the step where the
# first one peaked.
print("When the best PSNR is : ", best_psnr, "The best MSSIM is : ", img_mssim_losses[img_psnr_losses.index(best_psnr)])

print("When the best MSSIM is : ", best_MSSIM, "The best PSNR is : ", img_psnr_losses[img_mssim_losses.index(best_MSSIM)])

print("Additionally the best mse and psnr from the sinogram respectively")

print(min(img_mse_losses))
print(max(img_psnr_losses))

In [None]:
# Number of training-loss values recorded so far (one is appended per step).
len(mse_losses)

In [None]:
# Lowest training MSE recorded so far.
min(mse_losses)

In [None]:
# Highest training PSNR recorded so far.
max(psnr_losses)

In [None]:
# NOTE(review): `sinog_tensor` is not defined anywhere in the visible part of
# this file - looks like a leftover from a sinogram experiment; confirm or remove.
torch.max(sinog_tensor)

In [None]:
# Save the network's state dict to a file at a relative path.
torch.save(img_siren.state_dict(),'img_rec_1e4.pth')

## Testing

In [None]:
#!pip install -q condacolab
# NOTE(review): presumably sets up conda inside the Colab runtime; the package
# has to be installed first (see the commented pip line above) - confirm this
# cell is still needed.
import condacolab
condacolab.install()