# INR Demo
This notebook gives a brief demonstration of how our INR implementations can be used. In all five experiments we fit an INR to a single image from the CIFAR-10 dataset. The model hyperparameters are stored in separate `.yaml` files, which are loaded with `omegaconf`.

In [None]:
import sys

# Make the bundled flexconv sources importable; the membership check keeps re-runs from
# inserting the same entry twice.
FLEXCONV_DIR = './flexconv'
if FLEXCONV_DIR not in sys.path:
    sys.path.insert(0, FLEXCONV_DIR)
import torch
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchviz import make_dot
import numpy as np
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
import time

from INR import regularize_gabornet
from INR import MLP
from INR import Gabor
from INR.utils import cifar_grid, coordinate_grid

The following cell determines which experiments are run.

In [None]:
# Toggle which experiments are trained. Every model is still constructed (and later
# visualized, untrained) when its flag is False.
RUN_RELU_MLP_INR = False
RUN_SINE_MLP_INR1 = False
RUN_SINE_MLP_INR2 = False

# NOTE(review): RUN_SIREN is not read by any cell in this notebook.
RUN_SIREN = True

RUN_GABOR1 = True
RUN_GABOR2 = True

# Index of the CIFAR-10 training image that every INR is fitted to.
IMAGE_IDX = 0

# Seed the RNGs so weight initialization, and therefore every reported MSE/PSNR,
# is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

We use the CIFAR-10 dataset.

In [None]:
# CIFAR-10 training split, downloaded into ./data on first use; images become tensors.
training_set = CIFAR10(root="data", transform=ToTensor(), download=True)

All INRs are trained on the following image.

In [None]:
# Each dataset item is an (image, label) pair; the image tensor is channel-first, so
# permute it to height x width x channel for matplotlib.
training_img = training_set[IMAGE_IDX]
plt.imshow(training_img[0].permute((1, 2, 0)))
plt.axis('off')
plt.show()

In [None]:
# Target image in height x width x channel layout (the dataset stores channel first).
input_img = training_img[0].permute(1, 2, 0)

# Peak values for the PSNR computations below.
MAX_I_CHANNEL = input_img.max().item()      # brightest value over all channels
MAX_I = input_img.mean(dim=2).max().item()  # brightest channel-averaged pixel; NOTE(review): unused below

# Hyperparameters of the three MLP-based models.
cfg_ReLU = OmegaConf.load('./configs/RELU_MLP_INR_cfg.yaml')
cfg_sine1 = OmegaConf.load('./configs/RELU_SINE_INR1_cfg.yaml')
cfg_sine2 = OmegaConf.load('./configs/RELU_SINE_INR2_cfg.yaml')

This notebook does not use the `coordinate_grid` function, since the `cifar_grid` function works well enough for CIFAR-10. The `coordinate_grid` function takes a list or tensor of ranges (e.g. `[[-1,1], [0,10]]`) and a list of sizes (i.e. the number of samples per dimension) and creates a coordinate grid from the Cartesian product of linspaces.

In [None]:
# A demo of the coordinate_grid function on a 3-D domain with explicit sample counts
demo_domain = [[0, 10], [-10, 10], [-10, 0]]
demo_sizes = [10, 20, 10]
print(coordinate_grid(demo_domain, demo_sizes).size())

# Note that setting reshape to False will collapse all domain dimensions into one
print(coordinate_grid(demo_domain, demo_sizes, reshape=False).size())

# Without explicit sizes, the grid uses unit steps
unit_step_domain = [[-3, 3], [0, 1]]
print(coordinate_grid(unit_step_domain))

Here the first three experiments are run. All models used here are MLPs (either ReLU MLPs or SIRENs).

In [None]:
# We will be using the same coordinates for all models
train_coordinates = cifar_grid(32)

# Setup the first model and its optimization scheme
relu_MLP_INR = MLP(**cfg_ReLU)
optimizer = torch.optim.Adam(relu_MLP_INR.model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss(reduction='sum')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8)

# Run the first experiment
if RUN_RELU_MLP_INR:
    t_relu_start = time.time()
    epochs = [1500]*5
    relu_losses = relu_MLP_INR.fit(input_img, optimizer, criterion, scheduler, epochs, image_grid=train_coordinates)
    t_relu_end = time.time()
    print("Time ReLU:", t_relu_end-t_relu_start)
    plt.plot(relu_losses)

    

# Setup the second model and its optimization scheme
sine_MLP1_INR = MLP(**cfg_sine1)
optimizer = torch.optim.Adam(sine_MLP1_INR.model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss(reduction='sum')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8)

# Run the second experiment
if RUN_SINE_MLP_INR1:
    t_sine1_start = time.time()
    epochs = [5]*200
    sine1_losses = sine_MLP1_INR.fit(input_img, optimizer, criterion, scheduler, epochs, image_grid=train_coordinates)
    t_sine1_end = time.time()
    print("Time Sine 1:", t_sine1_end-t_sine1_start)
    plt.plot(sine1_losses)


    
# Setup the third model and its optimization scheme.
sine_MLP2_INR = MLP(**cfg_sine2)
optimizer = torch.optim.Adam(sine_MLP2_INR.model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss(reduction='sum')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.01)

# Run the third experiment
if RUN_SINE_MLP_INR2:
    t_sine2_start = time.time()
    epochs = [200]*5
    sine2_losses = sine_MLP2_INR.fit(input_img, optimizer, criterion, scheduler, epochs, image_grid=train_coordinates)
    t_sine2_end = time.time()
    print("Time Sine 2:", t_sine2_end-t_sine2_start)
    plt.plot(sine2_losses)

plt.show()

In [None]:
# Visualize the learned image representations

# Visualize the learned image representations at 32, 64, 128 and 256 px.
# NOTE(review): models whose RUN_* flag is False are evaluated and shown untrained.
resolutions = [32, 64, 128, 256]
mlp_models = {"ReLU": relu_MLP_INR, "sine1": sine_MLP1_INR, "sine2": sine_MLP2_INR}

with torch.no_grad():
    # Later cells also read coordinates1..coordinates4, so keep these global names.
    coordinates1 = cifar_grid(32)
    coordinates2 = cifar_grid(64)
    coordinates3 = cifar_grid(128)
    coordinates4 = cifar_grid(256)
    coordinate_grids = [coordinates1, coordinates2, coordinates3, coordinates4]

    fig, ax = plt.subplots(len(mlp_models), 1 + len(resolutions), figsize=(20, 8))
    for row, (name, model) in enumerate(mlp_models.items()):
        outputs = [model(grid) for grid in coordinate_grids]

        # `criterion` sums the squared error; PSNR is defined on the per-element mean,
        # so compute the true MSE on the native 32x32 reconstruction.
        mse = torch.nn.functional.mse_loss(outputs[0], input_img).item()
        print(f"MSE {name}:", mse)
        print(f"PSNR {name}:", 20 * np.log10(MAX_I_CHANNEL) - 10 * np.log10(mse))
        print()

        # Column 0 is the ground truth, columns 1-4 the increasing output resolutions.
        ax[row][0].imshow(input_img)
        for col, image in enumerate(outputs, start=1):
            ax[row][col].imshow(image)
        for axis in ax[row]:
            axis.axis('off')
plt.show()

The fourth experiment (below) uses a multiplicative filter network (MFN) with Gabor filters as the model.

In [None]:
# These same coordinates are used as before
# Same 32x32 training grid as in the MLP experiments
train_coordinates = cifar_grid(32)

# Experiment 4: multiplicative filter network with Gabor filters
cfg_gabor1 = OmegaConf.load('./configs/Gabor_INR_cfg1.yaml')
Gabor1 = Gabor(**cfg_gabor1.net)
epochs = [300]  # a single stage of 300 epochs

optimizer = torch.optim.Adam(Gabor1.parameters(), lr=cfg_gabor1.train.lr)
criterion = torch.nn.MSELoss(reduction='sum')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9)

if RUN_GABOR1:
    tic = time.time()
    gabor_losses = Gabor1.fit(input_img, optimizer, criterion, scheduler, epochs, image_grid=train_coordinates)
    print("Time Gabor:", time.time() - tic)
    plt.plot(gabor_losses)

plt.show()

In [None]:
# Evaluate the Gabor MFN from experiment 4 at increasing output resolutions.
with torch.no_grad():
    # Later cells also read coordinates1..coordinates4, so keep these global names.
    coordinates1 = cifar_grid(32)
    coordinates2 = cifar_grid(64)
    coordinates3 = cifar_grid(128)
    coordinates4 = cifar_grid(256)
    grids_by_resolution = {32: coordinates1, 64: coordinates2, 128: coordinates3, 256: coordinates4}

    # Reshape each prediction into a (res, res, 3) image for display and scoring.
    gabor1_images = [Gabor1(grid).reshape((res, res, 3)) for res, grid in grids_by_resolution.items()]

    # `criterion` sums the squared error; PSNR needs the per-element mean squared error.
    mse_gabor1 = torch.nn.functional.mse_loss(gabor1_images[0], input_img).item()
    print("MSE gabor:", mse_gabor1)
    print("PSNR gabor:", 20 * np.log10(MAX_I_CHANNEL) - 10 * np.log10(mse_gabor1))

    fig, ax = plt.subplots(1, 5, figsize=(20, 8))
    ax[0].imshow(input_img)
    for axis, image in zip(ax[1:], gabor1_images):
        axis.imshow(image)
    for axis in ax:
        axis.axis('off')
plt.show()

The cell below visualizes the backward (autograd) graph of a Gabor network with regularization.

In [None]:
# Render the autograd graph of (data loss + Gabor regularizer) for a fresh, untrained network
# built from the experiment-4 config.
criterion = torch.nn.MSELoss(reduction='sum')
Gabor_visualize = Gabor(**cfg_gabor1.net)
Gabor_visualize.train()

reconstruction = Gabor_visualize(coordinates1)
loss = criterion(reconstruction, input_img)
gabor_reg = regularize_gabornet(gabor_net=Gabor_visualize.model, **cfg_gabor1.regularize_params)

total_loss = loss + gabor_reg
named_params = dict(Gabor_visualize.named_parameters())
make_dot(total_loss, params=named_params)

In the cell below the fifth and final experiment is run. This experiment again uses an MFN with Gabor filters.

In [None]:
# Setup the fifth model and its optimization scheme
cfg_gabor2 = OmegaConf.load('./configs/Gabor_INR_cfg2.yaml')
Gabor2 = Gabor(**cfg_gabor2.net)
criterion = torch.nn.MSELoss(reduction='sum')

# Optimize the network's parameters plus those of every Gabor filter. Use an ordered list
# (not a set) so the optimizer's parameter order is stable across runs, and skip parameters
# that model.parameters() already yielded so nothing is registered twice.
params = list(Gabor2.model.parameters())
seen_param_ids = {id(p) for p in params}
for gabor_filter in Gabor2.model.filters:
    for param in gabor_filter.parameters():
        if id(param) not in seen_param_ids:
            seen_param_ids.add(id(param))
            params.append(param)

optimizer = torch.optim.Adam(params, lr=cfg_gabor2.train.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7)
epochs = [100] * 10

# Run the fifth experiment (with Gabor regularization)
if RUN_GABOR2:
    t_gabor2_start = time.perf_counter()
    gabor2_losses = Gabor2.fit(input_img, optimizer, criterion, scheduler, epochs, image_grid=train_coordinates,
                               regularize=True, **cfg_gabor2.regularize_params)
    print("Time Gabor 2:", time.perf_counter() - t_gabor2_start)

In [None]:
# Evaluate the regularized Gabor MFN from experiment 5 at increasing output resolutions.
# The coordinate grids are rebuilt here so this cell does not depend on another cell's leftovers.
with torch.no_grad():
    resolutions = (32, 64, 128, 256)
    # Reshape each prediction into a (res, res, 3) image for display and scoring.
    gabor2_images = [Gabor2(cifar_grid(res)).reshape((res, res, 3)) for res in resolutions]

    # `criterion` sums the squared error; PSNR needs the per-element mean squared error.
    mse_gabor2 = torch.nn.functional.mse_loss(gabor2_images[0], input_img).item()
    print("MSE gabor2:", mse_gabor2)
    print("PSNR gabor2:", 20 * np.log10(MAX_I_CHANNEL) - 10 * np.log10(mse_gabor2))

    fig, ax = plt.subplots(1, 5, figsize=(20, 8))
    ax[0].imshow(input_img)
    for axis, image in zip(ax[1:], gabor2_images):
        axis.imshow(image)
    for axis in ax:
        axis.axis('off')
plt.show()