In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os

from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt

import time

def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.
    sidelen: int
    dim: int'''
    tensors = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
    mgrid = torch.stack(torch.meshgrid(*tensors), dim=-1)
    mgrid = mgrid.reshape(-1, dim)
    return mgrid

In [2]:
class SineLayer(nn.Module):
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
    
    # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the 
    # nonlinearity. Different signals may require different omega_0 in the first layer - this is a 
    # hyperparameter.
    
    # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of 
    # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)
    
    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        
        self.init_weights()
    
    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 
                                             1 / self.in_features)      
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 
                                             np.sqrt(6 / self.in_features) / self.omega_0)
        
    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))
    
    def forward_with_intermediate(self, input): 
        # For visualization of activation distributions
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate
    
    
class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        
        self.net = []
        self.net.append(SineLayer(in_features, hidden_features, 
                                  is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features, 
                                      is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)
                
            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features, 
                                      is_first=False, omega_0=hidden_omega_0))
        
        self.net = nn.Sequential(*self.net)
    
    def forward(self, coords):
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords        

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!'''
        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)
                
                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()
                    
                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else: 
                x = layer(x)
                
                if retain_grad:
                    x.retain_grad()
                    
            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations

In [3]:
def get_cameraman_tensor(sidelength):
    img = Image.fromarray(skimage.data.camera())        
    transform = Compose([
        Resize(sidelength),
        ToTensor(),
        Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
    ])
    img = transform(img)
    return img

In [4]:
class ImageFitting(Dataset):
    def __init__(self, sidelength):
        super().__init__()
        img = get_cameraman_tensor(sidelength)
        self.pixels = img.permute(1, 2, 0).view(-1, 1)
        self.coords = get_mgrid(sidelength, 2)

    def __len__(self):
        return 1

    def __getitem__(self, idx):    
        if idx > 0: raise IndexError
            
        return self.coords, self.pixels

In [None]:
cameraman = ImageFitting(256)
dataloader = DataLoader(cameraman, batch_size=1, pin_memory=True, num_workers=0)

img_siren = Siren(in_features=2, out_features=1, hidden_features=256, 
                  hidden_layers=3, outermost_linear=True)
img_siren.to('cuda')

In [None]:
total_steps = 500 # Since the whole image is our dataset, this just means 500 gradient descent steps.
steps_til_summary = 10

optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())

model_input, ground_truth = next(iter(dataloader))


for step in range(total_steps):
    model_output, coords = img_siren(model_input.to('cuda'))    
    loss = ((model_output - ground_truth.to('cuda'))**2).mean()

    if not step % steps_til_summary:
        print("Step %d, Total loss %0.6f" % (step, loss))

    optim.zero_grad()
    loss.backward()
    optim.step()

In [None]:
model_input.shape

In [None]:
!pip install 'ray[tune]'

In [17]:
from ray import tune

In [27]:
import ray

In [29]:
ray.init(num_gpus=1, ignore_reinit_error=True) 

2023-05-19 07:07:09,522	INFO worker.py:1454 -- Calling ray.init() again after it has already been called.


0,1
Python version:,3.10.11
Ray version:,2.4.0


In [21]:
def training_function(config):
  omega = config["omega"]
  cameraman = ImageFitting(256)
  dataloader = DataLoader(cameraman, batch_size=1, pin_memory=True, num_workers=0)

  img_siren = Siren(in_features=2, out_features=1, hidden_features=256, 
                  hidden_layers=3, outermost_linear=True, first_omega_0 = omega, hidden_omega_0 = omega)
  img_siren.cuda()
  total_steps = 500 # Since the whole image is our dataset, this just means 500 gradient descent steps.
  steps_til_summary = 10

  optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())

  model_input, ground_truth = next(iter(dataloader))


  for step in range(total_steps):
      model_output, coords = img_siren(model_input.cuda())    
      loss = ((model_output - ground_truth.cuda())**2).mean()

      #if not step % steps_til_summary:
          #print("Step %d, Total loss %0.6f" % (step, loss))
      tune.report(loss = loss)
      optim.zero_grad()
      loss.backward()
      optim.step()

In [30]:
analysis = tune.run(training_function, config={"omega":tune.grid_search([10,20,30,40,50,60,70,80,90,100])}, resources_per_trial={"gpu": 1})

0,1
Current time:,2023-05-19 07:11:51
Running for:,00:04:19.05
Memory:,2.6/12.7 GiB

Trial name,status,loc,omega,iter,total time (s)
training_function_d34b4_00000,TERMINATED,172.28.0.12:19030,10,500,21.3742
training_function_d34b4_00001,TERMINATED,172.28.0.12:19220,20,500,21.4653
training_function_d34b4_00002,TERMINATED,172.28.0.12:19382,30,500,21.3115
training_function_d34b4_00003,TERMINATED,172.28.0.12:19537,40,500,21.1998
training_function_d34b4_00004,TERMINATED,172.28.0.12:19691,50,500,21.1731
training_function_d34b4_00005,TERMINATED,172.28.0.12:19847,60,500,21.0408
training_function_d34b4_00006,TERMINATED,172.28.0.12:20005,70,500,21.4797
training_function_d34b4_00007,TERMINATED,172.28.0.12:20162,80,500,22.1166
training_function_d34b4_00008,TERMINATED,172.28.0.12:20319,90,500,21.1922
training_function_d34b4_00009,TERMINATED,172.28.0.12:20478,100,500,21.2136


[2m[36m(training_function pid=19030)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined][32m [repeated 2x across cluster][0m


Trial name,date,done,experiment_tag,hostname,iterations_since_restore,loss,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id
training_function_d34b4_00000,2023-05-19_07-07-58,True,0_omega=10,bbfb37d3f579,500,0.00988233,172.28.0.12,19030,21.3742,0.0452771,21.3742,1684480078,500,d34b4_00000
training_function_d34b4_00001,2023-05-19_07-08-25,True,1_omega=20,bbfb37d3f579,500,0.00152368,172.28.0.12,19220,21.4653,0.0409653,21.4653,1684480105,500,d34b4_00001
training_function_d34b4_00002,2023-05-19_07-08-51,True,2_omega=30,bbfb37d3f579,500,0.00095058,172.28.0.12,19382,21.3115,0.0396395,21.3115,1684480131,500,d34b4_00002
training_function_d34b4_00003,2023-05-19_07-09-16,True,3_omega=40,bbfb37d3f579,500,0.000647309,172.28.0.12,19537,21.1998,0.0435688,21.1998,1684480156,500,d34b4_00003
training_function_d34b4_00004,2023-05-19_07-09-41,True,4_omega=50,bbfb37d3f579,500,0.000465181,172.28.0.12,19691,21.1731,0.0463526,21.1731,1684480181,500,d34b4_00004
training_function_d34b4_00005,2023-05-19_07-10-07,True,5_omega=60,bbfb37d3f579,500,0.000288714,172.28.0.12,19847,21.0408,0.0397301,21.0408,1684480207,500,d34b4_00005
training_function_d34b4_00006,2023-05-19_07-10-33,True,6_omega=70,bbfb37d3f579,500,0.000213284,172.28.0.12,20005,21.4797,0.0413172,21.4797,1684480233,500,d34b4_00006
training_function_d34b4_00007,2023-05-19_07-10-59,True,7_omega=80,bbfb37d3f579,500,0.000147477,172.28.0.12,20162,22.1166,0.0424807,22.1166,1684480259,500,d34b4_00007
training_function_d34b4_00008,2023-05-19_07-11-25,True,8_omega=90,bbfb37d3f579,500,0.000100838,172.28.0.12,20319,21.1922,0.0403044,21.1922,1684480285,500,d34b4_00008
training_function_d34b4_00009,2023-05-19_07-11-51,True,9_omega=100,bbfb37d3f579,500,6.1887e-05,172.28.0.12,20478,21.2136,0.0393956,21.2136,1684480311,500,d34b4_00009


[2m[36m(training_function pid=19220)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19220)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19382)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19537)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19537)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19691)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19847)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=19847)[0m   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
[2m[36m(training_function pid=20005)[0m   return _VF.meshgrid(tensors, **kwargs)  # t

In [31]:
df = analysis.results_df
df

Unnamed: 0_level_0,loss,time_this_iter_s,done,training_iteration,date,timestamp,time_total_s,pid,hostname,node_ip,time_since_restore,iterations_since_restore,experiment_tag,config/omega
trial_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
d34b4_00000,"tensor(0.0099, device='cuda:0', requires_grad=...",0.045277,True,500,2023-05-19_07-07-58,1684480078,21.374188,19030,bbfb37d3f579,172.28.0.12,21.374188,500,0_omega=10,10
d34b4_00001,"tensor(0.0015, device='cuda:0', requires_grad=...",0.040965,True,500,2023-05-19_07-08-25,1684480105,21.465313,19220,bbfb37d3f579,172.28.0.12,21.465313,500,1_omega=20,20
d34b4_00002,"tensor(0.0010, device='cuda:0', requires_grad=...",0.039639,True,500,2023-05-19_07-08-51,1684480131,21.311458,19382,bbfb37d3f579,172.28.0.12,21.311458,500,2_omega=30,30
d34b4_00003,"tensor(0.0006, device='cuda:0', requires_grad=...",0.043569,True,500,2023-05-19_07-09-16,1684480156,21.199846,19537,bbfb37d3f579,172.28.0.12,21.199846,500,3_omega=40,40
d34b4_00004,"tensor(0.0005, device='cuda:0', requires_grad=...",0.046353,True,500,2023-05-19_07-09-41,1684480181,21.173064,19691,bbfb37d3f579,172.28.0.12,21.173064,500,4_omega=50,50
d34b4_00005,"tensor(0.0003, device='cuda:0', requires_grad=...",0.03973,True,500,2023-05-19_07-10-07,1684480207,21.040816,19847,bbfb37d3f579,172.28.0.12,21.040816,500,5_omega=60,60
d34b4_00006,"tensor(0.0002, device='cuda:0', requires_grad=...",0.041317,True,500,2023-05-19_07-10-33,1684480233,21.479664,20005,bbfb37d3f579,172.28.0.12,21.479664,500,6_omega=70,70
d34b4_00007,"tensor(0.0001, device='cuda:0', requires_grad=...",0.042481,True,500,2023-05-19_07-10-59,1684480259,22.116592,20162,bbfb37d3f579,172.28.0.12,22.116592,500,7_omega=80,80
d34b4_00008,"tensor(0.0001, device='cuda:0', requires_grad=...",0.040304,True,500,2023-05-19_07-11-25,1684480285,21.192205,20319,bbfb37d3f579,172.28.0.12,21.192205,500,8_omega=90,90
d34b4_00009,"tensor(6.1887e-05, device='cuda:0', requires_g...",0.039396,True,500,2023-05-19_07-11-51,1684480311,21.213622,20478,bbfb37d3f579,172.28.0.12,21.213622,500,9_omega=100,100


In [23]:
torch.cuda.is_available()

True