# Siren Denoising
This is a colab notebook that implements unsupervised image denoising with SIREN with several variants. The code corresponds to the paper *Denoising Capacity of Implicit Image Representation*. This notebook is modified from the [official implementation](https://github.com/vsitzmann/siren)'s `explore_siren.ipynb` notebook. For questions or comments, please contact me at leozdong@stanford.edu

## Initial Setup

This defines the LPIPS metric we will use to measure perceptual similarity

In [None]:
# Install and import the LPIPS package (notebook shell magic).
!pip install lpips
import lpips
# AlexNet-backed LPIPS network; `lpips_np` reads this module-level name.
loss_fn_alex = lpips.LPIPS(net='alex')

In [None]:
# Mount Google Drive; the dataset and output paths used later live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os

from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt

import time

from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.

    sidelen: int, number of samples along every axis
    dim: int, number of axes

    Returns a tensor of shape (sidelen ** dim, dim). The first coordinate
    varies slowest, so for dim=2 the rows follow a row-major image layout.'''
    axis = torch.linspace(-1, 1, steps=sidelen)
    # indexing='ij' is what the legacy default did; passing it explicitly pins
    # the layout and avoids the deprecation warning of the implicit form.
    mgrid = torch.stack(torch.meshgrid(*([axis] * dim), indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)
  
def rgb_float2uint(rgb):
  """Convert a float image to 8-bit.

  Values are clamped to [0, 1], scaled by 255 and cast (truncated) to uint8.
  """
  clamped = np.clip(rgb, a_min=0, a_max=1)
  scaled = clamped * 255
  return scaled.astype(np.uint8)

def lpips_np(img_targ, img):
  """LPIPS distance between two HxWx3 images with values in [0, 1].

  Uses the module-level `loss_fn_alex` network.

  Args:
    img_targ: reference image, array-like of shape (H, W, 3)
    img: image to compare, array-like of shape (H, W, 3)

  Returns:
    The LPIPS distance as a Python float.
  """
  def to_batch(arr):
    # HWC -> 1xCxHxW float32, the layout the LPIPS network consumes
    return torch.as_tensor(arr).float().permute(2, 0, 1).unsqueeze(0)

  with torch.no_grad():
    # The images here live in [0, 1]; normalize=True makes LPIPS rescale them
    # to the [-1, 1] range its network expects.
    return loss_fn_alex(to_batch(img_targ), to_batch(img), normalize=True).item()

Now, we code up the sine layer, which will be the basic building block of SIREN. This is a much more concise implementation than the one in the main code, as here, we aren't concerned with the baseline comparisons.

In [None]:
class SineLayer(nn.Module):
    """Linear layer followed by a scaled sine: sin(omega_0 * (W x + b)).

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for the
    discussion of omega_0.

    The weight initialisation depends on `is_first`:
      * is_first=True: weights are uniform in [-1/in_features, 1/in_features];
        omega_0 then acts as a frequency factor on the pre-activations and is
        a hyperparameter of the signal being fit.
      * is_first=False: weights are uniform in +/- sqrt(6/in_features)/omega_0,
        i.e. divided by omega_0 (see supplement Sec. 1.5).
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from the SIREN uniform range (bias untouched)."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation)

    def forward_with_intermediate(self, input):
        """Return (sin(z), z) with z the scaled pre-activation; used for visualization."""
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation
    
    
class Siren(nn.Module):
    """MLP built from SineLayers, optionally ending in a plain linear layer.

    Args:
        in_features: size of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden SineLayers after the first layer.
        out_features: size of the output.
        outermost_linear: if True the last layer is `nn.Linear` (no sine),
            otherwise it is another SineLayer.
        first_omega_0: omega_0 of the first layer.
        hidden_omega_0: omega_0 of all remaining layers.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers.extend(SineLayer(hidden_features, hidden_features,
                                is_first=False, omega_0=hidden_omega_0)
                      for _ in range(hidden_layers))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)

            # Same range as the hidden SineLayers use for their weights
            with torch.no_grad():
                bound = np.sqrt(6 / hidden_features) / hidden_omega_0
                final_linear.weight.uniform_(-bound, bound)

            layers.append(final_linear)
        else:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return (output, coords), where coords is a detached copy of the input
        with requires_grad=True so derivatives w.r.t. the input can be taken."""
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!

        The result is an ordered dict: 'input' first, then one entry per
        recorded tensor keyed by "<layer class>_<running index>".'''
        # Imported locally so the method does not depend on a later notebook
        # cell having imported OrderedDict already.
        from collections import OrderedDict

        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                # pre-activation of the sine layer
                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            # layer output (recorded for every layer)
            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations

Finally, we define SIREN with spline positional encoding. The following class definitions are modified from https://github.com/microsoft/SplinePosEnc/blob/main/models.py.

In [None]:
class PosProj(nn.Module):
  """Append fixed random linear projections to the input coordinates.

  The output keeps the `in_dim` original coordinates and adds
  `out_dim - in_dim` projections onto random directions that are stored in
  the (non-trainable) buffer `proj`.
  """

  def __init__(self, in_dim, out_dim):
    super().__init__()
    assert(out_dim > in_dim)
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.channel = out_dim - in_dim
    self.register_buffer('proj', torch.Tensor(in_dim, self.channel))
    self.reset_parameters()

  def reset_parameters(self):
    """Draw new random directions; each column is scaled to norm ~ 1/sqrt(in_dim)."""
    with torch.no_grad():
      directions = torch.randn_like(self.proj)
      column_norm = torch.norm(directions, dim=0, keepdim=True)
      # TODO: use small scale
      denom = 1.0e-6 + column_norm * (self.in_dim ** 0.5)
      self.proj.copy_(directions / denom)

  def forward(self, coords):
    lead_shape = list(coords.size())[:-1]
    # merge the first two dims so a single matrix product covers all points
    flat = torch.flatten(coords, end_dim=1)
    projected = flat.mm(self.proj).view(lead_shape + [self.channel])
    return torch.cat([coords, projected], dim=-1)


class OptPosEnc(nn.Module):
  """Learnable piecewise-linear (spline) positional encoding.

  Every input dimension owns `code_num` learnable code vectors of size
  `code_channel`, stored side by side in `shape_code` with shape
  [code_channel, in_features * code_num]. A coordinate in [-1, 1] is encoded
  by linearly interpolating between its two neighbouring codes, and the
  encodings of all dimensions are summed.
  """

  def __init__(self, in_features, code_num=256, code_channel=64):
    super().__init__()
    self.in_features = in_features
    self.out_features = code_channel
    self.code_num = code_num

    code_size = [code_channel, in_features * code_num]
    self.shape_code = nn.Parameter(torch.Tensor(*code_size))
    self.reset_parameters()

  def reset_parameters(self,):
    nn.init.xavier_uniform_(self.shape_code)

  def forward(self, coords):
    return self._forward(coords, self.shape_code)

  def _forward(self, coords, shape_code):
    """Encode `coords` (expected shape [1, N, D] per the shape comments below)
    with the code table `shape_code`; returns a tensor of shape [1, N, C]."""
    pt_num, in_features = coords.size(1), coords.size(2)
    assert in_features == self.in_features
    code_num = shape_code.size(1) // in_features
    # offset of every dimension's sub-table inside shape_code; [1, 1, D, 1]
    mul = torch.arange(in_features, dtype=torch.int64, device=coords.device) * code_num
    mul = mul.view(1, 1, in_features, 1)
    mask = torch.tensor([[[[0, 1]]]], dtype=torch.float32, device=coords.device)

    coords = (coords + 1.0) * ((code_num - 1) / 2.0) # [-1, 1] -> [0, code_num-1]
    corners = torch.floor(coords).detach()    # [1, N, D]
    corners = corners.unsqueeze(-1) + mask    # [1, N, D, 2]
    # Clamp each corner to its own sub-table before adding the offset. A
    # coordinate exactly at +1 has its upper corner at code_num (weight 0),
    # which would otherwise read the next dimension's codes or, for the last
    # dimension, index past the end of the table.
    index = torch.clamp(corners.to(torch.int64), 0, code_num - 1) + mul  # [1, N, D, 2]
    coordsf = coords.unsqueeze(-1) - corners  # [1, N, D, 2], local coords [-1, 1]
    weights = 1.0 - torch.abs(coordsf)        # (1, N, D, 2)

    coords_code = torch.index_select(shape_code, 1, index.view(-1))
    coords_code = coords_code.view(-1, pt_num, in_features, 2) # (C, N, D, 2)
    # weighted sum over the two corners and over all input dimensions
    output = torch.sum(coords_code * weights, dim=(-2, -1), keepdim=True)
    output = output.squeeze(-1).permute(2, 1, 0)
    return output

  def upsample(self, size=64):
    """Linearly resample every dimension's code table to `size` entries.

    Returns a tensor of shape [code_channel, in_features * size]; the module
    itself is not modified.
    """
    code = self.shape_code.view(self.out_features, self.in_features, self.code_num)
    code = code.permute(1, 0, 2)
    # F.upsample is deprecated; interpolate is its direct replacement
    output = torch.nn.functional.interpolate(code, size=size, mode='linear',
                                             align_corners=True)
    output = output.permute(1, 0, 2)
    output = output.reshape(self.out_features, -1)
    return output


class PosEncSiren(nn.Module):
  """SIREN operating on spline-encoded coordinates.

  Pipeline: input coords -> PosProj (adds random projections up to `projs`
  dims) -> OptPosEnc (learnable spline encoding) -> Siren with a linear
  output layer.
  """

  def __init__(self, in_features=2, out_features=3, hidden_layers=3,
               hidden_features=256, projs=32):
    super().__init__()
    assert projs >= in_features
    self.proj = PosProj(in_features, projs)
    in_features = projs
    self.pos_enc = OptPosEnc(in_features)
    self.net = Siren(in_features=self.pos_enc.out_features, 
                     out_features=out_features, hidden_features=hidden_features, 
                     hidden_layers=hidden_layers, outermost_linear=True)
    self.reset_parameters()

  def reset_parameters(self):
    """Initialise the spline codes to (noisy) identity ramps.

    The code channels are split into `in_features` groups; each group of
    dimension i is set to a linear ramp from -1 to 1 over that dimension's
    `code_num` entries plus small uniform noise, and is zero elsewhere.
    """
    with torch.no_grad():
      shape_code = torch.zeros_like(self.pos_enc.shape_code)
      channel, num = self.pos_enc.shape_code.shape
      code_num, in_features = self.pos_enc.code_num, self.pos_enc.in_features

      delta =  2./ (code_num-1)
      ch = channel // in_features
      # the 0.1 * delta slack makes the end point +1 part of the range
      t = torch.arange(-1, 1 + 0.1 * delta, step=delta)
      t = t.unsqueeze(0).repeat(ch, 1)
      for i in range(in_features):
        n = (torch.rand_like(t) - 0.5) * 1.0e-2
        shape_code[i*ch:(i+1)*ch, i*code_num:(i+1)*code_num] = t + n
      self.pos_enc.shape_code.copy_(shape_code)
  
  def forward(self, input):
    """Return (output, coords), with coords a differentiable copy of `input`."""
    coords_in = input.clone().detach().requires_grad_(True)
    enc = self.pos_enc(self.proj(coords_in))
    # Call the underlying nn.Sequential directly: Siren.forward detaches its
    # input, which would cut the graph between the network and the encoding,
    # so `pos_enc.shape_code` would never receive gradients and the returned
    # coordinates would be disconnected from the output.
    output = self.net.net(enc)
    return output, coords_in


And finally, differential operators that allow us to leverage autograd to compute gradients, the laplacian, etc. For this project, we only use the gradient to calculate the TV regularization.

In [None]:
def laplace(y, x):
    """Laplacian of y with respect to x: the divergence of the gradient."""
    return divergence(gradient(y, x), x)


def divergence(y, x):
    """Divergence of the vector field y with respect to x.

    Sums d y[..., i] / d x[..., i] over the last axis of y; the result keeps a
    trailing dimension of size 1. The graph is retained (create_graph=True) so
    the result can be differentiated again.
    """
    total = 0.
    for axis in range(y.shape[-1]):
        component = y[..., axis]
        d_component = torch.autograd.grad(component, x, torch.ones_like(component), create_graph=True)[0]
        total = total + d_component[..., axis:axis+1]
    return total


def gradient(y, x, grad_outputs=None):
    """Gradient of y with respect to x via autograd.

    `grad_outputs` defaults to ones shaped like y. The graph is kept
    (create_graph=True) so the result is itself differentiable. Returns None
    when y does not depend on x (allow_unused=True).
    """
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True, allow_unused=True)
    return grad

## Experiments Setup

We will use an image from the BSDS300 dataset. First, let's set up the image loader and add noise to it.

In [None]:
import os
import skimage.io as io
# Dataset root (read by get_bsds_tensor) and the directory that receives
# checkpoints and figures; both live on the mounted Google Drive.
bsds300_dir = '/content/drive/My Drive/Colab Notebooks/BSDS300/images'
output_dir = '/content/drive/My Drive/Colab Notebooks/siren_out'

def get_bsds_tensor(sidelength, split='test', id='108005'):
  """Load a BSDS300 image and return its centered square crop.

  Args:
    sidelength (int): Side length of the square crop, in pixels
    split (str): Sub-directory of `bsds300_dir` to read from
    id (str): Image id; the file read is `<bsds300_dir>/<split>/<id>.jpg`

  Returns:
    Float tensor of shape (sidelength, sidelength, channels) holding the
    pixel values divided by 255.

  Raises:
    ValueError: If the image is smaller than `sidelength` along either axis.
  """
  path = os.path.join(bsds300_dir, split, f'{id}.jpg')
  img = torch.tensor(io.imread(path) / 255).float()

  height, width = img.shape[0], img.shape[1]
  if sidelength > height or sidelength > width:
    # a negative crop offset would silently produce a wrong or empty slice
    raise ValueError(
        f'Cannot crop a {sidelength}x{sidelength} patch from a {height}x{width} image')

  # Center crop (the image is not rescaled)
  top = (height - sidelength) // 2
  left = (width - sidelength) // 2
  return img[top:top + sidelength, left:left + sidelength, :]

In [None]:
# Load the default test image as a 256x256 crop and display it.
img = get_bsds_tensor(256)
print(img.shape)
fig = plt.figure()
plt.imshow(img)

Add Gaussian noise to this image. Show the resulting PSNR between the noisy and clean image, which will be the baseline.

In [None]:
def get_noisy_image(img, sigma):
    """Return a copy of `img` corrupted by zero-mean Gaussian noise.

    Args:
        img: image, torch.tensor with values from 0 to 1
        sigma: standard deviation of the noise

    Returns:
        Float tensor shaped like `img`, clamped back into [0, 1].
    """
    mean = torch.zeros_like(img)
    std = sigma * torch.ones_like(img)
    noise = torch.normal(mean, std)
    noisy = img + noise
    return torch.clip(noisy, 0, 1).float()

In [None]:
# Baseline: corrupt the clean image with sigma=0.1 Gaussian noise and measure
# how far the noisy image is from the clean one (PSNR and LPIPS).
# NOTE(review): ImageFitting calls get_noisy_image on its own, so the target
# the models are trained on is a different noise realization than this
# `img_noisy` — confirm that is intended.
img_noisy = get_noisy_image(img, 0.1).numpy()
fig = plt.figure()
plt.imshow(img_noisy)
psnr_noisy_clean = compare_psnr(img.numpy(), img_noisy)
print(f"PSNR between clean and noisy image: {psnr_noisy_clean}")
lpips_noisy_clean = lpips_np(img.numpy(), img_noisy)
print(f"LPIPS between clean and noisy image: {lpips_noisy_clean}")

In [None]:
# JPEG stores 8-bit samples, so convert the float image explicitly instead of
# relying on the writer's implicit (lossy, version-dependent) conversion.
io.imsave(os.path.join(output_dir, 'img_noisy.jpg'), rgb_float2uint(img_noisy))

In [None]:
class ImageFitting(Dataset):
    """Single-item dataset pairing a coordinate grid with a noisy image.

    The only item (index 0) is `(coords, pixels)`: `coords` is the flattened
    grid from `get_mgrid(sidelength, 2)` and `pixels` is the noisy image
    flattened to (sidelength**2, channels). The clean image is kept for
    evaluation only.

    Args:
        sidelength: side length of the square image crop.
        split: dataset split forwarded to `get_bsds_tensor`.
        id: image id forwarded to `get_bsds_tensor`.
        sigma: standard deviation of the Gaussian noise added to the image.
    """

    def __init__(self, sidelength, split='test', id='108005', sigma=0.1):
        super().__init__()
        self.clean_pixels = get_bsds_tensor(sidelength, split, id)
        self.coords = get_mgrid(sidelength, 2)

        # Add noise; the noisy image is the target to train on
        noisy = get_noisy_image(self.clean_pixels, sigma)

        # Flatten spatial dimension for training
        self.pixels = noisy.reshape(-1, noisy.shape[-1])

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0: raise IndexError

        return self.coords, self.pixels

    def get_clean_img(self):
        """Return the clean (noise-free) image as an (H, W, C) tensor."""
        return self.clean_pixels

We define the TV prior loss that can be used as regularization

In [None]:
def image_mse_TV_prior(k1, model, model_output, coords, gt_img):
  """Calculate loss with TV prior regularization

  The TV term is estimated on freshly sampled random coordinates (half as
  many as in `coords`), not on the pixel grid itself.

  Args:
    k1 (float): Weight of TV regularization term
    model (nn.Module): SIREN model; called as `model(coords)` and expected to
      return `(output, coords_with_grad)`
    model_output (tensor): Output values at query `coords`
    coords (tensor): Coordinates that we query the model to get `model_output`,
      of shape (batch, points, dims)
    gt_img (tensor): Ground-truth image values at `coords`

  Returns:
    dict with 'img_loss' (mean squared error) and 'prior_loss' (k1 times the
    mean absolute input gradient at the random coordinates).
  """
  # Query random coordinates, uniform in [-1, 1), for TV loss calculation.
  # They are created on the device of `coords`, so the loss also works when
  # the model does not live on a CUDA device.
  n_batch, n_points, n_dims = coords.shape
  coords_rand = 2 * (torch.rand((n_batch, n_points // 2, n_dims),
                                device=coords.device) - 0.5)
  model_out_rand, model_in_rand = model(coords_rand)

  img_loss = ((model_output - gt_img) ** 2).mean()
  tv = torch.abs(gradient(model_out_rand, model_in_rand)).mean()
  return {'img_loss': img_loss,
          'prior_loss': k1 * tv}

Let's instantiate the dataset.

In [None]:
# Build the single-image dataset and keep module-level handles that later
# cells rely on: `img_clean` (numpy), `dataloader`, and the CUDA copies of the
# coordinate grid (`model_input`) and noisy target (`ground_truth`).
tiger = ImageFitting(256)
img_clean = tiger.get_clean_img().numpy()
dataloader = DataLoader(tiger, batch_size=1, pin_memory=True, num_workers=0)
model_input, ground_truth = next(iter(dataloader))
model_input, ground_truth = model_input.cuda(), ground_truth.cuda()

This is the main training loop for image fitting. Depending on the `loss_type` parameter, we can either fit the noisy image directly or add TV regularization.

In [None]:
def train(img_siren, loss_type='fidelity', k1=1e-4, total_steps=1000, steps_til_summary=20, early_stop=600, spe=False, lr=1e-4):
  """Fit `img_siren` to the noisy image served by the module-level `dataloader`.

  Relies on module-level state: `dataloader`, `img_clean`, `psnr_noisy_clean`,
  `lpips_noisy_clean` and `output_dir`.

  Args:
    img_siren (nn.Module): Model called as `model(coords) -> (output, coords)`
    loss_type (str): 'fidelity' (MSE only) or 'tv' (MSE + TV prior)
    k1 (float): Weight of the TV term; only used when loss_type == 'tv'
    total_steps (int): Number of optimisation steps (steps 0..total_steps run)
    steps_til_summary (int): Interval, in steps, between summaries (print,
      LPIPS, saved figure)
    early_stop (int): Step at which the "early stop" checkpoint is written
    spe (bool): True for the spline-positional-encoding model; tags the
      output files and skips the input-gradient visualisation
    lr (float): Adam learning rate

  Returns:
    dict with per-step 'psnrs_clean' / 'psnrs_noisy', per-summary
    'lpipss_clean' / 'lpipss_noisy', per-step 'tvs' (empty unless
    loss_type == 'tv') and 'best_step' (step with the best clean PSNR).
  """
  assert loss_type in ('fidelity', 'tv')

  optim = torch.optim.Adam(lr=lr, params=img_siren.parameters())
  model_input, ground_truth = next(iter(dataloader))

  # Side length of the square image, recovered from the flattened pixel count
  sidelen = int(round(ground_truth.shape[1] ** 0.5))

  # For visualization
  img_in = ground_truth.view(sidelen, sidelen, -1).numpy()
  bbox_caption = {'facecolor': 'white', 'pad': 8}

  # Run on whatever device the model lives on
  device = next(img_siren.parameters()).device
  model_input, ground_truth = model_input.to(device), ground_truth.to(device)

  # Save for record
  psnrs_clean = []
  psnrs_noisy = []
  lpipss_clean = []
  lpipss_noisy = []
  tvs = []

  best_psnr_clean = -float('inf')
  best_step = 0

  model_tag = 'spe_' if spe else ''

  # Summary figures go to a per-loss sub-directory; make sure it exists
  summary_dir = os.path.join(output_dir, loss_type)
  os.makedirs(summary_dir, exist_ok=True)

  # Write the "early stop" checkpoint exactly once: at `early_stop`, or at the
  # last step when training ends before that. Saving at both steps would
  # overwrite the early-stop weights with the final ones under the same name.
  checkpoint_step = min(early_stop, total_steps)

  for step in range(total_steps + 1):
      model_output, coords = img_siren(model_input)
      if loss_type == 'fidelity':
        loss = ((model_output - ground_truth)**2).mean()
      else:
        loss_dict = image_mse_TV_prior(k1, img_siren, model_output, coords, ground_truth)
        loss = loss_dict['img_loss'] + loss_dict['prior_loss']
        # divide the weight back out to record the unweighted TV estimate
        tvs.append(loss_dict['prior_loss'].detach().cpu().numpy() / k1)
      
      # Compute PSNR metric
      img_out = model_output.cpu().view(sidelen, sidelen, -1).detach().numpy()
      psnr_noisy = compare_psnr(img_in, img_out)
      psnrs_noisy.append(psnr_noisy)
      psnr_clean = compare_psnr(img_clean, img_out)
      psnrs_clean.append(psnr_clean)

      # TODO: have a better stopping criterion because we cannot assume access to
      # underlying clean image
      if best_psnr_clean < psnr_clean:
        best_psnr_clean = psnr_clean
        best_step = step
        torch.save(img_siren.state_dict(), 
                   os.path.join(output_dir, f"{loss_type}_{model_tag}best.pt"))
       
      # Save early stop model
      if step == checkpoint_step:
        torch.save(img_siren.state_dict(), 
                   os.path.join(output_dir, f"{loss_type}_{model_tag}step{early_stop}.pt"))

      if not step % steps_til_summary:
          print("Step %d, Total loss %0.6f" % (step, loss))
          if spe:
            # the input gradient is not visualised for the SPE model
            img_grad = torch.zeros_like(model_output)
          else:
            img_grad = gradient(model_output, coords)
          
          # TODO: have better stopping criterion?
          # Save best model?
          # Compute LPIPS metric
          with torch.no_grad():
            lpips_noisy = lpips_np(img_in, img_out)
            lpipss_noisy.append(lpips_noisy)
            lpips_clean = lpips_np(img_clean, img_out)
            lpipss_clean.append(lpips_clean)
          
          # Visualization
          fig, axes = plt.subplots(1, 4, figsize=(24,6))
          axes[0].imshow(img_out)
          axes[0].text(12, 20, f'Model output. PSNR_noisy={round(psnr_noisy, 2)}, PSNR_clean={round(psnr_clean, 2)}', bbox=bbox_caption)
          axes[0].text(12, 50, f'LPIPS_noisy={round(lpips_noisy, 2)}, LPIPS_clean={round(lpips_clean, 2)}', bbox=bbox_caption)
          axes[1].imshow(img_grad.norm(dim=-1).cpu().view(sidelen, sidelen).detach().numpy())
          axes[1].text(12, 20, 'Model output gradient', bbox=bbox_caption)
          axes[2].imshow(img_in)
          axes[2].text(12, 20, f'Model target (noisy). PSNR_clean={round(psnr_noisy_clean, 2)}', bbox=bbox_caption)
          axes[2].text(12, 50, f'LPIPS_clean={round(lpips_noisy_clean, 2)}', bbox=bbox_caption)
          axes[3].imshow(img_clean)
          axes[3].text(12, 20, 'Clean image', bbox=bbox_caption)
          plt.savefig(os.path.join(summary_dir, f'{loss_type}_{model_tag}step{step}.png'), dpi=300)

          # Show image at lower frequency
          if not step % (steps_til_summary * 5):
            plt.show()
          
          plt.close()

      optim.zero_grad()
      loss.backward()
      optim.step()

  return {'psnrs_clean': psnrs_clean, 'psnrs_noisy': psnrs_noisy, 
          'lpipss_clean': lpipss_clean, 'lpipss_noisy': lpipss_noisy, 
          'tvs': tvs, 'best_step': best_step}
  

Finally, we define all the models we need to run the experiments.

In [None]:
# One model per experiment, all moved to the GPU:
#   img_siren     - plain SIREN trained with the fidelity loss
#   img_siren_tv  - same architecture, trained with the TV prior
#   img_siren_spe - SIREN on spline-encoded coordinates
img_siren = Siren(in_features=2, out_features=3, hidden_features=256, 
                     hidden_layers=3, outermost_linear=True)
img_siren.cuda()

img_siren_tv = Siren(in_features=2, out_features=3, hidden_features=256, 
                     hidden_layers=3, outermost_linear=True)
img_siren_tv.cuda()

img_siren_spe = PosEncSiren(in_features=2, out_features=3, hidden_features=256, 
                           hidden_layers=3)
img_siren_spe.cuda();

## Model 1: Fitting the noisy image directly

First, let's simply fit that noisy image!

We seek to parameterize a corrupted RGB image $\hat{I}(x)$ with pixel coordinates $x$ with a SIREN $\Phi(x)$.

That is we seek the function $\Phi$ such that:
$\mathcal{L}_{\text{fidelity}} = \int_{\Omega} \Vert \Phi(\mathbf{x}) - \hat{I}(\mathbf{x})\Vert_2 d\mathbf{x}$
 is minimized, in which $\Omega$ is the domain of the image. 

During training, we visualize the fitted image output (`Model output`), the gradient of the fitted image (`Model output gradient`), the target noisy image that the model is fitting to (`Model target (noisy)`), and the underlying clean image that the model never sees (`Clean image`).

We also show the `PSNR_noisy`, the PSNR between the current image to the noisy target image, and `PSNR_clean`, the PSNR between the current image to the underlying clean image. We do the same for LPIPS.

In [None]:
# Train the plain SIREN with the data-fidelity (MSE) loss only.
# NOTE: total_steps and steps_til_summary are module-level names that the
# later training cells reassign; the plotting cells below read them.
total_steps=1200
steps_til_summary=20
train_out = train(img_siren, loss_type='fidelity', 
                  total_steps=total_steps, 
                  steps_til_summary=steps_til_summary)
psnrs_clean = train_out['psnrs_clean']
psnrs_noisy = train_out['psnrs_noisy']
lpipss_clean = train_out['lpipss_clean']
lpipss_noisy = train_out['lpipss_noisy']

Plot the trajectory of the PSNR to the clean image and the PSNR to the noisy image. We see that even though the model never sees the clean image and only fits the noisy image, it avoids fitting the noise and reaches a high `PSNR_clean` in the first few hundred iterations, before eventually starting to fit the noise, at which point `PSNR_clean` decreases (starting at around iteration 500).

In [None]:
# Per-step PSNR of the fidelity model against the clean and the noisy image,
# with the noisy-vs-clean PSNR as a constant reference line.
fig = plt.figure(1, dpi=150)
plt.plot(psnrs_clean, label="PSNR_clean (fidelity)")
plt.plot(psnrs_noisy, label="PSNR_noisy (fidelity)")
plt.plot(np.ones(len(psnrs_noisy)) * psnr_noisy_clean, label="PSNR of noisy to clean")
plt.xlabel("Iteration")
plt.ylabel("PSNR")
plt.title("PSNR for training with data fidelity")
plt.legend()
plt.savefig(os.path.join(output_dir, 'fidelity_PSNR.png'), dpi=300)
plt.show()

In [None]:
# LPIPS of the fidelity model, recorded once per summary step.
# train() also records a summary at the final step `total_steps`, so the lists
# hold one more entry than the x axis; drop it ([:-1]) so x and y match, as
# the TV and SPE plots do.
fig = plt.figure(1, dpi=150)
plt.plot(np.arange(0, total_steps, steps_til_summary), lpipss_clean[:-1], label="LPIPS_clean (fidelity)")
plt.plot(np.arange(0, total_steps, steps_til_summary), lpipss_noisy[:-1], label="LPIPS_noisy (fidelity)")
plt.plot(np.arange(0, total_steps, steps_til_summary), np.ones(len(lpipss_noisy[:-1])) * lpips_noisy_clean, label="LPIPS (alex) of noisy to clean")
plt.xlabel("Iteration")
plt.ylabel("LPIPS")
plt.title("LPIPS metric for training with data fidelity")
plt.legend()
plt.savefig(os.path.join(output_dir, 'fidelity_LPIPS.png'), dpi=300)
plt.show()

## Model 2: Fitting the noisy image with TV prior

Next, we add TV regularization for fitting the image. The new loss function becomes

$\mathcal{L}_{\text{tv}} = \int \Vert \Phi(\mathbf{x}) - \hat{I}(\mathbf{x})\Vert_2 + \kappa \Vert \nabla_{\mathbf{x}}\Phi(\mathbf{x})\Vert_1  d\mathbf{x}$,
where $\kappa$ is a hyperparameter for the TV regularization weight.

In [None]:
# Train the second SIREN with the TV prior (weight k1=5e-4); the early-stop
# checkpoint is taken at step 2000.
total_steps=2400
steps_til_summary=40
train_out_tv = train(img_siren_tv, loss_type='tv', k1=5e-4,
                  total_steps=total_steps, 
                  steps_til_summary=steps_til_summary,
                  early_stop=2000, lr=1e-3)
psnrs_clean_tv = train_out_tv['psnrs_clean']
psnrs_noisy_tv = train_out_tv['psnrs_noisy']
lpipss_clean_tv = train_out_tv['lpipss_clean']
lpipss_noisy_tv = train_out_tv['lpipss_noisy']
tvs = train_out_tv['tvs']

In [None]:
# Per-step PSNR of the TV-regularised model, with the noisy-vs-clean PSNR as
# a constant reference line.
fig = plt.figure(1, dpi=150)
plt.plot(psnrs_clean_tv, label="PSNR_clean (TV)")
plt.plot(psnrs_noisy_tv, label="PSNR_noisy (TV)")

plt.plot(np.ones(len(psnrs_noisy_tv)) * psnr_noisy_clean, label="PSNR of noisy to clean")
plt.xlabel("Iteration")
plt.ylabel("PSNR")
plt.title("PSNR for traning with TV regularizer")
plt.legend()
plt.savefig(os.path.join(output_dir, 'tv_PSNR.png'), dpi=300)
plt.show()

In [None]:
# LPIPS of the TV model per summary step; [:-1] drops the summary recorded at
# the final step so the lists match the x axis.
fig = plt.figure(1, dpi=150)
plt.plot(np.arange(0, total_steps, steps_til_summary), lpipss_clean_tv[:-1], label="LPIPS_clean (TV)")
plt.plot(np.arange(0, total_steps, steps_til_summary), lpipss_noisy_tv[:-1], label="LPIPS_noisy (TV)")
plt.plot(np.arange(0, total_steps, steps_til_summary), np.ones(len(lpipss_noisy_tv[:-1])) * lpips_noisy_clean, label="LPIPS (alex) of noisy to clean")
plt.xlabel("Iteration")
plt.ylabel("LPIPS")
plt.title("LPIPS for training with TV regularizer")
plt.legend()
plt.savefig(os.path.join(output_dir, 'tv_LPIPS.png'), dpi=300)
plt.show()

In [None]:
# Unweighted TV estimate recorded at every step of the TV training run.
fig = plt.figure(1, dpi=150)
plt.plot(tvs)
plt.xlabel("Iteration")
plt.ylabel("Value")
plt.title("Total variation")
plt.savefig(os.path.join(output_dir, 'tv.png'), dpi=300)
plt.show()

## Model 3: Fitting Noisy Image with Spline Coordinate Encoding
Finally, we encode the input coordinates with spline coordinate encoding before passing it to SIREN. The new loss function is $\mathcal{L}_{\text{spe}} = \int \Vert \Phi(S(\mathbf{x})) - \hat{I}(\mathbf{x})\Vert_2  d\mathbf{x}$, where $S$ is the spline coordinate projection function.

In [None]:
# Train the spline-positional-encoding model with the fidelity loss.
# With loss_type='fidelity', train() ignores k1 and returns an empty 'tvs'
# list, so `tvs_spe` is empty.
total_steps=2400
steps_til_summary=40
train_out_spe = train(img_siren_spe, loss_type='fidelity', k1=5e-5,
                  total_steps=total_steps, 
                  steps_til_summary=steps_til_summary, spe=True,
                  early_stop=2000, lr=1e-3)
psnrs_clean_spe = train_out_spe['psnrs_clean']
psnrs_noisy_spe = train_out_spe['psnrs_noisy']
lpipss_clean_spe = train_out_spe['lpipss_clean']
lpipss_noisy_spe = train_out_spe['lpipss_noisy']
tvs_spe = train_out_spe['tvs']

In [None]:
# Per-step PSNR of the SPE model, with the noisy-vs-clean PSNR as a constant
# reference line.
fig = plt.figure(1, dpi=150)
plt.plot(psnrs_clean_spe, label="PSNR_clean (SPE)")
plt.plot(psnrs_noisy_spe, label="PSNR_noisy (SPE)")

plt.plot(np.ones(len(psnrs_noisy_spe)) * psnr_noisy_clean, label="PSNR of noisy to clean")
plt.xlabel("Iteration")
plt.ylabel("PSNR")
plt.title("PSNR for traning with SPE")
plt.legend()
plt.savefig(os.path.join(output_dir, 'spe_PSNR.png'), dpi=300)
plt.show()

In [None]:
# LPIPS of the SPE model per summary step; [:-1] drops the summary recorded at
# the final step so the lists match the x axis.
fig = plt.figure(1, dpi=150)
plt.plot(np.arange(0, total_steps, steps_til_summary), lpipss_clean_spe[:-1], label="LPIPS_clean (SPE)")
plt.plot(np.arange(0, total_steps, steps_til_summary), lpipss_noisy_spe[:-1], label="LPIPS_noisy (SPE)")
plt.plot(np.arange(0, total_steps, steps_til_summary), np.ones(len(lpipss_noisy_spe[:-1])) * lpips_noisy_clean, label="LPIPS (alex) of noisy to clean")
plt.xlabel("Iteration")
plt.ylabel("LPIPS")
plt.title("LPIPS for training with SPE")
plt.legend()
plt.savefig(os.path.join(output_dir, 'spe_LPIPS.png'), dpi=300)
plt.show()

## Summary of Training Curves

In [None]:
# Clean-image PSNR of the three models on one plot.
# train() records steps 0..total_steps inclusive, so every PSNR list has one
# more entry than its x axis; [:-1] drops the final step for all three curves.
fig = plt.figure(1, dpi=150)

plt.plot(np.arange(0, 1200), psnrs_clean[:-1], label="PSNR_clean (fidelity)")
plt.plot(np.arange(0, 2400), psnrs_clean_tv[:-1], label="PSNR_clean (TV)")
# 20-step moving average; 'valid' mode shortens the curve by 19 samples
psnrs_clean_spe_smoothed = np.convolve(psnrs_clean_spe[:-1], np.ones(20), 'valid') / 20
plt.plot(np.arange(0, 2400-19), psnrs_clean_spe_smoothed, label="PSNR_clean (SPE, smoothed)")

plt.xlabel("Iteration")
plt.ylabel("PSNR")
plt.title("PSNR comparison")
plt.legend()
plt.savefig(os.path.join(output_dir, 'comparison_PSNR.png'), dpi=300)
plt.show()

In [None]:
# Clean-image LPIPS of the three models on one plot.
# Each list holds one value per summary step including the final step; [:-1]
# drops that last value so every curve matches its x axis.
fig = plt.figure(1, dpi=150)

plt.plot(np.arange(0, 1200, 20), lpipss_clean[:-1], label="LPIPS_clean (fidelity)")
plt.plot(np.arange(0, 2400, 40), lpipss_clean_tv[:-1], label="LPIPS_clean (TV)")
plt.plot(np.arange(0, 2400, 40), lpipss_clean_spe[:-1], label="LPIPS_clean (SPE)")

plt.xlabel("Iteration")
plt.ylabel("LPIPS")
plt.title("LPIPS comparison")
plt.legend()
plt.savefig(os.path.join(output_dir, 'comparison_LPIPS.png'), dpi=300)
plt.show()

### Load Trained Models

In [None]:
# Steps used in the checkpoint file names; they must match the `early_stop`
# values the corresponding train() calls were run with.
fidelity_stop = 600
tv_stop = 2000
spe_stop = 2000

# Restore the "step{early_stop}" checkpoints written by train().
img_siren.load_state_dict(torch.load(os.path.join(output_dir, f'fidelity_step{fidelity_stop}.pt')))
img_siren_tv.load_state_dict(torch.load(os.path.join(output_dir, f'tv_step{tv_stop}.pt')))
img_siren_spe.load_state_dict(torch.load(os.path.join(output_dir, f'fidelity_spe_step{spe_stop}.pt')))

# Qualitative Comparison of SIREN Variants

In [None]:
# Re-declared here but not read anywhere in this cell.
fidelity_stop = 600
tv_stop = 2000
spe_stop = 2000

# Evaluate each loaded model on the full coordinate grid (module-level
# `model_input`) and score the 256x256 output against the clean image.
with torch.no_grad():
  # Get fidelity results
  img_siren.eval()
  model_output, _ = img_siren(model_input)
  img_out_fidelity = model_output.cpu().view(256, 256, -1).detach().numpy()
  fidelity_best_psnr = compare_psnr(img_clean, img_out_fidelity)
  fidelity_best_lpips = lpips_np(img_clean, img_out_fidelity)
  
  # Get tv results
  img_siren_tv.eval()
  model_output, _ = img_siren_tv(model_input)
  img_out_tv = model_output.cpu().view(256, 256, -1).detach().numpy()
  tv_best_psnr = compare_psnr(img_clean, img_out_tv)
  tv_best_lpips = lpips_np(img_clean, img_out_tv)

  # Get spe results
  img_siren_spe.eval()
  model_output, _ = img_siren_spe(model_input)
  img_out_spe = model_output.cpu().view(256, 256, -1).detach().numpy()
  spe_best_psnr = compare_psnr(img_clean, img_out_spe)
  spe_best_lpips = lpips_np(img_clean, img_out_spe)

In [None]:
# Report the clean-image PSNR and LPIPS of the three loaded models.
print(f"fidelity_best_psnr: {fidelity_best_psnr}")
print(f"tv_best_psnr: {tv_best_psnr}")
print(f"spe_best_psnr: {spe_best_psnr}")
print(f"fidelity_best_lpips: {fidelity_best_lpips}")
print(f"tv_best_lpips: {tv_best_lpips}")
print(f"spe_best_lpips: {spe_best_lpips}")

In [None]:
# Side-by-side figure: noisy input, the three model outputs and the clean
# image, each captioned with its PSNR / LPIPS against the clean image.
bbox_caption = {'facecolor': 'white', 'pad': 8}
fig, axes = plt.subplots(1, 5, figsize=(30, 6))

axes[0].imshow(img_noisy)
axes[0].text(12, 20, f'Noisy image. PSNR_clean={round(psnr_noisy_clean, 3)}', bbox=bbox_caption)
axes[0].text(12, 50, f'LPIPS_clean={round(lpips_noisy_clean, 3)}', bbox=bbox_caption)

axes[1].imshow(img_out_fidelity)
axes[1].text(12, 20, f'Fidelity model. PSNR_clean={round(fidelity_best_psnr, 3)}', bbox=bbox_caption)
axes[1].text(12, 50, f'LPIPS_clean={round(fidelity_best_lpips, 3)}', bbox=bbox_caption)

axes[2].imshow(img_out_tv)
axes[2].text(12, 20, f'TV model. PSNR_clean={round(tv_best_psnr, 3)}', bbox=bbox_caption)
axes[2].text(12, 50, f'LPIPS_clean={round(tv_best_lpips, 3)}', bbox=bbox_caption)

axes[3].imshow(img_out_spe)
axes[3].text(12, 20, f'SPE model. PSNR_clean={round(spe_best_psnr, 3)}', bbox=bbox_caption)
axes[3].text(12, 50, f'LPIPS_clean={round(spe_best_lpips, 3)}', bbox=bbox_caption)

axes[4].imshow(img_clean)
axes[4].text(12, 20, 'Clean image', bbox=bbox_caption)

plt.savefig(os.path.join(output_dir, 'denoised_comparison.png'), dpi=300)
plt.show()

# Model 4: Continuous Bilateral Filtering on SIREN representation

Here we implement a continuous version of bilateral filtering on the learned SIREN model. See paper for details.

In [None]:
def siren_blur(model, n_average=25, scale=0.05, sigma_loc=0.05, sigma_int=0.05):
  """Bilateral-style smoothing of the image represented by `model`.

  The model is queried `n_average` times at randomly perturbed copies of the
  coordinate grid from the module-level `dataloader`. Every response is
  weighted by a Gaussian in its coordinate offset and a Gaussian in its
  (per-channel) intensity difference to the unperturbed output, and the
  weighted responses are averaged.

  Args:
    model (nn.Module): Model called as `model(coords) -> (output, coords)`
    n_average (int): Number of perturbed queries per pixel
    scale (float): Width of the perturbation; offsets are uniform in
      [-scale/2, scale/2) per coordinate
    sigma_loc (float): Std of the Gaussian weight on the coordinate offset
    sigma_int (float): Std of the Gaussian weight on the intensity difference

  Returns:
    numpy array of shape (H, W, channels) with the filtered image.
  """
  model_input, _ = next(iter(dataloader))
  # Query on the device the model lives on
  device = next(model.parameters()).device
  model_input = model_input.to(device)
  # Side length of the square image, recovered from the number of grid points
  sidelen = int(round(model_input.shape[1] ** 0.5))

  with torch.no_grad():
    # Reference output at the unperturbed grid coordinates
    model_output, _ = model(model_input)
    img_out_gt = model_output.detach().cpu().view(sidelen, sidelen, -1).numpy()

    # Calculate multiple responses at perturbed coordinates
    model_input = model_input.repeat(n_average, 1, 1)
    noise = (torch.rand(model_input.shape, device=device) - 0.5) * scale
    model_output, _ = model(model_input + noise)
    img_out_mult = model_output.detach().cpu().view(n_average, sidelen, sidelen, -1).numpy()

    # Calculate weight of each response
    location_weight = np.exp(-np.linalg.norm(noise.cpu().numpy(), axis=-1)**2 / (2 * sigma_loc**2))
    location_weight = location_weight.reshape(n_average, sidelen, sidelen, 1)
    intensity_weight = np.exp(-(img_out_mult - img_out_gt)**2 / (2 * sigma_int**2))
    weight = location_weight * intensity_weight
    # normalise so the weights average to one across the samples
    weight /= weight.mean(0)
    
    final_img_out = (weight * img_out_mult).mean(0)

  return final_img_out

In [None]:
# Apply the continuous bilateral filter to the fidelity model.
img_out_fidelity_blur = siren_blur(img_siren, scale=0.1, sigma_loc=0.06, sigma_int=0.06)

In [None]:
# Fidelity model output before and after bilateral filtering, captioned with
# PSNR / LPIPS against the clean image.
bbox_caption = {'facecolor': 'white', 'pad': 8}
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(img_out_fidelity)
axes[0].text(12, 20, f'Fidelity model. PSNR_clean={round(compare_psnr(img_clean, img_out_fidelity), 3)}', bbox=bbox_caption)
axes[0].text(12, 50, f'LPIPS_clean={round(lpips_np(img_clean, img_out_fidelity), 3)}', bbox=bbox_caption)

axes[1].imshow(img_out_fidelity_blur)
axes[1].text(12, 20, f'Fidelity model blurred. PSNR_clean={round(compare_psnr(img_clean, img_out_fidelity_blur), 3)}', bbox=bbox_caption)
axes[1].text(12, 50, f'LPIPS_clean={round(lpips_np(img_clean, img_out_fidelity_blur), 3)}', bbox=bbox_caption)

plt.savefig(os.path.join(output_dir, 'fidelity_blurred.png'), dpi=300)
plt.show()

In [None]:
# Apply the continuous bilateral filter to the TV model.
img_out_tv_blur = siren_blur(img_siren_tv, scale=0.1, sigma_loc=0.06, sigma_int=0.06)

In [None]:
# TV model output before and after bilateral filtering, captioned with
# PSNR / LPIPS against the clean image.
bbox_caption = {'facecolor': 'white', 'pad': 8}
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(img_out_tv)
axes[0].text(12, 20, f'TV model. PSNR_clean={round(compare_psnr(img_clean, img_out_tv), 3)}', bbox=bbox_caption)
axes[0].text(12, 50, f'LPIPS_clean={round(lpips_np(img_clean, img_out_tv), 3)}', bbox=bbox_caption)

axes[1].imshow(img_out_tv_blur)
axes[1].text(12, 20, f'TV model blurred. PSNR_clean={round(compare_psnr(img_clean, img_out_tv_blur), 3)}', bbox=bbox_caption)
axes[1].text(12, 50, f'LPIPS_clean={round(lpips_np(img_clean, img_out_tv_blur), 3)}', bbox=bbox_caption)

plt.savefig(os.path.join(output_dir, 'tv_blurred.png'), dpi=300)
plt.show()

Note: ideally I would run SPE as well, but the free version of Colab cannot really handle it at a reasonable speed using CUDA without running out of memory.

# Baseline Comparisons

## Baseline 1: BM3D

In [None]:
# Install and import the BM3D baseline (notebook shell magic).
!pip install bm3d
import bm3d

In [None]:
# Denoise img_noisy with BM3D and cast the result to float32.
# NOTE(review): sigma_psd=25/255 presumably matches the std of the noise added to
# img_noisy on a [0, 1] intensity scale — confirm against where img_noisy is built.
img_out_bm3d = bm3d.bm3d(img_noisy, sigma_psd=25/255).astype(np.float32)

In [None]:
# Two-panel figure: noisy input (left) next to the BM3D result (right).
# Each panel is captioned with PSNR and LPIPS computed against img_clean.
bbox_caption = {'facecolor': 'white', 'pad': 8}
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
left_ax, right_ax = axes

left_ax.imshow(img_noisy)
left_ax.text(
    12, 20,
    f'Noisy image. PSNR_clean={round(compare_psnr(img_clean, img_noisy), 3)}',
    bbox=bbox_caption,
)
left_ax.text(
    12, 50,
    f'LPIPS_clean={round(lpips_np(img_clean, img_noisy), 3)}',
    bbox=bbox_caption,
)

right_ax.imshow(img_out_bm3d)
right_ax.text(
    12, 20,
    f'BM3D model. PSNR_clean={round(compare_psnr(img_clean, img_out_bm3d), 3)}',
    bbox=bbox_caption,
)
right_ax.text(
    12, 50,
    f'LPIPS_clean={round(lpips_np(img_clean, img_out_bm3d), 3)}',
    bbox=bbox_caption,
)

# Save before show() so the figure is written to disk while it is still current.
plt.savefig(
    os.path.join(output_dir, 'denoised_bm3d.png'),
    dpi=300,
)
plt.show()

## Baseline 2: Deep Image Prior

Here we directly load the results of trained Deep Image Prior. See their project page for a Colab notebook on how to train it yourself.

In [None]:
# Load the saved Deep Image Prior result, divide by 255, and cast to float32.
img_out_dip = (io.imread(os.path.join(output_dir, 'dip_out.jpg')) / 255).astype(np.float32)

# Two-panel figure: noisy input (left) next to the Deep Image Prior result (right).
# Each panel is captioned with PSNR and LPIPS computed against img_clean.
bbox_caption = {'facecolor': 'white', 'pad': 8}
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
left_ax, right_ax = axes

left_ax.imshow(img_noisy)
left_ax.text(
    12, 20,
    f'Noisy image. PSNR_clean={round(compare_psnr(img_clean, img_noisy), 3)}',
    bbox=bbox_caption,
)
left_ax.text(
    12, 50,
    f'LPIPS_clean={round(lpips_np(img_clean, img_noisy), 3)}',
    bbox=bbox_caption,
)

right_ax.imshow(img_out_dip)
right_ax.text(
    12, 20,
    f'Deep Image Prior model. PSNR_clean={round(compare_psnr(img_clean, img_out_dip), 3)}',
    bbox=bbox_caption,
)
right_ax.text(
    12, 50,
    f'LPIPS_clean={round(lpips_np(img_clean, img_out_dip), 3)}',
    bbox=bbox_caption,
)

# Save before show() so the figure is written to disk while it is still current.
plt.savefig(
    os.path.join(output_dir, 'denoised_dip.png'),
    dpi=300,
)
plt.show()

## Baseline 3: Denoising Convolutional Neural Network (DnCNN)
Similarly, I load a DnCNN that is already trained.

In [None]:
from collections import OrderedDict

def sequential(*args):
    """Flatten modules and nn.Sequential containers into one nn.Sequential.

    Args:
        *args: nn.Module instances and/or nn.Sequential containers. The
            children of any nn.Sequential are spliced in individually;
            arguments that are not nn.Module instances are ignored.

    Returns:
        The single argument itself when exactly one is given (no wrapping),
        otherwise an nn.Sequential holding the flattened modules in order.

    Raises:
        NotImplementedError: if the only argument is an OrderedDict.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        # A lone module needs no container around it.
        return only

    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            # Unpack nested containers so the result stays one level deep.
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)

# --------------------------------------------
# return the layers selected by `mode`, combined via sequential()
# (default mode 'CBR' = Conv + BN + ReLU)
# --------------------------------------------
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
    """Build the layers named by the characters of ``mode``, in order.

    Recognized characters:
        'C'  nn.Conv2d
        'T'  nn.ConvTranspose2d
        'B'  nn.BatchNorm2d (momentum=0.9, eps=1e-04, affine)
        'I'  nn.InstanceNorm2d (affine)
        'R' / 'r'  nn.ReLU, in-place / not in-place
        'L' / 'l'  nn.LeakyReLU, in-place / not in-place
        '2' / '3' / '4'  nn.PixelShuffle with that upscale factor
        'U' / 'u' / 'v'  nearest-neighbour nn.Upsample with scale 2 / 3 / 4
        'M'  nn.MaxPool2d (padding fixed to 0)
        'A'  nn.AvgPool2d (padding fixed to 0)

    Args:
        in_channels: input channels for 'C' / 'T'.
        out_channels: output channels for 'C' / 'T'; feature count for 'B' / 'I'.
        kernel_size: kernel size for 'C' / 'T' / 'M' / 'A'.
        stride: stride for 'C' / 'T' / 'M' / 'A'.
        padding: padding for 'C' / 'T'.
        bias: whether 'C' / 'T' layers have a bias term.
        mode: string of layer characters, processed left to right.
        negative_slope: slope used by 'L' / 'l'.

    Returns:
        The result of ``sequential(*layers)`` for the layers built from ``mode``.

    Raises:
        NotImplementedError: if ``mode`` contains an unrecognized character;
            the message names that character.
    """
    L = []
    for t in mode:
        if t == 'C':
            L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'T':
            L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'B':
            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == 'I':
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == 'R':
            L.append(nn.ReLU(inplace=True))
        elif t == 'r':
            L.append(nn.ReLU(inplace=False))
        elif t == 'L':
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == 'l':
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'v':
            L.append(nn.Upsample(scale_factor=4, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            # The previous message used .format(t) without a placeholder, so
            # the offending character was silently dropped from the error.
            raise NotImplementedError('Undefined type: {}'.format(t))
    return sequential(*L)


class DnCNN(nn.Module):
    """Convolutional denoiser whose network output is subtracted from its input.

    The layer stack is stored as ``self.model``; ``forward`` returns
    ``x - self.model(x)``.
    """

    def __init__(self, in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR'):
        """Assemble the head, body and tail convolution stacks.

        Args:
            in_nc: channel number of input.
            out_nc: channel number of output.
            nc: channel number of the intermediate layers.
            nb: total number of conv layers (head + ``nb - 2`` body + tail).
            act_mode: batch norm + activation function; 'BR' means BN+ReLU.
                Must contain 'R' or 'L'. The head layer uses only the last
                character, i.e. the activation.

        Background: batch normalization and residual learning are beneficial
        to Gaussian denoising (especially for a single noise level). The
        residual of a noisy image corrupted by additive white Gaussian noise
        (AWGN) follows a constant Gaussian distribution, which stabilizes
        batch normalization during training.
        """
        super().__init__()
        has_activation = 'R' in act_mode or 'L' in act_mode
        assert has_activation, 'Examples of activation function: R, L, BR, BL, IR, IL'
        use_bias = True

        # Head: conv + activation only.
        first_layer = conv(in_nc, nc, mode='C' + act_mode[-1], bias=use_bias)
        # Body: nb - 2 identical conv blocks using the full act_mode.
        hidden_layers = []
        for _ in range(nb - 2):
            hidden_layers.append(conv(nc, nc, mode='C' + act_mode, bias=use_bias))
        # Tail: bare conv mapping back to the output channel count.
        last_layer = conv(nc, out_nc, mode='C', bias=use_bias)

        self.model = sequential(first_layer, *hidden_layers, last_layer)

    def forward(self, x):
        """Return ``x`` minus the network's output for ``x``."""
        predicted = self.model(x)
        return x - predicted

In [None]:
# Build a single-channel DnCNN (ReLU only, no batch norm) and load the
# pretrained weights from output_dir.
dncnn = DnCNN(in_nc=1, out_nc=1, nc=64, nb=17, act_mode='R')
# map_location='cpu': the model and its inputs stay on the CPU in this cell, and
# this lets a checkpoint that was saved from a GPU session load on a CPU-only runtime.
dncnn_state = torch.load(os.path.join(output_dir, 'dncnn_25.pth'), map_location='cpu')
dncnn.load_state_dict(dncnn_state, strict=True)
dncnn.eval()
with torch.no_grad():
  # The network takes one channel, so each of the three colour channels is
  # denoised independently as a (1, 1, H, W) batch.
  dncnn_out_channels = []
  for i in range(3):
    channel = torch.tensor(img_noisy[:, :, i]).unsqueeze(0).unsqueeze(0)
    out = dncnn(channel).squeeze(0).squeeze(0).cpu().numpy()
    dncnn_out_channels.append(out)

# Re-assemble the per-channel results along the last axis.
img_out_dncnn = np.stack(dncnn_out_channels, -1)

In [None]:
# Two-panel figure: noisy input (left) next to the DnCNN result (right).
# Each panel is captioned with PSNR and LPIPS computed against img_clean.
bbox_caption = {'facecolor': 'white', 'pad': 8}
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
left_ax, right_ax = axes

left_ax.imshow(img_noisy)
left_ax.text(
    12, 20,
    f'Noisy image. PSNR_clean={round(compare_psnr(img_clean, img_noisy), 3)}',
    bbox=bbox_caption,
)
left_ax.text(
    12, 50,
    f'LPIPS_clean={round(lpips_np(img_clean, img_noisy), 3)}',
    bbox=bbox_caption,
)

right_ax.imshow(img_out_dncnn)
right_ax.text(
    12, 20,
    f'DnCNN model. PSNR_clean={round(compare_psnr(img_clean, img_out_dncnn), 3)}',
    bbox=bbox_caption,
)
right_ax.text(
    12, 50,
    f'LPIPS_clean={round(lpips_np(img_clean, img_out_dncnn), 3)}',
    bbox=bbox_caption,
)

# Save before show() so the figure is written to disk while it is still current.
plt.savefig(
    os.path.join(output_dir, 'denoised_dnc.png'),
    dpi=300,
)
plt.show()

# Quantitative Result Bar Plot
Note: I had to hard-code the values because I messed up saving them when first training.

In [None]:
# Hard-coded final PSNR for each method, keyed by method name.
# The keys cover the SIREN variants and the baseline methods plotted below.
psnrs_final = {
    'fidelity': 26.303,
    'tv': 21.243,
    'spe': 23.196,
    'fidelity_blur': 27.006,
    'dip': 28.002,
    'bm3d': 27.488,
    'dncnn': 27.964
}

# Hard-coded final LPIPS for each method, using the same keys as psnrs_final.
lpipss_final = {
    'fidelity': 0.146,
    'tv': 0.239,
    'spe': 0.241,
    'fidelity_blur': 0.1022,
    'dip': 0.103,
    'bm3d': 0.176,
    'dncnn': 0.126
}

In [None]:
# Bar chart of PSNR per method. Every bar starts at `offset` and ends at the
# method's PSNR, so its height is the difference from `offset`.
fig = plt.figure(1, dpi=150)
# NOTE(review): 20.122 is presumably the PSNR of the noisy input — confirm.
offset = 20.122
# Iterate items() directly: the enumerate index was unused and the value was
# being looked up a second time by key.
for key, final_psnr in psnrs_final.items():
  plt.bar(key, final_psnr - offset, bottom=offset)
plt.xlabel("SIREN variants and baseline methods")
plt.ylabel("PSNR")
plt.title("Increase in PSNR after denoising")
plt.savefig(os.path.join(output_dir, 'psnr_bar.png'), dpi=300)
plt.show()

In [None]:
# Bar chart of LPIPS per method. Every bar starts at `offset` and ends at the
# method's LPIPS, so bars extend downward when the method's value is lower.
fig = plt.figure(1, dpi=150)
# NOTE(review): 0.263 is presumably the LPIPS of the noisy input — confirm.
offset = 0.263
# Iterate items() directly: the enumerate index was unused and the value was
# being looked up a second time by key.
for key, final_lpips in lpipss_final.items():
  plt.bar(key, final_lpips - offset, bottom=offset)
plt.xlabel("SIREN variants and baseline methods")
plt.ylabel("LPIPS")
plt.title("Decrease in LPIPS after denoising")
plt.savefig(os.path.join(output_dir, 'lpips_bar.png'), dpi=300)
plt.show()