In [1]:
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray
from hydra_zen import load_from_yaml, instantiate, get_target
from mlflow import MlflowClient, set_tracking_uri
from torch import nn
from torch_geometric.loader import DataLoader
from torch_harmonics import plotting

from g2s.models.geometry_2_sphere import Mesh2Radar
from g2s.datasets.radar_dataset import RadarDataset
from g2s.datasets.transforms.general import Compose
from g2s.datasets.transforms.radar import Log, Abs, Normalize, Center
from g2s.datasets.transforms.mesh import MeshNormalize

In [2]:
# Point MLflow at the local file store and fetch every run of the experiment.
# NOTE(review): later cells select runs by position (runs[1], runs[12]). MLflow's
# default ordering is newest-first, so those positions shift whenever a new run is
# logged. Pinning run IDs would make this notebook reproducible.
MLFLOW_TRACKING_URI = 'geometry2sphere/datasets/out/mlflow'
EXPERIMENT_ID = "711364059123036807"

set_tracking_uri(MLFLOW_TRACKING_URI)
# `experiment_ids` is documented as a list of experiment IDs; pass one explicitly.
runs = MlflowClient().search_runs(experiment_ids=[EXPERIMENT_ID])

In [None]:
# Load the Equiformer model from its training run's output directory.
# NOTE(review): runs[12] is a positional pick into the search_runs result; the
# index refers to a different run once new runs are logged.
output_dir = runs[12].data.params['output_dir']
# Sort the matches so the chosen config is deterministic (glob order depends on
# the filesystem), and fail with a clear message instead of a bare IndexError.
config_paths = sorted(Path(output_dir).glob("**/config.yaml"))
assert config_paths, f"no config.yaml found under {output_dir}"
cfg = load_from_yaml(str(config_paths[0]))
equiformer = get_target(cfg.module)
equiformer = equiformer.load_from_checkpoint(
    str(Path(output_dir) / 'last.ckpt'),
    backbone=instantiate(cfg.module.backbone),
    criterion=instantiate(cfg.module.criterion),
    optim=instantiate(cfg.module.optim),
)

In [None]:
# Load the O2S model from its training run's output directory.
# NOTE(review): runs[1] is a positional pick into the search_runs result; the
# index refers to a different run once new runs are logged.
output_dir = runs[1].data.params['output_dir']
# Sort the matches so the chosen config is deterministic (glob order depends on
# the filesystem), and fail with a clear message instead of a bare IndexError.
config_paths = sorted(Path(output_dir).glob("**/config.yaml"))
assert config_paths, f"no config.yaml found under {output_dir}"
cfg = load_from_yaml(str(config_paths[0]))
o2s = get_target(cfg.module)
o2s = o2s.load_from_checkpoint(
    str(Path(output_dir) / 'last.ckpt'),
    backbone=instantiate(cfg.module.backbone),
    criterion=instantiate(cfg.module.criterion),
    optim=instantiate(cfg.module.optim),
)

In [4]:
# Build the test split using the transform from the training config in `output_dir`.
# NOTE(review): `output_dir` is whatever the most recently executed model-loading
# cell set. In top-to-bottom order that is the O2S run, but re-running the
# Equiformer cell afterwards would silently switch this to Equiformer's transform.
config_paths = sorted(Path(output_dir).glob("**/config.yaml"))
assert config_paths, f"no config.yaml found under {output_dir}"
cfg = load_from_yaml(str(config_paths[0]))
test_ds_fp = 'geometry2sphere/datasets/radar_test.nc'
ds = RadarDataset(
    test_ds_fp,
    'test',
    seed=0,
    transform=instantiate(cfg.test_dataset.transform),
    shuffle_before_split=False,
    mesh_mode='simple',
    orientation_mode='full',
)

In [5]:
# Seed the shuffle so Restart & Run All draws the same batch. Without a seed,
# the figures and the weight ranking further down change on every run.
dl = DataLoader(
    ds,
    batch_size=2,
    num_workers=0,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
itr = iter(dl)

In [6]:
# Run the O2S model on one batch from the loader.
# load_from_checkpoint leaves the module in training mode, so switch to eval()
# before inference; otherwise any dropout / batch-norm layers behave as in training.
mesh, target = next(itr)
o2s.eval()
with torch.no_grad():
    # Equiformer inference is disabled; only O2S predictions are plotted below.
    o2s_pred, w = o2s(mesh.cuda())

In [7]:
# Render compare.mp4: O2S prediction (left) vs. ground truth (right), one frame per R.
fig = plt.figure(layout='constrained', figsize=(12, 8), dpi=72)
subfigs = fig.subfigures(1, 2)

b = 0  # index of the sample within the batch to animate
num_frames = target.size(1)
# One shared colour scale so the two panels are directly comparable. It does not
# depend on the frame, so compute it once outside the loop.
vmin = target[b].min().item()
vmax = target[b].max().item()

panels = (
    (subfigs[0], o2s_pred, "prediction"),
    (subfigs[1], target, "target"),
)

moviewriter = animation.FFMpegWriter(fps=30)
# `saving` calls setup() and guarantees finish() even if plotting raises, so the
# ffmpeg subprocess is never left running and the file is always finalised.
with moviewriter.saving(fig, 'compare.mp4', dpi=72):
    for frame in range(num_frames):
        if frame % 10 == 0:
            print(f"frame={frame}")
        for subfig, data, label in panels:
            plotting.plot_spherical_fn(
                data[b, frame].cpu().squeeze().numpy(),
                fig=subfig,
                vmin=vmin,
                vmax=vmax,
                title=f"{label} R={frame}",
                colorbar=False,
            )
        plt.draw()
        moviewriter.grab_frame()
        # Clear the panels so the next frame draws on empty subfigures.
        for subfig in subfigs:
            subfig.clear()

plt.close(fig)

frame=0
frame=10
frame=20
frame=30
frame=40
frame=50
frame=60


<video controls src="compare.mp4" />

<video controls src="compare.mp4" />

In [None]:
# Element-wise absolute error between ground truth and the O2S prediction, on CPU.
# NOTE(review): relies on target and o2s_pred having broadcast-compatible shapes.
# The plotting cells squeeze both per frame, so confirm neither carries an extra
# singleton dimension.
error = (target.cpu() - o2s_pred.cpu()).abs()

In [None]:
# Render error.mp4: prediction, ground truth and |target - prediction| side by side.
fig = plt.figure(layout='constrained', figsize=(12, 8), dpi=72)
subfigs = fig.subfigures(1, 3)

b = 0  # index of the sample within the batch to animate
num_frames = target.size(1)
# Prediction and target share the target's range so they are directly comparable.
vmin = target[b].min().item()
vmax = target[b].max().item()
# The absolute error is non-negative, so it gets its own [0, max error] scale.
# Reusing the target's range (which may extend below zero) wastes colour range
# and clips any error larger than the target maximum.
err_vmax = error[b].max().item()

panels = (
    (subfigs[0], o2s_pred, "prediction", vmin, vmax),
    (subfigs[1], target, "target", vmin, vmax),
    (subfigs[2], error, "|error|", 0.0, err_vmax),
)

moviewriter = animation.FFMpegWriter(fps=30)
# `saving` calls setup() and guarantees finish() even if plotting raises.
with moviewriter.saving(fig, 'error.mp4', dpi=72):
    for frame in range(num_frames):
        if frame % 10 == 0:
            print(f"frame={frame}")
        for subfig, data, label, lo, hi in panels:
            plotting.plot_spherical_fn(
                data[b, frame].cpu().squeeze().numpy(),
                fig=subfig,
                vmin=lo,
                vmax=hi,
                title=f"{label} R={frame}",
            )
        plt.draw()
        moviewriter.grab_frame()
        # Clear the panels so the next frame draws on empty subfigures.
        for subfig in subfigs:
            subfig.clear()

plt.close(fig)

<video controls src="error.mp4" />

In [None]:
# Sort the magnitudes of w[0, 33] from largest to smallest. `v` holds the sorted
# magnitudes and `i` their original positions.
# NOTE(review): what axis 1 / index 33 of `w` represents is not visible here — document it.
v, i = torch.sort(w[0, 33].abs(), descending=True)

In [None]:
# Ten largest magnitudes (`v` is sorted in descending order).
v[0:10]

In [None]:
# Positions in w[0, 33] of the ten largest magnitudes.
i[0:10]

In [None]:
# Apparently hand-recorded results of the ranking cells above, kept for reference.
# The first list looks like i[:10] (positions) and the second like v[:10]
# (magnitudes, in descending order).
# NOTE(review): only the last expression is displayed when this cell runs. The
# numbers depend on which batch the DataLoader drew, so they may not reproduce.
[  0,   6,  20,  72, 156, 110,  42, 272, 506, 210]
[4.9247, 1.4585, 0.8119, 0.5424, 0.4051, 0.3920, 0.3726, 0.2634, 0.2194, 0.2064]