In [1]:
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray
from hydra_zen import get_target, instantiate, load_from_yaml
from mlflow import MlflowClient, set_tracking_uri
from torch import nn
from torch_geometric.loader import DataLoader
from torch_harmonics import plotting

from o2s.datasets.radar_dataset import RadarDataset
from o2s.datasets.transforms.general import Compose
from o2s.datasets.transforms.mesh import MeshNormalize
from o2s.datasets.transforms.radar import Abs, Center, Log, Normalize
from o2s.models.object_2_sphere import Mesh2Sphere
from mlflow import MlflowClient, set_tracking_uri

In [2]:
# MLflow tracking store and the experiment that holds the trained models.
TRACKING_URI = '/home/colin/hdd/workspace/object2sphere/datasets/out/mlflow'
EXPERIMENT_ID = "711364059123036807"

set_tracking_uri(TRACKING_URI)
# Later cells select runs by *position* in this list, so the ordering is made
# explicit here (newest run first, matching MLflow's default). Positions shift
# whenever a new run is logged to the experiment -- prefer run ids when known.
runs = MlflowClient().search_runs(
    experiment_ids=[EXPERIMENT_ID],
    order_by=["attributes.start_time DESC"],
)

In [None]:
# Load the Equiformer model from its MLflow run.
# NOTE(review): run id b7cdf38de03d4257aacb39457acc795c was recorded both here
# and in the Mesh2Sphere cell, so it cannot identify both runs -- confirm which
# run it belongs to and select runs by id instead of list position.
EQUIFORMER_RUN_INDEX = 20

output_dir = runs[EQUIFORMER_RUN_INDEX].data.params['output_dir']
# Choose the config deterministically (glob order is filesystem-dependent) and
# fail with a readable message instead of an IndexError when none exists.
config_paths = sorted(Path(output_dir).glob("**/config.yaml"))
if not config_paths:
    raise FileNotFoundError(f"no config.yaml found under {output_dir}")
cfg = load_from_yaml(str(config_paths[0]))

module_cls = get_target(cfg.module)
# The backbone, criterion and optimizer are rebuilt from the run's config and
# handed to load_from_checkpoint; `.eval()` switches the module to inference
# behaviour (e.g. dropout disabled) for the comparison below.
equiformer = module_cls.load_from_checkpoint(
    str(Path(output_dir) / 'last.ckpt'),
    backbone=instantiate(cfg.module.backbone),
    criterion=instantiate(cfg.module.criterion),
    optim=instantiate(cfg.module.optim),
).eval()



In [None]:
# Load the Mesh2Sphere (o2s) model from its MLflow run. `output_dir` and `cfg`
# deliberately stay bound to this run: the test-dataset cell reads them.
# NOTE(review): run id b7cdf38de03d4257aacb39457acc795c was recorded both here
# and in the Equiformer cell, so it cannot identify both runs -- confirm which
# run it belongs to and select runs by id instead of list position.
O2S_RUN_INDEX = 21

output_dir = runs[O2S_RUN_INDEX].data.params['output_dir']
# Choose the config deterministically (glob order is filesystem-dependent) and
# fail with a readable message instead of an IndexError when none exists.
config_paths = sorted(Path(output_dir).glob("**/config.yaml"))
if not config_paths:
    raise FileNotFoundError(f"no config.yaml found under {output_dir}")
cfg = load_from_yaml(str(config_paths[0]))

module_cls = get_target(cfg.module)
# The backbone, criterion and optimizer are rebuilt from the run's config and
# handed to load_from_checkpoint; `.eval()` switches the module to inference
# behaviour (e.g. dropout disabled) for the comparison below.
o2s = module_cls.load_from_checkpoint(
    str(Path(output_dir) / 'last.ckpt'),
    backbone=instantiate(cfg.module.backbone),
    criterion=instantiate(cfg.module.criterion),
    optim=instantiate(cfg.module.optim),
).eval()

In [None]:
# Test split of the radar dataset. Its transforms come from the config of the
# run whose `output_dir` was bound most recently (the Mesh2Sphere cell above).
# NOTE(review): Equiformer is evaluated with these same transforms -- confirm
# both runs were trained with identical test transforms.
TEST_DS_FP = '/home/colin/hdd/workspace/object2sphere/datasets/radar_test.nc'

# Deterministic config choice with a readable error when none exists.
config_paths = sorted(Path(output_dir).glob("**/config.yaml"))
if not config_paths:
    raise FileNotFoundError(f"no config.yaml found under {output_dir}")
cfg = load_from_yaml(str(config_paths[0]))

test_ds_fp = TEST_DS_FP
ds = RadarDataset(
    test_ds_fp,
    'test',
    seed=0,
    transform=instantiate(cfg.test_dataset.transform),
    shuffle_before_split=False,
    mesh_mode='simple',
    orientation_mode='full',
)

In [None]:
# Shuffled loader over the test set. The generator is seeded so the batch that
# is visualised below is the same after every kernel restart.
LOADER_SEED = 0

dl = DataLoader(
    ds,
    batch_size=10,
    num_workers=0,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
itr = iter(dl)

In [None]:
# Run both models on the next batch from `itr`. Note: re-running this cell
# advances to a different batch.
mesh, target = next(itr)

# Make sure both models run in inference mode regardless of how they were
# loaded (eval() is idempotent).
equiformer.eval()
o2s.eval()

# Transfer the input once and share it between the two forward passes.
mesh = mesh.cuda()
with torch.no_grad():
    equiformer_pred, _ = equiformer(mesh)
    o2s_pred, _ = o2s(mesh)

In [None]:
# Render an MP4 comparing Mesh2Sphere, Equiformer and the target for one batch
# sample, with one video frame per index along dim 1 of `target`.
VIDEO_PATH = 'compare.mp4'
VIDEO_FPS = 30
VIDEO_DPI = 72

b = 0  # index of the batch sample to animate
num_frames = target.size(1)

# Move each panel's data to host memory once rather than on every frame.
panels = [
    ("Mesh2Sphere", o2s_pred[b].cpu()),
    ("Equiformer", equiformer_pred[b].cpu()),
    ("Target", target[b].cpu()),
]
# One shared colour scale (taken over the whole target batch) so the three
# panels are directly comparable; computed once, outside the frame loop.
target_cpu = target.cpu()
vmin = target_cpu.min().item()
vmax = target_cpu.max().item()

fig = plt.figure(layout='constrained', figsize=(12, 8), dpi=VIDEO_DPI)
subfigs = fig.subfigures(1, len(panels))
moviewriter = animation.FFMpegWriter(fps=VIDEO_FPS)

# `saving` calls setup() and guarantees finish() runs even if a frame raises,
# so ffmpeg is never left running with a half-written file.
with moviewriter.saving(fig, VIDEO_PATH, dpi=VIDEO_DPI):
    for frame in range(num_frames):
        if frame % 10 == 0:
            print(f"frame={frame}")

        for subfig, (label, series) in zip(subfigs, panels):
            plotting.plot_spherical_fn(
                series[frame].squeeze().numpy(),
                fig=subfig,
                vmin=vmin,
                vmax=vmax,
                title=f"{label}: R={frame}",
            )

        # grab_frame renders the figure itself; no explicit draw needed.
        moviewriter.grab_frame()
        for subfig in subfigs:
            subfig.clear()

plt.close(fig)

<video controls src="compare.mp4" />