In [32]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="2"
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from pathlib import Path
root_path = Path.cwd().parent.parent.parent.absolute()
import sys
current_dir = Path.cwd()
repo_root = current_dir.parent.parent.parent  # Go up 3 levels
pdpo_path = str(repo_root.absolute())
# Add repository root to path
sys.path.insert(0, pdpo_path)




In [33]:
import jax
import jax.numpy as jnp
import jax.random as jrn
from jax import jit, vmap
from flax import nnx
import optax
import numpy as np
import matplotlib.pyplot as plt
import seaborn
from typing import Dict, Tuple, Optional
import time

In [34]:
# Set plot style
plt.style.use('default')

# JAX Configuration
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

device = jax.devices()[0]
jax.default_device(device)




JAX version: 0.7.0
Available devices: [CudaDevice(id=0)]


<jax._src.config.StateContextManager at 0x7f40b0639db0>

In [35]:
# Set random seed for reproducibility
SEED = 1
key = jrn.PRNGKey(SEED)
print(f"Random seed set to: {SEED}")

Random seed set to: 1


In [36]:
from pdpo.generative.models.matching_methods import FlowMatching,StochasticInterpolant
from pdpo.models.builder import create_model,create_mlp,create_mlp_time_embedding
from pdpo.data.toy_datasets import inf_train_gen
from pdpo.core.types import ModelParams, SampleArray, TimeArray
from pdpo.ode.solvers import MidpointSolver


In [37]:
# Test configuration
CONFIG = {
    'model': {
        'input_dim': 2,
        'hidden_dim': 128,
        'num_layers': 4,
        'activation': 'relu',
        'time_varying': True
    },
    'training': {
        'batch_size': 1000,
        'learning_rate': 1e-3,
        'num_epochs': 1000,
        'print_every': 100
    },
    'data': {
        'source': 'std_gaussian',
        'target': '8gaussians',
        'dim': 2
    }
}

In [38]:
print("Configuration loaded:")
for section, params in CONFIG.items():
    print(f"  {section}: {params}")

print("\n" + "="*50)
print("SETUP COMPLETE - Ready for Flow Matching tests")
print("="*50)

Configuration loaded:
  model: {'input_dim': 2, 'hidden_dim': 128, 'num_layers': 4, 'activation': 'relu', 'time_varying': True}
  training: {'batch_size': 1000, 'learning_rate': 0.001, 'num_epochs': 1000, 'print_every': 100}
  data: {'source': 'std_gaussian', 'target': '8gaussians', 'dim': 2}

SETUP COMPLETE - Ready for Flow Matching tests


In [39]:
architecture = [
    CONFIG['model']['input_dim'],      # 2 - input dimension  
    CONFIG['model']['num_layers'],     # 4 - number of layers
    CONFIG['model']['hidden_dim'],     # 64 - hidden layer width
    CONFIG['model']['activation'],      # 'relu' - activation function
    CONFIG['model']['time_varying']
]

print(f"MLP Architecture: {architecture}")
print(f"  - Input dimension: {architecture[0]}")
print(f"  - Hidden width: {architecture[1]}")  
print(f"  - Number of layers: {architecture[2]}")
print(f"  - Activation: {architecture[3]}")

MLP Architecture: [2, 4, 128, 'relu', True]
  - Input dimension: 2
  - Hidden width: 4
  - Number of layers: 128
  - Activation: relu


In [40]:
vf_model = create_mlp(input_size = architecture[0],
num_layers = architecture[1],
layer_width = architecture[2],
activation = architecture[3],
time_varying = architecture[4] # only for mlp
)


nnx.display(vf_model)

Traceback (most recent call last):
  File "/work/Sebas/miniconda3/envs/PDPO_jax/lib/python3.13/site-packages/treescope/renderers.py", line 225, in _render_subtree
    postprocessed_result = hook(
        node=node, path=path, node_renderer=render_without_this_hook
    )
  File "/work/Sebas/miniconda3/envs/PDPO_jax/lib/python3.13/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
    result = autoviz(node, path)
  File "/work/Sebas/miniconda3/envs/PDPO_jax/lib/python3.13/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
    jax.sharding.PositionalSharding
  File "/work/Sebas/miniconda3/envs/PDPO_jax/lib/python3.13/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0



In [41]:
optimizer = nnx.Optimizer(vf_model,optax.adam(CONFIG['training']['learning_rate']),wrt = nnx.Param)
scheduler = nnx

In [42]:
ode_solver = MidpointSolver

In [43]:
# model = FlowMatching(
#     vf_model = vf_model,
#     optimizer = optimizer,
#     ode_solver = ode_solver
# )

model = StochasticInterpolant(
    vf_model = vf_model,
    optimizer = optimizer,
    ode_solver = ode_solver
)

In [47]:
t = jnp.array(0.0)
x = jrn.normal(key, (10,2))

In [None]:
from pdpo.ode.solvers import eval_model
from pdpo.core.types import PRNGKeyArray


In [None]:
# Training setup
batch_size = CONFIG['training']['batch_size']
num_epochs = CONFIG['training']['num_epochs']
losses = []

print(f"Training config: {num_epochs} epochs, batch_size={batch_size}, lr={CONFIG['training']['learning_rate']}")

In [None]:
key_train = key
for epoch in range(num_epochs):
    # Generate training data
    key_train, key_batch = jrn.split(key_train)
    source_data = inf_train_gen(CONFIG['data']['source'], key_batch, batch_size, CONFIG['data']['dim'])
    target_data = inf_train_gen(CONFIG['data']['target'], key_batch, batch_size, CONFIG['data']['dim'])
    
    # Training step
    loss, metrics = model.training_step(key_batch, target_data,source_data)
    
    
    losses.append(float(loss))
    
    if epoch % CONFIG['training']['print_every'] == 0:
        print(f"Epoch {epoch:3d}: Loss = {loss:.6f}")
        generated = model.sample_trajectory(
            x0 = source_data,
            ode_solver = ode_solver
        )
        fig = plt.figure(figsize=(10,5))
        plt.scatter(source_data[:,0],source_data[:,1],s= 1, label = 'source')
        plt.scatter(target_data[:,0],target_data[:,1],s = 1, label = 'target')
        plt.scatter(generated[:,0],generated[:,1],s = 1, label = 'generate')
        plt.legend()
        plt.show()
        

In [None]:
plt.figure(figsize=(8, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f'{model.method_name} Loss Convergence')
plt.grid(True)
plt.show()

In [None]:
generated = model.sample_trajectory(
x0 = source_data,
ode_solver = ode_solver
)
fig = plt.figure(figsize=(10,5))
plt.scatter(source_data[:,0],source_data[:,1],s= 1, label = 'source')
plt.scatter(target_data[:,0],target_data[:,1],s = 1, label = 'target')
plt.scatter(generated[:,0],generated[:,1],s = 1, label = 'generate')
plt.legend()
plt.show()