In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import e3nn
from e3nn import o3, io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import open3d as o3d
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io

# plotly.io.renderers.default = "notebook"

from tqdm.notebook import tqdm
from functools import partial, reduce


from utils import load_model, save_model, CustomLRScheduler
from data_generation import SimpleShapeDataset, SimpleShapeUniformRayDataset, BoxesDataset, SimpleShapeGridDataset
from model import EncoderDecoder
from convolution import S2ConvNet_Autoencoder
from visualize import visualize_points, visualize_signal
from losses import GridLoss, WeightedGridLoss, WeightedPointLoss, WeightedGridLossWithRotation

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


'cuda'

In [2]:
LMAX = 4  # maximum spherical-harmonic degree used throughout the notebook

# Spherical-tensor helper for degrees 0..LMAX (same parity arguments as before).
sphten = io.SphericalTensor(lmax=LMAX, p_val=1, p_arg=1)
# Only two samples are needed: dataset[0] and dataset[1] are the interpolation endpoints.
dataset = BoxesDataset(lmax=LMAX, n_samples=2)
len(dataset)

2

In [3]:
model = S2ConvNet_Autoencoder(
    LMAX,
    # presumably the max degree per layer (cf. "SO3Activation (l_in -> l_out)" in the printed model)
    l_list=[LMAX, 3, 3, 3, 2, 2, 1],
    channels=[4, 8, 8, 8, 16, 16, 32],
).to(device)
print(model)
# numel() is an exact int per tensor; np.prod(shape) returns the float 1.0 for 0-d parameters.
n_params = sum(p.numel() for p in model.parameters())
print("number of parameters = ", n_params)



S2ConvNet_Autoencoder(
  (encoder): Sequential(
    (0): Linear(1x0e+1x1e+1x2e+1x3e+1x4e -> 1x0e+3x1e+5x2e+7x3e+9x4e | 25 weights)
    (1): SO3Activation (4 -> 4)
    (2): BatchNorm (1x0e+3x1e+5x2e+7x3e+9x4e, eps=1e-05, momentum=0.1)
    (3): Linear(1x0e+3x1e+5x2e+7x3e+9x4e -> 1x0e+3x1e+5x2e+7x3e+9x4e | 165 weights)
    (4): BatchNorm (1x0e+3x1e+5x2e+7x3e+9x4e, eps=1e-05, momentum=0.1)
    (5): SO3Activation (4 -> 3)
    (6): Linear(1x0e+3x1e+5x2e+7x3e -> 1x0e+3x1e+5x2e+7x3e | 84 weights)
    (7): BatchNorm (1x0e+3x1e+5x2e+7x3e, eps=1e-05, momentum=0.1)
    (8): SO3Activation (3 -> 3)
    (9): Linear(1x0e+3x1e+5x2e+7x3e -> 1x0e+3x1e+5x2e+7x3e | 84 weights)
    (10): BatchNorm (1x0e+3x1e+5x2e+7x3e, eps=1e-05, momentum=0.1)
    (11): SO3Activation (3 -> 3)
    (12): Linear(1x0e+3x1e+5x2e+7x3e -> 1x0e+3x1e+5x2e+7x3e | 84 weights)
    (13): BatchNorm (1x0e+3x1e+5x2e+7x3e, eps=1e-05, momentum=0.1)
    (14): SO3Activation (3 -> 2)
    (15): Linear(1x0e+3x1e+5x2e -> 1x0e+3x1e+5x2e | 35 weight



In [5]:
# from utils import get_checkpoints_dir

# checkpoint_path = f'{get_checkpoints_dir()}/old/checkpoint_final_state.pt'
# checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint)

# # load_model(model, 'l=1_final')

In [6]:
def dot_almost(a, b, c, eps=1e-5, put_assert=True):
    """Check that every dot product of ``a`` and ``b`` over the last axis equals ``c`` within ``eps``.

    Args:
        a, b: tensors with matching shapes; the dot product is taken over the last axis.
        c: expected value (a number or a tensor that compares against a scalar).
        eps: absolute tolerance.
        put_assert: if True, a mismatch raises ``AssertionError``; otherwise the
            mismatch is only printed and ``False`` is returned.

    Returns:
        True when all dot products lie in ``[c - eps, c + eps]``, else False
        (only reachable when ``put_assert`` is False).

    Raises:
        AssertionError: on mismatch when ``put_assert`` is True.
    """
    dt = (a * b).sum(dim=-1)
    if (dt.max() > (c + eps)) or (dt.min() < (c - eps)):
        print("dot product is ", dt, "expected", c)
        # Explicit raise instead of `assert`, so the check still fires under `python -O`.
        if put_assert:
            raise AssertionError(f"dot product {dt} differs from expected {c} by more than {eps}")
        return False
    return True


def orthogonal_vector(a, b):
    """Return, row by row, a unit vector orthogonal to both ``a`` and ``b``.

    Args:
        a, b: tensors of shape (batch, dim); ``dim`` should be at least 3 so that
            a direction orthogonal to both exists.

    Returns:
        Tensor of shape (batch, dim) with unit-norm rows; the sign is arbitrary.

    Raises:
        AssertionError (from ``dot_almost``) if the result fails the sanity checks.
    """
    batch, dim = a.shape
    # Columns of A: a, b, then the standard basis (so A always has full row rank).
    # Build the identity on a's device/dtype so CUDA / float64 inputs work.
    eye = torch.eye(dim, dtype=a.dtype, device=a.device).expand(batch, dim, dim)
    A = torch.cat([torch.stack([a, b], dim=1), eye], dim=1).transpose(1, 2)
    # Reduced QR: Q is (batch, dim, dim) and, since R is upper triangular, a and b
    # lie in the span of Q's first two columns, so the last column is orthogonal to both.
    q, _ = torch.linalg.qr(A)
    axis = q[:, :, -1]

    dot_almost(a, axis, 0)
    dot_almost(b, axis, 0)
    dot_almost(axis, axis, 1)
    return axis

def rotate(R, vecs):
    """Apply a batch of matrices to a batch of vectors: ``out[n] = R[n] @ vecs[n]``."""
    as_columns = vecs.unsqueeze(-1)
    return torch.matmul(R, as_columns).squeeze(-1)
    
def interpolate_in_1D(v1, v2, s):
    """Linearly interpolate scalar (l=0) coefficients: ``v1`` at s=0, ``v2`` at s=1.

    An l=0 coefficient is unchanged by rotations, so plain linear interpolation
    already commutes with rotating the inputs. A 1-D "direction" is just a sign,
    and the sign may flip between endpoints, so there is nothing to rotate along.
    """
    step = v2 - v1
    return v1 + (step * s)

def interpolate_in_3D(v1, v2, s, eps=1e-8):
    """Rotate-and-scale interpolation between two batches of 3-vectors.

    The direction moves from ``v1`` to ``v2`` along the great circle they span.
    This is the same as rotating ``v1`` about the axis ``v1 x v2`` by ``s`` times
    the angle between them. The length is interpolated geometrically:
    ``|v1| * (|v2| / |v1|) ** s``.

    Args:
        v1, v2: tensors of shape (..., 3) with matching shapes.
        s: interpolation parameter; ``s=0`` gives ``v1`` and ``s=1`` gives ``v2``.
        eps: norms below this are treated as zero.

    Returns:
        Tensor of shape (..., 3).

    Notes:
        Norms are per row (the previous version took a single norm over the
        whole batch). Parallel and antiparallel inputs are handled explicitly
        instead of producing a NaN rotation axis. Rows where either endpoint is
        (near) zero have no direction to rotate, so they fall back to linear
        interpolation.
    """
    norm1 = torch.linalg.norm(v1, dim=-1, keepdim=True)
    norm2 = torch.linalg.norm(v2, dim=-1, keepdim=True)
    zero_len = (norm1 < eps) | (norm2 < eps)
    safe1 = torch.where(zero_len, torch.ones_like(norm1), norm1)
    safe2 = torch.where(zero_len, torch.ones_like(norm2), norm2)
    u1 = v1 / safe1
    u2 = v2 / safe2

    # Rounding can push the cosine slightly outside [-1, 1], which would make acos NaN.
    cos_angle = (u1 * u2).sum(dim=-1, keepdim=True).clamp(-1.0, 1.0)
    angle = torch.acos(cos_angle)

    # Unit vector in the (u1, u2) plane orthogonal to u1; equals (u1 x u2) x u1 / |u1 x u2|.
    ortho = u2 - cos_angle * u1
    ortho_norm = torch.linalg.norm(ortho, dim=-1, keepdim=True)
    # (Anti)parallel inputs span no plane: use any unit vector orthogonal to u1 instead.
    basis = torch.zeros_like(u1)
    basis.scatter_(-1, u1.abs().argmin(dim=-1, keepdim=True), 1.0)
    fallback = torch.cross(u1, basis, dim=-1)
    fallback = fallback / torch.linalg.norm(fallback, dim=-1, keepdim=True).clamp_min(eps)
    ortho_unit = torch.where(ortho_norm > eps, ortho / ortho_norm.clamp_min(eps), fallback)

    theta = angle * s
    direction = u1 * torch.cos(theta) + ortho_unit * torch.sin(theta)
    length = safe1 * torch.pow(safe2 / safe1, s)
    rotated = direction * length

    linear = v1 + (v2 - v1) * s
    return torch.where(zero_len, linear, rotated)


def interpolate_in_highD(irrep, v1, v2, s):
    """Interpolate rows of ``v1``/``v2`` (dim >= 3) by reducing to a 3-D problem.

    An orthonormal basis whose first two vectors span {v1, v2} is built with a
    QR decomposition. Both vectors are projected onto its first three vectors,
    interpolated with ``interpolate_in_3D``, and mapped back. The interpolation
    stays in the plane of v1 and v2, so the projection round-trip is exact.

    Args:
        irrep: unused; kept for call compatibility.
        v1, v2: tensors of shape (batch, dim).
        s: interpolation parameter (0 -> v1, 1 -> v2).

    Returns:
        Tensor of shape (batch, dim).
    """
    batch, dim = v1.shape
    # Identity on v1's device/dtype so CUDA / float64 inputs work.
    eye = torch.eye(dim, dtype=v1.dtype, device=v1.device).expand(batch, dim, dim)
    A = torch.cat([torch.stack([v1, v2], dim=1), eye], dim=1).transpose(1, 2)
    # Reduced QR (torch.qr is deprecated); Q is (batch, dim, dim).
    q, _ = torch.linalg.qr(A)
    basis = q[:, :, :3]  # (batch, dim, 3) orthonormal columns
    v1_3d = torch.einsum('ndi,nd->ni', basis, v1)
    v2_3d = torch.einsum('ndi,nd->ni', basis, v2)
    v_interp_3d = interpolate_in_3D(v1_3d, v2_3d, s)
    return torch.einsum('ndi,ni->nd', basis, v_interp_3d)


def interpolate_in_latent_space(repr, latent1, latent2, s):
    """Interpolate two latent vectors irrep copy by irrep copy.

    ``repr`` is presumably an e3nn ``Irreps``: ``repr.ls`` gives the degree l of
    each copy, laid out consecutively along the last axis, each occupying
    ``2l + 1`` entries. Width-1 pieces are lerped (``interpolate_in_1D``),
    width-3 pieces go through ``interpolate_in_3D``, and wider ones through
    ``interpolate_in_highD``. The pieces must exactly cover the last axis.
    """
    assert(latent1.shape == latent2.shape)
    out = torch.empty_like(latent1)

    start = 0
    for degree in repr.ls:
        stop = start + 2 * degree + 1
        piece1 = latent1[..., start:stop]
        piece2 = latent2[..., start:stop]
        width = stop - start
        if width == 1:
            blended = interpolate_in_1D(piece1, piece2, s)
        elif width == 3:
            blended = interpolate_in_3D(piece1, piece2, s)
        else:
            blended = interpolate_in_highD(o3.Irreps(f'{degree}e'), piece1, piece2, s)
        out[..., start:stop] = blended
        start = stop

    assert(start == out.shape[-1])
    return out


def linear_interpolate_in_latent_space(latent1, latent2, s):
    """Plain element-wise linear interpolation: ``latent1`` at s=0, ``latent2`` at s=1.

    Used as the non-irrep-aware baseline against ``interpolate_in_latent_space``.
    """
    offset = latent2 - latent1
    scaled = s * offset
    return latent1 + scaled



# this is just for debugging
def interpolate_in_latent_space_hint(axis, angle, repr, latent1, latent2, s):
    """Debug helper: "interpolate" by rotating ``latent1`` with a known rotation.

    Each irrep copy listed in ``repr.ls`` (width ``2l + 1``) is multiplied by its
    e3nn D matrix for a rotation of ``angle * s`` about ``axis``. ``latent2`` is
    only used for the shape check; its values do not affect the result.
    Presumably meant for sanity-checking the decoder when the true relative
    rotation between the two inputs is known.
    """
    assert(latent1.shape == latent2.shape)
    rotated = torch.empty_like(latent1)

    offset = 0
    for degree in repr.ls:
        width = 2 * degree + 1
        span = slice(offset, offset + width)
        wigner = o3.Irreps(f'{degree}e').D_from_axis_angle(axis=axis, angle=angle * s)
        rotated[..., span] = torch.einsum('ij,nj->ni', wigner, latent1[..., span])
        offset += width

    assert(offset == rotated.shape[-1])
    return rotated


In [27]:
# Representations exposed by the model: latent irreps and the signal's spherical tensor.
latent_repr = model.latent_repr
model_sphten_repr = model.model_sphten_repr
sphten = model.model_sphten_repr

lmax = LMAX  # read by save_sh below

# Known test rotation (90 degrees about z); only used by the debugging "hint" interpolation.
axis = torch.tensor([0, 0, 1]).float()
angle = torch.tensor(torch.pi / 2)
R = o3.axis_angle_to_matrix(axis, angle)
D = sphten.D_from_matrix(R)

sh1 = dataset[0].squeeze()
sh2 = dataset[1].squeeze()

# Inference only: in eval mode the BatchNorm layers use their running statistics
# rather than the statistics of this single-sample batch, and they no longer update
# those statistics, so re-running this cell does not change the model.
model.eval()
with torch.no_grad():
    sh1 = sh1.unsqueeze(0).to(device)
    sh2 = sh2.unsqueeze(0).to(device)
    inp1, latent1, out1 = model(sh1)
    inp2, latent2, out2 = model(sh2)

# Flatten each latent to a single row (batch of 1) for per-irrep interpolation;
# init_shape restores the layout the decoder expects.
init_shape = latent1.shape
latent1 = latent1.reshape(-1).unsqueeze(0)
latent2 = latent2.reshape(-1).unsqueeze(0)

def save_sh(met, idx, sh):
    """Render spherical-harmonic coefficients ``sh`` as a 3-D surface and save ``f'{met}-{idx}.png'``.

    Reads the notebook-global ``lmax``. Axes are hidden and the scene background
    is transparent so the PNGs can be cropped and tiled afterwards.
    """
    def _hidden_axis(label):
        return dict(title=label, showgrid=False, visible=False)

    scene = dict(
        aspectmode="cube",
        xaxis=_hidden_axis('X'),
        yaxis=_hidden_axis('Y'),
        zaxis=_hidden_axis('Z'),
        bgcolor='rgba(0, 0, 0, 0)',  # transparent scene background
    )
    layout = go.Layout(scene=scene)

    tensor = e3nn.io.SphericalTensor(lmax, 1, 1)  # (lmax, p_val, p_arg)
    surface = tensor.plotly_surface(sh, radius=True)[0]
    fig = go.Figure([go.Surface(surface)], layout=layout)

    # Hide the colour bar and legend.
    fig.update(layout_coloraxis_showscale=False)
    fig.update(layout_showlegend=False)
    fig.update_coloraxes(showscale=False)
    fig.write_image(f'{met}-{idx}.png')

    
N = 8  # interpolation steps: s = 0, 1/(N-1), ..., 1
rows = 2
columns = N


def decode_interpolations(interp_fn, tag):
    """Decode latents interpolated between ``latent1`` and ``latent2`` and save one image per step.

    Args:
        interp_fn: callable ``(latent1, latent2, s) -> flat latent`` working on CPU tensors.
        tag: filename prefix; step ``j`` is written to ``f'{tag}-{j}.png'`` by ``save_sh``.
    """
    for j in range(columns):
        # Guard N == 1 (single frame) against division by zero.
        s = j / (columns - 1) if columns > 1 else 0.0
        latent = interp_fn(latent1.cpu(), latent2.cpu(), s)
        latent = latent.to(device).reshape(init_shape)
        with torch.no_grad():
            out = model.decoder(latent)
        save_sh(tag, j, out.cpu())


# Row 1: irrep-aware (rotation-style) interpolation; row 2: plain linear baseline.
decode_interpolations(partial(interpolate_in_latent_space, latent_repr), 'equiv_rot')
decode_interpolations(linear_interpolate_in_latent_space, 'lin_rot')

In [28]:
from PIL import Image
import numpy as np

def get_name(name, n=None):
    """Center-crop the frames ``f'{name}-{i}.png'`` and tile them left to right.

    The tiled strip is also saved as ``f'{name}-stacked.png'``.

    Args:
        name: filename prefix the frames were written with (e.g. by ``save_sh``).
        n: number of frames to read; defaults to the notebook-global ``N``.

    Returns:
        The tiled strip as a NumPy array (frames concatenated along the width axis).
    """
    if n is None:
        n = N

    crops = []
    for i in range(n):
        # `with` closes the file handle; the crop is converted to an array before it closes.
        with Image.open(f"{name}-{i}.png") as img:
            width, height = img.size
            # Square crop, half the shorter side, centred in the frame.
            side = min(width, height) * 0.5
            left = (width - side) / 2
            top = (height - side) / 2
            right = (width + side) / 2
            bottom = (height + side) / 2
            crops.append(np.array(img.crop((left, top, right, bottom))))

    # Place the frames side by side (horizontal strip).
    stacked_image = np.hstack(crops)
    Image.fromarray(stacked_image).save(f'{name}-stacked.png')
    return stacked_image


# Top strip: irrep-aware interpolation; bottom strip: linear baseline.
stacked = np.vstack([get_name(tag) for tag in ('equiv_rot', 'lin_rot')])
stacked_pil_image = Image.fromarray(stacked)
stacked_pil_image.save('all-stacked.png')