In [None]:
import os
import sys

import jax
import jax.numpy as jnp
import numpy as np
import optax
from genot.models.model import GENOT
from genot.nets.nets import MLP_vector_field
import ott
import torch
import seaborn as sns
import genot
import jax
import ott
import diffrax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from genot.nets.nets import MLP_vector_field, MLP_marginal
import sklearn.preprocessing as pp
import scanpy as sc
from ott.solvers.linear import sinkhorn, acceleration
from genot.data.data import MixtureNormalSampler
from genot.plotting.plots import plot_1D_unbalanced
from eot_benchmark.gaussian_mixture_benchmark import (
    get_guassian_mixture_benchmark_ground_truth_sampler,
    get_guassian_mixture_benchmark_sampler,
    get_test_input_samples,
)
from eot_benchmark.metrics import calculate_cond_bw, compute_BW_UVP_by_gt_samples


class Loader:
    """Adapt a benchmark sampler into an endless iterator of JAX batches.

    Each ``next()`` call draws ``batch_size`` samples with
    ``sampler.sample(batch_size)``, calls ``.cpu()`` on the result and converts
    it to a ``jax.numpy`` array. ``__next__`` never raises ``StopIteration``,
    so the stream does not end.

    Args:
        sampler: object exposing ``sample(n)``; its return value must support
            ``.cpu()``. NOTE(review): presumably a torch-backed eot_benchmark
            sampler living on a CUDA device — confirm.
        batch_size: number of samples drawn per ``next()`` call.
    """

    def __init__(self, sampler, batch_size):
        self.sampler = sampler
        self.batch_size = batch_size

    def __iter__(self):
        # Completes the iterator protocol so a Loader works with ``iter()`` and
        # ``for`` loops, not only with bare ``next()`` calls.
        return self

    def __next__(self):
        # ``.cpu()`` brings the batch to host memory before handing it to JAX.
        return jnp.asarray(self.sampler.sample(self.batch_size).cpu())


def _normalize_eps(eps):
    """Return ``eps`` as an ``int`` when it is an integral value above 0.1.

    Integral epsilons (``1.0``, ``10.0``) become ``1`` / ``10`` so they are
    passed on and formatted in file names as integers (``1`` rather than
    ``1.0``). Any other value is returned unchanged. In particular, a
    non-integral value such as ``1.5`` is NOT truncated, which would silently
    change the regularisation strength.
    """
    if eps > 0.1 and float(eps).is_integer():
        return int(eps)
    return eps


# Directory the report file is written to.
out_dir = "./"

# Raw command-line arguments; not read anywhere in this cell.
# NOTE(review): presumably intended for overriding the constants below.
arguments = sys.argv
DIM = 64  # choose according to benchmark
EPS = 1.0  # choose according to benchmark
EPS = _normalize_eps(EPS)
BATCH_SIZE = 2048
K_NOISE_PER_X = 1
ITERATIONS = 100_000
LR = 1e-5
SEED = 0


# Training iterations handed to GENOT (``iterations=``). Kept separate from
# ITERATIONS, presumably so training could be split into several logged
# chunks; with a single evaluation round it simply equals ITERATIONS.
iters = ITERATIONS

# Samples drawn per test input for the conditional metric.
NUM_SAMPLES_cBVP = 1_000
# Samples drawn for the unconditional metric.
NUM_SAMPLES_BVP = 100_000
# Index of the CUDA device the benchmark samplers are placed on.
GPU_DEVICE = 0

# All three benchmark samplers share the same problem definition: dimension,
# entropic regularisation, batch size and CUDA device. ``download=True`` lets
# the helpers fetch the benchmark data when it is not present locally.
_benchmark_kwargs = dict(
    dim=DIM,
    eps=EPS,
    batch_size=BATCH_SIZE,
    device=f"cuda:{GPU_DEVICE}",
    download=True,
)

# Source marginal of the benchmark.
input_sampler = get_guassian_mixture_benchmark_sampler(
    input_or_target="input", **_benchmark_kwargs
)

# Target marginal of the benchmark.
target_sampler = get_guassian_mixture_benchmark_sampler(
    input_or_target="target", **_benchmark_kwargs
)

# Ground-truth plan sampler; not referenced elsewhere in the visible code.
ground_truth_plan_sampler = get_guassian_mixture_benchmark_ground_truth_sampler(
    **_benchmark_kwargs
)

# Temporary helper only; keep the module namespace unchanged.
del _benchmark_kwargs

# Network parameterising the GENOT vector field. The positional 1024s and
# ``n_frequencies`` are passed straight through; see MLP_vector_field for
# their exact meaning.
neural_net = MLP_vector_field(DIM, 1024, 1024, 1024, n_frequencies=1024)

# Sinkhorn (entropic OT) solver handed to GENOT as ``ot_solver``.
ot_solver = ott.solvers.linear.sinkhorn.Sinkhorn()

# A second Sinkhorn solver is only created when several noise samples are
# drawn per source point; otherwise GENOT receives None.
if K_NOISE_PER_X > 1:
    solver_latent_to_data = ott.solvers.linear.sinkhorn.Sinkhorn()
else:
    solver_latent_to_data = None

# AdamW; a weight decay of 1e-10 makes the decoupled decay negligible.
optimizer = optax.adamw(learning_rate=LR, weight_decay=1e-10)

otfm = GENOT(
    neural_net,
    # Problem definition.
    input_dim=DIM,
    output_dim=DIM,
    epsilon=EPS,
    # Optimisation settings.
    optimizer=optimizer,
    iterations=iters,
    seed=SEED,
    # Coupling solvers.
    ot_solver=ot_solver,
    k_noise_per_x=K_NOISE_PER_X,
    solver_latent_to_data=solver_latent_to_data,
    latent_to_data_scale_cost="mean",
)

# Batch loaders over the source and target samplers (BATCH_SIZE per draw).
s_sampler, t_sampler = (
    Loader(input_sampler, BATCH_SIZE),
    Loader(target_sampler, BATCH_SIZE),
)

# Metric history per evaluation round (the loop below runs a single round).
bws = []
cond_bws = []

# Fixed benchmark test inputs, moved to host memory.
test_samples = get_test_input_samples(dim=DIM, device=f"cuda:{GPU_DEVICE}").cpu()
# Convert to NumPy explicitly before repeating. Handing the torch tensor to
# np.repeat first dispatches to torch's own ``Tensor.repeat`` (a different
# signature) and only works through NumPy's fallback/re-wrap path.
# Result: each test input copied NUM_SAMPLES_cBVP times along axis 1.
test_samples_repeated = np.repeat(
    np.asarray(test_samples)[:, None, :], NUM_SAMPLES_cBVP, axis=1
)

# Not used in this cell; kept for interactive use in later cells.
cpu_device = jax.devices('cpu')[0]
for i in range(1):
    # Presumably runs GENOT training (it was configured with
    # ``iterations=iters``), drawing batches from the two loaders.
    otfm(s_sampler, t_sampler)

    # Conditional metric: transport the repeated test inputs with one seed
    # per vmapped slot.
    # NOTE(review): jax.vmap maps keyword arguments over their leading axis
    # too, so this requires len(test_samples) == NUM_SAMPLES_cBVP — confirm.
    predicted = jax.vmap(lambda *args, **kwargs: otfm.transport(*args, **kwargs)[0])(
        jnp.asarray(test_samples_repeated), seed=jnp.arange(NUM_SAMPLES_cBVP)
    )
    predicted_squeezed = jnp.squeeze(predicted)
    cond_bw = calculate_cond_bw(
        test_samples, torch.tensor(np.asarray(predicted_squeezed)), eps=EPS, dim=DIM
    )

    # Unconditional metric: transport fresh source samples and compare them
    # with fresh target samples.
    source_samples = np.asarray(input_sampler.sample(NUM_SAMPLES_BVP).cpu())
    predicted = otfm.transport(jnp.asarray(source_samples))
    # NOTE(review): takes element [0] of transport's output, then index 0 along
    # its leading axis — confirm against GENOT.transport's return layout.
    predictions = torch.tensor(np.asarray(predicted[0][0, ...]))
    target_samples = target_sampler.sample(NUM_SAMPLES_BVP)
    bw = compute_BW_UVP_by_gt_samples(
        predictions.cpu().numpy(), target_samples.cpu().numpy()
    )

    bws.append(bw)
    cond_bws.append(cond_bw)

# Persist the metric histories: row 0 is ``cond_bws``, row 1 is ``bws``
# (one column per evaluation round, assuming each metric is a scalar).
# The save was previously commented out while the file was still opened in
# "wb" mode. That left an empty, unloadable ``.npy`` file behind and truncated
# any report from an earlier run with the same settings.
report_path = os.path.join(
    out_dir,
    f"{DIM}_{EPS}_{BATCH_SIZE}_{K_NOISE_PER_X}_{ITERATIONS}_{LR}_{SEED}_report.npy",
)
with open(report_path, "wb") as f:
    np.save(f, np.array([cond_bws, bws]))


  from .autonotebook import tqdm as notebook_tqdm
2024-10-20 12:07:25.262836: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-20 12:07:25.407540: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-20 12:07:25.466017: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Downloading...
From (original): https://drive.google.com/uc?id=1HNXbrkozARbz4r8fdFbjvPw8R74n1oiY
From (redirected): https://drive.google.com/uc?id=1HNXbrkozARbz4r8fdFbjvPw8R74n1oiY&confirm=t&uuid=3c82ad63-150b-4f15-8f30-f39fda2ba45c
To: /home/icb/dominik.klein/eot_benchmark_data/gaussian_mixture_benchmark_data.zip

In [2]:
# List the devices visible to JAX (the recorded output shows one CUDA device).
jax.devices()

[CudaDevice(id=0)]

In [4]:
# Display the GENOT model instance built and run in the first cell.
otfm

<genot.models.model.GENOT at 0x7f4494113150>