In [None]:
# Move the working directory up one level so that the relative paths used below
# (the "logs" directory, the `src` package import) resolve — presumably the
# notebook lives in a subdirectory of the project root; TODO confirm.
# NOTE(review): not idempotent — re-running this cell moves up another level.
%cd ..

In [None]:
import datetime
import os
import time
from glob import glob
from pathlib import Path
from types import SimpleNamespace

import hydra
import matplotlib.pyplot as plt
import pandas as pd
import torch
from omegaconf import DictConfig, OmegaConf

from src.utils import register_resolvers

In [None]:
# Expose only the GPU with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Switch autograd off globally: the cells visible below only load checkpoints,
# sample and plot, so no gradients are required.
torch.set_grad_enabled(False)

In [None]:
def get_version(hparam_path):
    """Return the integer run version encoded in the parent directory name.

    Args:
        hparam_path: ``pathlib.Path`` of an ``hparams.yaml`` file whose parent
            directory is named ``version_<N>``.

    Returns:
        ``N`` as an ``int``.

    Raises:
        AssertionError: if the parent directory name does not start with
            ``"version_"``.
        ValueError: if the text after the last underscore is not an integer.
    """
    dir_name = hparam_path.parent.name
    assert dir_name.startswith("version_"), f"unexpected directory name: {dir_name!r}"
    # Split once, from the right: an unbounded split yields more than two parts
    # (and fails to unpack) as soon as the name contains a second underscore.
    _, suffix = dir_name.rsplit("_", 1)
    return int(suffix)


def get_ctime(file):
    """Return the ``st_ctime`` of ``file`` as a naive, local-time ``datetime``.

    Args:
        file: object exposing ``stat()`` (e.g. a ``pathlib.Path``).

    Returns:
        ``datetime.datetime`` built from ``file.stat().st_ctime``. Per the Python
        documentation the meaning of ``st_ctime`` is platform dependent
        (metadata change time on Unix, creation time on Windows).
    """
    return datetime.datetime.fromtimestamp(file.stat().st_ctime)


def get_nearest_ctime(hparam_path, checkpoints):
    """Pick a ``last*`` checkpoint for the run described by ``hparam_path``.

    NOTE(review): despite the name, no timestamps are compared. The first
    candidate in iteration order is returned. Choosing the candidate whose
    creation time is closest to that of ``hparam_path`` was tried and then
    disabled; it is intentionally not re-enabled here.

    Args:
        hparam_path: path of the run's ``hparams.yaml``. Currently unused; kept
            so the signature stays compatible with existing callers.
        checkpoints: iterable of checkpoint paths (objects with a ``.name``).

    Returns:
        The first path whose file name contains ``"last"``, or ``None`` when the
        iterable holds no such path.
    """
    candidates = (c for c in checkpoints if "last" in c.name)
    # ``None`` (instead of an IndexError) lets callers treat "no checkpoint" as
    # an ordinary, filterable outcome.
    return next(candidates, None)


def get_latest(checkpoints):
    """Return the most recently created ``last*`` checkpoint.

    Args:
        checkpoints: iterable of checkpoint paths (objects with ``.name`` and
            ``.stat()``).

    Returns:
        The candidate with the greatest ``get_ctime`` value, or ``None`` when no
        file name contains ``"last"``.
    """
    opts = [c for c in checkpoints if "last" in c.name]
    if not opts:
        # ``max`` on an empty sequence raises ValueError; report "nothing found".
        return None
    if len(opts) > 1:
        # Diagnostic: show the candidates whenever the choice is ambiguous.
        print(opts)
    return max(opts, key=get_ctime)


def get_checkpoint(hparam_path):
    """Locate the checkpoint that belongs to the run of ``hparam_path``.

    The checkpoint directory is resolved as ``checkpoints`` three levels above
    the hparams file, i.e. ``<root>/<dir>/version_<N>/hparams.yaml`` maps to
    ``<root>/checkpoints/*.ckpt``.

    An earlier, now-disabled strategy selected ``last.ckpt`` for version 0 and
    ``last-v<N>.ckpt`` for version ``N > 0``; selection is currently delegated
    to ``get_nearest_ctime``.

    Args:
        hparam_path: ``pathlib.Path`` of the run's ``hparams.yaml``.

    Returns:
        The selected checkpoint path, or ``None`` when the checkpoint directory
        does not exist or contains no ``*.ckpt`` file with ``"last"`` in its name.

    Raises:
        AssertionError: if the hparams file is not inside a ``version_*``
            directory (raised by ``get_version``).
    """
    # Called for its validation only; the version number itself is not used.
    get_version(hparam_path)

    ckpt_dir = hparam_path.parent.parent.parent / "checkpoints"
    if not ckpt_dir.exists():
        return None

    # Materialise the glob and keep only the ``last*`` candidates up front, so a
    # run without one is reported as ``None`` here instead of leaving an empty
    # selection for the helper to deal with.
    checkpoints = [c for c in ckpt_dir.glob("*.ckpt") if "last" in c.name]
    if not checkpoints:
        return None

    return get_nearest_ctime(hparam_path, checkpoints)


def get_max_epoch(hparam_path):
    """Return the number of rows logged in the run's ``metrics.csv``.

    The file is looked up next to ``hparam_path``. The row count is used by the
    caller as a measure of how long the run trained.

    Args:
        hparam_path: ``pathlib.Path`` of the run's ``hparams.yaml``.

    Returns:
        Number of data rows in ``metrics.csv``; ``0`` when the file is missing
        or cannot be read.
    """
    try:
        return len(pd.read_csv(hparam_path.parent / "metrics.csv"))
    except Exception:
        # Best effort: an unreadable or absent metrics file counts as "no
        # epochs". Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return 0

In [None]:
log_dir = Path("logs")

# Keep only the runs whose metrics.csv holds at least 900 rows.
long_runs = [h for h in log_dir.glob("**/hparams.yaml") if get_max_epoch(h) >= 900]

# Pair each remaining run with its checkpoint and drop the runs that have none,
# so that `hparam_files[i]` and `ckpt_files[i]` always describe the same run.
run_ckpt_pairs = [(h, get_checkpoint(h)) for h in long_runs]
hparam_files = [h for h, c in run_ckpt_pairs if c is not None]
ckpt_files = [c for _, c in run_ckpt_pairs if c is not None]

print(len(hparam_files))

In [None]:
def get_model_cls(model_idx):
    """Resolve the model class named in a run's ``hparams.yaml``.

    Args:
        model_idx: index into the module-level ``hparam_files`` list.

    Returns:
        The class referenced by the config's ``model._target_`` dotted path.
    """
    from importlib import import_module

    dotted_path = OmegaConf.load(hparam_files[model_idx]).model._target_
    # Everything before the last dot is the module, the remainder the class name.
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(import_module(module_name), class_name)


def load_model(model_idx, device):
    """Restore the model of run ``model_idx`` from its checkpoint.

    Args:
        model_idx: index into the module-level ``hparam_files`` / ``ckpt_files``.
        device: forwarded to ``load_from_checkpoint`` as the ``device`` keyword.
            NOTE(review): presumably consumed by the model class itself — confirm
            it is not meant to be ``map_location``.

    Returns:
        The object returned by the model class's ``load_from_checkpoint``.
    """
    model_cls = get_model_cls(model_idx)
    checkpoint_path = ckpt_files[model_idx]
    return model_cls.load_from_checkpoint(checkpoint_path, device=device)


def load_dataset(model_idx, device):
    """Build the test dataloader described by a run's ``hparams.yaml``.

    Args:
        model_idx: index into the module-level ``hparam_files`` list.
        device: value assigned to the instantiated data module's ``device``
            attribute before ``setup`` is called.

    Returns:
        The result of the data module's ``test_dataloader()``.
    """
    data_cfg = OmegaConf.load(hparam_files[model_idx]).data
    datamodule = hydra.utils.instantiate(data_cfg)
    datamodule.device = device
    datamodule.setup("test")
    test_loader = datamodule.test_dataloader()
    return test_loader

In [None]:
def extract_metadata(hparam_file):
    """Summarise one run's configuration as a flat dict (one table row).

    Args:
        hparam_file: ``pathlib.Path`` of the run's ``hparams.yaml``.

    Returns:
        Dict with the keys ``model_family``, ``data_family``, ``ot``,
        ``loss_fn``, ``k``, ``algorithm``, ``sort``, ``break_symmetry`` and
        ``date``. ``data_family`` is ``None`` for an unrecognised data module, so
        every row carries the same set of keys.

    Raises:
        ValueError: if the model target is neither a ``KSinFlowMatchingModule``
            nor a ``KSinFeedForwardModule``.
    """
    cfg = OmegaConf.load(hparam_file)

    model_target = cfg.model._target_
    if model_target.endswith("KSinFlowMatchingModule"):
        model_family = "flow"
    elif model_target.endswith("KSinFeedForwardModule"):
        model_family = "mlp"
    else:
        raise ValueError("Unknown model family")

    data_target = cfg.data._target_
    if data_target.endswith("FMDataModule"):
        data_family = "fm"
    elif data_target.endswith("KSinDataModule"):
        data_family = "ksin"
    else:
        # Keep the key present: a missing key would give rows of differing shape.
        data_family = None

    return {
        "model_family": model_family,
        "data_family": data_family,
        "ot": cfg.model.get("optimal_transport", False),
        "loss_fn": cfg.model.get("loss_fn", None),
        "k": cfg.data.get("k", 0.0),
        "algorithm": cfg.data.get("algorithm", None),
        "sort": cfg.data.get("sort_frequencies", False),
        "break_symmetry": cfg.data.get("break_symmetry", False),
        # Timestamp of the hparams file, used as the run's date.
        "date": get_ctime(hparam_file),
    }


# One metadata dict per retained run, in the same order as `hparam_files`.
rows = list(map(extract_metadata, hparam_files))

In [None]:
df = pd.DataFrame(rows)

# Runs of interest: flow models with k == 2, without frequency sorting and
# without symmetry breaking. Bracket access avoids any clash between column
# names and DataFrame attributes.
is_selected = (
    (df["k"] == 2)
    & ~df["sort"]
    & ~df["break_symmetry"]
    & (df["model_family"] == "flow")
)
df[is_selected]

In [None]:
# Position of the run to inspect within `hparam_files` / `ckpt_files`.
# NOTE(review): 8 is hard-coded and depends on the order in which the log
# directory was globbed — presumably read off the table above; TODO confirm.
model_idx = 8
# Restore the model and build the matching test dataloader for that run.
model = load_model(model_idx, device)
dl = load_dataset(model_idx, device)

In [None]:
# Display the loaded model's repr as the cell output.
model

In [None]:
# A batch from the test loader unpacks into three items; `render` is called
# below as render(<first half of params>, <second half of params>, length).
signal, params, render = next(iter(dl))
batch_size, num_samples = signal.shape[0], signal.shape[-1]

# Repeat the first parameter row across the whole batch (cloned so the result
# owns its memory), then re-render the signals from those identical parameters.
params = params[:1].expand(batch_size, -1).clone()
signal = render(*params.chunk(2, dim=-1), num_samples)

In [None]:
steps = 100  # number of steps handed to the sampler
cfg = 5.0  # third argument of `model._sample` — presumably a guidance scale; TODO confirm
old_data = False  # when True, rescale the two halves of `sample` (see below)

print("Sampling...")
t = time.time()
sample, y, x = model._sample((signal, params, render), steps, cfg)
dur = time.time() - t
sps = steps / dur
print(f"Done. {steps} steps in {dur:.2f} seconds ({sps} steps/sec)")

if old_data:
    # In-place rescaling: first half x -> 2x/pi - 1, second half x -> 2x - 1.
    half = sample.shape[-1] // 2
    sample[..., :half] = 2 * sample[..., :half] / torch.pi - 1
    sample[..., half:] = 2 * sample[..., half:] - 1

yc, sc = y.cpu(), sample.cpu()

# Every ordered pair (i, j), i != j, among the first half of the last dimension.
n_cols = y.shape[-1] // 2
index_pairs = [(i, j) for i in range(n_cols) for j in range(n_cols) if i != j]

# Draw all of `y` first (black), then overlay the samples (faint red crosses).
for i, j in index_pairs:
    plt.scatter(yc[:, i], yc[:, j], color="black")
for i, j in index_pairs:
    plt.scatter(sc[:, i], sc[:, j], marker="+", color="red", alpha=0.05)

plt.xlim(-1, 1)
plt.ylim(-1, 1)

In [None]:
# Render signals from the sampled parameters, at the same length as `x`.
x_ = render(*sample.chunk(2, dim=-1), x.shape[-1])

# log10 magnitude spectra of `x` and of the rendered samples.
X, X_ = (torch.fft.rfft(sig).abs().log10() for sig in (x, x_))

# First row of `X` on top of every rendered spectrum (faint red).
plt.plot(X[0].cpu())
plt.plot(X_.T.cpu(), alpha=0.01, color="red")
plt.show()

In [None]:
import ot as pot

In [None]:
def divmod(a, b):
    """Tensor counterpart of the built-in ``divmod`` (which this name shadows).

    Args:
        a: dividend.
        b: divisor.

    Returns:
        Tuple ``(quotient, remainder)`` where the quotient comes from
        ``torch.div(a, b, rounding_mode="floor")`` and the remainder from
        ``torch.remainder(a, b)``.
    """
    quotient = torch.div(a, b, rounding_mode="floor")
    remainder = torch.remainder(a, b)
    return quotient, remainder


def _sample_from_ot_map(x0: torch.Tensor, x1: torch.Tensor, z: torch.Tensor):
    """Re-pair ``x0`` and ``x1`` by sampling index pairs from an entropic OT plan.

    A Sinkhorn plan between uniform weights on the rows of ``x0`` and ``x1`` is
    computed, ``z.shape[0]`` flat indices are drawn from it (with replacement)
    and decomposed into a row index ``i`` (into ``x0``) and a column index ``j``
    (into ``x1``).

    Args:
        x0: 2-D tensor of source points, one per row.
        x1: 2-D tensor of target points, one per row.
        z: tensor indexed with the same ``j`` as ``x1``, so it is expected to be
            paired row-wise with ``x1``; its first dimension also sets the
            number of pairs drawn.

    Returns:
        Tuple ``(x0[i], x1[j], z[j])``.
    """
    batch_size = z.shape[0]
    a = pot.unif(x0.shape[0], type_as=x0)
    b = pot.unif(x1.shape[0], type_as=x1)
    # NOTE(review): `cdist` already yields the Euclidean distance, so the cost
    # here is its square root — confirm this is intended rather than the
    # squared distance.
    costs = torch.cdist(x0, x1).sqrt()

    ot_map = pot.sinkhorn(
        a, b, costs, 0.1, method="sinkhorn", numItermax=1000, verbose=True, stopThr=1e-6
    )
    # NOTE(review): the plan entries are squared before being used as sampling
    # weights — confirm this is intended. `multinomial` does not require the
    # weights to sum to one.
    pi = ot_map.flatten().square()
    samples = torch.multinomial(pi, batch_size, replacement=True)

    # The flattened plan is row-major over (len(x0), len(x1)): a flat index is
    # i * len(x1) + j, so it must be split by the number of columns. Splitting
    # by `batch_size` was only right when it happened to equal len(x1).
    i, j = divmod(samples, ot_map.shape[1])

    return x0[i], x1[j], z[j]

In [None]:
# Two independent standard-normal point clouds (1024 points, 10 dimensions
# each), created on the same device as `params`.
x0 = torch.randn(1024, 10, device=params.device)
x1 = torch.randn(1024, 10, device=params.device)
# Encode the signal batch; `z` is handed to `_sample_from_ot_map` in the next cell.
# NOTE(review): presumably `z` should have one row per row of `x1` (1024) —
# TODO confirm against the batch size of `signal`.
z = model.encoder(signal)
# Mean squared difference of each (x0, x1) pair before re-pairing.
(x0 - x1).square().mean(dim=-1)

In [None]:
# Re-pair the points with index pairs sampled from the OT plan.
# NOTE(review): x0, x1 and z are overwritten, so re-running this cell resamples
# from the already re-paired tensors rather than from the original draws.
x0, x1, z = _sample_from_ot_map(x0, x1, z)
# Mean squared difference of each pair after re-pairing.
(x0 - x1).square().mean(dim=-1)

In [None]:
# Display the second element of a fresh batch from the test loader (the item
# unpacked as `params` earlier in the notebook).
next(iter(dl))[1]