# Multiresolution Hash Encoding

In [None]:
from typing import List
import torch
import numpy as np
from multiresolution_hash_encoding import MultiresolutionHashEncoding
from PIL import Image
from matplotlib import pyplot as plt
Image.MAX_IMAGE_PIXELS = None

In [None]:
# Number of full passes over all pixel batches made by the training loop.
EPOCHS = 1

In [None]:
class ImageEncoder(torch.nn.Module):
    """Map 2-D coordinates to three values in (0, 1).

    Inputs are passed through a ``MultiresolutionHashEncoding`` and the
    resulting features through a small fully connected network whose last
    layer is a sigmoid.
    """

    def __init__(self, image_dims):
        """Build the hash encoding and the MLP head.

        Args:
            image_dims: Forwarded to the encoding as ``N_max``.
        """
        super().__init__()
        # NOTE(review): the positional arguments are presumably
        # (hash table size, features per level, input dimension) — confirm
        # against MultiresolutionHashEncoding's signature.
        self.encoding = MultiresolutionHashEncoding(2**22, 2, 2, levels=16, N_min=16, N_max=image_dims)
        # The first layer expects 32 input features; presumably this is
        # 16 levels x 2 features per level from the encoding — verify if the
        # encoding arguments above are ever changed.
        self.mlp = torch.nn.ModuleList(
            [
                torch.nn.Linear(32, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, 3),
                torch.nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Encode ``x`` and run it through every MLP layer in order.

        Args:
            x: Input accepted by ``self.encoding``.

        Returns:
            The output of the final layer in ``self.mlp``.
        """
        x = self.encoding(x)
        for layer in self.mlp:
            x = layer(x)
        return x

In [None]:
# Fit an ImageEncoder to a single image: every pixel is one training sample
# (normalized coordinate -> colour).
img = Image.open("test_files/matthias-reim.jpg")
# The model predicts three channels; grayscale/RGBA/CMYK files would otherwise
# fail at the reshape in the loss below.
if img.mode != "RGB":
    img = img.convert("RGB")

# Longest image side in pixels: handed to the encoder as N_max and used to
# scale pixel indices into [0, 1).
im_dim = max(img.width, img.height)
normalization = im_dim

# One device object for the model and every tensor, so they cannot end up on
# different GPUs.
device = torch.device("cuda:0")
m = ImageEncoder(im_dim).to(device)

optimizer = torch.optim.Adam(
    m.parameters(),
    betas=(0.9, 0.999),
    eps=1e-15,
    lr=1e-2,
)

loss_func = torch.nn.MSELoss()

# All (row, col) pixel indices, shuffled once and then cut into fixed-size
# batches; the last batch may be shorter.
pixel_coords = np.mgrid[0:img.height, 0:img.width].T.reshape(-1, 2)
np.random.shuffle(pixel_coords)
BATCH_SIZE = 5000
grid: List[np.ndarray] = [
    pixel_coords[start:start + BATCH_SIZE]
    for start in range(0, len(pixel_coords), BATCH_SIZE)
]

print(f"Image size ({img.width}, {img.height})")

# Target colours scaled to [0, 1] to match the sigmoid output range.
np_img = np.array(img).astype(np.float32) / 255.0

for epoch in range(EPOCHS):
    for step, batch in enumerate(grid):
        # Column 0 holds row indices and column 1 column indices.
        target = torch.from_numpy(np_img[batch[:, 0], batch[:, 1]]).to(device)
        inputs = torch.from_numpy(batch.astype(np.float32) / normalization).to(device)

        optimizer.zero_grad()

        pred_im = m(inputs)
        loss: torch.Tensor = loss_func(pred_im.reshape(target.shape), target)

        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch + 1}/{EPOCHS}, Step: {step + 1}/{len(grid)}, loss:{loss.item()}.")


In [None]:
# Render the learned image by querying the model at every pixel, then show it
# next to the original.
# Pixels are evaluated in chunks so that a very large image does not have to
# fit through the model (or onto the GPU) in a single forward pass.
EVAL_BATCH_SIZE = 2**18

with torch.no_grad():
    # Use whatever device the model lives on instead of hard-coding one.
    model_device = next(m.parameters()).device
    n_pixels = img.height * img.width

    chunks = []
    for start in range(0, n_pixels, EVAL_BATCH_SIZE):
        stop = min(start + EVAL_BATCH_SIZE, n_pixels)
        # Row-major flat index -> (row, col), the same order a full
        # np.mgrid[0:H, 0:W].reshape(2, -1).T would produce.
        flat = np.arange(start, stop, dtype=np.int64)
        rows, cols = np.divmod(flat, img.width)
        coords = np.stack((rows, cols), axis=1).astype(np.float32) / im_dim
        chunks.append(m(torch.from_numpy(coords).to(model_device)).cpu())

    pred_im = torch.cat(chunks)
    im_arr = pred_im.numpy().reshape(img.height, img.width, 3)
    plt.imshow(im_arr)
    plt.figure()
    plt.imshow(img)