# Multiresolution Hash Encoding

In [None]:
import torch
import numpy as np
from multiresolution_hash_encoding import MultiresolutionHashEncoding
from PIL import Image
# Disable Pillow's decompression-bomb size check so arbitrarily large images
# can be opened; only appropriate because the image files loaded here are trusted.
Image.MAX_IMAGE_PIXELS = None

In [None]:
# Number of full training passes over every image tile.
EPOCHS = 1

In [None]:
class ImageEncoder(torch.nn.Module):
    """Coordinate network mapping 2-D input coordinates to RGB colours.

    A ``MultiresolutionHashEncoding`` turns each input coordinate into a
    feature vector. A small MLP (two hidden ReLU layers of width 32) then
    maps that vector to three values, squashed into (0, 1) by a final sigmoid.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the positional arguments (2**22, 2, 2) are presumably
        # hash-table size, input dimensionality and features per level; with
        # levels=16 that would yield the 32 features the first Linear layer
        # expects. Confirm against MultiresolutionHashEncoding's signature.
        self.encoding = MultiresolutionHashEncoding(2**22, 2, 2, levels=16, N_min=16, N_max=56718)
        self.mlp = torch.nn.ModuleList(
            [
                torch.nn.Linear(32, 32),
                torch.nn.ReLU(),
                torch.nn.Linear(32, 32),
                torch.nn.ReLU(),
                torch.nn.Linear(32, 3),
                torch.nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Encode ``x`` and run the features through the MLP layers in order.

        Args:
            x: batch of input coordinates, one row per point.

        Returns:
            Tensor whose last dimension has size 3, with values in (0, 1).
        """
        x = self.encoding(x)
        for layer in self.mlp:
            x = layer(x)
        return x

# Train the coordinate -> RGB model on the image one rectangular tile at a
# time: every tile is one optimisation step, and the LR scheduler steps once
# per epoch.
img = Image.open("test_files/tokyo.jpg")

# A single explicit device for the model and every tensor fed to it, so the
# parameters and the batches can never end up on different GPUs.
device = torch.device("cuda:0")
m = ImageEncoder().to(device)

optimizer = torch.optim.Adam(
            m.parameters(),
            betas=(0.9, 0.999),
            eps=1e-08)
# Halve the learning rate every 30 scheduler steps (one step per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

loss_func = torch.nn.MSELoss()

# Number of tiles along x (image width) and y (image height); one tile == one batch.
im_x_split = 200
im_y_split = 800

# Tile size in pixels. Floor division means the right/bottom remainder
# (img.size % split) is never visited by training.
sub_x_dim = img.size[0] // im_x_split
sub_y_dim = img.size[1] // im_y_split
if sub_x_dim == 0 or sub_y_dim == 0:
    # Otherwise every tile would be empty and every loss NaN.
    raise ValueError(
        f"Image of size {img.size} is too small for a {im_x_split}x{im_y_split} tiling"
    )

print(f"Image batch size ({sub_x_dim}, {sub_y_dim})")

n_batches = im_x_split * im_y_split
m.train()
for i in range(EPOCHS):
    for j in range(im_x_split):
        for k in range(im_y_split):
            batch_idx = j * im_y_split + k
            print(f"{batch_idx + 1} out of {n_batches} batches.")

            x0, y0 = j * sub_x_dim, k * sub_y_dim
            # convert("RGB") guarantees 3 channels to match the model's 3 outputs.
            tile = img.crop((x0, y0, x0 + sub_x_dim, y0 + sub_y_dim)).convert("RGB")
            # Target colours, shape (sub_y_dim, sub_x_dim, 3), scaled to [0, 1].
            im = torch.tensor(np.array(tile).astype(np.float32) / 255.0, device=device)

            # (x, y) pixel coordinates, generated y-outer / x-inner so that row
            # r * sub_x_dim + c is the pixel stored at im[r, c]. An x-major grid
            # would pair each prediction with the wrong pixel after the reshape
            # to im.shape below.
            # NOTE(review): these are raw integer pixel indices; confirm that
            # MultiresolutionHashEncoding expects unnormalised coordinates.
            ys, xs = np.mgrid[y0:y0 + sub_y_dim, x0:x0 + sub_x_dim]
            im_coords = torch.tensor(
                np.stack((xs.ravel(), ys.ravel()), axis=1), device=device
            )

            optimizer.zero_grad()

            pred_im = m(im_coords)

            loss: torch.Tensor = loss_func(pred_im.reshape(im.shape), im)

            loss.backward()
            optimizer.step()

    scheduler.step()

