In [1]:
import os

# The notebook may be started from inside experiments/; step up one directory
# so relative paths used later (e.g. "data/modelnet10", "samples/rotation.pt")
# resolve against the parent of experiments/.
cwd = os.getcwd()
if cwd.endswith("experiments"):
    os.chdir("..")
del cwd

os.getcwd()

'/home/lars/code/python/context-aware-flow-matching'

In [2]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning import seed_everything
from lightning_trainable.utils import find_checkpoint
from matplotlib.animation import ArtistAnimation
from torch.utils.data import random_split
from tqdm import tqdm, trange

import src.utils as U
import src.visualization as viz
from src.datasets import ModelNet10Dataset
from src.models import Model

In [3]:
# Nothing in this notebook trains, so switch autograd off globally for all
# following cells (trailing semicolon hides the returned object's repr).
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f636a7ddcc0>

In [4]:
# Seed the global RNGs (via lightning's seed_everything) before the dataset
# split below, so random_split produces the same partition on every run.
seed_everything(seed=42)

Seed set to 42


42

In [5]:
# Local copy of ModelNet10. If the directory is already populated the dataset
# skips the download, and it pre-loads every mesh up front (see the output).
modelnet_root = "data/modelnet10"
dataset = ModelNet10Dataset(modelnet_root)

Found non-empty ModelNet10Dataset in data/modelnet10, skipping download...


Pre-Loading Meshes: 100%|███████████████████████████████████████████████████████████████████| 4899/4899 [00:53<00:00, 92.24it/s]


In [6]:
# 80/10/10 split drawn from the globally seeded RNG.
# NOTE(review): this only matches the model's true validation set if training
# used the same seed, dataset ordering and fractions — confirm against the
# training script.
split_fractions = [0.8, 0.1, 0.1]
train_data, val_data, test_data = random_split(dataset, lengths=split_fractions)

In [7]:
# Locate the checkpoint to sample from and wrap it in a Path, so that its
# parent directories (where hparams.yaml lives) can be navigated in the next cell.
cp = Path(find_checkpoint())
cp

PosixPath('lightning_logs/version_10/checkpoints/last.ckpt')

In [8]:
# Restore hyper-parameters from the hparams.yaml two levels above the
# checkpoint file, then load the weights. The data splits are handed to the
# model as well; the conditioning cell later reads them back via model.val_data.
# Prefer the first GPU but fall back to CPU so the notebook still runs without
# CUDA (previously "cuda:0" was hard-coded and failed on CPU-only machines).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

hparams = Model.hparams_type.from_yaml(cp.parent.parent / "hparams.yaml")
model = Model.load_from_checkpoint(
    cp,
    hparams=hparams,
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    map_location=device,
)
model = model.eval()

In [9]:
# ---- sampling / animation configuration ----
shapes = 1          # number of conditioning shapes (one video each)
points = 2048       # points per generated cloud
features = 3        # coordinates per point (presumably x, y, z)

duration = 8        # seconds of video per shape
fps = 60            # frames per second of the saved video
steps = duration * fps  # total frames == number of rotation angles

chunk_size = 32     # chunk_size passed to torch.vmap when sampling
use_blender = True  # True: render frames with viz.scatter_bp; False: matplotlib 3D scatter

In [10]:
# Per-frame rotation angles. Only the third Euler angle is non-zero and sweeps
# 0 -> 360 over `steps` frames. The endpoint is excluded so the video loops
# without a duplicated frame.
# NOTE(review): assumes U.rotate interprets angles in degrees — confirm.
full_turn = torch.tensor([0.0, 0.0, 360.0])
t = torch.linspace(0, 1, steps + 1)[:-1]                          # (steps,)
euler = t[:, None] * full_turn                                    # (steps, 3), broadcast instead of repeat
euler = euler[:, None, None, :].expand(-1, shapes, points, -1)    # (steps, shapes, points, 3)
euler = euler.to(model.device)

# Guard against the config drifting from what downstream cells expect.
assert euler.shape == (steps, shapes, points, features)
euler.shape

torch.Size([192, 4, 2048, 3])

In [11]:
# Draw a single latent noise cloud per shape and reuse it for every frame.
# Each frame's copy is rotated by that frame's Euler angles. The rotation is
# applied to the flattened rows, and the result is reshaped back to
# (steps, shapes, points, features). Assumes sample_noise returns
# (shapes, points, features) — TODO confirm.
noise = U.expand_dim(model.sample_noise((shapes, points)).unsqueeze(0), 0, steps)
noise = U.rotate(noise.flatten(0, 2), euler.flatten(0, 2)).reshape(steps, shapes, points, features)

In [12]:
# Pick `shapes` random validation clouds as conditioning inputs. The fixed
# temporary seed makes the choice reproducible.
# NOTE(review): assumes U.temporary_seed restores the previous RNG state on exit.
# Named `conditions` (not `samples`) so it is not confused with the generated
# clouds produced later. Element [1] of each dataset item is used — presumably
# the point cloud; verify against ModelNet10Dataset.__getitem__.
with U.temporary_seed(42):
    picked = torch.randperm(len(model.val_data))[:shapes].tolist()  # plain ints for dataset indexing
    conditions = torch.stack([model.val_data[i][1] for i in picked])
conditions = conditions.to(model.device)

# Embed each conditioning cloud once, then repeat the embedding along a new
# leading frame axis so it lines up with the per-frame rotated noise.
embeddings = model.embed(conditions)
embeddings = U.expand_dim(embeddings.unsqueeze(0), 0, steps)

In [13]:
# Integrate the flow ODE for every frame. torch.vmap maps model.sample_from over
# the leading (frame) axis of noise/embeddings, `chunk_size` frames at a time.
# This is the expensive step (~7 minutes per chunk in the recorded run), so the
# result is persisted to disk for the rendering cells.
samples_path = Path("samples") / "rotation.pt"
# torch.save does not create missing directories. Create it up front so a long
# sampling run is not lost at the final save.
samples_path.parent.mkdir(parents=True, exist_ok=True)

sample_fn = torch.vmap(model.sample_from, chunk_size=chunk_size)
samples = sample_fn(noise, embeddings, integrator="euler", steps=250, progress=True)  # 250 integrator steps
torch.save(samples, samples_path)

  full_bar = Bar(frac,
Solving ODE: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [07:02<00:00]
Solving ODE: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [07:09<00:00]
Solving ODE: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [07:06<00:00]
Solving ODE: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [07:05<00:00]
Solving ODE: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [07:16<00:00]
Solving ODE: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [07:10<00:00]


In [14]:
# Reload the generated clouds onto the CPU as a NumPy array for plotting.
# The rendering loop indexes it as samples[frame, shape].
samples = (
    torch.load("samples/rotation.pt", map_location="cpu")
    .numpy()
)

In [15]:
# Render one rotating-shape video per conditioning shape, one frame per rotation step.
# Use whichever ffmpeg is on PATH, falling back to the previously hard-coded location.
plt.rcParams["animation.ffmpeg_path"] = shutil.which("ffmpeg") or "/usr/bin/ffmpeg"

# Animation.save requires the output directory to exist.
Path("animations").mkdir(parents=True, exist_ok=True)

for shape in trange(shapes, desc="shapes"):
    fig = plt.figure(figsize=(4, 4))
    artists = []

    # leave=False clears the inner bar after each shape so nested bars don't pile up.
    for frame in trange(steps, desc=f"frames (shape {shape})", leave=False):
        # NOTE(review): on recent matplotlib, add_subplot(111) creates a NEW axes
        # on every call, so the figure accumulates `steps` stacked axes.
        # ArtistAnimation only toggles the listed artists' visibility, not the
        # axes. Confirm the output renders as intended.
        if use_blender:
            ax = fig.add_subplot(111)
            artist = viz.scatter_bp(samples[frame, shape], ax=ax)
        else:
            ax = fig.add_subplot(111, projection="3d")
            artist = viz.scatter(samples[frame, shape], ax=ax)

        artists.append([artist])

    animation = ArtistAnimation(fig, artists, interval=int(1000 / fps), blit=False)
    animation.save(f"animations/rotation_{shape:06d}.mp4", fps=fps, dpi=200)
    # Close this specific figure. A plotting helper may have made another figure
    # current, and a bare plt.close() would then leave `fig` (and its axes) open.
    plt.close(fig)

  0%|                                                                                                     | 0/4 [00:00<?, ?it/s]
  0%|                                                                                                   | 0/192 [00:00<?, ?it/s][A
  1%|▍                                                                                          | 1/192 [00:03<10:10,  3.20s/it][A
  1%|▉                                                                                          | 2/192 [00:06<09:24,  2.97s/it][A
  2%|█▍                                                                                         | 3/192 [00:08<09:09,  2.91s/it][A
  2%|█▉                                                                                         | 4/192 [00:11<09:00,  2.87s/it][A
  3%|██▎                                                                                        | 5/192 [00:14<08:53,  2.85s/it][A
  3%|██▊                                                                       

Saved: '/home/lars/code/python/context-aware-flow-matching/_.png'
 Time: 00:02.80 (Saving: 00:00.03)




  8%|███████                                                                                   | 15/192 [00:43<08:30,  2.89s/it][A
  8%|███████▌                                                                                  | 16/192 [00:46<08:28,  2.89s/it][A
  9%|███████▉                                                                                  | 17/192 [00:49<08:35,  2.94s/it][A
  9%|████████▍                                                                                 | 18/192 [00:52<08:36,  2.97s/it][A
 10%|████████▉                                                                                 | 19/192 [00:55<08:36,  2.99s/it][A
 10%|█████████▍                                                                                | 20/192 [00:58<08:31,  2.97s/it][A
 11%|█████████▊                                                                                | 21/192 [01:01<08:28,  2.97s/it][A
 11%|██████████▎                                                           

Saved: '/home/lars/code/python/context-aware-flow-matching/_.png'
 Time: 00:02.78 (Saving: 00:00.03)




 16%|██████████████                                                                            | 30/192 [01:22<07:24,  2.74s/it][A
 16%|██████████████▌                                                                           | 31/192 [01:25<07:22,  2.75s/it][A
 17%|███████████████                                                                           | 32/192 [01:28<07:16,  2.73s/it][A
 17%|███████████████▍                                                                          | 33/192 [01:31<07:13,  2.73s/it][A
 18%|███████████████▉                                                                          | 34/192 [01:33<07:14,  2.75s/it][A
 18%|████████████████▍                                                                         | 35/192 [01:36<07:09,  2.74s/it][A
 19%|████████████████▉                                                                         | 36/192 [01:39<07:04,  2.72s/it][A
 19%|█████████████████▎                                                    