In [1]:
import torch
import torch.nn as nn

In [2]:
import numpy as np
import os

In [3]:
import matplotlib
from pylab import *
import skimage as sk
import skimage.io as io
from skimage.color import rgb2gray
from scipy.signal import convolve2d
import scipy.signal as signal

In [4]:
from torch.utils.data import Dataset, DataLoader

In [5]:
# Load the fox image from disk as a NumPy array via skimage.io.imread.
fox = io.imread("images/fox.jpg")

In [6]:
class Model(nn.Module):
  """Coordinate MLP (neural field) mapping (N, C) positions to (N, out_channels) values.

  ``forward`` positionally encodes its input with ``num_frequencies`` bands and then
  applies Linear -> ReLU -> Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid,
  so every output value lies in (0, 1).

  Args:
    num_frequencies: number of sinusoid bands L used by ``positionEncoder``.
    in_channels: number of raw coordinate channels C.
    hidden_dim: width of the three hidden layers.
    out_channels: number of outputs per position (3 for RGB).

  With the defaults the encoded input is 2 * (1 + 2 * 10) = 42 features wide.
  """

  def __init__(self, num_frequencies=10, in_channels=2, hidden_dim=256, out_channels=3):
    super(Model, self).__init__()
    self.num_frequencies = num_frequencies
    # Raw coordinates plus one sin and one cos per channel per frequency band.
    encoded_dim = in_channels * (1 + 2 * num_frequencies)
    self.linear1 = nn.Linear(encoded_dim, hidden_dim)
    self.activation = torch.nn.ReLU()
    self.linear2 = nn.Linear(hidden_dim, hidden_dim)
    # Distinct weights for the second hidden->hidden layer, so the two hidden
    # layers do not share parameters.
    self.linear2b = nn.Linear(hidden_dim, hidden_dim)
    self.linear3 = nn.Linear(hidden_dim, out_channels)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Map coordinates ``x`` of shape (N, C) to outputs of shape (N, out_channels)."""
    x = self.positionEncoder(x, self.num_frequencies)
    x = self.activation(self.linear1(x))
    x = self.activation(self.linear2(x))
    x = self.activation(self.linear2b(x))
    return self.sigmoid(self.linear3(x))

  def positionEncoder(self, positions, L):
    """Sinusoidal positional encoding.

    Args:
      positions: array-like or tensor of shape (N, C).
      L: number of frequency bands; band k uses angle 2**k * pi * p.

    Returns:
      float32 tensor of shape (N, C * (1 + 2 * L)) laid out as
      [p_0..p_{C-1}, then for each k: sin(2**k*pi*p_c) for every c,
      cos(2**k*pi*p_c) for every c].

    Raises:
      ValueError: if ``positions`` is not 2-D.
    """
    # as_tensor avoids the copy/warning torch.tensor issues for tensor input and
    # keeps the autograd graph when the input is already float32.
    positions = torch.as_tensor(positions, dtype=torch.float32)
    if positions.dim() != 2:
      raise ValueError(f"positions must have shape (N, C), got {tuple(positions.shape)}")
    frequencies = 2.0 ** torch.arange(L, dtype=torch.float32, device=positions.device)
    angles = positions.unsqueeze(1) * frequencies.view(1, -1, 1) * np.pi  # (N, L, C)
    bands = torch.stack((torch.sin(angles), torch.cos(angles)), dim=2)  # (N, L, 2, C)
    # flatten (not view(N, -1)) so N == 0 also works.
    return torch.cat((positions, bands.flatten(start_dim=1)), dim=-1)

In [8]:
def positionEncoder(positions, L):
    """Sinusoidal positional encoding of (N, C) coordinates.

    Args:
        positions: array-like or tensor of shape (N, C).
        L: number of frequency bands; band k uses angle 2**k * pi * p.

    Returns:
        float32 tensor of shape (N, C * (1 + 2 * L)) laid out as
        [p_0..p_{C-1}, then for each k: sin(2**k*pi*p_c) for every c,
        cos(2**k*pi*p_c) for every c].

    Raises:
        ValueError: if ``positions`` is not 2-D.
    """
    positions = torch.as_tensor(positions, dtype=torch.float32)
    if positions.dim() != 2:
        raise ValueError(f"positions must have shape (N, C), got {tuple(positions.shape)}")
    frequencies = 2.0 ** torch.arange(L, dtype=torch.float32, device=positions.device)
    angles = positions.unsqueeze(1) * frequencies.view(1, -1, 1) * np.pi  # (N, L, C)
    bands = torch.stack((torch.sin(angles), torch.cos(angles)), dim=2)  # (N, L, 2, C)
    # Row count comes from the input; flatten keeps each sample's features in its own row.
    return torch.cat((positions, bands.flatten(start_dim=1)), dim=-1)

In [None]:
test_points = np.array([[1, 2], [3, 4]])
test_pe = positionEncoder(test_points, 2)
# Expected row layout: [x, y] followed, for each band multiplier m in (1, 2), by
# sin(m*pi*x), sin(m*pi*y), cos(m*pi*x), cos(m*pi*y).
test_expected = np.array([
    [x, y] + [trig(m * np.pi * v) for m in (1, 2) for trig in (np.sin, np.cos) for v in (x, y)]
    for x, y in test_points
])
np.isclose(np.array(test_pe), test_expected, atol= 1e-6)

In [None]:
from torch.utils.data import Dataset
from PIL import Image
class Dataloader(Dataset):
  """Dataset holding every image file found directly inside ``folder_path``.

  Regular files whose extension (case-insensitive) is in ``IMAGE_EXTENSIONS`` are
  opened with PIL and kept in ``image_list``. Subdirectories and other files are
  skipped. Entries are visited in sorted name order so indices are stable.
  """

  IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')

  def __init__(self, folder_path):
    self.image_list = []
    for filename in sorted(os.listdir(folder_path)):
      file_path = os.path.join(folder_path, filename)
      # The filter must run for every directory entry, i.e. inside the loop.
      if os.path.isfile(file_path) and filename.lower().endswith(self.IMAGE_EXTENSIONS):
        # Image.open is lazy and holds the file open; copy() reads the pixel
        # data and the with-block then closes the file handle.
        with Image.open(file_path) as image:
          self.image_list.append(image.copy())

  @staticmethod
  def getTrainingSet():
    """Placeholder (TODO): not implemented yet, returns None."""
    return

  @staticmethod
  def getTestingSet():
    """Placeholder (TODO): not implemented yet, returns None."""
    return

  def __len__(self) -> int:
    return len(self.image_list)

  def __getitem__(self, index):
    """Return the PIL image stored at position ``index``."""
    return self.image_list[index]

In [107]:
import random
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, TensorDataset

In [315]:
def sampleImage(img, N):
  """Sample pixel positions and their colors from an (H, W, channels) image.

  Args:
    img: image array of shape (H, W, channels). Colors are divided by 255, so
      8-bit values are assumed (TODO confirm for non-uint8 inputs).
    N: number of pixels to draw. Positions are drawn uniformly *with
      replacement*, so duplicates are possible. If the image has at most N
      pixels, every pixel is returned exactly once instead, in row-major order.

  Returns:
    [coords, colors] as float32 tensors. coords is (M, 2) holding (x / W, y / H),
    with x the column and y the row index. colors is (M, channels) scaled by
    1/255. M is N, or H * W for a small image.
  """
  assert len(img.shape) == 3
  image_height, image_width = img.shape[0], img.shape[1]
  if image_height * image_width <= N:
    # Small image: use every pixel, but keep the same [coords, colors] format.
    ys, xs = np.meshgrid(np.arange(image_height), np.arange(image_width), indexing="ij")
    xs, ys = xs.ravel(), ys.ravel()
  else:
    xs = np.random.randint(0, image_width, size=N)
    ys = np.random.randint(0, image_height, size=N)
  coords = np.column_stack((xs, ys))
  colors = img[ys, xs]  # fancy indexing: one color vector per (row, column) pair
  coords_normalized = coords / np.array([image_width, image_height])
  colors_normalized = colors / 255.0
  assert len(coords_normalized) == len(colors_normalized)
  return [torch.tensor(coords_normalized, dtype=torch.float32), torch.tensor(colors_normalized, dtype=torch.float32)]

In [316]:
def PSNR(value, max_value=1.0):
  """Peak signal-to-noise ratio, in decibels, for a mean squared error ``value``.

  PSNR = 10 * log10(max_value**2 / value). The default ``max_value`` of 1.0 suits
  colors scaled to [0, 1] (as ``sampleImage`` produces); pass 255 for 8-bit data.
  """
  # PSNR is defined with a base-10 logarithm (decibels), not the natural log.
  return 10 * np.log10(max_value ** 2 / value)

Training

In [7]:
import matplotlib
from pylab import *
import skimage as sk
import skimage.io as io
from skimage.color import rgb2gray
from scipy.signal import convolve2d
import scipy.signal as signal

In [105]:
# Float32 copy of the target image, plus the model, MSE loss and Adam optimizer.
fox_tensor = torch.tensor(fox, dtype=torch.float32)
model = Model()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
! pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

In [304]:
def PE(positions, L):
    """Sinusoidal positional encoding of (N, C) coordinates.

    Args:
        positions: array-like or tensor of shape (N, C).
        L: number of frequency bands; band k uses angle 2**k * pi * p.

    Returns:
        float32 tensor of shape (N, C * (1 + 2 * L)) laid out as
        [p_0..p_{C-1}, then for each k: sin(2**k*pi*p_c) for every c,
        cos(2**k*pi*p_c) for every c].

    Raises:
        ValueError: if ``positions`` is not 2-D.
    """
    # as_tensor avoids re-wrapping (and warning on) tensors that are already float32.
    positions = torch.as_tensor(positions, dtype=torch.float32)
    if positions.dim() != 2:
        raise ValueError(f"positions must have shape (N, C), got {tuple(positions.shape)}")
    frequencies = 2.0 ** torch.arange(L, dtype=torch.float32, device=positions.device)
    angles = positions.unsqueeze(1) * frequencies.view(1, -1, 1) * np.pi  # (N, L, C)
    # Stacking sin/cos on dim 2 gives, per band, all sins then all cosines once flattened.
    bands = torch.stack((torch.sin(angles), torch.cos(angles)), dim=2)  # (N, L, 2, C)
    return torch.cat((positions, bands.flatten(start_dim=1)), dim=-1)

In [None]:
# Integer points make every sin term ~0 and every cos term +/-1, which can hide a
# wrong feature order; the fractional last row exercises the full layout.
test_points = np.array([[1, 2], [3, 4], [5, 6], [0.25, 0.1]])
test_pe_3 = PE(test_points, 3)
test_pe_2 = PE(test_points, 2)


def expected_encoding(points, L):
    """NumPy float64 reference: [x, y], then per band k: sin(2**k*pi*x), sin(2**k*pi*y), cos(2**k*pi*x), cos(2**k*pi*y)."""
    return np.array([
        [x, y] + [trig(2 ** k * np.pi * v) for k in range(L) for trig in (np.sin, np.cos) for v in (x, y)]
        for x, y in points
    ])


test_expected_3 = expected_encoding(test_points, 3)
test_expected_2 = expected_encoding(test_points, 2)

# PE computes in float32; for angles up to 24*pi the rounding error can exceed
# 1e-6, so compare with a 1e-5 tolerance. Check both L values.
{
    "L=3": np.allclose(np.array(test_pe_3), test_expected_3, atol=1e-5),
    "L=2": np.allclose(np.array(test_pe_2), test_expected_2, atol=1e-5),
}

In [313]:
# Draw ONE sample and unpack it, so each coordinate stays paired with the color
# at that pixel. Two separate calls would pair coordinates from one random draw
# with colors from another.
sample_coords, sample_colors = sampleImage(fox, 10000)

In [None]:
def normalize_image(im):
    """Linearly rescale ``im`` so its minimum maps to 0 and its maximum to 1.

    A constant image (max == min) has no range to stretch. It is mapped to an
    all-zero float array instead of dividing by zero and producing NaNs.
    """
    lo, hi = np.amin(im), np.amax(im)
    if hi == lo:
        return np.zeros_like(im, dtype=np.float64)
    return (im - lo) / (hi - lo)

In [None]:
with torch.no_grad():
    model.eval()
    image_height, image_width = fox.shape[:2]
    # Query every pixel using the same normalization as sampleImage:
    # channel 0 = x / width (column), channel 1 = y / height (row).
    grid_x, grid_y = np.meshgrid(np.arange(image_width) / image_width,
                                 np.arange(image_height) / image_height)  # each (H, W)
    # The encoder expects a flat (N, 2) batch, not an (H, W, 2) grid.
    coords = np.stack((grid_x, grid_y), axis=-1).reshape(-1, 2)
    coords_tensor = torch.tensor(coords, dtype=torch.float32)
    outputs = model(coords_tensor)
    # One output row per pixel in row-major order, so it reshapes back to an image.
    predicted_image = outputs.view(image_height, image_width, -1)
    print(outputs.shape)