In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import e3nn
from e3nn import o3, io
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import open3d as o3d
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io

# plotly.io.renderers.default = "notebook"

from tqdm.notebook import tqdm
from functools import partial, reduce


from utils import load_model, save_model, CustomLRScheduler

# CPU is forced for now (original note: "for now batch=1 so no need for gpu").
# NOTE(review): train() below trains full-batch, so enabling the GPU may be worth revisiting.
USE_GPU = False
device = 'cuda' if USE_GPU and torch.cuda.is_available() else 'cpu'
device


In [2]:
from data_generation import SimpleShapeDataset, SimpleShapeUniformRayDataset

# Take the first batch (batch_size=1) of the ray dataset as a single sample for visual inspection.
dataset = SimpleShapeUniformRayDataset(sample_points=20000)
loader = DataLoader(dataset, batch_size=1)
data = next(iter(loader)).float()

In [9]:
# What does the sample look like when projected onto spherical harmonics?
# Each point is mapped to its SH values, which are then averaged over batch and points.
sphten = io.SphericalTensor(11, 1, 1)
sh_per_point = o3.spherical_harmonics(sphten, x=data, normalize=False)
signal = sh_per_point.mean(dim=[0, 1])

fig = make_subplots(rows=1, cols=2, specs=[[{'is_3d': True}, {'is_3d': True}]])
sh_surface = go.Surface(sphten.plotly_surface(signal, radius=True)[0])
point_scatter = go.Scatter3d(x=data[0, :, 0], y=data[0, :, 1], z=data[0, :, 2])
# Right: SH reconstruction as a radial surface. Left: the raw points.
fig.add_trace(sh_surface, row=1, col=2)
fig.add_trace(point_scatter, row=1, col=1)

In [4]:
class SelfTP(nn.Module):
    """Tensor product of a feature with itself, optionally followed by e3nn BatchNorm.

    Args:
        irreps_in: irreps of the input; used for both operands of the tensor product.
        irreps_out: irreps of the output.
        batch_norm: if truthy, apply ``e3nn.nn.BatchNorm(irreps_out)`` after the product.
            The BatchNorm module is constructed regardless of this flag, so the
            parameter / state-dict layout does not depend on it.
    """

    def __init__(self, irreps_in, irreps_out, batch_norm=True):
        super().__init__()
        self.tp = o3.FullyConnectedTensorProduct(
            irreps_in1=irreps_in,
            irreps_in2=irreps_in,
            irreps_out=irreps_out,
        )
        self.bn = e3nn.nn.BatchNorm(irreps_out)
        self.batch_norm = batch_norm

    def forward(self, x):
        """Return ``tp(x, x)``, normalized by BatchNorm when ``batch_norm`` is set."""
        product = self.tp(x, x)
        return self.bn(product) if self.batch_norm else product


class EncoderDecoder(nn.Module):
    """Tensor-product autoencoder acting on spherical-harmonic coefficient vectors.

    The encoder is a stack of ``SelfTP`` layers whose irreps go from a spherical
    tensor of degree ``lmax`` down to degree 1, one degree per layer, while the
    channel multiplicity doubles per layer (capped at ``max_channel``). The
    decoder is the mirror image, going back up to degree ``lmax``.

    Args:
        lmax: maximum degree of the input spherical tensor.
        max_channel: upper bound on the multiplicity of each layer's irreps.

    Attributes:
        irrep_layer: list of irreps per layer, input (index 0) to latent (last).
        model_sphten_repr: spherical tensor of degree ``lmax`` describing inputs/outputs.
        latent_repr: irreps of the latent (the last entry of ``irrep_layer``).
    """

    def __init__(self, lmax=8, max_channel=64):
        super().__init__()

        self.irrep_layer = []
        multiplicity = 1
        for layer_lmax in reversed(range(1, lmax + 1)):  # the last layer has lmax=1
            layer_irreps = io.SphericalTensor(lmax=layer_lmax, p_arg=1, p_val=1) * multiplicity
            self.irrep_layer.append(layer_irreps.sort()[0].simplify())
            # TODO: tune max_channel
            multiplicity = min(multiplicity * 2, max_channel)

        self.model_sphten_repr = io.SphericalTensor(lmax=lmax, p_arg=1, p_val=1)
        self.latent_repr = self.irrep_layer[-1]

        n_layers = len(self.irrep_layer)
        # TODO: should the last layer have batch norm? use a fully connected tensor product?
        self.encoder = nn.Sequential(*[
            SelfTP(
                irreps_in=self.irrep_layer[k],
                irreps_out=self.irrep_layer[k + 1],
                batch_norm=k < n_layers - 2,  # no BatchNorm on the layer producing the latent
            )
            for k in range(n_layers - 1)
        ])

        # TODO: is the information bottleneck limited enough?
        self.decoder = nn.Sequential(*[
            SelfTP(
                irreps_in=self.irrep_layer[k + 1],
                irreps_out=self.irrep_layer[k],
                batch_norm=k > 0,  # no BatchNorm on the final reconstruction layer
            )
            for k in reversed(range(n_layers - 1))
        ])

    def forward(self, inp):
        """Encode and decode ``inp`` (spherical-harmonic coefficients computed by the caller).

        Returns:
            ``(inp, latent, out)``: the unchanged input, the latent features and
            the reconstruction.
        """
        latent = self.encoder(inp)
        return inp, latent, self.decoder(latent)

In [5]:
model = EncoderDecoder().to(device)

# Total number of scalar entries across all parameter tensors.
n_parameters = sum(np.prod(param.shape) for param in model.parameters())
print("number of parameters = ", n_parameters)

# NOTE(review): this re-creates the same dataset as the earlier sample-loading cell.
dataset = SimpleShapeUniformRayDataset(sample_points=20000)

In [6]:
# Precompute one spherical-harmonic coefficient vector per dataset sample:
# the SH values of every point (up to the model's lmax) averaged over the points.
shs = []

# Note: the loop variables (`i`, `points`, `sh`) remain in the notebook namespace after this cell.
for i in range(len(dataset)):
    points = torch.from_numpy(dataset[i])
    sh = o3.spherical_harmonics(l=model.model_sphten_repr, x=points, normalize=False).mean(dim=0)
    shs.append(sh)

In [7]:
def train(
        model,
        dataset,
        loss_fn,
        epochs=600,
        checkpoint_interval=50,
        initial_rl=1,
        scheduler=None,
        optimizer=None,
        device=None,
):
    """Train ``model`` to reconstruct its input.

    The DataLoader uses ``batch_size=len(dataset)``, so each epoch is a single
    full-batch optimization step.

    Args:
        model: module whose forward returns ``(inp, latent, out)``; the loss compares
            ``inp`` with ``out``.
        dataset: indexable collection of input tensors (e.g. a list of tensors).
        loss_fn: callable ``loss_fn(inp, out)`` returning a scalar tensor.
        epochs: number of epochs to run.
        checkpoint_interval: call ``save_model`` every this many epochs (including epoch 0).
        initial_rl: learning rate for the default Adam optimizer; ignored if
            ``optimizer`` is given.
        scheduler: optional scheduler, stepped once per epoch.
        optimizer: optional optimizer; defaults to Adam over ``model.parameters()``.
        device: device to move batches to; defaults to the device of the model's
            first parameter.

    Returns:
        List of float losses, one per optimization step (one per epoch here).
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=initial_rl)
    if device is None:
        # Follow the model instead of relying on a notebook-global `device`.
        device = next(model.parameters()).device

    dataloader = DataLoader(dataset, batch_size=len(dataset))

    pbar = tqdm(range(epochs))
    losses = []
    for epoch in pbar:
        for batch in dataloader:
            batch = batch.float().to(device)
            inp, _latent, out = model(batch)
            loss = loss_fn(inp, out)
            optimizer.zero_grad()
            loss.backward()
            # Largest per-parameter gradient norm, for monitoring only; 0.0 if no grads exist.
            grad_norms = [torch.linalg.norm(p.grad).item() for p in model.parameters() if p.grad is not None]
            max_grad = max(grad_norms, default=0.0)
            optimizer.step()
            loss_value = loss.item()
            pbar.set_description(f'Epoch {epoch+1} Loss: {loss_value:.8f} Lr={optimizer.param_groups[0]["lr"]} MaxGrad={max_grad:.4f}')
            losses.append(loss_value)

        if epoch % checkpoint_interval == 0:
            name = f'epoch_{epoch}'
            save_model(model, name)
            print(f'saved {name}')

        if scheduler is not None:
            scheduler.step()
    return losses

In [10]:
# Snapshot the freshly initialised weights under 'initial_state'; a later cell
# restarts training from this checkpoint via load_model(model, 'initial_state').
save_model(model, 'initial_state')
# load_model(model, 'sample')

In [12]:
# Summed (not averaged) squared error between the input and its reconstruction.
loss_fn = torch.nn.MSELoss(reduction='sum')

# NOTE(review): lr=1 is 1000x Adam's default of 1e-3; a later cell restarts from the
# initial weights with 0.01.
initial_rl = 1
optimizer = torch.optim.Adam(model.parameters(), lr=initial_rl)
custom_scheduler = CustomLRScheduler(optimizer, initial_rl)
# Alternative schedule that was tried (author's note: apparently you shouldn't change the lr):
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200, 300, 400, 500], gamma=1.1)

In [18]:
# First training run with the optimizer/scheduler configured above.
losses = train(model=model, dataset=shs, loss_fn=loss_fn, epochs=100, initial_rl=None, scheduler=custom_scheduler, optimizer=optimizer)
save_model(model, 'final_state')

# train() is full-batch, so there is one loss value per epoch.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title=f'Training loss (initial lr={initial_rl})', xlabel='epoch', ylabel='loss')
plt.show()

In [49]:
# Restart from the saved initial weights with a smaller learning rate and a
# fresh optimizer and scheduler.
load_model(model, 'initial_state')

initial_rl = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=initial_rl)
custom_scheduler = CustomLRScheduler(optimizer, initial_rl)

losses = train(model=model, dataset=shs, loss_fn=loss_fn, epochs=100, initial_rl=None, scheduler=custom_scheduler, optimizer=optimizer)
save_model(model, 'final_state')

# train() is full-batch, so there is one loss value per epoch.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title=f'Training loss from initial_state (initial lr={initial_rl})', xlabel='epoch', ylabel='loss')
plt.show()

In [54]:
# Re-save the checkpoint currently stored as 'final_state' under the name
# 'almost_perfect_state', so it is kept even though later cells save 'final_state' again.
# (load_model / save_model come from utils — presumably they load into / save from `model`.)
load_model(model, 'final_state')
save_model(model, 'almost_perfect_state')

In [60]:
# Continue training the current weights, reusing `optimizer` and `custom_scheduler`
# created in an earlier cell (run that cell first on a fresh kernel).
# NOTE(review): CustomLRScheduler.set_rl presumably resets the learning rate to 0.01 — confirm in utils.
custom_scheduler.set_rl(0.01)
losses = train(model=model, dataset=shs, loss_fn=loss_fn, epochs=100, initial_rl=None, scheduler=custom_scheduler, optimizer=optimizer)
save_model(model, 'final_state')

# train() is full-batch, so there is one loss value per epoch.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss (continued run)', xlabel='epoch', ylabel='loss')
plt.show()

In [13]:
load_model(model, 'almost_perfect_state')

# NOTE: this iterates over the precomputed SH coefficients `shs`, not the raw point clouds.
dataloader = DataLoader(shs, batch_size=len(shs))
single_data = next(iter(dataloader)).float().to(device)

# Evaluate without autograd and with BatchNorm in eval mode, so this forward pass
# does not update the BatchNorm running statistics stored in `model`. The previous
# train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        inp, latent, out = model(single_data)
finally:
    model.train(was_training)
inp, latent, out = inp.cpu(), latent.cpu(), out.cpu()

print('loss', loss_fn(inp, out))
print('L-inf loss', torch.max(torch.abs(inp - out)))

sphten = model.model_sphten_repr
fig = make_subplots(
    rows=1, cols=2,
    specs=[[{'is_3d': True}, {'is_3d': True}]],
    subplot_titles=['reconstruction (out)', 'input (inp)'],
)
fig.add_trace(go.Surface(sphten.plotly_surface(out, radius=True)[0]), row=1, col=1)
fig.add_trace(go.Surface(sphten.plotly_surface(inp, radius=True)[0]), row=1, col=2)

In [12]:
# Display the input SH coefficients (the full tensor — consider `inp.shape` or a slice if large).
inp

In [13]:
# Display the reconstructed SH coefficients (the full tensor — consider a slice if large).
out

In [14]:
# Display the latent features produced by the encoder.
latent

### Rotation test: interpolating in latent space between a shape and a rotated copy of it

In [15]:
# Reference point cloud for the rotation test: the first dataset sample.
pc1 = dataset[0]

In [20]:
# TODO: later try with a different object
# Rotate the reference cloud `pc1` (each point p -> R p). Reusing `pc1` instead of
# re-reading dataset[0] guarantees pc2 is a rotated copy of the same sample, and
# `rotation` is kept so the result can be checked later.
rotation = o3.rand_matrix()
pc2 = np.einsum('ij,nj->ni', rotation.numpy(), pc1)

In [21]:
# Side-by-side view: the reference cloud (left) and its rotated copy (right).
point_clouds = (pc1, pc2)
fig = make_subplots(rows=1, cols=len(point_clouds), specs=[[{'is_3d': True} for _ in point_clouds]])
for col, cloud in enumerate(point_clouds, start=1):
    fig.add_trace(go.Scatter3d(x=cloud[:, 0], y=cloud[:, 1], z=cloud[:, 2]), row=1, col=col)
fig

In [22]:
# Reload the chosen checkpoint and switch to eval mode; this changes how the
# e3nn BatchNorm layers inside SelfTP behave during the forward passes below.
load_model(model, 'almost_perfect_state')
model.eval()

In [23]:
# SH coefficients of each cloud (SH of every point averaged over the points), with a
# leading batch dimension of one.
sh1, sh2 = (
    o3.spherical_harmonics(
        l=model.model_sphten_repr,
        x=torch.from_numpy(point_cloud).float().unsqueeze(0),
        normalize=False,
    ).mean(dim=1)
    for point_cloud in (pc1, pc2)
)

with torch.no_grad():
    (inp1, latent1, out1), (inp2, latent2, out2) = model(sh1), model(sh2)

def interpolate_in_latent_space(repr, latent1, latent2, s, eps=1e-8):
    """Interpolate between two latent tensors, irrep by irrep.

    Scalars (l=0) are interpolated linearly. Vectors (l=1) are rotated from
    ``latent1``'s direction towards ``latent2``'s by ``s`` times the angle between
    them, about the axis ``v1 x v2`` (spherical linear interpolation). Their norm is
    interpolated geometrically: ``|v1| * (|v2| / |v1|) ** s``.

    Args:
        repr: irreps description of the latent; only ``repr.ls`` is read (the l of
            each irrep, one entry per copy). Only l = 0 and l = 1 are supported.
        latent1, latent2: tensors of identical shape ``(..., dim)``; the irreps are
            laid out consecutively along the last axis.
        s: interpolation parameter; ``s=0`` gives ``latent1`` and ``s=1`` gives ``latent2``.
        eps: threshold below which a vector norm or sin(angle) is treated as zero.

    Returns:
        Tensor with the same shape as ``latent1``.

    Raises:
        NotImplementedError: if ``repr`` contains an irrep with l > 1.
    """
    assert latent1.shape == latent2.shape
    res = torch.empty_like(latent1)

    ind = 0
    for l in repr.ls:
        sz = 2 * l + 1
        vec1 = latent1[..., ind:ind + sz]
        vec2 = latent2[..., ind:ind + sz]
        linear = vec1 + (vec2 - vec1) * s

        if sz == 1:
            # todo we do linear interpolation because the signs might be different... why do we not have to care about this in other irreps?
            res[..., ind:ind + sz] = linear
        elif sz == 3:
            norm1 = torch.linalg.norm(vec1, dim=-1, keepdim=True)
            norm2 = torch.linalg.norm(vec2, dim=-1, keepdim=True)
            # A (near-)zero vector has no direction: fall back to linear interpolation there.
            zero_vec = (norm1 <= eps) | (norm2 <= eps)
            safe_norm1 = torch.where(zero_vec, torch.ones_like(norm1), norm1)
            safe_norm2 = torch.where(zero_vec, torch.ones_like(norm2), norm2)
            unit1 = vec1 / safe_norm1
            unit2 = vec2 / safe_norm2

            # atan2 recovers the full angle in [0, pi]; asin(|u1 x u2|) would fold obtuse
            # angles onto acute ones, so s=1 would not land on latent2.
            sin_theta = torch.linalg.norm(torch.cross(unit1, unit2, dim=-1), dim=-1, keepdim=True)
            cos_theta = (unit1 * unit2).sum(dim=-1, keepdim=True)
            theta = torch.atan2(sin_theta, cos_theta)

            # Slerp weights; equivalent to rotating unit1 about unit1 x unit2 by s * theta.
            # For (anti)parallel vectors the axis is undefined, so use linear weights.
            parallel = sin_theta <= eps
            safe_sin = torch.where(parallel, torch.ones_like(sin_theta), sin_theta)
            w1 = torch.where(parallel, (1 - s) * torch.ones_like(theta), torch.sin((1 - s) * theta) / safe_sin)
            w2 = torch.where(parallel, s * torch.ones_like(theta), torch.sin(s * theta) / safe_sin)
            direction = w1 * unit1 + w2 * unit2

            magnitude = safe_norm1 * (safe_norm2 / safe_norm1) ** s
            res[..., ind:ind + sz] = torch.where(zero_vec, linear, direction * magnitude)
        else:
            raise NotImplementedError("l > 1 is not supported in interpolation sorry :))")

        ind += sz
    assert ind == res.shape[-1]
    return res


# interpolate_in_latent_space(repr=o3.Irreps('1e') + o3.Irrep('0e'), latent1=torch.tensor([10, 0, 0, 0]).float().unsqueeze(0), latent2=torch.tensor([1, -1, 0, 10]).float().unsqueeze(0), s=1)
# torch.max(torch.abs(interpolate_in_latent_space(repr=model.latent_repr, latent1=latent1, latent2=latent2, s=1) - latent2))

In [25]:
# Decode latents interpolated between latent1 (s=0) and latent2 (s=1).
# Panels show s = 0, 1/N, ..., 1; previously the loop stopped at s=(N-1)/N and the
# last panel re-plotted that result instead of the endpoint.
N = 3  # number of interpolation intervals
assert N >= 1
rows = 1
columns = N + 1
steps = [k / N for k in range(N + 1)]
sphten = model.model_sphten_repr

fig = make_subplots(
    rows=rows,
    cols=columns,
    specs=[[{'is_3d': True} for j in range(columns)] for i in range(rows)],
    subplot_titles=[f's = {s:.2f}' for s in steps],
)

for k, s in enumerate(steps):
    latent_interp = interpolate_in_latent_space(repr=model.latent_repr, latent1=latent1, latent2=latent2, s=s)
    with torch.no_grad():
        out_interp = model.decoder(latent_interp).detach().cpu()
    row, col = divmod(k, columns)
    fig.add_trace(go.Surface(sphten.plotly_surface(out_interp, radius=True)[0]), row=row + 1, col=col + 1)

fig