# Demonstration of an SO3 Equivariant Autoencoder

In [163]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import e3nn
from e3nn import o3, io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# import open3d as o3d
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io

# plotly.io.renderers.default = "notebook"

from tqdm.notebook import tqdm
from functools import partial, reduce


from utils import load_model, save_model, CustomLRScheduler
from data_generation import BoxesDataset, SimpleShapeGridDataset

# from model import EncoderDecoder
from model import S2ConvNet_Autoencoder, s2_irreps, so3_irreps
from visualize import visualize_points, visualize_signal
from losses import (
    GridLoss,
    WeightedGridLoss,
    WeightedPointLoss,
    WeightedGridLossWithRotation,
    IOULoss,
)

import time

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# `device` is read as a notebook global by the helper functions below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu' # for now batch=1 so no need for gpu
device  # bare last expression so the cell displays the selected device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


'cuda'

# Inspect Data

In [164]:
# Maximum spherical-harmonic degree used for the input signals; reused by the
# model, the losses and the IoU helpers further down.
LMAX = 4
# 100 samples from the project's BoxesDataset (see data_generation.py).
dataset = BoxesDataset(lmax=LMAX, n_samples=100)
# Spherical-tensor descriptor handed to the plotting helper (`plotly_surface`).
sphten = io.SphericalTensor(lmax=LMAX, p_arg=1, p_val=1)
len(dataset)

100

# Create the model

We found that repeating several layers at the same `lmax` before reducing it was necessary for the model to learn.


In [170]:
# Note: This cell takes a while to run because we have to compile the e3nn modules
# `l_list` and `channels` have one entry per stage (7 each); the commented-out
# lines are alternative configurations that were tried.
model = S2ConvNet_Autoencoder(
    LMAX,
    # l_list=[LMAX, 3, 3, 3, 2, 2, 2],
    l_list=[LMAX, 3, 3, 3, 2, 2, 1],
    channels=[4, 8, 8, 8, 16, 16, 32],
    # l_list=[LMAX, ],
    # channels=[4, ],
).to(device)
print(model)
# Total number of scalar entries across all parameter tensors.
print("number of parameters = ", sum([np.prod(x.shape) for x in model.parameters()]))

S2ConvNet_Autoencoder(
  (encoder): Sequential(
    (0): Linear(1x0e+1x1e+1x2e+1x3e+1x4e -> 1x0e+3x1e+5x2e+7x3e+9x4e | 25 weights)
    (1): SO3Activation (4 -> 4)
    (2): BatchNorm (1x0e+3x1e+5x2e+7x3e+9x4e, eps=1e-05, momentum=0.1)
    (3): Linear(1x0e+3x1e+5x2e+7x3e+9x4e -> 1x0e+3x1e+5x2e+7x3e+9x4e | 165 weights)
    (4): BatchNorm (1x0e+3x1e+5x2e+7x3e+9x4e, eps=1e-05, momentum=0.1)
    (5): SO3Activation (4 -> 3)
    (6): Linear(1x0e+3x1e+5x2e+7x3e -> 1x0e+3x1e+5x2e+7x3e | 84 weights)
    (7): BatchNorm (1x0e+3x1e+5x2e+7x3e, eps=1e-05, momentum=0.1)
    (8): SO3Activation (3 -> 3)
    (9): Linear(1x0e+3x1e+5x2e+7x3e -> 1x0e+3x1e+5x2e+7x3e | 84 weights)
    (10): BatchNorm (1x0e+3x1e+5x2e+7x3e, eps=1e-05, momentum=0.1)
    (11): SO3Activation (3 -> 3)
    (12): Linear(1x0e+3x1e+5x2e+7x3e -> 1x0e+3x1e+5x2e+7x3e | 84 weights)
    (13): BatchNorm (1x0e+3x1e+5x2e+7x3e, eps=1e-05, momentum=0.1)
    (14): SO3Activation (3 -> 2)
    (15): Linear(1x0e+3x1e+5x2e -> 1x0e+3x1e+5x2e | 35 weight

# Define Plotting Tools

In [166]:
def plot_predictions(model, datas, sphten):
    """Display input vs. reconstruction side by side for every sample.

    For each element of ``datas`` the model is evaluated without gradient
    tracking and a two-panel 3D plotly figure is shown: the left panel is the
    first entry of the input returned by the model, the right panel the first
    entry of its output.

    Args:
        model: callable returning a ``(inp, latent, out)`` triple.
        datas: iterable of tensors, each passed to ``model`` as-is (no batch
            dimension is added here).
        sphten: object exposing ``plotly_surface(signal, radius=True)``.
    """
    for sample in datas:
        with torch.no_grad():
            inp, _latent, out = model(sample.to(device))

        # One row, two 3D scenes.
        fig = make_subplots(
            rows=1,
            cols=2,
            specs=[[{"is_3d": True}, {"is_3d": True}]],
        )
        for col, signal in enumerate((inp, out), start=1):
            surface = sphten.plotly_surface(signal[0].cpu().detach(), radius=True)[0]
            fig.add_trace(go.Surface(surface), row=1, col=col)
        fig.show()


def get_iou(model, datas, lmax):
    """Compute the IoU of the model's reconstruction for each sample, one by one.

    Args:
        model: callable returning a ``(inp, latent, out)`` triple.
        datas: iterable of unbatched tensors; a leading batch dimension is
            added to each before it is fed to the model.
        lmax: forwarded to ``IOULoss``.

    Returns:
        list[float]: one IoU value per sample, in iteration order.
    """
    metric = IOULoss(lmax)
    scores = []
    for sample in datas:
        batch = sample.unsqueeze(0).to(device)  # add batch dim, move to device
        with torch.no_grad():
            _inp, _latent, out = model(batch)
        scores.append(metric.compute_iou(out, batch).item())
    return scores

def get_mean_iou(model, dataloader, lmax):
    """Compute the mean IoU of the model over everything a dataloader yields.

    Args:
        model: callable returning a ``(inp, latent, out)`` triple.
        dataloader: iterable of batched tensors.
        lmax: forwarded to ``IOULoss``.

    Returns:
        tuple: ``(mean_iou, ious)`` where ``ious`` is the 1-D tensor of all IoU
        values returned by ``IOULoss.compute_iou`` across batches and
        ``mean_iou`` is their mean (a 0-dim tensor).
    """
    iou = IOULoss(lmax, n_points=100)
    ious = []
    for data in dataloader:
        batch = data.to(device)
        with torch.no_grad():
            _inp, _latent, out = model(batch)
        # reshape(-1) lets both per-sample vectors and 0-dim scalars be
        # concatenated below (torch.cat rejects 0-dim tensors).
        ious.append(iou.compute_iou(out, batch).reshape(-1))
    ious = torch.concat(ious)
    # Average over the IoU values actually collected. The previous code divided
    # their sum by the number of *batches*, which over-estimates the mean by a
    # factor of the batch size whenever compute_iou returns one value per sample.
    mean_iou = ious.mean()
    return mean_iou, ious


def debug_iou(model, datas, lmax):
    """Debug helper: print the IoU of one stacked batch of the first 10 samples.

    Args:
        model: callable returning a ``(inp, latent, out)`` triple.
        datas: iterable of unbatched tensors; at most the first 10 are used.
        lmax: forwarded to ``IOULoss``.
    """
    metric = IOULoss(lmax, n_points=100)

    # Collect at most ten samples from the front of the iterable.
    head = []
    for position, sample in enumerate(datas):
        head.append(sample)
        if position == 9:
            break

    data_batch = torch.stack(head)
    print(data_batch.shape)

    with torch.no_grad():
        _inp, _latent, out = model(data_batch.to(device))

    print("Debug: ", metric.compute_iou(out, data_batch.to(device)))

def debug_iou2(model, datas, lmax):
    """Debug helper: run samples one at a time, then print one IoU over all of them.

    Each sample gets a batch dimension and is passed through the model
    individually; the outputs and inputs are then concatenated along the batch
    dimension and a single ``compute_iou`` call is printed.

    Args:
        model: callable returning a ``(inp, latent, out)`` triple.
        datas: iterable of unbatched tensors.
        lmax: forwarded to ``IOULoss``.
    """
    metric = IOULoss(lmax)
    predictions = []
    targets = []
    for sample in datas:
        batched = sample.unsqueeze(0)  # Add batch dim
        with torch.no_grad():
            _inp, _latent, out = model(batched.to(device))
        predictions.append(out)
        targets.append(batched)

    all_predictions = torch.cat(predictions)
    all_targets = torch.cat(targets)
    print("debug2", metric.compute_iou(all_predictions, all_targets.to(device)))


class IOU_logger():
    """Accumulates the mean IoU per logged epoch and persists it under ``logs/``."""

    def __init__(self):
        # Parallel lists: self.ious[k] is the mean IoU logged at self.epochs[k].
        self.ious = []
        self.epochs = []

    def log_iou(self, epoch, model, dataloader, lmax):
        """Evaluate the model sample by sample and record the mean IoU.

        Args:
            epoch: value stored alongside the IoU (the training epoch index).
            model: callable returning a ``(inp, latent, out)`` triple.
            dataloader: a ``DataLoader`` (its ``.dataset`` is evaluated) or any
                iterable of unbatched samples.
            lmax: forwarded to ``get_iou``.
        """
        # Evaluate the data behind the loader that was passed in. The previous
        # code read the notebook-global `dataset`, so the argument was ignored
        # and results depended on whatever `dataset` was bound to at call time.
        samples = getattr(dataloader, "dataset", dataloader)
        ious = get_iou(model, samples, lmax)
        mean_iou = np.mean(ious)
        self.ious.append(mean_iou)
        self.epochs.append(epoch)

    def save_iou(self):
        """Write the logged IoUs and epochs to ``logs/iou.npy`` / ``logs/iou_epochs.npy``."""
        import os

        # np.save does not create missing parent directories.
        os.makedirs("logs", exist_ok=True)
        np.save("logs/iou.npy", np.array(self.ious))
        np.save("logs/iou_epochs.npy", np.array(self.epochs))


# Train Model

In [167]:
def train(
    model,
    dataset,
    loss_fn,
    epochs=600,
    checkpoint_interval=50,
    initial_rl=1,
    batch_size=1,
    scheduler=None,
    optimizer=None,
    inspection_interval=None,
    inspection_func=None,
    trial_name="model",
    iou_interval=None,
    iou_logger=None,
):
    """Run the autoencoder training loop and return the per-step losses.

    Args:
        model: callable returning ``(inp, latent, out)``; its parameters are optimised.
        dataset: passed to ``DataLoader`` with ``batch_size``.
        loss_fn: called as ``loss_fn(out, inp)``.
        epochs: number of passes over the dataloader.
        checkpoint_interval: a checkpoint named ``{trial_name}_epoch_{epoch}``
            is saved whenever ``epoch % checkpoint_interval == 0``.
        initial_rl: learning rate for the Adam optimiser that is created only
            when ``optimizer`` is None.
        batch_size: dataloader batch size.
        scheduler: if given, ``scheduler.step()`` is called once per epoch.
        optimizer: optimiser to use; defaults to Adam over ``model.parameters()``.
        inspection_interval / inspection_func: when both are set,
            ``inspection_func()`` runs on epochs divisible by the interval.
        trial_name: prefix for checkpoint names.
        iou_interval / iou_logger: when both are set, the logger records and
            saves the IoU on epochs divisible by the interval (uses the
            notebook-global ``LMAX``).

    Returns:
        list[float]: the loss of every optimisation step, in order.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=initial_rl)

    dataloader = DataLoader(dataset, batch_size=batch_size)
    pbar = tqdm(range(epochs))

    losses = []
    last_epoch_loss = 0
    for epoch in pbar:
        running_loss = 0
        for data_idx, batch in enumerate(dataloader):
            batch = batch.float().to(device)
            inp, _latent, out = model(batch)
            loss = loss_fn(out, inp)

            optimizer.zero_grad()
            loss.backward()
            # Largest per-parameter gradient norm, shown in the progress bar.
            max_grad = max(
                torch.linalg.norm(param.grad).item()
                for param in model.parameters()
                if param.grad is not None
            )
            optimizer.step()

            loss_value = loss.item()
            pbar.set_description(
                f'Epoch {epoch+1} DataIdx {data_idx} Loss: {loss_value:.6f} Last_Epoch_Loss={last_epoch_loss:.6f} MaxGrad={max_grad:.4f} Lr={optimizer.param_groups[0]["lr"]}'
            )
            losses.append(loss_value)
            running_loss += loss_value

        # Sum (not mean) of the step losses of the epoch that just finished.
        last_epoch_loss = running_loss

        should_inspect = (
            inspection_interval is not None
            and inspection_func is not None
            and epoch % inspection_interval == 0
        )
        if should_inspect:
            inspection_func()

        should_log_iou = (
            iou_interval is not None
            and epoch % iou_interval == 0
            and iou_logger is not None
        )
        if should_log_iou:
            iou_logger.log_iou(epoch, model, dataloader, LMAX)
            iou_logger.save_iou()

        if epoch % checkpoint_interval == 0:
            name = f"{trial_name}_epoch_{epoch}"
            save_model(model, name)
            print(f"saved {name}")

        if scheduler is not None:
            scheduler.step()
    return losses

In [None]:
# save_model(model, 'initial_state')
# Best-effort warm start: keep the freshly initialised weights when loading fails.
# NOTE(review): the training cells save under "final_state", not
# "model_final_state" - confirm how save_model/load_model map names to files.
try:
    load_model(model, "model_final_state")
except Exception as err:
    # `except Exception` (rather than a bare `except:`) still lets
    # KeyboardInterrupt / SystemExit through, and the reason is reported
    # instead of being silently discarded.
    print("No model found")
    print(f"  reason: {type(err).__name__}: {err}")

In [171]:
# Seed PyTorch's RNG. Note that the model was already constructed in an earlier
# cell, so this seed does not affect its initial weights.
torch.manual_seed(0)
initial_rl = 0.01  # initial learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=initial_rl)
# Project-specific scheduler (utils.CustomLRScheduler); stepped once per epoch by train().
custom_scheduler = CustomLRScheduler(optimizer, initial_rl)
loss_fn = WeightedGridLoss(LMAX).to(device)
# loss_fn = IOULoss(LMAX, n_points=100).to(device)


def inspection_func():
    """Plot the reconstruction of one random dataset sample and print its IoU.

    Reads the notebook globals ``dataset``, ``model``, ``sphten`` and ``LMAX``.
    """
    sample_index = np.random.choice(len(dataset))
    sample = dataset[sample_index]
    plot_predictions(model, [sample], sphten)
    iou = get_iou(model, [sample], lmax=LMAX)
    print(f"IoU: {iou}")

In [172]:
trial_name = "iou_lmax1_b1"

# Snapshot the weights as they are *before* this training run.
save_model(model, "initial_state")

custom_scheduler.set_rl(0.01)
losses = train(
    model=model,
    dataset=dataset,
    loss_fn=loss_fn,
    epochs=20000,
    initial_rl=None,  # unused: an optimizer is passed explicitly
    scheduler=custom_scheduler,
    optimizer=optimizer,
    inspection_func=inspection_func,
    inspection_interval=500,
    checkpoint_interval=200,
    # batch_size=32,
    batch_size=1,
    trial_name=trial_name,
    iou_interval=1,  # log the IoU after every epoch
    iou_logger=IOU_logger(),
)

# Persist the trained weights. The original cell also re-saved the model as
# "initial_state" at this point, which overwrote the pre-training snapshot
# above with the *trained* weights; that second save has been removed.
save_model(model, "final_state")

plt.plot(losses)

  0%|          | 0/20000 [00:00<?, ?it/s]

IoU: [0.8676156997680664]
saved iou_lmax1_b1_epoch_0


In [None]:
# One-off IoU evaluation of the current model, stored as epoch 0 and written to logs/.
logger = IOU_logger()
logger.log_iou(0, model, DataLoader(dataset, batch_size=32), LMAX)
logger.save_iou()

In [None]:
# Read back the IoU history written by IOU_logger.save_iou().
iou_log = np.load("logs/iou.npy")
print(iou_log)

# Test Equivariance

In [None]:
# Restore a checkpoint for the equivariance test.
# NOTE(review): the training cell above uses trial_name "iou_lmax1_b1", so this
# "iou_lmax1_..." checkpoint presumably comes from a different run - confirm it exists.
load_model(model, "iou_lmax1_epoch_2400")

In [None]:
# For each box, rotate for n times and calculate equivariance
def get_equivariance_along_axis(
    model,
    data,
    axis: str,
    lmax_input: int,
    lmax_latent: int,
    n_rotations,
):
    """Measure the encoder's equivariance error for rotations about one axis.

    For ``n_rotations`` angles ``i * pi / n_rotations`` (``i = 0..n_rotations-1``)
    the function compares "rotate the input, then encode" against "encode,
    then rotate the latent" and records the mean absolute difference.

    Args:
        model: object whose ``encoder`` attribute is applied to the data.
        data: input tensor; its last dimension is rotated with the input
            representation ``s2_irreps(lmax_input)``.
        axis: ``"x"``, ``"y"`` or ``"z"``.
        lmax_input: degree used to build the input representation.
        lmax_latent: degree used to build the latent representation
            ``so3_irreps(lmax_latent)``.
        n_rotations: number of angles to test.

    Returns:
        torch.Tensor: 1-D tensor with one mean absolute deviation per angle.
    """
    input_irreps = s2_irreps(lmax_input)
    latent_irreps = so3_irreps(lmax_latent)

    # One-hot rotation axis derived from the axis name.
    axis_vec = torch.tensor(
        [int(axis == name) for name in ("x", "y", "z")], dtype=torch.float32
    )

    deviations = []
    for step in range(n_rotations):
        angle = torch.tensor(np.pi / n_rotations * step)
        D_input = input_irreps.D_from_axis_angle(axis_vec, angle)
        D_latent = latent_irreps.D_from_axis_angle(axis_vec, angle)

        # Path 1: rotate the input, then encode.
        rotated_input = torch.einsum("ij, ...j->...i", D_input, data)
        encoded_rotated = model.encoder(rotated_input.to(device)).detach().cpu()

        # Path 2: encode, then rotate the latent.
        encoded = model.encoder(data.to(device)).detach().cpu()
        rotated_encoded = torch.einsum("ij, ...j->...i", D_latent, encoded)

        gap = torch.abs(encoded_rotated - rotated_encoded)
        deviations.append(torch.mean(gap).item())

    return torch.tensor(deviations)

def compute_all_equivariance(
    model=model,
    dataset=dataset,
    n_samples=10,
    n_rotations=10,
    lmax_input=None,
    lmax_latent=1,
):
    """Average the encoder's equivariance error over a dataset for the x, y and z axes.

    Args:
        model: model whose encoder is tested (default bound when this cell ran).
        dataset: sized iterable of unbatched samples (default bound when this cell ran).
        n_samples: currently unused - every sample of ``dataset`` is evaluated.
            Kept so existing calls keep working.
        n_rotations: number of rotation angles tested per axis.
        lmax_input: degree of the input representation; defaults to the
            notebook-global ``LMAX``.
        lmax_latent: degree of the latent representation (previously
            hard-coded to 1).

    Returns:
        dict[str, torch.Tensor]: for each axis name, a tensor of length
        ``n_rotations`` holding the deviation averaged over the dataset.
    """
    if lmax_input is None:
        lmax_input = LMAX

    equivariance_results = {}
    for axis in ["x", "y", "z"]:
        running_mean = torch.zeros((n_rotations,))
        for data in tqdm(dataset):
            # The encoder is fed one sample at a time with a leading batch dim.
            running_mean += get_equivariance_along_axis(
                model,
                data.unsqueeze(0),
                axis,
                lmax_input=lmax_input,
                lmax_latent=lmax_latent,
                n_rotations=n_rotations,
            )

        running_mean /= len(dataset)
        equivariance_results[axis] = running_mean
    return equivariance_results

# Per-axis equivariance error of the loaded model over the whole dataset.
results = compute_all_equivariance(model, dataset)
print(results)

In [None]:
def print_results(results):
    """Print one LaTeX table row per axis.

    Each row has the form ``$axis$ & v0& v1 ... \\\\`` with values formatted
    to three decimals.

    Args:
        results: mapping from axis name to an iterable of numbers.
    """
    for axis in results:
        cells = "".join(f"& {val:.3f}" for val in results[axis])
        print(f"${axis}$ " + cells + " \\\\")

# Emit the equivariance table rows for pasting into a LaTeX document.
print_results(results)

# Test Reconstruction

# Test Interpolation

In [None]:
# from e3nn.util.test import assert_equivariant

# assert_equivariant(model.encoder, irreps_in=[model.model_sphten_repr], irreps_out=[model.latent_repr])

In [None]:
# from e3nn.util.test import assert_equivariant
# from convolution import so3_irreps, s2_irreps, SO3Activation

# def fix_dim(x):
#     print(x.shape)
#     return x.reshape((x.shape[0], -1))

# def print_dim(x):
#     print(x.shape)
#     return x

# model_sphten_repr = io.SphericalTensor(lmax=4, p_val=1, p_arg=1)

# # inner_layer_repr = 16 * so3_irreps(2)
# # inner_layer_repr = 1 * so3_irreps(4)
# # assert_equivariant(func=lambda x: fix_dim(model.encoder[:2](print_dim(x.unsqueeze(1).to(device)))).cpu(), irreps_in=[model_sphten_repr], irreps_out=[inner_layer_repr])

# inner_layer_repr = model.latent_repr
# assert_equivariant(func=lambda x: fix_dim(model.encoder(print_dim(x.unsqueeze(1).to(device)))).cpu(), irreps_in=[model_sphten_repr], irreps_out=[inner_layer_repr])


# # lmax = 4
# # encoder_list = []
# # encoder_list.append(e3nn.o3.Linear(s2_irreps(lmax), so3_irreps(lmax), f_in=1, f_out=1, internal_weights=True))
# # encoder_list.append(SO3Activation(lmax, lmax, torch.relu, 11))
# # sample = nn.Sequential(*encoder_list)

# # assert_equivariant(func=lambda x: fix_dim(sample(print_dim(x.unsqueeze(1)))), irreps_in=[model_sphten_repr], irreps_out=[inner_layer_repr], tolerance=1e-2)


In [None]:
def dot_almost(a, b, c, eps=1e-5, put_assert=True):
    """Check that every row-wise dot product of ``a`` and ``b`` is within ``eps`` of ``c``.

    Args:
        a, b: tensors multiplied element-wise and summed over the last dimension.
        c: expected dot-product value.
        eps: absolute tolerance.
        put_assert: when True, a failed check raises ``AssertionError``
            (after printing the offending values) instead of returning False.

    Returns:
        bool: True when all dot products lie in ``[c - eps, c + eps]``,
        False otherwise (only reachable with ``put_assert=False``).
    """
    dots = torch.sum(a * b, dim=-1)
    out_of_range = (dots.max() > (c + eps)) or (dots.min() < (c - eps))
    if not out_of_range:
        return True

    print("dot product is ", dots, "expected", c)
    assert not put_assert
    return False


def orthogonal_vector(a, b):
    """Return, per batch row, a unit vector orthogonal to both ``a`` and ``b``.

    The vectors ``a`` and ``b`` followed by the standard basis are used as the
    columns of a matrix; the last column of the Q factor of its reduced QR
    decomposition is returned. The result is verified with ``dot_almost``,
    which raises ``AssertionError`` when the checks fail.

    Args:
        a, b: tensors of shape ``(batch, dim)``.

    Returns:
        torch.Tensor: shape ``(batch, dim)``.
    """
    batch, dim = a.shape
    A = torch.stack([a, b], dim=1)
    A = torch.cat([A, torch.eye(dim).repeat(batch, 1, 1)], dim=1).permute(
        [0, 2, 1]
    )  # transpose -> columns are a, b, e_1, ..., e_dim
    # torch.qr is deprecated; torch.linalg.qr with its default "reduced" mode
    # is the documented replacement for torch.qr(A) (some=True).
    q, _r = torch.linalg.qr(A)
    axis = q[:, :, -1]

    dot_almost(a, axis, 0)
    dot_almost(b, axis, 0)
    dot_almost(axis, axis, 1)
    return axis


def rotate(R, vecs):
    """Apply a batch of matrices to a batch of vectors.

    Args:
        R: tensor of shape ``(n, i, j)``.
        vecs: tensor of shape ``(n, j)``.

    Returns:
        torch.Tensor: shape ``(n, i)`` with ``out[n] = R[n] @ vecs[n]``.
    """
    # Treat each vector as a column, multiply, then drop the trailing axis.
    columns = vecs.unsqueeze(-1)
    return torch.matmul(R, columns).squeeze(-1)


def interpolate_in_1D(v1, v2, s):
    """Linearly interpolate between ``v1`` (``s = 0``) and ``v2`` (``s = 1``).

    Used for the scalar (l = 0) components of the latent.
    """
    # todo we do linear interpolation because the signs might be different... why do we not have to care about this in other irreps?
    delta = v2 - v1
    return v1 + delta * s


def interpolate_in_3D(v1, v2, s):
    """Interpolate between two 3-vectors by rotating and rescaling ``v1`` towards ``v2``.

    The direction is rotated about ``v1 x v2`` by ``s`` times the angle
    between the vectors, and the length is interpolated geometrically
    (``|v1| * (|v2| / |v1|) ** s``).

    Fallbacks:
        * either vector (near) zero -> plain linear interpolation;
        * vectors (near) parallel -> only the length is interpolated;
        * vectors (near) antiparallel -> plain linear interpolation, because
          the rotation axis is not unique.

    The two last cases previously divided a zero cross product by its zero
    norm and returned NaNs.

    Args:
        v1, v2: tensors whose last dimension has size 3. Norms are taken over
            the whole tensor, so the function effectively assumes a single
            vector (e.g. shape ``(1, 3)``) - TODO confirm if true batching is needed.
        s: interpolation parameter; 0 gives ``v1``, 1 gives ``v2``.

    Returns:
        torch.Tensor: interpolated vector with the shape of ``v1``.

    Raises:
        AssertionError: if rotating ``v1`` by the full angle does not land on
            the direction of ``v2`` (internal consistency check).
    """
    norm1 = torch.norm(v1)
    norm2 = torch.norm(v2)

    eps = 1e-4
    if norm1 <= eps or norm2 <= eps:
        # just do linear interp
        return v1 + (v2 - v1) * s

    dot = (v1 * v2).sum(dim=-1)
    cos_angle = dot / (norm1 * norm2)
    cos_angle = torch.clip(cos_angle, -1, 1)  # to avoid nans in acos
    axis = torch.cross(v1, v2, dim=-1)
    axis_norm = torch.norm(axis, dim=-1, keepdim=True)

    # |v1 x v2| = |v1| |v2| sin(angle): a vanishing cross product means the
    # vectors are collinear and the rotation axis is undefined.
    collinear_eps = 1e-6
    if torch.any(axis_norm <= collinear_eps * norm1 * norm2):
        if torch.all(cos_angle > 0):
            # Same direction: no rotation needed, interpolate the length only.
            return v1 * torch.pow(norm2 / norm1, s)
        # Opposite directions: no unique rotation plane, fall back to linear.
        return v1 + (v2 - v1) * s

    angle = torch.acos(cos_angle)
    axis = axis / axis_norm

    R = o3.axis_angle_to_matrix(axis=axis, angle=angle * s)

    # Sanity check: the full rotation must align v1 with v2.
    R_full = o3.axis_angle_to_matrix(axis=axis, angle=angle)
    assert dot_almost(v2, rotate(R_full, v1), norm1 * norm2)

    return rotate(R, v1) * torch.pow(norm2 / norm1, s)


def interpolate_in_highD(irrep, v1, v2, s):
    """Interpolate two higher-dimensional vectors inside a 3-D subspace.

    ``v1``, ``v2`` and the standard basis are used as matrix columns; the first
    three columns of the Q factor of the reduced QR decomposition serve as an
    orthonormal basis. Both vectors are projected onto that basis,
    interpolated with ``interpolate_in_3D`` and mapped back.

    Args:
        irrep: unused.
        v1, v2: tensors of shape ``(batch, dim)``.
        s: interpolation parameter forwarded to ``interpolate_in_3D``.

    Returns:
        torch.Tensor: shape ``(batch, dim)``.
    """
    batch, dim = v1.shape
    A = torch.stack([v1, v2], dim=1)
    A = torch.cat([A, torch.eye(dim).repeat(batch, 1, 1)], dim=1).permute(
        [0, 2, 1]
    )  # transpose -> columns are v1, v2, e_1, ..., e_dim
    # torch.qr is deprecated; torch.linalg.qr with its default "reduced" mode
    # is the documented replacement for torch.qr(A) (some=True).
    q, _r = torch.linalg.qr(A)
    axis = q[:, :, :3]  # this is our basis
    v1_3d = torch.einsum("ndi,nd->ni", axis, v1)
    v2_3d = torch.einsum("ndi,nd->ni", axis, v2)
    v_interp_3d = interpolate_in_3D(v1_3d, v2_3d, s)
    v_interp = torch.einsum("ndi,ni->nd", axis, v_interp_3d)
    return v_interp

    # norm1 = torch.linalg.norm(vec1, dim=-1, keepdim=True)
    # norm2 = torch.linalg.norm(vec2, dim=-1, keepdim=True)
    # vec1_norm = vec1 / norm1
    # vec2_norm = vec2 / norm2
    # axis = orthogonal_vector(vec1_norm, vec2_norm)
    # axis_norm = torch.norm(axis, dim=-1)
    # axis /= axis_norm.unsqueeze(-1)
    # # theta = torch.asin(axis_norm)

    # x_axis = vec1_norm
    # y_axis = vec2_norm - (vec2_norm * x_axis).sum(dim=-1).unsqueeze(-1) * x_axis
    # y_axis /= torch.linalg.norm(y_axis, dim=-1, keepdim=True)

    # # handle degenerate cases
    # cos = (vec1_norm * vec2_norm).sum(dim=-1)
    # y_axis = torch.where(
    #     (torch.abs(cos) > 1-(1e-4)).unsqueeze(-1),
    #     orthogonal_vector(x_axis, axis),
    #     y_axis
    # )

    # dot_almost(x_axis, x_axis, 1)
    # dot_almost(axis, axis, 1)
    # dot_almost(y_axis, y_axis, 1)

    # dot_almost(vec1_norm, axis, 0)
    # dot_almost(vec2_norm, axis, 0)

    # dot_almost(x_axis, y_axis, 0)
    # dot_almost(x_axis, axis, 0)
    # dot_almost(y_axis, axis, 0)

    # x_coord = (x_axis * vec2_norm).sum(dim=-1)
    # y_coord = (y_axis * vec2_norm).sum(dim=-1)

    # theta = torch.atan2(y_coord, x_coord)

    # R = o3.Irreps(f'1x{l}e').D_from_axis_angle(axis=axis, angle=theta * s)

    # should_be_vec2_norm1 = torch.einsum('nij,nj->ni', o3.Irreps(f'1x{l}e').D_from_axis_angle(axis=axis, angle=+theta), vec1_norm)
    # should_be_vec2_norm2 = torch.einsum('nij,nj->ni', o3.Irreps(f'1x{l}e').D_from_axis_angle(axis=axis, angle=-theta), vec1_norm)
    # print('---')
    # print(vec1)
    # print(vec2)
    # assert(dot_almost(should_be_vec2_norm1, vec2_norm, 1, eps=1e-2, put_assert=False) or dot_almost(should_be_vec2_norm2, vec2_norm, 1, eps=1e-2, put_assert=False))

    # vec_interp_size = vec1 * torch.pow(norm2 / norm1, s)
    # vec_interp_rotated = torch.einsum('nij,nj->ni', R, vec_interp_size)
    # vec_interp = vec_interp_rotated
    # res[..., ind:ind+sz] = vec_interp


def interpolate_in_latent_space(repr, latent1, latent2, s):
    """Interpolate two latent vectors irrep by irrep.

    The last dimension of the latents is walked in chunks of size ``2l + 1``
    for every ``l`` in ``repr.ls``. Size-1 chunks are interpolated with
    ``interpolate_in_1D`` and size-3 chunks with ``interpolate_in_3D``.

    Args:
        repr: object whose ``ls`` attribute lists the degree of every chunk.
        latent1, latent2: tensors of identical shape.
        s: interpolation parameter; 0 gives ``latent1``, 1 gives ``latent2``.

    Returns:
        torch.Tensor: interpolated latent with the shape of ``latent1``.

    Raises:
        Exception: if any chunk has ``l > 1``.
        AssertionError: if the shapes differ or the chunks do not exactly
            cover the last dimension.
    """
    res = torch.empty_like(latent1)
    assert latent1.shape == latent2.shape

    # Chunk width -> interpolation routine.
    interpolators = {1: interpolate_in_1D, 3: interpolate_in_3D}

    offset = 0
    for l in repr.ls:
        width = 2 * l + 1
        chunk = slice(offset, offset + width)
        interpolate = interpolators.get(width)
        if interpolate is None:
            raise Exception("l > 1 is not supported in interpolation sorry :))")
        res[..., chunk] = interpolate(latent1[..., chunk], latent2[..., chunk], s)
        offset += width

    assert offset == res.shape[-1]
    return res


def linear_interpolate_in_latent_space(latent1, latent2, s):
    """Plain linear interpolation: ``latent1`` at ``s = 0``, ``latent2`` at ``s = 1``."""
    step = latent2 - latent1
    return latent1 + s * step


# this is just for debugging
def interpolate_in_latent_space_hint(axis, angle, repr, latent1, latent2, s):
    """Debug variant: rotate ``latent1`` by a known axis/angle instead of inferring it.

    Every chunk of size ``2l + 1`` (one per entry of ``repr.ls``) is
    multiplied by the degree-``l`` rotation matrix for ``angle * s`` about
    ``axis``. ``latent2`` is only used for the shape check.

    Args:
        axis, angle: rotation passed to ``Irreps.D_from_axis_angle``.
        repr: object whose ``ls`` attribute lists the degree of every chunk.
        latent1: tensor of shape ``(n, total_dim)``.
        latent2: tensor with the same shape as ``latent1``.
        s: fraction of ``angle`` to apply.

    Returns:
        torch.Tensor: rotated latent with the shape of ``latent1``.
    """
    res = torch.empty_like(latent1)
    assert latent1.shape == latent2.shape

    start = 0
    for l in repr.ls:
        stop = start + 2 * l + 1
        D = o3.Irreps(f"{l}e").D_from_axis_angle(axis=axis, angle=angle * s)
        res[..., start:stop] = torch.einsum("ij,nj->ni", D, latent1[..., start:stop])
        start = stop

    assert start == res.shape[-1]
    return res

In [None]:
# latent_repr = model.latent_repr
# model_sphten_repr = model.model_sphten_repr

# axis = torch.tensor([1, 1, 1]).float()
# angle = torch.tensor(torch.pi)
# R = o3.axis_angle_to_matrix(axis, angle)
# sphten = model.model_sphten_repr
# D = sphten.D_from_matrix(R)

# sh1 = dataset[0].squeeze()
# sh2 = D @ sh1

# with torch.no_grad():
#     sh1 = sh1.unsqueeze(0).to(device)
#     sh2 = sh2.unsqueeze(0).to(device)
#     inp1, latent1, out1 = model(sh1)
#     inp2, latent2, out2 = model(sh2)

# init_shape = latent1.shape
# latent1 = latent1.reshape(-1).unsqueeze(0)
# latent2 = latent2.reshape(-1).unsqueeze(0)

# N = 4
# rows = 4
# columns = N

# fig = make_subplots(rows=rows, cols=columns, specs=[[{'is_3d': True} for j in range(columns)] for i in range(rows)])

# # rotated input
# i = 0
# for j in range(columns):
#     s = j/(N-1)
#     D = sphten.D_from_axis_angle(axis=axis, angle=angle*s)
#     sh = (D.to(device) @ sh1.squeeze(0)).unsqueeze(0)
#     fig.add_trace(go.Surface(sphten.plotly_surface(sh.cpu(), radius=True)[0]), row=i+1, col=j+1)


# # # rotated latent space
# # i = 1
# # for j in range(columns):
# #     s = j/(N-1)
# #     D = (latent_repr).D_from_axis_angle(axis=axis, angle=angle*s)
# #     latent = (D.to(device) @ latent1.squeeze(0)).unsqueeze(0)
# #     latent = latent.reshape(init_shape)
# #     with torch.no_grad():
# #         out = model.decoder(latent)
# #     fig.add_trace(go.Surface(sphten.plotly_surface(out.cpu(), radius=True)[0]), row=i+1, col=j+1)


# i = 1
# for j in range(columns):
#     s = j/(N-1)
#     # latent = interpolate_in_latent_space(latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j/(N-1)).to(device)
#     latent = interpolate_in_latent_space_hint(axis, angle, latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j/(N-1)).to(device)
#     latent = latent.reshape(init_shape)
#     with torch.no_grad():
#         out = model.decoder(latent)
#     fig.add_trace(go.Surface(sphten.plotly_surface(out.cpu(), radius=True)[0]), row=i+1, col=j+1)


# i = 2
# for j in range(columns):
#     s = j/(N-1)
#     latent = linear_interpolate_in_latent_space(latent1=latent1.cpu(), latent2=latent2.cpu(), s=s).to(device)
#     latent = latent.to(device)
#     latent = latent.reshape(init_shape)
#     with torch.no_grad():
#         out = model.decoder(latent)
#     fig.add_trace(go.Surface(sphten.plotly_surface(out.cpu(), radius=True)[0]), row=i+1, col=j+1)


# fig.show()
# # inp1 = inp1.cpu()
# # inp2 = inp2.cpu()
# # out1 = out1.cpu()
# # out2 = out2.cpu()
# # sh1 = sh1.cpu()
# # sh2 = sh2.cpu()
# # fig.add_trace(go.Surface(sphten.plotly_surface(sh1, radius=True)[0]), row=i+1, col=1)
# # fig.add_trace(go.Surface(sphten.plotly_surface(sh2, radius=True)[0]), row=i+1, col=columns)

# fig.write_image('interp.pdf')

In [None]:
# latent_repr = model.latent_repr
# model_sphten_repr = model.model_sphten_repr

# sphten = model.model_sphten_repr

# # sh1 = dataset[0].squeeze()
# # sh2 = dataset[1].squeeze()

# axis = torch.tensor([0, 0, 1]).float()
# angle = torch.tensor(torch.pi/2)
# R = o3.axis_angle_to_matrix(axis, angle)
# sphten = model.model_sphten_repr
# D = sphten.D_from_matrix(R)

# sh1 = dataset[0].squeeze()
# sh2 = D @ sh1


# with torch.no_grad():
#     sh1 = sh1.unsqueeze(0).to(device)
#     sh2 = sh2.unsqueeze(0).to(device)
#     inp1, latent1, out1 = model(sh1)
#     inp2, latent2, out2 = model(sh2)

# init_shape = latent1.shape
# latent1 = latent1.reshape(-1).unsqueeze(0)
# latent2 = latent2.reshape(-1).unsqueeze(0)

# N = 7
# rows = 2
# columns = N

# fig = make_subplots(rows=rows, cols=columns, specs=[[{'is_3d': True} for j in range(columns)] for i in range(rows)])

# i = 0
# for j in range(columns):
#     s = j/(N-1)
#     latent = interpolate_in_latent_space(latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j/(N-1)).to(device)
#     latent = latent.to(device)
#     # latent = interpolate_in_latent_space_hint(axis, angle, latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j/(N-1)).to(device)
#     latent = latent.reshape(init_shape)
#     with torch.no_grad():
#         out = model.decoder(latent)
#     fig.add_trace(go.Surface(sphten.plotly_surface(out.cpu(), radius=True)[0]), row=i+1, col=j+1)


# i = 1
# for j in range(columns):
#     s = j/(N-1)
#     latent = linear_interpolate_in_latent_space(latent1=latent1.cpu(), latent2=latent2.cpu(), s=s).to(device)
#     latent = latent.to(device)
#     latent = latent.reshape(init_shape)
#     with torch.no_grad():
#         out = model.decoder(latent)
#     fig.add_trace(go.Surface(sphten.plotly_surface(out.cpu(), radius=True)[0]), row=i+1, col=j+1)


# fig.show()
# # inp1 = inp1.cpu()
# # inp2 = inp2.cpu()
# # out1 = out1.cpu()
# # out2 = out2.cpu()
# # sh1 = sh1.cpu()
# # sh2 = sh2.cpu()
# # fig.add_trace(go.Surface(sphten.plotly_surface(sh1, radius=True)[0]), row=i+1, col=1)
# # fig.add_trace(go.Surface(sphten.plotly_surface(sh2, radius=True)[0]), row=i+1, col=columns)

# fig.write_image('interp.pdf')

In [None]:
# # equivariance check

# # R = o3.axis_angle_to_matrix(torch.tensor([1, 1, 1]).float(), torch.tensor(torch.pi/3))
# axis = torch.tensor([1, 1, 1]).float()
# angle = torch.tensor(torch.pi)
# R = o3.axis_angle_to_matrix(axis, angle)
# sphten = model.model_sphten_repr
# D = sphten.D_from_matrix(R)

# sh1 = dataset[0].squeeze()
# sh2 = D @ sh1

# with torch.no_grad():
#     sh1 = sh1.unsqueeze(0).to(device)
#     sh2 = sh2.unsqueeze(0).to(device)
#     sh1, latent1, out1 = model(sh1)
#     sh2, latent2, out2 = model(sh2)

# fig = make_subplots(rows=3, cols=2, specs=[[{'is_3d': True} for j in range(2)] for i in range(3)])
# fig.add_trace(go.Surface(sphten.plotly_surface(sh1.cpu(), radius=True)[0]), row=1, col=1)
# fig.add_trace(go.Surface(sphten.plotly_surface(D @ sh1.squeeze().cpu(), radius=True)[0]), row=2, col=1)
# fig.add_trace(go.Surface(sphten.plotly_surface(sh2.cpu(), radius=True)[0]), row=3, col=1)

# fig.add_trace(go.Surface(sphten.plotly_surface(out1.cpu(), radius=True)[0]), row=1, col=2)
# fig.add_trace(go.Surface(sphten.plotly_surface(D @ out1.squeeze().cpu(), radius=True)[0]), row=2, col=2)
# fig.add_trace(go.Surface(sphten.plotly_surface(out2.cpu(), radius=True)[0]), row=3, col=2)


In [None]:
# Encode two dataset samples and flatten their latents so they can be
# interpolated irrep by irrep in the loops below.
latent_repr = model.latent_repr
model_sphten_repr = model.model_sphten_repr

sphten = model.model_sphten_repr

lmax = 4  # read as a global by save_sh()
# sh1 = dataset[0].squeeze()
# sh2 = dataset[1].squeeze()

# NOTE(review): axis / angle / R / D are only referenced by the commented-out
# "hint" interpolation in this cell - presumably leftovers from the rotation experiment.
axis = torch.tensor([0, 0, 1]).float()
angle = torch.tensor(torch.pi / 2)
R = o3.axis_angle_to_matrix(axis, angle)
sphten = model.model_sphten_repr
D = sphten.D_from_matrix(R)

# The two shapes to interpolate between.
sh1 = dataset[0].squeeze()
sh2 = dataset[1].squeeze()


with torch.no_grad():
    sh1 = sh1.unsqueeze(0).to(device)
    sh2 = sh2.unsqueeze(0).to(device)
    inp1, latent1, out1 = model(sh1)
    inp2, latent2, out2 = model(sh2)

# Remember the latent's original shape so interpolated latents can be
# reshaped back before decoding; work on a flat (1, total_dim) view meanwhile.
init_shape = latent1.shape
latent1 = latent1.reshape(-1).unsqueeze(0)
latent2 = latent2.reshape(-1).unsqueeze(0)


def save_sh(met, idx, sh):
    """Render a spherical signal as a bare 3D surface and save it as ``{met}-{idx}.png``.

    Axes, grid, colour scale and legend are hidden and the scene background is
    transparent. Reads the notebook-global ``lmax``.

    Args:
        met: file-name prefix (the interpolation method).
        idx: file-name index (the interpolation step).
        sh: signal passed to ``SphericalTensor.plotly_surface``.
    """

    def hidden_axis(label):
        # Axis that keeps its title but draws neither grid nor the axis itself.
        return dict(title=label, showgrid=False, visible=False)

    scene = dict(
        aspectmode="cube",
        xaxis=hidden_axis("X"),
        yaxis=hidden_axis("Y"),
        zaxis=hidden_axis("Z"),
        # Set background color to transparent
        bgcolor="rgba(0, 0, 0, 0)",
    )
    layout = go.Layout(scene=scene)

    sphten = e3nn.io.SphericalTensor(lmax, 1, 1)  # positional: p_val=1, p_arg=1
    surface = go.Surface(sphten.plotly_surface(sh, radius=True)[0])

    fig = go.Figure([surface], layout=layout)
    fig.update(layout_coloraxis_showscale=False)
    fig.update(layout_showlegend=False)
    fig.update_coloraxes(showscale=False)
    fig.write_image(f"{met}-{idx}.png")


N = 8  # number of interpolation steps; also read by get_name() in the next cell
rows = 2
columns = N

# Row 0: irrep-wise (rotation-aware) interpolation, saved as equiv_rot-<j>.png.
i = 0
for j in range(columns):
    s = j / (N - 1)  # interpolation parameter, 0 at the first step and 1 at the last
    latent = interpolate_in_latent_space(
        latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j / (N - 1)
    ).to(device)
    latent = latent.to(device)
    # latent = interpolate_in_latent_space_hint(axis, angle, latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j/(N-1)).to(device)
    latent = latent.reshape(init_shape)  # back to the shape the decoder received during encoding
    with torch.no_grad():
        out = model.decoder(latent)
    save_sh("equiv_rot", j, out.cpu())


# Row 1: plain linear interpolation of the flat latents, saved as lin_rot-<j>.png.
i = 1
for j in range(columns):
    s = j / (N - 1)
    latent = linear_interpolate_in_latent_space(
        latent1=latent1.cpu(), latent2=latent2.cpu(), s=s
    ).to(device)
    latent = latent.to(device)
    latent = latent.reshape(init_shape)
    with torch.no_grad():
        out = model.decoder(latent)
    save_sh("lin_rot", j, out.cpu())

In [None]:
from PIL import Image
import numpy as np


def get_name(name):
    """Build a horizontal strip from the centre crops of ``{name}-0.png`` .. ``{name}-{N-1}.png``.

    From every frame the centred square with half the shorter side length is
    cropped; the crops are joined left to right, saved as
    ``{name}-stacked.png`` and returned. Reads the notebook-global ``N``.

    Args:
        name: file-name prefix of the frames.

    Returns:
        numpy.ndarray: the horizontally stacked image.
    """

    def center_crop(path):
        # Crop the centred square whose side is half the shorter image side.
        img = Image.open(path)
        width, height = img.size
        D = min(width, height) * 0.5
        box = ((width - D) / 2, (height - D) / 2, (width + D) / 2, (height + D) / 2)
        return np.array(img.crop(box))

    tiles = [center_crop(f"{name}-{i}.png") for i in range(N)]

    # Join the crops side by side (along the horizontal axis).
    stacked_image = np.hstack(tiles)
    Image.fromarray(stacked_image).save(f"{name}-stacked.png")

    return stacked_image


# Final figure: irrep-wise interpolation strip on top, linear interpolation strip below.
stacked = np.vstack([get_name("equiv_rot"), get_name("lin_rot")])
stacked_pil_image = Image.fromarray(stacked)
stacked_pil_image.save(f"all-stacked.png")

In [None]:
# Keep the model trained on the 100-box dataset; it is reloaded further down.
save_model(model, "train_100_boxes_l1")

In [None]:
# NOTE: rebinds the global `dataset` to a 2-element list, so every later cell
# (and any earlier cell that is re-run) sees only these two samples.
dataset = [dataset[0], dataset[1]]  # the two things we interpolated between

## Interpolate between the exact same two objects, but with a network trained only on those two objects

In [None]:
# Fresh model with the same architecture as above; rebinding `model` discards
# the in-memory weights of the previously trained network.
model = S2ConvNet_Autoencoder(
    LMAX, l_list=[LMAX, 3, 3, 3, 2, 2, 1], channels=[4, 8, 8, 8, 16, 16, 32]
).to(device)
print(model)
# Total number of scalar entries across all parameter tensors.
print("number of parameters = ", sum([np.prod(x.shape) for x in model.parameters()]))

In [None]:
# Seed PyTorch's RNG. The model was already constructed in the previous cell,
# so this seed does not affect its initial weights.
torch.manual_seed(0)
initial_rl = 0.001  # initial learning rate (10x smaller than in the first run)
optimizer = torch.optim.Adam(model.parameters(), lr=initial_rl)
custom_scheduler = CustomLRScheduler(optimizer, initial_rl)
loss_fn = WeightedGridLoss(LMAX).to(device)


def inspection_func():
    """Plot the reconstruction of one randomly chosen dataset sample.

    Replaces the earlier definition of the same name (this version does not
    print the IoU). Reads the notebook globals ``dataset``, ``model`` and ``sphten``.
    """
    sample_index = np.random.choice(len(dataset))
    plot_predictions(model, [dataset[sample_index]], sphten)

In [None]:
# Snapshot the weights as they are *before* this training run.
save_model(model, "initial_state")

custom_scheduler.set_rl(0.001)
losses = train(
    model=model,
    dataset=dataset,
    loss_fn=loss_fn,
    epochs=5000,
    initial_rl=None,  # unused: an optimizer is passed explicitly
    scheduler=custom_scheduler,
    optimizer=optimizer,
    inspection_func=inspection_func,
    inspection_interval=200,
    checkpoint_interval=200,
    batch_size=1,
)

# Persist the trained weights. The original cell also re-saved the model as
# "initial_state" at this point, which overwrote the pre-training snapshot
# above with the *trained* weights; that second save has been removed.
save_model(model, "final_state")

plt.plot(losses)

In [None]:
# Choose which weights the interpolation cells below use.
# NOTE(review): as written this replaces the weights just trained on two boxes
# with the 100-box model; swap the commented lines to keep the 2-box model instead.
# save_model(model, "train_2_boxes_l1")
# load_model(model, "train_2_boxes_l1")
load_model(model, "train_100_boxes_l1")

In [None]:
import random

# Shuffle a list copy so that dataset[0] / dataset[1] pick a different pair
# for the interpolation below. A dedicated, seeded generator is used so that
# Restart & Run All selects the same pair every time; the previous unseeded
# random.shuffle made the resulting figures irreproducible.
dataset = list(dataset)
random.Random(0).shuffle(dataset)

In [None]:
# Same setup as the earlier interpolation cell, now for the first two samples
# of the shuffled dataset; additionally saves the inputs and reconstructions.
latent_repr = model.latent_repr
model_sphten_repr = model.model_sphten_repr

sphten = model.model_sphten_repr

lmax = 4  # read as a global by save_sh()
sh1 = dataset[0].squeeze()
sh2 = dataset[1].squeeze()

# NOTE(review): axis / angle / R / D are not used later in this cell -
# presumably leftovers from the rotation experiment.
axis = torch.tensor([0, 0, 1]).float()
angle = torch.tensor(torch.pi / 2)
R = o3.axis_angle_to_matrix(axis, angle)
sphten = model.model_sphten_repr
D = sphten.D_from_matrix(R)

sh1 = dataset[0].squeeze()
sh2 = dataset[1].squeeze()

with torch.no_grad():
    sh1 = sh1.unsqueeze(0).to(device)
    sh2 = sh2.unsqueeze(0).to(device)
    inp1, latent1, out1 = model(sh1)
    inp2, latent2, out2 = model(sh2)

# NOTE: these calls run before the `def save_sh` later in this cell, so they
# use the definition from the earlier interpolation cell (without the NaN
# assertion) and fail on a fresh kernel if that cell was not executed.
save_sh("sh", 1, sh1.cpu())
save_sh("sh", 2, sh2.cpu())
save_sh("out", 1, out1.cpu())
save_sh("out", 2, out2.cpu())

# Remember the latent's original shape so interpolated latents can be
# reshaped back before decoding; work on a flat (1, total_dim) view meanwhile.
init_shape = latent1.shape
latent1 = latent1.reshape(-1).unsqueeze(0)
latent2 = latent2.reshape(-1).unsqueeze(0)


def save_sh(met, idx, sh):
    """Render a spherical signal as a bare 3D surface and save it as ``{met}-{idx}.png``.

    Axes, grid, colour scale and legend are hidden and the scene background is
    transparent. Reads the notebook-global ``lmax``.

    Args:
        met: file-name prefix (the interpolation method).
        idx: file-name index (the interpolation step).
        sh: signal passed to ``SphericalTensor.plotly_surface``; must not
            contain NaNs.

    Raises:
        AssertionError: if ``sh`` contains a NaN.
    """
    assert not sh.isnan().any()

    def hidden_axis(label):
        # Axis that keeps its title but draws neither grid nor the axis itself.
        return dict(title=label, showgrid=False, visible=False)

    scene = dict(
        aspectmode="cube",
        xaxis=hidden_axis("X"),
        yaxis=hidden_axis("Y"),
        zaxis=hidden_axis("Z"),
        # Set background color to transparent
        bgcolor="rgba(0, 0, 0, 0)",
    )
    layout = go.Layout(scene=scene)

    sphten = e3nn.io.SphericalTensor(lmax, 1, 1)  # positional: p_val=1, p_arg=1
    surface = go.Surface(sphten.plotly_surface(sh, radius=True)[0])

    fig = go.Figure([surface], layout=layout)
    fig.update(layout_coloraxis_showscale=False)
    fig.update(layout_showlegend=False)
    fig.update_coloraxes(showscale=False)
    fig.write_image(f"{met}-{idx}.png")


N = 8  # number of interpolation steps; also read by get_name() in the next cell
rows = 2
columns = N

# Row 0: irrep-wise (rotation-aware) interpolation, saved as equiv_rot-<j>.png.
i = 0
for j in range(columns):
    s = j / (N - 1)  # interpolation parameter, 0 at the first step and 1 at the last
    latent = interpolate_in_latent_space(
        latent_repr, latent1=latent1.cpu(), latent2=latent2.cpu(), s=j / (N - 1)
    ).to(device)
    latent = latent.to(device)
    latent = latent.reshape(init_shape)  # back to the shape the decoder received during encoding
    with torch.no_grad():
        out = model.decoder(latent)
    save_sh("equiv_rot", j, out.cpu())


# Row 1: plain linear interpolation of the flat latents, saved as lin_rot-<j>.png.
i = 1
for j in range(columns):
    s = j / (N - 1)
    latent = linear_interpolate_in_latent_space(
        latent1=latent1.cpu(), latent2=latent2.cpu(), s=s
    ).to(device)
    latent = latent.to(device)
    latent = latent.reshape(init_shape)
    with torch.no_grad():
        out = model.decoder(latent)
    save_sh("lin_rot", j, out.cpu())

In [None]:
from PIL import Image
import numpy as np


def get_name(name):
    """Build a horizontal strip from the centre crops of ``{name}-0.png`` .. ``{name}-{N-1}.png``.

    From every frame the centred square with half the shorter side length is
    cropped; the crops are joined left to right, saved as
    ``{name}-stacked.png`` and returned. Reads the notebook-global ``N``.

    Args:
        name: file-name prefix of the frames.

    Returns:
        numpy.ndarray: the horizontally stacked image.
    """

    def center_crop(path):
        # Crop the centred square whose side is half the shorter image side.
        img = Image.open(path)
        width, height = img.size
        D = min(width, height) * 0.5
        box = ((width - D) / 2, (height - D) / 2, (width + D) / 2, (height + D) / 2)
        return np.array(img.crop(box))

    tiles = [center_crop(f"{name}-{i}.png") for i in range(N)]

    # Join the crops side by side (along the horizontal axis).
    stacked_image = np.hstack(tiles)
    Image.fromarray(stacked_image).save(f"{name}-stacked.png")

    return stacked_image


# Final figure: irrep-wise interpolation strip on top, linear interpolation strip below.
# NOTE: writes to the same "all-stacked.png" as the earlier cell and overwrites it.
stacked = np.vstack([get_name("equiv_rot"), get_name("lin_rot")])
stacked_pil_image = Image.fromarray(stacked)
stacked_pil_image.save(f"all-stacked.png")