# Multiresolution Hash Encoding

In [None]:
from typing import List
import torch
from torch import nn
import numpy as np
from multiresolution_hash_encoding import MultiresolutionHashEncoding
from PIL import Image
from matplotlib import pyplot as plt
import pickle
# Per Pillow's documentation, setting MAX_IMAGE_PIXELS to None turns off the
# pixel-count ("decompression bomb") check, so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

## Train Images

In [None]:
# Number of full passes over all pixel batches.
EPOCHS = 1
# Number of pixel coordinates per optimisation step.
BATCH_SIZE = 10_000
# When True, the trained weights are written to "albert.pt".
# NOTE(review): "SAFE" is presumably a typo for "SAVE"; the name is kept
# because the training cell reads it.
SAFE_MODEL = True
# When True, the final reconstruction is written to "reconstructed.png".
SAFE_FINAL_IMAGE = True

# These settings create a GIF from the first N batches of the first epoch.
# According to the original author, these values should work with 16 GB RAM
# on the Einstein image.
MAKE_TRAINING_GIF = True
# Upper bound on captured frames. Every frame is a full-resolution float32
# reconstruction held in memory until the GIF is written.
MAX_GIF_IMAGES = 100 # This can take a lot of memory
# Playback speed; converted to a per-frame duration of 1000 // GIF_FPS ms.
GIF_FPS = 15
# Factor applied to both image dimensions when frames are resized for the GIF.
GIF_SCALE = 0.2

In [None]:
class ImageEncoder(torch.nn.Module):
    """Coordinate network: a multiresolution hash encoding followed by a small MLP.

    Args:
        image_dims: Forwarded to ``MultiresolutionHashEncoding`` as ``N_max``.
        use_color: If True the network emits 3 values per input row,
            otherwise 1.

    The last layer is a sigmoid, so every output lies in (0, 1).
    """

    def __init__(self, image_dims, use_color=False):
        super().__init__()
        # NOTE(review): the positional arguments (2**14, 2, 2) are presumably
        # table size, features per level and input dimension; confirm against
        # MultiresolutionHashEncoding.
        self.encoding = MultiresolutionHashEncoding(
            2**14, 2, 2, levels=16, N_min=16, N_max=image_dims
        )

        out_channels = 3 if use_color else 1
        # The MLP expects 32 input features (presumably 16 levels x 2
        # features per level; confirm against the encoding).
        head = [
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_channels),
            nn.Sigmoid(),
        ]
        self.mlp = nn.ModuleList(head)

    def forward(self, x):
        """Encode ``x`` and run the result through every MLP stage in order."""
        features = self.encoding(x)
        for stage in self.mlp:
            features = stage(features)
        return features

def PSNR(img:np.ndarray, target:np.ndarray, max_value:float=1.0) -> float:
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img: Image under test.
        target: Reference image; must be broadcast-compatible with ``img``.
        max_value: Largest possible pixel value. The default of 1.0 is for
            images normalised to [0, 1]; pass 255 for 8-bit data.

    Returns:
        ``10 * log10(max_value**2 / MSE)``, or ``inf`` when the two images
        are identical (MSE of zero).
    """
    img = np.asarray(img)
    target = np.asarray(target)

    # Integer arrays (e.g. uint8) wrap around on subtraction, which would
    # silently corrupt the error, so promote them to float first. Float
    # inputs are left alone to avoid copying large images.
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    if not np.issubdtype(target.dtype, np.floating):
        target = target.astype(np.float64)

    # Accumulate in float64 so the mean stays accurate for float32 images.
    mse = np.mean(np.square(img - target), dtype=np.float64)
    if mse == 0:
        # Identical images: avoid the divide-by-zero warning.
        return float('inf')

    psnr = 10 * np.log10(max_value ** 2 / mse)
    return float(psnr)

@torch.no_grad()
def predict_full(m:nn.Module, img:"Image.Image", num_chunks:int=5) -> np.ndarray:
    """Evaluate the model at every pixel coordinate of ``img``.

    Args:
        m: Model mapping an ``(N, 2)`` float tensor of normalised
            ``(row, column)`` coordinates to ``N`` predictions.
        img: Only ``img.height`` and ``img.width`` are read; the pixel data
            is not used.
        num_chunks: Number of roughly equal chunks the coordinates are split
            into, so the whole image is never pushed through the model at
            once.

    Returns:
        A float32 array of shape ``(height, width)`` when the model emits one
        value per pixel, otherwise ``(height, width, channels)``.
    """
    height, width = img.height, img.width
    # Coordinates are divided by the longer image side. This is the same
    # normalisation the training loop in this file applies (its ``im_dim``).
    norm = max(width, height)

    # Run on whatever device the model lives on instead of assuming cuda:0.
    first_param = next(m.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # One (row, column) pair per pixel.
    coords = np.mgrid[0:height, 0:width].T.reshape(-1, 2)
    chunk_len = max(1, int(np.ceil(coords.shape[0] / max(1, num_chunks))))

    out = None
    # Stepping by chunk_len never yields an empty chunk, even for tiny images.
    for start in range(0, coords.shape[0], chunk_len):
        chunk = coords[start:start + chunk_len]
        pred = m(torch.as_tensor(chunk, dtype=torch.float32, device=device) / norm)
        pred = pred.reshape(chunk.shape[0], -1).cpu().numpy()
        if out is None:
            # The channel count is only known after the first prediction.
            out = np.empty((height, width, pred.shape[1]), dtype=np.float32)
        out[chunk[:, 0], chunk[:, 1]] = pred

    if out is None:
        # Image without pixels: nothing was predicted.
        return np.empty((height, width), dtype=np.float32)
    if out.shape[2] == 1:
        return out.reshape(height, width)
    return out
    

In [None]:
# Fit an ImageEncoder to a single image, plot the training curves, optionally
# export a GIF of the early training progress, and save the model and the
# final reconstruction.
img = Image.open("test_files/albert.jpg")

# Pixel coordinates are divided by the longer image side before they are fed
# to the network.
im_dim = max(img.width, img.height)

m = ImageEncoder(im_dim).cuda()

# Weight decay is applied to the MLP parameters only, not to the encoding.
optimizer = torch.optim.Adam(
            [{'params': m.encoding.parameters()}, {'params':m.mlp.parameters(), 'weight_decay':1e-6}],
            betas=(0.9, 0.999),
            eps=1e-15,
            lr = 1e-2)

loss_func = torch.nn.MSELoss()

# One (row, column) pair per pixel, shuffled once and then cut into batches.
grid = np.mgrid[0:img.height, 0:img.width].T.reshape(-1, 2)
np.random.shuffle(grid)
grid:List[np.ndarray] = [grid[i * BATCH_SIZE:(i+1) * BATCH_SIZE] for i in range(int(np.ceil(len(grid) / BATCH_SIZE)))]
print(f"Image size ({img.width}, {img.height})")

# Target pixel values scaled to [0, 1] to match the network's sigmoid output.
np_img = np.array(img).astype(np.float32) / 255.0

loss_list = []
psnr_list = []
image_list = []

for i in range(EPOCHS):
    for j, inp in enumerate(grid):
        im = torch.tensor(np_img[inp[:,0], inp[:,1]], device='cuda:0')

        optimizer.zero_grad()

        pred_im = m(torch.tensor(inp.astype(np.float32) / im_dim, device='cuda:0'))

        loss:torch.Tensor = loss_func(pred_im.reshape(im.shape), im)

        # Copy the scalar to the host once and reuse it for logging and history.
        loss_value = loss.item()
        print(f"Epoch: {i + 1}/{EPOCHS}, Batch: {j + 1}/{len(grid)}, loss:{loss_value:.4f}.", end='\r')

        loss_list.append(loss_value)
        # Frames are captured before the optimiser step, so frame j shows the
        # model after j updates.
        if MAKE_TRAINING_GIF and j < MAX_GIF_IMAGES and i == 0:
            image_list.append(predict_full(m, img))

        loss.backward()
        optimizer.step()

    psnr = PSNR(predict_full(m, img), np_img)
    psnr_list.append(psnr)
    print()
    print(f"PSNR: {psnr}")

loss_list = np.array(loss_list)
psnr_list = np.array(psnr_list)

plt.plot(loss_list)
plt.xlabel("Batch")
plt.ylabel("MSE")
plt.figure()
plt.plot(psnr_list)
plt.xlabel("Epoch")
plt.ylabel("PSNR")

# Skip the export when no frame was captured (e.g. the image yielded no batches).
if MAKE_TRAINING_GIF and image_list:
    imgs = [
        Image.fromarray((t_img * 255).astype(np.uint8)).resize(
                # PIL expects (width, height); array shape is (height, width).
                tuple(int(t_img.shape[1-i] * GIF_SCALE)
                for i in range(2))
            )
        for t_img in image_list
    ]
    # duration is the number of milliseconds between frames
    imgs[0].save("albert.gif", save_all=True, append_images=imgs[1:], duration=1000 // GIF_FPS, loop=0)

# Drop the gradient tensors now that training is finished.
optimizer.zero_grad(set_to_none=True)

reconstructed_img = predict_full(m, img)

fig, axes = plt.subplots(1, 2)

axes[0].imshow(img, cmap='gray')
axes[0].set_title("Original")
axes[0].axis('off')

axes[1].imshow(reconstructed_img, cmap='gray')
axes[1].set_title("Reconstructed")
axes[1].axis('off')

if SAFE_MODEL:
    torch.save(m.state_dict(), "albert.pt")

if SAFE_FINAL_IMAGE:
    Image.fromarray((reconstructed_img*255).astype(np.uint8)).save("reconstructed.png")
