In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
from skimage.util import random_noise
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt
import time

In [None]:
from Models.SIREN import Siren

In [None]:
def get_mgrid(sidelen, dim=2):
    '''Generate a flattened grid of coordinates spanning [-1, 1] along every axis.

    sidelen: int -- number of samples per axis.
    dim: int -- number of coordinate dimensions.

    Returns a float tensor of shape (sidelen**dim, dim). Rows are ordered with
    the last axis varying fastest ("ij" / matrix indexing), so for dim=2,
    coordinate 0 changes along the rows of a reshaped (sidelen, sidelen) grid.
    '''
    axes = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
    # indexing='ij' is the historical default; passing it explicitly keeps the
    # same grid and silences the deprecation warning recent PyTorch emits.
    mgrid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)

In [None]:
# Tiny 3x3 example of the per-axis samples and the meshgrid that get_mgrid stacks.
# indexing='ij' matches the old default and avoids PyTorch's deprecation warning.
sidelen = 3
dim = 2
tensors = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
torch.meshgrid(*tensors, indexing='ij')

## Differential Operators

In [None]:
def laplace(y, x):
    """Laplacian of ``y`` with respect to ``x``, computed as div(grad(y)).

    Both ``gradient`` and ``divergence`` build their derivatives with
    ``create_graph=True``, so the returned tensor is itself differentiable.
    """
    return divergence(gradient(y, x), x)


def divergence(y, x):
    """Divergence of the vector field ``y`` with respect to the inputs ``x``.

    Accumulates d y[..., i] / d x[..., i] over the last axis of ``y``. Each
    partial derivative comes from its own autograd call with
    ``create_graph=True`` (so the result can be differentiated again) and is
    kept as a trailing singleton axis, giving a result shaped like
    ``x[..., :1]``. If ``y`` has an empty last axis the float ``0.`` is returned.
    """
    total = 0.
    for comp in range(y.shape[-1]):
        component = y[..., comp]
        partials = torch.autograd.grad(component, x, torch.ones_like(component), create_graph=True)[0]
        total += partials[..., comp:comp + 1]
    return total


def gradient(y, x, grad_outputs=None):
    """Vector-Jacobian product of ``y`` with respect to ``x``.

    ``grad_outputs`` weights each element of ``y``; when omitted it defaults to
    ones shaped like ``y``, which yields the plain gradient of a per-sample
    scalar output. ``create_graph=True`` keeps the result differentiable so it
    can be fed to higher-order operators such as ``divergence``.
    """
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return grad

## Get Image

In [None]:
def get_cameraman_tensor(sidelength):
    """Load scikit-image's 'camera' test image as a normalized torch tensor.

    sidelength: int -- forwarded to torchvision ``Resize`` (target size of the
        shorter image edge).

    Returns the image after ``ToTensor`` and ``Normalize(mean=0.5, std=0.5)``,
    i.e. a channel-first float tensor whose values are mapped from [0, 1] to
    [-1, 1].
    """
    preprocess = Compose([
        Resize(sidelength),
        ToTensor(),
        # (v - 0.5) / 0.5 rescales ToTensor's [0, 1] range to [-1, 1].
        Normalize(mean=torch.Tensor([0.5]), std=torch.Tensor([0.5])),
    ])
    camera = Image.fromarray(skimage.data.camera())
    return preprocess(camera)

In [None]:
# Sanity check on the loader: expected a single-channel 256x256 tensor.
cameraman_shape = get_cameraman_tensor(sidelength=256).shape
print(cameraman_shape)

## Fit the Image

In [None]:
class ImageFitting(Dataset):
    """Single-sample dataset pairing a coordinate grid with cameraman pixel values.

    The only item is ``(coords, pixels)``: ``coords`` is ``get_mgrid(sidelength, 2)``
    and ``pixels`` is the (optionally noise-corrupted) image flattened to one
    column, with rows in the same order as the coordinate rows.

    Args:
        sidelength: forwarded to ``get_cameraman_tensor``.
        noisy: if True, corrupt the image with ``skimage.util.random_noise``.
        mode: noise type when ``noisy`` is True. 'gaussian' and 's&p' are used
            as given; any other value (including None) falls back to 'localvar'.
        var: variance for 'gaussian' noise; None keeps skimage's default.
        amount: fraction of corrupted pixels for 's&p' noise; None keeps
            skimage's default.
    """

    def __init__(self, sidelength, noisy=False, mode=None, var=None, amount=None):
        super().__init__()
        img = get_cameraman_tensor(sidelength)
        if noisy:
            img = self._add_noise(img, mode, var, amount)
        # (C, H, W) -> (H, W, C) -> (H*W, 1): one row per pixel, row-major,
        # which matches the row order produced by get_mgrid.
        self.pixels = img.permute(1, 2, 0).view(-1, 1)
        self.coords = get_mgrid(sidelength, 2)

    @staticmethod
    def _add_noise(img, mode, var, amount):
        """Return ``img`` corrupted by ``random_noise`` as a float32 tensor."""
        # Forward only the parameters the caller set. random_noise fills in its
        # own defaults for *missing* keywords, but an explicit None is passed
        # through and fails inside skimage (e.g. gaussian mode without var).
        if mode == 'gaussian':
            kwargs = {} if var is None else {'var': var}
        elif mode == 's&p':
            kwargs = {} if amount is None else {'amount': amount}
        else:
            mode, kwargs = 'localvar', {}
        noisy_img = random_noise(img, mode=mode, **kwargs)
        # random_noise can return float64; cast back so noisy and clean
        # datasets share the float32 dtype of get_cameraman_tensor's output.
        return torch.from_numpy(noisy_img).float()

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError
        return self.coords, self.pixels

## Ground Truth and Noisy Variants of the Image

In [None]:
# Compare the clean cameraman image with the three noise models used in this notebook.
# Noise is applied to the normalized image, whose values lie in [-1, 1].
side_length = 256
cameraman = ImageFitting(side_length, noisy=False)

# Recover the 2-D image from the flattened (H*W, 1) pixel column.
image_tensor = cameraman.pixels.view(side_length, side_length)
image_array = image_tensor.numpy()

noisy_image_r = random_noise(image_array, mode='localvar')
noisy_image_g = random_noise(image_array, mode='gaussian', var=0.2)
noisy_image_p = random_noise(image_array, mode='s&p', amount=0.5)

panels = [
    (image_array, "Original Image"),
    (noisy_image_r, "Noisy Image (localvar noise)"),
    (noisy_image_g, "Noisy Image (gaussian noise, var=0.2)"),
    (noisy_image_p, "Noisy Image (s&p noise, amount=0.5)"),
]
fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture, cmap='gray')
    ax.set_title(title)
plt.show()


## Gradient and Laplacian Computation

In [None]:
import scipy.ndimage

def compute_gradient(image, mode='constant'):
    """Sobel derivative estimates of a 2-D image along both array axes.

    Args:
        image: 2-D array-like. Integer (and bool) inputs are converted to
            float64 first: scipy.ndimage writes its result in the input dtype,
            so signed gradients of e.g. a uint8 image would otherwise wrap
            around. Floating/complex inputs keep their dtype.
        mode: boundary handling forwarded to ``scipy.ndimage.sobel``. The
            default 'constant' pads with zeros, which produces strong responses
            along the border wherever the image is non-zero there.

    Returns:
        (dx, dy): Sobel response along axis 0 (rows) and along axis 1
        (columns), each with the same shape as ``image``.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.inexact):
        image = image.astype(np.float64)
    dx = scipy.ndimage.sobel(image, axis=0, mode=mode)
    dy = scipy.ndimage.sobel(image, axis=1, mode=mode)
    return dx, dy

def compute_laplacian(image, mode='constant'):
    """Discrete Laplacian of a 2-D image using the 5-point stencil.

    Args:
        image: 2-D array-like. Integer (and bool) inputs are converted to
            float64 first: the stencil produces negative values, and
            scipy.ndimage would otherwise store them in the (possibly
            unsigned) input dtype. Floating/complex inputs keep their dtype.
        mode: boundary handling forwarded to ``scipy.ndimage.convolve``; the
            default 'constant' treats pixels outside the image as 0.

    Returns:
        Array shaped like ``image``: sum of the four axis-aligned neighbours
        minus four times the centre pixel.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.inexact):
        image = image.astype(np.float64)
    laplacian_kernel = np.array([[0, 1, 0],
                                 [1, -4, 1],
                                 [0, 1, 0]])
    return scipy.ndimage.convolve(image, laplacian_kernel, mode=mode)



In [None]:
# Inspect the derivatives of the noisy training target. To compare other inputs,
# change the ImageFitting arguments (e.g. mode='s&p', amount=0.1, or noisy=False).
side_length = 256
cameraman = ImageFitting(side_length, noisy=True, mode='localvar')

# Recover the 2-D image from the flattened (H*W, 1) pixel column.
image_tensor = cameraman.pixels.view(side_length, side_length)
image_array = image_tensor.numpy()

# "x" below is array axis 0 (rows), i.e. coordinate 0 of get_mgrid.
dx, dy = compute_gradient(image_array)
laplacian = compute_laplacian(image_array)

panels = [
    (image_array, "Input Image (localvar noise)"),
    (dx, "Gradient along x"),
    (dy, "Gradient along y"),
    (laplacian, "Laplacian"),
]
fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture, cmap='gray')
    ax.set_title(title)
plt.show()

## Load the Dataset and Build the Model

In [None]:
# Training target: the cameraman image with 'localvar' noise. Other targets tried:
#   ImageFitting(256)                                        -- clean image
#   ImageFitting(256, noisy=True, mode='gaussian', var=0.2)
#   ImageFitting(256, noisy=True, mode='s&p', amount=0.1)
cameraman = ImageFitting(sidelength=256, noisy=True, mode='localvar')

# The dataset holds a single (coords, pixels) sample, so one batch is the whole image.
dataloader = DataLoader(cameraman, batch_size=1, num_workers=0, pin_memory=True)

# 2-D coordinates in, one pixel value out.
img_siren = Siren(
    in_features=2,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True,
)
img_siren.cuda()

In [None]:
# Shape check: random_noise (default 'gaussian' mode) should preserve the input shape.
img = get_cameraman_tensor(256)
img_n = torch.from_numpy(random_noise(img, mode='gaussian'))
print(img_n.shape)

## Training Loop

In [None]:
total_steps = 300  # The whole image is a single batch, so this is 300 gradient-descent steps.
steps_til_summary = 10

optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())

# The dataset holds exactly one (coords, pixels) sample; fetch it once and reuse it.
model_input, ground_truth = next(iter(dataloader))
model_input, ground_truth = model_input.cuda(), ground_truth.cuda()
# With batch_size=1 the pixels have shape (1, side*side, 1) for a square image;
# recover the side length for the previews instead of hard-coding it.
preview_side = round(ground_truth.shape[1] ** 0.5)

loss_array = []

for step in range(total_steps):
    model_output, coords = img_siren(model_input)
    loss = ((model_output - ground_truth)**2).mean()
    # Store a plain Python float rather than a 0-d numpy array.
    loss_array.append(loss.item())

    if not step % steps_til_summary:
        print("Step %d, Total loss %0.6f" % (step, loss_array[-1]))
        # Copy the prediction to the host once and derive every panel from it.
        img_pred = model_output.detach().cpu().view(preview_side, preview_side).numpy()
        img_grad_x, img_grad_y = compute_gradient(img_pred)
        img_laplacian = compute_laplacian(img_pred)

        panels = [
            (img_pred, "Trained Image"),
            (img_grad_x, "Gradient along x"),
            (img_grad_y, "Gradient along y"),
            (img_laplacian, "Laplacian"),
        ]
        fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
        for ax, (picture, title) in zip(axes, panels):
            ax.imshow(picture, cmap='gray')
            ax.set_title(title)
        plt.show()
        plt.close(fig)  # don't accumulate open figures across summaries

    optim.zero_grad()
    loss.backward()
    optim.step()


In [None]:
final_loss = loss.detach().cpu().numpy()  # MSE of the last training step (computed before its optimizer update)
print(final_loss)

## Plot the Loss Curve

In [None]:
# One point per training step, in the order the losses were recorded.
fig, ax = plt.subplots()
ax.plot(loss_array)
ax.set_xlabel('iteration')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Sum of two sinusoids over [-5*pi, 5*pi]: with a rational frequency ratio the sum
# is periodic; with ratio pi (an irrational number, only approximated in floating
# point, hence "pseudo") it does not repeat.
with torch.no_grad():
    coords = get_mgrid(2**10, 1) * 5 * np.pi

    for factor, title in [(2, "Rational multiple"), (np.pi, "Pseudo-irrational multiple")]:
        sin_1 = torch.sin(coords)
        sin_2 = torch.sin(coords * factor)
        # Named `sin_sum` (not `sum`) so the builtin sum() is not shadowed.
        sin_sum = sin_1 + sin_2

        fig, ax = plt.subplots(figsize=(16, 2))
        ax.plot(coords, sin_sum)
        ax.plot(coords, sin_1)
        ax.plot(coords, sin_2)
        ax.set_title(title)
        plt.show()

In [None]:
# Extrapolation check: evaluate the trained SIREN on a dense grid far outside the
# [-1, 1]^2 coordinate range it was fitted on.
grid_side = 1024
coord_scale = 50
with torch.no_grad():
    out_of_range_coords = get_mgrid(grid_side, 2) * coord_scale
    model_out, _ = img_siren(out_of_range_coords.cuda())

    fig, ax = plt.subplots(figsize=(16, 16))
    # Grayscale, like every other image in this notebook.
    ax.imshow(model_out.cpu().view(grid_side, grid_side).numpy(), cmap='gray')
    ax.set_title(f"SIREN output on [-{coord_scale}, {coord_scale}]^2 (fitted on [-1, 1]^2)")
    plt.show()