# Multiresolution Hash Encoding

In [None]:
from typing import List
import torch
from torch import nn
import numpy as np
from multiresolution_hash_encoding import MultiresolutionHashEncoding
from transformer import ContextTransformer
from PIL import Image
from matplotlib import pyplot as plt
import pickle
# Lift PIL's decompression-bomb pixel limit so very large images (such as tokyo.jpg) can be opened.
Image.MAX_IMAGE_PIXELS = None

## Train Images

In [None]:
EPOCHS: int = 10  # number of full passes over the training batches in each training loop

In [None]:
class ImageEncoder(torch.nn.Module):
    """Coordinate network: multiresolution hash encoding followed by a small MLP.

    The MLP ends in a 3-unit linear layer and a sigmoid, so each input
    coordinate is mapped to 3 values in [0, 1].
    """

    def __init__(self, image_dims):
        super().__init__()
        # 16 levels with the coarsest resolution at 16 and the finest at image_dims.
        # NOTE(review): the remaining positional arguments (2**22, 2, 2) are presumably
        # table size / input dim / features per level — confirm against the encoder.
        self.encoding = MultiresolutionHashEncoding(2**22, 2, 2, levels=16, N_min=16, N_max=image_dims)
        width = 64
        # 32 input features — presumably 16 levels x 2 features per level (verify).
        layers = [
            torch.nn.Linear(32, width), torch.nn.ReLU(),
            torch.nn.Linear(width, width), torch.nn.ReLU(),
            torch.nn.Linear(width, 3), torch.nn.Sigmoid(),
        ]
        self.mlp = torch.nn.ModuleList(layers)

    def forward(self, x):
        out = self.encoding(x)
        for module in self.mlp:
            out = module(out)
        return out

def PSNR(img: np.ndarray, target: np.ndarray, data_range: float = 1.0):
    """Return the peak signal-to-noise ratio (in dB) between *img* and *target*.

    Args:
        img: Predicted image (any shape broadcast-compatible with *target*).
        target: Reference image.
        data_range: Peak signal value; 1.0 for images scaled to [0, 1]
            (the default, matching the original behaviour), 255.0 for 8-bit data.

    Returns:
        ``10 * log10(data_range**2 / MSE)``; ``inf`` when the images are identical.
    """
    # Cast to float64 first: subtracting unsigned integer arrays (e.g. uint8 pixels)
    # would otherwise wrap around instead of going negative.
    diff = np.asarray(img, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    mse = np.square(diff).mean()
    if mse == 0:
        return float("inf")
    return 10 * np.log10((data_range ** 2) / mse)

@torch.no_grad()
def predict_full(m: ImageEncoder, img: Image, im_dim=None, num_chunks: int = 5) -> np.ndarray:
    """Evaluate *m* at every pixel of *img* and assemble the predicted image.

    Pixel coordinates are generated as (row, col) pairs and divided by *im_dim*,
    which is the same normalisation the training loop applies.

    Args:
        m: Trained coordinate model; it is run on the device holding its parameters.
        img: PIL image whose ``height``/``width`` define the output size.
        im_dim: Coordinate normaliser. Defaults to ``max(img.width, img.height)``,
            the value the training cells use.
        num_chunks: Number of forward passes the pixels are split into, to bound memory.

    Returns:
        float32 array of shape ``(height, width)`` for single-output models, or
        ``(height, width, C)`` for models with ``C > 1`` outputs (e.g. RGB).
    """
    if im_dim is None:
        im_dim = max(img.width, img.height)
    device = next(m.parameters()).device

    # np.mgrid(...).T.reshape(-1, 2) yields every (row, col) pair of the image.
    coords = np.mgrid[0:img.height, 0:img.width].T.reshape(-1, 2)
    if coords.shape[0] == 0:
        return np.empty((img.height, img.width), dtype=np.float32)

    chunk = int(np.ceil(coords.shape[0] / max(1, num_chunks)))
    out = None
    for start in range(0, coords.shape[0], chunk):
        idx = coords[start:start + chunk]
        pred = m(torch.tensor(idx, device=device) / im_dim).cpu().numpy()
        # One row of C values per pixel; a plain (B,) output becomes (B, 1).
        pred = pred.reshape(idx.shape[0], -1)
        if out is None:
            out = np.empty((img.height, img.width, pred.shape[1]), dtype=np.float32)
        out[idx[:, 0], idx[:, 1]] = pred

    # Keep the original 2-D result for single-channel models.
    if out.shape[-1] == 1:
        out = out[..., 0]
    return out
    

In [None]:
img = Image.open("test_files/albert.jpg")

# Coordinates are normalised by the longest image side.
im_dim = max(img.width, img.height)

m = ImageEncoder(im_dim).cuda()

# Only the MLP parameters get weight decay; the hash-encoding parameters do not.
optimizer = torch.optim.Adam(
            [{'params': m.encoding.parameters()}, {'params':m.mlp.parameters(), 'weight_decay':1e-6}],
            betas=(0.9, 0.999),
            eps=1e-15,
            lr = 1e-2)

loss_func = torch.nn.MSELoss()

BATCH_SIZE = 10_000  # not used by the patch batching below
PATCH_SIZE = 100  # each training step uses one PATCH_SIZE x PATCH_SIZE tile of pixels

# Build each tile's (row, col) coordinates directly. Filtering the full pixel grid
# with np.where once per tile costs O(tiles * pixels). np.mgrid(...).T.reshape(-1, 2)
# enumerates columns in the outer loop and rows in the inner loop, which is the
# same point order the full-grid filter produced.
n_patch_rows = int(np.ceil(img.height / PATCH_SIZE))
n_patch_cols = int(np.ceil(img.width / PATCH_SIZE))
grid_l: List[np.ndarray] = []
for i in range(n_patch_rows):
    print(f"{i}/{n_patch_rows}", end='\r')
    for j in range(n_patch_cols):
        r0, c0 = i * PATCH_SIZE, j * PATCH_SIZE
        r1 = min(r0 + PATCH_SIZE, img.height)
        c1 = min(c0 + PATCH_SIZE, img.width)
        grid_l.append(np.mgrid[r0:r1, c0:c1].T.reshape(-1, 2))
grid = grid_l
print(f"Image batch size ({img.width}, {img.height})")

np_img = np.array(img).astype(np.float32) / 255.0

loss_list = []
psnr_list = []

for i in range(EPOCHS):
    for j, inp in enumerate(grid):
        # Target pixel values of this tile, looked up by (row, col).
        im = torch.tensor(np_img[inp[:,0], inp[:,1]], device='cuda:0')

        optimizer.zero_grad()

        pred_im = m(torch.tensor(inp.astype(np.float32) / im_dim, device='cuda:0'))

        loss:torch.Tensor = loss_func(pred_im.reshape(im.shape), im)

        print(f"Batch: {i + 1}/{EPOCHS}, Step: {j + 1}/{len(grid)}, loss:{loss.detach().cpu()}.")

        loss_list.append(loss.detach().cpu())

        loss.backward()
        optimizer.step()

    # Full-image PSNR after every epoch.
    psnr = PSNR(predict_full(m, img), np_img)
    psnr_list.append(psnr)
    print(f"PSNR: {psnr}")

loss_list = np.array(loss_list)
psnr_list = np.array(psnr_list)

optimizer.zero_grad(True)


In [None]:
# Per-step MSE of the last training run.
# NOTE(review): title/filename say "column wise" while the training cell batches by
# square tiles — confirm the label is still accurate.
fig = plt.figure()
fig.suptitle(f'column wise\nmin {loss_list.min()}')
plt.plot(loss_list, label='MSE')
plt.legend()
with open("test_files/results/MSE_column_wise.png", 'wb') as f:
    fig.savefig(f)

# Per-epoch PSNR of this run ('window') next to hard-coded values from earlier runs.
previous_psnr_runs = {
    'column': [10.2, 10.8, 12.8, 15.3, 14.9, 18.1, 18.25, 15.9],
    'row': [11.25, 10.9, 11.15, 11.5, 10.6, 10.4, 11.2, 11.22, 11.26, 9.46],
}
fig = plt.figure()
plt.plot(psnr_list, label='window')
for run_label, run_values in previous_psnr_runs.items():
    plt.plot(run_values, label=run_label)
plt.legend()
with open("test_files/results/PSNR_column_row.png", 'wb') as f:
    fig.savefig(f)

In [None]:
# Hard-coded PSNR results for several training batch sizes.
batch_sizes = [10_000, 50_000, 100_000, 500_000, 1_000_000]
psnr_by_batch_size = [34.02, 33.91, 32.7, 32.18, 30.2]
plt.plot(batch_sizes, psnr_by_batch_size, label='PSNR')
plt.legend()
plt.xlabel('batch size')
plt.ylabel('PSNR')

In [None]:
img = Image.open("test_files/albert.jpg")
im_dim = max(img.width, img.height)

# Recreate the model with the same configuration, then restore its saved weights.
m = ImageEncoder(im_dim).cuda()
m.load_state_dict(torch.load("test_files/test.pt"))

In [None]:
pred_np_im = predict_full(m, img)

# Reconstruction (model outputs lie in [0, 1]).
plt.imshow(pred_np_im, cmap='gray', vmin=0, vmax=1)
plt.axis('off')

# Ground truth in its own figure (8-bit pixel values).
plt.figure()
plt.axis('off')
plt.imshow(img, cmap='gray', vmin=0, vmax=255)

## Tokyo

In [None]:
img = Image.open("test_files/tokyo.jpg")

im_dim = max(img.width, img.height)
print(f"Image batch size ({img.width}, {img.height})")

img = None  # only the dimensions are needed; drop the reference to the large image
losses_list = []

# `m` is reused from a previously executed cell (training continues from its weights);
# uncomment to start from a freshly initialised model instead.
#m = ImageEncoder(im_dim).cuda()

# Only the MLP parameters get weight decay; the hash-encoding parameters do not.
optimizer = torch.optim.Adam(
            [{'params': m.encoding.parameters()}, {'params':m.mlp.parameters(), 'weight_decay':1e-6}],
            betas=(0.9, 0.999),
            eps=1e-15,
            lr = 1e-2)

loss_func = torch.nn.MSELoss()
BATCH_SIZE = 1_000_000
NUM_BATCH_FILES = 609  # NOTE(review): must match the number of files in test_files/batches — confirm

for e in range(EPOCHS):
    for j in range(NUM_BATCH_FILES):
        with open(f"test_files/batches/batch_{j}.pickle", 'rb') as f:
            data = pickle.load(f)
        sub_grid: np.ndarray = data["pos"]
        colors: np.ndarray = data["color"].astype(np.float32)
        data = None

        # Ceil division so a trailing partial chunk is also trained on. Plain floor
        # division skipped it, and skipped an entire batch file smaller than BATCH_SIZE.
        num_steps = (len(sub_grid) + BATCH_SIZE - 1) // BATCH_SIZE
        for k in range(num_steps):
            # Normalise coordinates by the longest image side, as in all other cells.
            grid = sub_grid[k*BATCH_SIZE:(k+1)*BATCH_SIZE].astype(np.float32) / im_dim
            np_img = colors[k*BATCH_SIZE:(k+1)*BATCH_SIZE]

            im = torch.tensor(np_img, device='cuda:0')

            optimizer.zero_grad()

            pred_im = m(torch.tensor(grid, device='cuda:0'))

            loss:torch.Tensor = loss_func(pred_im.reshape(im.shape), im)

            losses_list.append(loss.detach().cpu().numpy())

            print(f"Epoch: {e + 1}/{EPOCHS}, Batch: {j + 1}/{NUM_BATCH_FILES}, Step: {k + 1}/{num_steps}, loss:{loss.detach().cpu()}.", end='\r')

            loss.backward()
            optimizer.step()
    # Checkpoint the model and the loss history after every epoch.
    torch.save(m.state_dict(), f"test_files/tokyo_epoch{e}.pt")
    with open(f"test_files/loss_{e}.pickle", 'wb') as f:
        pickle.dump(np.array(losses_list), f)
optimizer.zero_grad(True)

In [None]:
# Plot the whole per-step loss history and persist it.
all_losses = np.array(losses_list)
plt.plot(all_losses)
with open("test_files/loss_total.pickle", 'wb') as f:
    pickle.dump(all_losses, f)

In [None]:
# Pixel window (x, y) to reconstruct and compare against the source image.
start_x, start_y = 28_000, 19_000
end_x, end_y = 30_000, 20_000

img = Image.open("test_files/tokyo.jpg")
reference_crop = img.crop((start_x, start_y, end_x, end_y))

# (x, y) pairs with y in the outer loop, so the predictions reshape to
# (rows = y, cols = x, 3) in the same layout as the cropped image.
window_coords = np.mgrid[start_x:end_x, start_y:end_y].T.reshape(-1, 2)
with torch.no_grad():
    pred_im = m(torch.tensor(window_coords, device='cuda:0') / im_dim)
    im_arr = pred_im.cpu().numpy().reshape(end_y - start_y, end_x - start_x, 3)

plt.imshow(im_arr)
plt.axis('off')
plt.figure()
plt.axis('off')
plt.imshow(reference_crop)

print(f"PSNR: {PSNR(im_arr, np.array(reference_crop, dtype=np.float32) / 255.0)}")

In [None]:
# Only the image dimensions are needed to rebuild the model.
img = Image.open("test_files/tokyo.jpg")
width, height = img.width, img.height
im_dim = max(width, height)
img = None  # drop the reference to the large image

m = ImageEncoder(im_dim).cuda()
m.load_state_dict(torch.load("test_files/tokyo_epoch2.pt"))
m.eval()

In [None]:
# Full-image MSE / PSNR of the loaded Tokyo model, streamed over the pickled batch files.
NUM_BATCH_FILES = 609  # NOTE(review): must match the number of files in test_files/batches — confirm
MSE = 0.0
norm = width * height * 3  # total number of colour values; turns the summed squared error into a mean
with torch.no_grad():
    for i in range(NUM_BATCH_FILES):
        print(f"Processing batch {i}", end="\r")
        with open(f"test_files/batches/batch_{i}.pickle", 'rb') as f:
            data = pickle.load(f)
        sub_grid: np.ndarray = data["pos"]
        colors: torch.Tensor = torch.from_numpy(data["color"].astype(np.float32)).cuda()
        data = None
        # Normalise coordinates exactly as the training loop does (float32 / im_dim).
        # Raw integer pixel coordinates would query the model far outside the
        # coordinate range it was trained on.
        coords = torch.from_numpy(sub_grid.astype(np.float32) / im_dim).cuda()
        # Evaluate in two chunks to bound GPU memory. tensor_split keeps the trailing
        # element that `shape[0] // 2` slicing dropped for odd-sized files.
        for pos_chunk, color_chunk in zip(torch.tensor_split(coords, 2), torch.tensor_split(colors, 2)):
            pred_im: torch.Tensor = m(pos_chunk)
            # .item() accumulates in Python float (double) rather than float32.
            MSE += (torch.square(pred_im - color_chunk) / norm).sum().item()

# Peak value 1: the stored colours are scaled to [0, 1].
# Named `psnr` so it does not shadow the PSNR() helper function.
psnr = 10 * np.log10(1 / MSE)

print(f"MSE: {MSE}")
print(f"PSNR: {psnr}")


Recorded full-image PSNR (dB):
- One run: 13.09
- Two runs: 12.98

### Create Batch Indices

In [None]:
BATCH_SIZE = 2_000_000  # pixels per batch file

img = Image.open("test_files/tokyo.jpg")
print("Aranging indices...", end='\r')
indices = np.arange(img.width * img.height)
width = img.width
height = img.height
img = None
print("Shuffling indices...", end='\r')
np.random.shuffle(indices)

num_batches = int(np.ceil(len(indices) / BATCH_SIZE))

# Reusable int32 output buffer; each file stores only its first len(inds) rows.
coords = np.empty((BATCH_SIZE, 2), dtype=np.int32)

for i in range(num_batches):
    print(f"Processing {i + 1} of {num_batches}.", end='\r')
    inds = indices[i * BATCH_SIZE:(i+1)*BATCH_SIZE]
    n = len(inds)
    # Vectorised flat-index decode. Column 0 = index // height is in [0, width);
    # column 1 = index % height is in [0, height). This replaces the per-element
    # Python loop and produces the same values.
    coords[:n, 0] = inds // height
    coords[:n, 1] = inds % height

    with open(f"test_files/batch_indices/indices_{i}.pickle", 'wb') as f:
        pickle.dump(coords[:n], f)



In [None]:
import numpy as np
from PIL import Image
import pickle
Image.MAX_IMAGE_PIXELS = None  # allow opening the very large source image

img = Image.open("test_files/tokyo.jpg")

# NOTE(review): only files 600..608 are generated here — presumably resuming an earlier run.
for i in range(600, 609):
    with open(f"test_files/batch_indices/indices_{i}.pickle", 'rb') as f:
        grid = pickle.load(f)
    np_img = np.empty((grid.shape[0], 3), np.uint8)
    # getpixel takes an (x, y) pair; keep the first three channels of each pixel.
    for j, (x, y) in enumerate(grid):
        np_img[j] = img.getpixel((x, y))[:3]
    np_img = np_img / 255.0
    with open(f"test_files/batches/batch_{i}.pickle", 'wb') as f:
        pickle.dump({"pos": grid, "color": np_img}, f)
    print(f"Done processing file {i}.")

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Five 2-D sample points.
A = np.array([[0, 0], [1, 1], [4, 2], [5, 2], [3, 3]], dtype=float)

plt.scatter(A[:, 0], A[:, 1])

# Pairwise Euclidean distance matrix via broadcasting: (5, 1, 2) - (1, 5, 2) -> (5, 5, 2).
pair_diff = A[:, np.newaxis, :] - A[np.newaxis, :, :]
print(np.sqrt(np.square(pair_diff).sum(axis=2)))

## Various experiments

In [None]:
# Visualise the hash encoding's intermediate values on a 100 x 100 coordinate grid.
# Relies on private methods of MultiresolutionHashEncoding (_scale_to_grid,
# _fast_hash, _interpolate); the tensor shapes below are not visible here — verify.
size = 2**7  # first encoder argument; also used as vmax for the grayscale plots

enc = MultiresolutionHashEncoding(size, 2, 1, levels=8, N_min=16, N_max=100)
# Integer grid coordinates divided by 100, i.e. values in [0, 0.99].
inputs = np.mgrid[0:100, 0:100].T.reshape(-1, 2)
inputs = torch.from_numpy(inputs) / 100

scaled_coords, grid_coords = enc._scale_to_grid(inputs)
hashes = enc._fast_hash(grid_coords)
print(hashes[..., 0, :])
# NOTE(review): the hash values themselves are fed in as the "features" to interpolate,
# presumably to visualise how interpolation blends neighbouring hash indices — confirm.
inpts = hashes.clone().unsqueeze(dim=1)

outpts = enc._interpolate(scaled_coords, grid_coords, inpts)

# assumes 100*100*8 output values — presumably one per level, shown as a 2 x 4 grid (verify).
outpts = outpts.reshape(100, 100, 2, 4)

fig, axes = plt.subplots(1,2)
axes[0].imshow(outpts[..., 0, 0].reshape(100, 100), vmin=0, vmax=size, cmap='gray')
axes[1].imshow(hashes[..., 0, 0].reshape(100, 100), vmin=0, vmax=size, cmap='gray')

fig, axes = plt.subplots(2, 4)

for i in range(2):
    for j in range(4):
        axes[i, j].imshow(outpts[:,:, i, j], vmin=0, vmax=size, cmap='gray')

# cRPE Attention

In [None]:
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
import torch_cluster

# ShapeNet point clouds rooted at data/ShapeNet; batch_size=1 (the DataLoader default) —
# one shape per iteration.
data = ShapeNet("data/ShapeNet")
data_ld = DataLoader(data, batch_size=1)

In [None]:
import itertools

class EmbeddingNet(nn.Module):
    """Per-point features from hash-encoded offsets to the k nearest neighbours, max-pooled."""

    k: int
    encoder: MultiresolutionHashEncoding

    def __init__(self, k:int):
        super().__init__()
        self.encoder = MultiresolutionHashEncoding(2**12, 3, 4, levels=16, N_min=1, N_max=1e7)
        self.k = k

    def forward(self, xyz: torch.Tensor):
        n_points = xyz.shape[0]
        # Row 1 of torch_cluster.knn's output holds indices into xyz; assumed grouped
        # by query point with exactly k neighbours each (true when n_points >= k).
        neighbours = torch_cluster.knn(xyz, xyz, self.k)[1].reshape(n_points, self.k)
        offsets = xyz.unsqueeze(1) - xyz[neighbours]
        encoded = self.encoder(offsets.reshape(-1, 3)).reshape(n_points, self.k, -1)
        pooled, _ = encoded.max(dim=1)
        return pooled

class SegNetClassifier(nn.Module):
    """Per-point head: LayerNorm -> Linear(input_dim, 128) -> ReLU -> Linear(128, 50) -> LogSoftmax."""

    classifier: nn.ModuleList

    def __init__(self, input_dim:int):
        super().__init__()
        hidden, n_classes = 128, 50
        self.classifier = nn.ModuleList([
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
            nn.LogSoftmax(-1),
        ])

    def forward(self, feats):
        out = feats
        for stage in self.classifier:
            out = stage(out)
        return out


class SegNet(nn.Module):
    """Point-cloud segmentation network.

    Pipeline: EmbeddingNet features -> two stages of (ContextTransformer, LayerNorm,
    Linear that doubles the width: 64 -> 128 -> 256) -> SegNetClassifier(256), which
    emits 50 per-point log-probabilities.
    """

    embedding: EmbeddingNet
    transformer_stack: nn.ModuleList
    classifier: SegNetClassifier

    short_connections: int  # k for the k-nearest-neighbour edges of each point
    long_connections: int  # number of farthest-point-sampled long-range edges per point

    def __init__(self, k: int, short_connections:int, long_connections:int):
        super().__init__()
        self.short_connections = short_connections
        self.long_connections = long_connections
        self.embedding = EmbeddingNet(k)
        # Flattened stack for i in (0, 1):
        #   ContextTransformer(64*2**i, 2**(i+1), 1e7, 2**12), LayerNorm(64*2**i), Linear(64*2**i -> 64*2**(i+1)).
        # Assumes EmbeddingNet yields 64 features per point (ContextTransformer argument
        # meanings are defined elsewhere — not visible here).
        self.transformer_stack = nn.ModuleList(
            list(itertools.chain.from_iterable([
                [ContextTransformer(64 * (2**i), 2**(i+1), 1e7, 2**12), nn.LayerNorm(64 * (2**i)), nn.Linear(64 * (2**i), 64 * (2**(i+1)))] 
                for i in range(2)]))
        )
        self.classifier = SegNetClassifier(256)

    def prep_edges(self, xyz:torch.Tensor) -> torch.LongTensor:
        """Build per-point neighbour indices: short_connections kNN + long_connections FPS samples.

        Returns a tensor of shape (N, short_connections + long_connections), assuming
        fps returns exactly long_connections + 1 samples (see below).
        """
        device = xyz.device
        # kNN neighbours of every point (the point itself presumably included, at distance 0).
        edges = torch_cluster.knn(xyz, xyz, self.short_connections)[1].reshape(xyz.shape[0], self.short_connections)
        # Farthest-point sampling of about long_connections + 1 points; the extra one is a
        # spare used to replace a point's own index when it was itself sampled.
        sub_sample_inds:torch.LongTensor = torch_cluster.fps(xyz, ratio=(self.long_connections + 1) / xyz.shape[0])
        # Every row starts as the same list of sampled indices.
        long_edges = sub_sample_inds.unsqueeze(0).expand(xyz.shape[0], *sub_sample_inds.shape).clone()
        # mask[r, c] is True where row r's sampled index is r itself (a self-edge).
        mask = (long_edges == torch.arange(xyz.shape[0], device=device).reshape(-1, 1))
        # Replace each self-edge with that row's last sample, then drop the last column.
        # Relies on at most one self-hit per row (sampled indices unique) so both sides match in size.
        long_edges[mask] = long_edges[mask.any(dim=1), -1]
        long_edges = long_edges[:,:-1]

        return torch.cat((edges, long_edges), dim=-1)

    def forward(self, xyz:torch.Tensor):
        """Return per-point class log-probabilities for the point cloud xyz."""
        #Prepare long and short edges
        edges = self.prep_edges(xyz)

        feats = self.embedding(xyz)
        for layer in self.transformer_stack:
            # ContextTransformer layers also need positions and edges; LayerNorm/Linear take features only.
            # NOTE(review): exact type check — a ContextTransformer subclass would take the else branch.
            if type(layer) is ContextTransformer:
                feats = layer(xyz, feats, edges)
            else:
                feats = layer(feats)
        feats = self.classifier(feats)
        return feats

# 10 neighbours for the embedding, 20 kNN edges and 10 long-range edges per point.
m = SegNet(k=10, short_connections=20, long_connections=10).cuda()

In [None]:
EPOCHS = 1
BATCH_SIZE = 20  # samples whose gradients are accumulated before each optimizer step

optimizer = torch.optim.Adam(m.parameters())

loss_func = torch.nn.NLLLoss()  # the model already outputs log-probabilities

m.train()

loss_list = []

for i in range(EPOCHS):
    m_loss = 0.0
    cts = 0
    optimizer.zero_grad()
    for j, batch in enumerate(data_ld):
        xyz, y = batch.pos.cuda(), batch.y.cuda()

        logits = m(xyz)

        loss:torch.Tensor = loss_func(logits, y)

        # Python float, so loss_list can be converted with np.array (CUDA tensors cannot).
        m_loss += loss.item()

        # Gradients accumulate across samples. They are cleared only after an
        # optimizer step, not per sample: clearing per sample discarded all but the
        # last sample of each group. The summed gradient is not rescaled; Adam is
        # largely insensitive to gradient scale.
        loss.backward()
        cts += 1
        if (j + 1) % BATCH_SIZE == 0:
            optimizer.step()
            optimizer.zero_grad()
            m_loss = m_loss / cts
            loss_list.append(m_loss)
            print(f"Epoch: {i + 1}/{EPOCHS}, Samples: {j + 1}/{len(data_ld)}, loss:{m_loss}.")
            m_loss = 0.0
            cts = 0

    # Apply the gradients of a trailing, incomplete group.
    if cts != 0:
        optimizer.step()
        optimizer.zero_grad()

loss_list = np.array(loss_list)

optimizer.zero_grad(True)