# Multiresolution Hash Encoding

In [None]:
from typing import List
import torch
import numpy as np
from multiresolution_hash_encoding import MultiresolutionHashEncoding
from PIL import Image
from matplotlib import pyplot as plt
import pickle
Image.MAX_IMAGE_PIXELS = None

## Train Images

In [None]:
# Number of passes over the training data made by the training loops below.
EPOCHS = 1

In [None]:
class ImageEncoder(torch.nn.Module):
    """Multiresolution hash encoding followed by a small MLP.

    The MLP takes 32 input features, has two hidden layers of 64 units with
    ReLU activations, and ends in 3 sigmoid outputs (values in (0, 1)).
    """

    def __init__(self, image_dims):
        """Build the encoding and the MLP.

        Args:
            image_dims: passed to the encoding as ``N_max``.
        """
        super().__init__()
        self.encoding = MultiresolutionHashEncoding(
            2**22, 2, 2, levels=16, N_min=16, N_max=image_dims
        )
        # NOTE(review): 32 input features presumably equals 16 levels x 2
        # features per level -- confirm against MultiresolutionHashEncoding.
        hidden_width = 64
        stages = [
            torch.nn.Linear(32, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, 3),
            torch.nn.Sigmoid(),
        ]
        self.mlp = torch.nn.ModuleList(stages)

    def forward(self, x):
        """Encode ``x`` and pass the result through every MLP stage in order."""
        features = self.encoding(x)
        for stage in self.mlp:
            features = stage(features)
        return features

In [None]:
# Fit an ImageEncoder to the husk test image.  Each training sample is one
# pixel: its (row, column) position divided by the larger image dimension is
# the input, and its colour is the target.
img = Image.open("test_files/husk.jpg")
im_dim = max(img.width, img.height)

m = ImageEncoder(im_dim).cuda()

# Weight decay is applied to the MLP parameters only, not to the encoding.
param_groups = [
    {'params': m.encoding.parameters()},
    {'params': m.mlp.parameters(), 'weight_decay': 1e-6},
]
optimizer = torch.optim.Adam(param_groups, lr=1e-2, betas=(0.9, 0.999), eps=1e-15)
loss_func = torch.nn.MSELoss()

# Every (row, column) pixel index, shuffled once and then cut into
# fixed-size batches; the final batch may be shorter.
grid = np.mgrid[0:img.height, 0:img.width].T.reshape(-1, 2)
np.random.shuffle(grid)
BATCH_SIZE = 1000
grid: List[np.ndarray] = [
    grid[start:start + BATCH_SIZE] for start in range(0, len(grid), BATCH_SIZE)
]

print(f"Image batch size ({img.width}, {img.height})")

# Pixel values as float32, divided by 255.
np_img = np.array(img).astype(np.float32) / 255.0

for epoch in range(EPOCHS):
    for step, coords in enumerate(grid):
        rows, cols = coords[:, 0], coords[:, 1]
        target = torch.tensor(np_img[rows, cols], device='cuda:0')
        positions = torch.tensor(coords.astype(np.float32) / im_dim, device='cuda:0')

        optimizer.zero_grad()

        pred_im = m(positions)
        loss: torch.Tensor = loss_func(pred_im.reshape(target.shape), target)

        print(f"Batch: {epoch + 1}/{EPOCHS}, Step: {step + 1}/{len(grid)}, loss:{loss.detach().cpu()}.")

        loss.backward()
        optimizer.step()

# Drop the gradient tensors (set_to_none=True) now that training is done.
optimizer.zero_grad(True)


In [None]:
# Train the encoder on the Tokyo image using the pre-shuffled pickled batches
# stored in test_files/batches/.  After every epoch the weights are saved,
# and at the end the per-step losses are pickled.
# NOTE: pickle.load can execute arbitrary code; only load batch files that
# were produced locally.
with Image.open("test_files/tokyo.jpg") as img:
    # Only the dimensions are needed here; pixel data comes from the batches.
    im_dim = max(img.width, img.height)
    print(f"Image batch size ({img.width}, {img.height})")

img = None
losses_list = []

# NOTE(review): with the next line commented out, `m` is whatever model is
# already defined (possibly one built for a different image size) -- confirm
# that this reuse is intended.
#m = ImageEncoder(im_dim).cuda()

optimizer = torch.optim.Adam(
            m.parameters(),
            betas=(0.9, 0.999),
            eps=1e-15,
            lr = 1e-2)

loss_func = torch.nn.MSELoss()
BATCH_SIZE = 1_000_000
# Number of batch_<j>.pickle files to read per epoch.
NUM_BATCH_FILES = 609

for e in range(EPOCHS):
    for j in range(NUM_BATCH_FILES):
        with open(f"test_files/batches/batch_{j}.pickle", 'rb') as f:
            data = pickle.load(f)
        sub_grid: np.ndarray = data["pos"]
        colors: np.ndarray = data["color"].astype(np.float32)
        data = None

        # Round up so that a trailing partial chunk is trained on too; floor
        # division silently dropped it (and skipped a short last file).
        num_steps = int(np.ceil(len(sub_grid) / BATCH_SIZE))

        for k in range(num_steps):
            chunk = slice(k * BATCH_SIZE, (k + 1) * BATCH_SIZE)
            # Positions are divided by the larger image dimension before
            # being fed to the model.
            grid = sub_grid[chunk].astype(np.float32) / im_dim
            np_img = colors[chunk]

            im = torch.tensor(np_img, device='cuda:0')

            optimizer.zero_grad()

            pred_im = m(torch.tensor(grid, device='cuda:0'))

            loss: torch.Tensor = loss_func(pred_im.reshape(im.shape), im)

            # Copy the loss to the host once and reuse it for logging.
            loss_value = loss.detach().cpu()
            losses_list.append(loss_value.numpy())

            print(f"Batch: {j + 1}/{NUM_BATCH_FILES}, Step: {k + 1}/{num_steps}, loss:{loss_value}.")

            loss.backward()
            optimizer.step()
    torch.save(m.state_dict(), f"test_files/tokyo_epoch{e}.pt")
# Drop the gradient tensors (set_to_none=True) now that training is done.
optimizer.zero_grad(True)
with open("test_files/loss_total.pickle", 'wb') as f:
    pickle.dump(np.array(losses_list), f)

In [None]:
# Plot the training losses collected in losses_list, in recording order.
plt.plot(np.array(losses_list))

In [None]:
# Render a crop of the image from the trained model and show it next to the
# same crop taken from the source image.
start_x = 28_359
start_y = 19_000

end_x = 29_559
end_y = 20_000

# (x, y) pixel coordinates of the crop, ordered row by row (y-major) so the
# predictions reshape directly into a (height, width, 3) image.
crop_coords = np.mgrid[start_x:end_x, start_y:end_y].T.reshape(-1, 2)

with torch.no_grad():
    pred_im = m(torch.tensor(crop_coords, device='cuda:0') / im_dim)

im_arr = pred_im.cpu().numpy().reshape(end_y - start_y, end_x - start_x, 3)
plt.imshow(im_arr)

plt.figure()
# Open the reference image here instead of relying on the global `img`,
# which the cells that read the image dimensions reset to None.
with Image.open("test_files/tokyo.jpg") as reference:
    plt.imshow(reference.crop((start_x, start_y, end_x, end_y)))

In [None]:
# Save the current model weights (state dict) to disk.
torch.save(m.state_dict(), "test_files/tokyo.pt")

In [None]:
# Rebuild the model at the Tokyo image's resolution, restore saved weights
# and switch the model to evaluation mode.
img = Image.open("test_files/tokyo.jpg")

# Only the dimensions are needed; the image reference is dropped right after.
width, height = img.size
im_dim = max(width, height)

img = None

m = ImageEncoder(im_dim).cuda()

m.load_state_dict(torch.load("test_files/tokyo_final_2.pt"))
m.eval()

In [None]:
# Evaluate the model over every pickled batch: accumulate the squared error
# between predicted and stored colours, then report MSE and PSNR.
# NOTE: pickle.load can execute arbitrary code; only load batch files that
# were produced locally.
NUM_BATCH_FILES = 609
# Number of colour values in the image (3 per pixel).
norm = width * height * 3
squared_error_sum = 0.0
with torch.no_grad():
    for i in range(NUM_BATCH_FILES):
        print(f"Processing batch {i}", end="\r")
        with open(f"test_files/batches/batch_{i}.pickle", 'rb') as f:
            data = pickle.load(f)
        sub_grid: np.ndarray = data["pos"]
        colors: torch.Tensor = torch.from_numpy(data["color"].astype(np.float32)).cuda()
        data = None

        # Evaluate each file in two chunks.  The second chunk runs to the end
        # so the last sample of an odd-length file is not dropped.
        total = sub_grid.shape[0]
        half = total // 2
        for lo, hi in ((0, half), (half, total)):
            if lo == hi:
                continue
            # Normalise positions exactly as during training (float32
            # coordinates divided by the larger image dimension); the model
            # was not trained on raw integer pixel indices.
            coords = torch.from_numpy(sub_grid[lo:hi].astype(np.float32) / im_dim).cuda()
            pred_im: torch.Tensor = m(coords)
            target = colors[lo:hi]
            # Sum in float64 and divide once at the end to limit rounding error.
            squared_error_sum += torch.square(
                pred_im.reshape(target.shape) - target
            ).sum(dtype=torch.float64).item()

MSE = squared_error_sum / norm
# Peak signal value is 1 because colours are stored in [0, 1].
PSNR = 10 * np.log10(1 / MSE) if MSE > 0 else float("inf")

print(f"MSE: {MSE}")
print(f"PSNR: {PSNR}")


## Create Batch Indices

In [None]:
# Shuffle all pixel indices of the image and write them to disk as (x, y)
# coordinate batches of at most BATCH_SIZE rows each.
BATCH_SIZE = 2_000_000

with Image.open("test_files/tokyo.jpg") as img:
    # Only the dimensions are needed; the file is closed right away.
    width = img.width
    height = img.height
img = None

print("Aranging indices...", end='\r')
indices = np.arange(width * height)
print("Shuffling indices...", end='\r')
np.random.shuffle(indices)

num_batches = int(np.ceil(len(indices) / BATCH_SIZE))

for i in range(num_batches):
    print(f"Processing {i + 1} of {num_batches}.", end='\r')
    inds = indices[i * BATCH_SIZE:(i+1)*BATCH_SIZE]
    # Flat index -> (index // height, index % height), computed for the whole
    # batch at once instead of one Python iteration per pixel.
    coords = np.stack(np.divmod(inds, height), axis=1).astype(np.int32)

    with open(f"test_files/batch_indices/indices_{i}.pickle", 'wb') as f:
        pickle.dump(coords, f)



In [None]:
import numpy as np
from PIL import Image
import pickle
Image.MAX_IMAGE_PIXELS = None

# Build training batches: for each saved index file, look up the colour of
# every listed pixel and store positions and colours together.
# NOTE: pickle.load can execute arbitrary code; only load index files that
# were produced locally.
img = Image.open("test_files/tokyo.jpg")

# Decode the image once into an array so each batch is gathered with a single
# indexing operation instead of one getpixel() call per pixel.  This keeps an
# additional full copy of the image in memory.
pixels = np.asarray(img)

# NOTE(review): only files 600-608 are processed -- presumably a resumed run;
# widen the range to regenerate every batch.
for i in range(600,609):
    with open(f"test_files/batch_indices/indices_{i}.pickle", 'rb') as f:
        grid = pickle.load(f)
    # grid rows are (x, y) as passed to getpixel; the array is indexed [y, x].
    # Keep the first three channels and scale to [0, 1] (float64, as before).
    np_img = pixels[grid[:, 1], grid[:, 0], :3] / 255.0
    with open(f"test_files/batches/batch_{i}.pickle", 'wb') as f:
        pickle.dump({"pos":grid, "color":np_img}, f)
    print(f"Done processing file {i}.")