In [1]:
%matplotlib inline

from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
from tqdm.notebook import tqdm as tqdm
from torchvision import transforms

import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import utils
import torch.nn as nn
import numpy as np
import os, imageio
import pickle
import torch
import cv2

### Dataloader

In [2]:
# Preprocessing applied to every MNIST sample when it is read from the dataset.
mnist_trans = transforms.Compose(
    [
        # Convert the image to a float tensor.
        transforms.ToTensor(),
        # mean=0 / std=1 leaves the values unchanged; kept as an explicit placeholder.
        transforms.Normalize(mean=0, std=1.0),
        # Centre crop to 28x28 (MNIST digits are presumably already this size,
        # in which case this is a no-op).
        transforms.CenterCrop(28),
    ]
)

In [3]:
# Load the MNIST training split. download=False: the files must already be
# present under ./../data.
mnist_full = datasets.MNIST(root='./../data', train=True, download=False, transform=mnist_trans)

# Keep a random subset of SUBSET_SIZE images and drop the remainder.
# A dedicated, seeded generator makes the subset identical on every run, so a
# given dataset index always refers to the same image (the index is what the
# training cell passes on as the image number).
SUBSET_SIZE = 10000
SPLIT_SEED = 0
mnist = torch.utils.data.random_split(
    mnist_full,
    [SUBSET_SIZE, len(mnist_full) - SUBSET_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)[0]  # random_split returns both parts; only the first (SUBSET_SIZE images) is used


### Device

In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Using PyTorch version:', torch.__version__, ' Device:', device)

# From here on, tensors built by factory functions are placed on `device`
# unless a device is given explicitly.
torch.set_default_device(device)

Using PyTorch version: 2.1.0.dev20230622+cu121  Device: cuda


In [5]:
# Use double precision for floating-point tensors and module parameters created
# from now on (this includes the networks built further down).
torch.set_default_dtype(torch.float64)
# Same call as in the previous cell; repeating it is harmless.
torch.set_default_device(device)

### Converting Dataset

In [6]:
def generate_pixel_coordinates(img):
    """Build a grid of normalised (x, y) coordinates, one per pixel of `img`.

    Args:
        img: 3-D tensor laid out as (channels, height, width). Only its
            spatial size is used; the pixel values are never read.

    Returns:
        numpy array of shape (height, width, 2). Entry [r, c] holds
        (x, y) = (c / width, r / height), so both coordinates lie in [0, 1).
    """
    # Reorder (C, H, W) -> (H, W, C) purely to read off the spatial extent.
    n_rows, n_cols = img.permute(1, 2, 0).shape[:2]

    # Evenly spaced positions in [0, 1); the end point 1.0 is excluded.
    col_positions = np.linspace(0, 1, n_cols, endpoint=False)
    row_positions = np.linspace(0, 1, n_rows, endpoint=False)

    # grid_x varies along columns, grid_y along rows; both are (height, width).
    grid_x, grid_y = np.meshgrid(col_positions, row_positions)

    # Pair the two grids on a new trailing axis -> (height, width, 2).
    return np.stack((grid_x, grid_y), axis=-1)

### Fourier Feature mapping


In [7]:
def input_mapping(x, B):
    '''
        Fourier-feature encoding of the coordinates in `x`.

        x: tensor whose last dimension holds the coordinates.
        B: projection matrix with one row per frequency, or None to
           return `x` untouched.

        Each coordinate vector is multiplied by 2*pi and projected onto the
        rows of B; the sines and cosines of the projections are concatenated
        along the last dimension, giving 2 * len(B) features per point.
    '''
    # No projection matrix -> identity mapping.
    if B is None:
        return x

    # Match B to x's dtype and device before multiplying.
    freqs = B.to(x)
    angles = torch.matmul(2. * np.pi * x, freqs.T)
    features = (torch.sin(angles), torch.cos(angles))
    return torch.cat(features, dim=-1)

In [8]:
def batch_loader(coord_input, image, map_dict, mapping='none'):
    """Encode the coordinates and build the train/test tensors for one image.

    Args:
        coord_input: coordinate grid (tensor or array-like); it is passed
            through `input_mapping` with the selected projection matrix.
        image: target pixel values, indexed the same way as `coord_input`
            along the first axis.
        map_dict: maps a mapping name to its projection matrix (or to None
            for "no encoding").
        mapping: key into `map_dict`. The default 'none' means "no encoding"
            and works even when `map_dict` has no 'none' entry; any other
            missing key still raises KeyError.

    Returns:
        (train_x, train_y, test_x, test_y). The test pair is the complete
        encoded grid and image; the train pair keeps every second entry
        along the first axis of each. All four are independent copies.
    """
    # Previously the default mapping raised KeyError unless the caller had
    # added a 'none' key to the dictionary; treat it as the identity instead.
    if mapping == 'none' and 'none' not in map_dict:
        B = None
    else:
        B = map_dict[mapping]

    coord_input = input_mapping(coord_input, B)

    def _copy(data):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
        # as_tensor + detach + clone yields the same detached copy without it
        # and still accepts array-like inputs.
        return torch.as_tensor(data).detach().clone()

    test_x = _copy(coord_input)
    test_y = _copy(image)
    train_x = _copy(coord_input[::2])
    train_y = _copy(image[::2])

    return train_x, train_y, test_x, test_y

### MLP network

In [9]:
class Net(nn.Module):
    """Small fully connected network: two hidden ReLU layers with dropout,
    followed by a linear output layer.

    Args:
        input_dim: number of features in the last dimension of the input.
        hidden_dim: width of both hidden layers (default 50).
        out_dim: number of output values per input point (default 1).
        dropout: dropout probability applied after each hidden activation
            (default 0.2).

    The layer attribute names (fc1, fc2, fc3, ...) and their creation order
    are part of the contract: state dicts saved from this class are keyed by
    those names.
    """

    def __init__(self, input_dim, hidden_dim=50, out_dim=1, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc1_drop = nn.Dropout(dropout)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2_drop = nn.Dropout(dropout)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Apply the network to `x`; only the last dimension is transformed
        (input_dim -> out_dim), leading dimensions are kept as they are."""
        x = self.fc1_drop(self.relu1(self.fc1(x)))
        x = self.fc2_drop(self.relu2(self.fc2(x)))
        return self.fc3(x)

## Train

In [10]:
# train_x,train_y,test_x,test_y = batch_loader(xy_grid,x.permute(1,2,0),B_dict,keys)

In [11]:
# input = test_x.to(device)
# target = test_y.to(device)

# model = Net(input.shape[2]).to(device)
# optimizer = torch.optim.Adam(list(model.parameters()), lr=1e-2)


# for epoch in tqdm(range(1501)):
#     optimizer.zero_grad()

#     generated = model(input)

#     loss = torch.nn.functional.l1_loss(target, generated)

#     loss.backward()
#     optimizer.step()

#     if epoch % 100 == 0:
#       print('Epoch %d, loss = %.03f' % (epoch, float(loss)))
#       plt.imshow(generated.detach().cpu().numpy(),cmap='gray')
#       plt.show()

In [12]:
def model_weights(coordinates, image, img_no, dictionary, mapping, save=False,
                  epochs=1501, lr=1e-2, weights_dir='../weights'):
    """Fit a fresh `Net` to a single image and optionally save its weights.

    Args:
        coordinates: pixel-coordinate grid passed on to `batch_loader`.
        image: image tensor laid out as (channels, height, width); it is
            permuted to (height, width, channels) before being used as target.
        img_no: identifier used as the file name of the saved weights.
        dictionary: mapping-name -> projection-matrix dict for `batch_loader`.
        mapping: key selecting the projection matrix in `dictionary`.
        save: when True, write the trained state dict to
            `<weights_dir>/<img_no>.pth`.
        epochs: number of optimisation steps (default 1501).
        lr: Adam learning rate (default 1e-2).
        weights_dir: directory for saved weights; created if missing.

    Returns:
        None. The trained model is discarded after the optional save.
    """
    # Only the full-resolution pair is fitted; the sub-sampled "train" pair
    # returned by batch_loader is not used here.
    _, _, test_x, test_y = batch_loader(coordinates, image.permute(1, 2, 0), dictionary, mapping)

    features = test_x.to(device)
    target = test_y.to(device)

    # The network input width is the size of the encoded feature axis.
    model = Net(features.shape[-1]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    generated = loss = None
    for _ in range(epochs):
        optimizer.zero_grad()

        generated = model(features)

        # l1_loss expects (input, target); the value is symmetric in the two.
        loss = torch.nn.functional.l1_loss(generated, target)

        loss.backward()
        optimizer.step()

    # Honour the `save` flag: it used to be accepted but silently ignored.
    if save:
        os.makedirs(weights_dir, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(weights_dir, '{fname}.pth'.format(fname=img_no)))

    # Drop every reference first so the cached blocks can actually be
    # released; calling empty_cache() while the model is still referenced
    # cannot free its memory.
    del model, optimizer, features, target, generated, loss
    torch.cuda.empty_cache()
    

In [13]:
# Number of random frequencies; input_mapping turns each coordinate pair into
# 2 * mapping_size features (a sine and a cosine per frequency).
mapping_size = 64

# One random 2-D frequency vector per row, drawn from a standard normal.
B_gauss = torch.normal(0, 1, size=(mapping_size, 2))

# Scaled copies of the frequency matrix, keyed by 'gauss_<scale>'.
B_dict = {}
for scale in [10.]:
    B_dict[f'gauss_{scale}'] = scale * B_gauss

# `keys` ends up holding the last key of B_dict (the only one here); it is the
# mapping name handed to the training code further down.
for k in B_dict.keys():
    keys = k

In [14]:
# Take the first sample of the subset; only the image is kept (the second
# element of the pair is discarded). Its size defines the coordinate grid.
x,_ = mnist[0]

# Coordinate grid for that image size, converted from a numpy array to a torch
# tensor. The same grid is passed to every model_weights call below.
xy_grid = torch.from_numpy(generate_pixel_coordinates(x))

In [None]:
# Fit one network per image in the subset. The dataset index is passed as the
# image number and saving is requested for every image.
# NOTE(review): each call runs the complete training loop inside model_weights,
# so this cell is long-running for the whole subset.
for index in tqdm(range(len(mnist))):

    model_weights(xy_grid,mnist[index][0],index,B_dict,keys,save=True)

  0%|          | 0/10000 [00:00<?, ?it/s]

  return func(*args, **kwargs)


### Load model

In [18]:
# Rebuild the architecture used for training. The Fourier encoding produces
# 2 * mapping_size input features (a sine and a cosine per frequency), which
# is 128 for mapping_size = 64.
model = Net(2 * mapping_size)

# map_location keeps the load working when the checkpoint was written on a
# different device than the current one; weights_only=True restricts
# unpickling to tensor data instead of arbitrary Python objects.
state_dict = torch.load('../weights/2.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Inspect the loaded parameters.
for param in model.parameters():
    print(param.data)

tensor([[-0.0795,  0.1682,  0.0875,  ..., -0.2525, -0.0463,  0.0028],
        [ 0.0552, -0.0240,  0.0436,  ..., -0.0965, -0.0113,  0.1088],
        [ 0.1314, -0.1306,  0.1767,  ..., -0.0286, -0.2119,  0.1687],
        ...,
        [-0.0419,  0.0082,  0.2348,  ..., -0.3616, -0.2333, -0.0336],
        [-0.0057,  0.0810, -0.1398,  ...,  0.1748,  0.3189, -0.1856],
        [ 0.0761,  0.0039,  0.0316,  ...,  0.1862,  0.0608,  0.1784]],
       device='cuda:0')
tensor([-0.0222, -0.2944, -0.0153, -0.1830, -0.1238, -0.7707, -0.4240,  0.0094,
        -0.4637, -0.4999, -0.0846, -0.8748,  0.2038, -0.7559,  0.0421,  0.3459,
        -0.3898, -0.9280,  0.0213, -0.9460, -0.0888, -0.0765, -0.0522,  0.0309,
        -0.7213, -0.4241, -0.1429, -0.5051,  0.0449, -0.2816, -0.7010, -0.0519,
        -0.2336, -0.1436, -0.0351, -0.7080,  0.2282, -0.4984,  0.1681, -0.6942,
        -0.1610, -0.4047, -0.6369, -0.4612, -0.3822, -0.5614, -0.1953, -0.1232,
        -0.1155, -0.6370], device='cuda:0')
tensor([[-3.4537e-