In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from dataclasses import dataclass
import argparse
# scikit-learn / pandas helpers (note: the phase-space file below is read with np.genfromtxt, not pandas)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
import math

from scipy.spatial import cKDTree
from astropy.coordinates import SkyCoord
import astropy.units as u

KeyboardInterrupt: 

In [None]:
# define properties for plotting
import matplotlib.cm as cm
from matplotlib import rcParams
from cycler import cycler

def rgb(r, g, b, scale=255.0):
  """Convert 8-bit RGB channel values (0-255) to a matplotlib float tuple in [0, 1].

  Dividing by 255 (not 256) maps a full channel value of 255 to exactly 1.0,
  so pure white / pure primaries are representable. ``scale`` can be changed
  for other channel depths.
  """
  s = float(scale)
  return (float(r) / s, float(g) / s, float(b) / s)

# Eight-colour palette; it drives the default colour cycle and patch colour.
cb2 = [
    rgb(31, 120, 180), rgb(255, 127, 0), rgb(51, 160, 44), rgb(227, 26, 28),
    rgb(166, 206, 227), rgb(253, 191, 111), rgb(178, 223, 138), rgb(251, 154, 153),
]

# Global matplotlib defaults for every figure in this notebook.
rcParams.update({
    'figure.figsize': (9, 7),
    'figure.dpi': 50,
    'lines.linewidth': 2,
    'axes.prop_cycle': cycler('color', cb2),
    'axes.facecolor': 'white',
    'axes.grid': False,
    'patch.facecolor': cb2[0],
    'patch.edgecolor': 'white',
    'font.size': 23,
    'font.weight': 300,
})

In [None]:
# Colab setup: report whether CUDA is visible, then mount Google Drive so the
# data file under /content/drive/MyDrive/ can be read by the next cell.
if torch.cuda.is_available():
    print("ran on GPU")

from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Load the phase-space sample and unpack its first eight columns, in order,
# into R_gal, X_gal, Y_gal, Z_gal, Vx_gal, Vy_gal, Vz_gal, weight_star.
phase_sample = np.genfromtxt('/content/drive/MyDrive/sample_0.7_75.txt')
(R_gal, X_gal, Y_gal, Z_gal,
 Vx_gal, Vy_gal, Vz_gal, weight_star) = phase_sample[:, :8].T

def vel_cartesian_to_galactic(pos, vel, err_p=0):
  """Project Cartesian velocities onto spherical (r, theta, phi) unit vectors.

  ``theta`` is a latitude (``theta = 0`` in the x-y plane, ``+pi/2`` on +z),
  ``phi = arctan2(y, x)`` is the azimuth.

  Parameters
  ----------
  pos, vel : array_like, shape (..., 3)
      Cartesian positions and velocities; the last axis is (x, y, z).
  err_p : unused
      Kept only for backward compatibility; it has no effect.

  Returns
  -------
  vr, v_theta, v_phi : ndarray, shape (...)
      Radial, latitudinal (positive towards +z) and azimuthal components.
  """
  pos = np.asarray(pos)
  vel = np.asarray(vel)
  x, y, z = pos[..., 0], pos[..., 1], pos[..., 2]
  vx, vy, vz = vel[..., 0], vel[..., 1], vel[..., 2]

  # arctan2(z, R_cyl) equals arcsin(z / r) for r > 0, but stays finite at the
  # origin (where z / r is 0/0 -> nan) and is better conditioned near the poles.
  theta = np.arctan2(z, np.hypot(x, y))
  phi = np.arctan2(y, x)

  cos_t, sin_t = np.cos(theta), np.sin(theta)
  cos_p, sin_p = np.cos(phi), np.sin(phi)

  vr = cos_t * cos_p * vx + cos_t * sin_p * vy + sin_t * vz
  v_theta = -sin_t * cos_p * vx - sin_t * sin_p * vy + cos_t * vz
  v_phi = -sin_p * vx + cos_p * vy
  return vr, v_theta, v_phi

# Stack the Cartesian columns into (N, 3) position / velocity arrays.
pos_gal = np.column_stack([X_gal, Y_gal, Z_gal])
vel_gal = np.column_stack([Vx_gal, Vy_gal, Vz_gal])

# Random sub-sample without replacement. A dedicated seeded generator makes the
# selected stars reproducible: set_seed() in the training cell runs only after
# this draw, so it cannot control it.
SUBSAMPLE_SEED = 42
N_small = min(50000, len(R_gal))  # never request more stars than exist
rng_subsample = np.random.default_rng(SUBSAMPLE_SEED)
idx = rng_subsample.choice(len(R_gal), size=N_small, replace=False)
pos_small = pos_gal[idx]
vel_small = vel_gal[idx]

# Keep stars whose spherical radius |(x, y, z)| lies in [R_slice_min, R_slice_max].
R_small = np.linalg.norm(pos_small, axis=1)
R_slice_min = 60.0
R_slice_max = 90.0
slice_mask = (R_small >= R_slice_min) & (R_small <= R_slice_max)

pos_small = pos_small[slice_mask]
vel_small = vel_small[slice_mask]
R_small = R_small[slice_mask]

print(f"Selected {pos_small.shape[0]} stars with {R_slice_min:.0f} <= R <= {R_slice_max:.0f} from {len(idx)} samples.")

# Spherical velocity components of the selected stars.
v_r_small, v_theta_small, v_phi_small = vel_cartesian_to_galactic(pos_small, vel_small)


In [None]:
def set_seed(seed: int = 42):
  """Seed NumPy's global RNG and PyTorch's CPU and CUDA generators with ``seed``."""
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

In [None]:
@dataclass
class NormStats:
  """Affine normalization ``x_norm = (x - center) / (scale + 1e-8)`` and its inverse.

  ``center`` and ``scale`` may be tensors or anything ``torch.as_tensor``
  accepts. Every method first converts them to the device and dtype of its
  input, so one instance serves CPU/GPU and float32/float64 data alike. The
  1e-8 term guards against a zero scale.
  """
  center: torch.Tensor  # [3]
  scale: torch.Tensor   # divisor; any shape broadcastable to the input

  @staticmethod
  def _on_like(value, x: torch.Tensor) -> torch.Tensor:
    # Materialize `value` as a tensor on x's device with x's dtype.
    return torch.as_tensor(value, device=x.device, dtype=x.dtype)

  def _scale_like(self, x: torch.Tensor) -> torch.Tensor:
    return self._on_like(self.scale, x)

  def _center_like(self, x: torch.Tensor) -> torch.Tensor:
    return self._on_like(self.center, x)

  def normalize(self, x3: torch.Tensor) -> torch.Tensor:
    """Physical -> normalized coordinates."""
    return (x3 - self._center_like(x3)) / (self._scale_like(x3) + 1e-8)

  def denormalize(self, x3_norm: torch.Tensor) -> torch.Tensor:
    """Normalized -> physical coordinates (inverse of ``normalize``)."""
    return x3_norm * (self._scale_like(x3_norm) + 1e-8) + self._center_like(x3_norm)

  def to_physical_velocity(self, v_norm: torch.Tensor) -> torch.Tensor:
    """Map a velocity expressed in normalized coordinates back to physical units.

    Only the scale is applied; the center offset does not affect velocities.
    """
    return v_norm * (self._scale_like(v_norm) + 1e-8)


def isotropic_stats(x: torch.Tensor, pct: float = 95.0) -> NormStats:
  """Fit isotropic normalization statistics to an [N, D] point cloud.

  The center is the per-coordinate mean. The single scalar scale is the
  ``pct``-th percentile of the Euclidean distances from that center, which is
  a robust alternative to the maximum radius.
  """
  mu = x.mean(dim=0)
  dist = (x - mu).pow(2).sum(dim=1).sqrt()
  return NormStats(center=mu, scale=torch.quantile(dist, pct / 100.0))


class LBVDataset(Dataset):
  """Dataset over an N x 3 array of positions with columns [x, y, z].

  Positions are stored as float32 and normalized once at construction.
  ``__getitem__`` returns one normalized row of shape [3].

  Args:
    data_np: array-like of shape (N, 3) in physical units.
    norm: statistics to normalize with. When None, they are fitted on
      ``data_np`` with ``isotropic_stats``. Pass the training set's ``norm``
      to put another dataset into the same coordinates.
  """
  def __init__(self, data_np: np.ndarray, norm: NormStats | None = None):
    data_np = np.asarray(data_np)  # also accept lists / other array-likes
    assert data_np.ndim == 2 and data_np.shape[1] == 3, "Expect Nx3 array [x, y, z]"
    x = torch.tensor(data_np, dtype=torch.float32)
    self.norm = isotropic_stats(x) if norm is None else norm
    self.x = self.norm.normalize(x)  # [N, 3], normalized

  def __len__(self):
    return self.x.shape[0]

  def __getitem__(self, idx):
    return self.x[idx]




In [None]:
# ------------------------------
# Model: u_θ(x,t)
# ------------------------------


class Flow(nn.Module):

  def __init__(self, hidden: int = 512, depth: int = 3, act: str = "silu"):

    super().__init__()

    Act = nn.ELU if act.lower() == "elu" else nn.SiLU

    in_dim = 3 + 1 + 1  # (x,y,z) + r + t

    out_dim = 3         # predict dx/dt on (x,y,z)

    layers = [nn.Linear(in_dim, hidden), Act()]

    for _ in range(depth - 1):

      layers += [nn.Linear(hidden, hidden), Act()]

    layers += [nn.Linear(hidden, out_dim)]

    self.net = nn.Sequential(*layers)

    self.hidden = hidden

    self.depth = depth

    self.act_name = act





  def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:

    # xt: [B,3]; t: [B,1]

    rt = torch.sqrt((xt ** 2).sum(dim=1, keepdim=True))

    return self.net(torch.cat([xt, rt, t], dim=-1))





  @torch.no_grad()

  def step_heun(self, x: torch.Tensor, t0: float, t1: float, device: torch.device) -> torch.Tensor:

    """One Heun (midpoint/RK2) step from t0 to t1 for the ODE x' = u_θ(x,t)."""

    B = x.shape[0]

    t0_t = torch.full((B, 1), t0, device=device)

    t1_t = torch.full((B, 1), t1, device=device)

    dt = (t1 - t0)

    k1 = self.forward(x, t0_t)

    x_mid = x + 0.5 * dt * k1

    k2 = self.forward(x_mid, 0.5 * (t0_t + t1_t))

    return x + dt * k2





# ------------------------------

# Training (straight-line CFM)

# ------------------------------




def make_r_weight_bins(train_ds, bins: int = 64):
  """Build inverse-frequency weights over the radius of ``train_ds.x[:, :3]``.

  Radii are measured in the dataset's *normalized* coordinates.

  Returns:
    edges: numpy array of bin edges, length ``bins + 1``.
    inv: numpy array of per-bin weights, length ``bins``, proportional to
      1 / count and scaled so the mean over *populated* bins is 1. Empty bins
      keep a very large weight (they contain no training sample); the default
      ``cap`` in r_weights_from_edges bounds it if ever queried.
  """
  with torch.no_grad():
    xyz = train_ds.x[:, :3]  # any columns beyond the first three are ignored
    r = torch.sqrt((xyz ** 2).sum(dim=1)).cpu().numpy()

  counts, edges = np.histogram(r, bins=bins, density=False)
  populated = counts > 0

  # Small additive floor so empty bins do not divide by zero. This is a
  # numerical guard, not Laplace (+1) smoothing.
  inv = 1.0 / (counts.astype(np.float64) + 1e-6)

  # Normalise over populated bins only. An empty bin's ~1e6 weight would
  # otherwise dominate the mean and shrink every real bin's weight towards 0.
  if populated.any():
    inv /= inv[populated].mean()
  else:
    inv = np.ones_like(inv)
  return edges, inv




def r_weights_from_edges(x_batch: torch.Tensor, edges: np.ndarray, inv: np.ndarray,

                         cap: float | None = 5.0) -> torch.Tensor:

  """Map x_batch (normalized coords, first 3 dims XYZ) to inverse-frequency weights by R.

  Returns weights as a tensor on x_batch.device.

  """

  xyz = x_batch[:, :3]

  r = torch.sqrt((xyz ** 2).sum(dim=1)).detach().cpu().numpy()

  idx = np.clip(np.digitize(r, edges) - 1, 0, len(inv) - 1)

  w = inv[idx]

  if cap is not None:

      w = np.minimum(w, float(cap))

  return torch.tensor(w, dtype=torch.float32, device=x_batch.device)




  def train_cfm(

  model: Flow,

  loader: DataLoader,

  epochs: int = 200,

  lr: float = 2e-4,

  wd: float = 2e-4,

  device: str = "cuda",

  ema_decay: float = 0.999,

  max_batches_per_epoch: int | None = None,

  # R-weighting controls (set edges/inv=None to disable)

  edges: np.ndarray | None = None,

  inv: np.ndarray | None = None,

  w_cap: float | None = 5.0,

  Rmin_phys: float | None = None,        # e.g. 60.0

  Rmax_phys: float | None = None,        # e.g. 90.0

  use_penalty: bool = True,

  lam: float = 0.05,                     # weight for the radial velocity penalty

  norm: NormStats | None = None,         # pass the same norm you used for LBVDataset

  radius_sampler=None,                   # callable returning radii in physical space

  align_prob: float = 0.8,

  dir_jitter: float = 0.05,

  penalty_warmup: int = 10,

  radius_jitter: float = 0.0,

  penalty_margin: float = 1.5,

):

  if radius_sampler is None:

    raise ValueError("train_cfm requires a radius_sampler callable returning radii in physical space.")

  if norm is None:

    raise ValueError("train_cfm requires normalization statistics (norm).")

  device = torch.device(device if torch.cuda.is_available() else "cpu")

  model.to(device)

  opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

  try:
    steps_per_epoch = len(loader)
  except TypeError:
    steps_per_epoch = 0
  if max_batches_per_epoch is not None and steps_per_epoch:
    steps_per_epoch = min(steps_per_epoch, max_batches_per_epoch)
  elif max_batches_per_epoch is not None and not steps_per_epoch:
    steps_per_epoch = max_batches_per_epoch
  steps_per_epoch = max(1, steps_per_epoch)
  total_steps = max(1, epochs * steps_per_epoch)
  sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)

  current_lr = sched.get_last_lr()[0]
  global_step = 0

  # EMA shadow

  ema_model = Flow(hidden=model.hidden, depth=model.depth, act=model.act_name)

  ema_model.load_state_dict(model.state_dict())

  ema_model.to(device)

  for p in ema_model.parameters():

    p.requires_grad_(False)



  def ema_update():

    with torch.no_grad():

      for p, pe in zip(model.parameters(), ema_model.parameters()):

        pe.data.mul_(ema_decay).add_(p.data, alpha=1.0 - ema_decay)



  beta = torch.distributions.Beta(2.0, 2.0)  # t ~ Beta(2,2)

  align_prob = float(max(0.0, min(1.0, align_prob)))

  penalty_warmup = max(1, penalty_warmup)



  model.train()



  for ep in range(1, epochs + 1):

    running = 0.0

    nbatches = 0

    lam_eff = lam if use_penalty else 0.0

    if use_penalty:

      lam_eff = lam * min(1.0, ep / float(penalty_warmup))



    for bidx, x1 in enumerate(loader):

      x1 = x1.to(device)

      B = x1.shape[0]



      with torch.no_grad():

        x1_phys = norm.denormalize(x1)

        dirs = x1_phys / (torch.linalg.norm(x1_phys, dim=1, keepdim=True) + 1e-12)

        r0_phys = radius_sampler(B, device=device, jitter=radius_jitter)

        mask_align = torch.rand(B, device=device) < align_prob

        x0_phys = torch.empty_like(x1_phys)

        if mask_align.any():

          dirs_align = dirs.clone()

          if dir_jitter > 0.0:

            eps = torch.randn_like(dirs_align)

            eps -= (eps * dirs_align).sum(dim=1, keepdim=True) * dirs_align

            dirs_align = dirs_align + dir_jitter * eps

            dirs_align = dirs_align / (dirs_align.norm(dim=1, keepdim=True) + 1e-12)

          x0_phys[mask_align] = dirs_align[mask_align] * r0_phys[mask_align]

        if (~mask_align).any():

          dirs_iso = torch.randn(((~mask_align).sum().item(), 3), device=device)

          dirs_iso = dirs_iso / (dirs_iso.norm(dim=1, keepdim=True) + 1e-12)

          x0_phys[~mask_align] = dirs_iso * r0_phys[~mask_align]

        if (Rmin_phys is not None) and (Rmax_phys is not None):

          r_curr = torch.linalg.norm(x0_phys, dim=1, keepdim=True)

          scale = torch.clamp(r_curr, min=Rmin_phys, max=Rmax_phys) / (r_curr + 1e-12)

          x0_phys = x0_phys * scale

        x0 = norm.normalize(x0_phys)



      # Sample t ∈ (0,1)

      t = beta.sample((B, 1)).to(device)         # t ~ Beta(2,2)



      # Straight line interpolation and target

      xt = (1.0 - t) * x0 + t * x1               # X_t

      target = (x1 - x0)                         # u* = X1 - X0 (t-constant)



      pred = model(xt, t)

      err  = ((pred - target) ** 2).sum(dim=1)   # [B]



      if edges is not None and inv is not None:

        w = r_weights_from_edges(x1, edges, inv, cap=w_cap)

        base_loss = (err * w).mean()

      else:

        base_loss = err.mean()



        pen_mean = torch.zeros(1, device=device)

        if use_penalty and (Rmin_phys is not None) and (Rmax_phys is not None):

          xt_phys = norm.denormalize(xt)

          r_xt_phys = torch.linalg.norm(xt_phys, dim=1, keepdim=True)

          pred_phys = norm.to_physical_velocity(pred)

          dirs_t = xt_phys / (r_xt_phys + 1e-12)

          v_radial = (pred_phys * dirs_t).sum(dim=1, keepdim=True)

          lower = Rmin_phys + penalty_margin

          upper = Rmax_phys - penalty_margin

          if not upper > lower:

            lower = Rmin_phys

            upper = Rmax_phys

          lower = torch.tensor(lower, device=device)

          upper = torch.tensor(upper, device=device)

          dist_low = F.relu(lower - r_xt_phys)

          dist_high = F.relu(r_xt_phys - upper)

          pen = dist_low * F.relu(-v_radial) + dist_high * F.relu(v_radial)

          pen_mean = pen.mean()

          total_loss = base_loss + lam_eff * pen_mean

        else:

          total_loss = base_loss



      opt.zero_grad(set_to_none=True)

      total_loss.backward()

      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

      opt.step()

      ema_update()

      global_step += 1
      if global_step <= total_steps:
        sched.step()
        current_lr = sched.get_last_lr()[0]

      running += total_loss.item()

      nbatches += 1



      if max_batches_per_epoch and (bidx + 1) >= max_batches_per_epoch:

        break



    if ep == 1 or ep % 10 == 0:

      print(f"[Epoch {ep:03d}] flow_mse={base_loss.item():.6f} pen={pen_mean.item():.6f} lam_eff={lam_eff:.3f} total={total_loss.item():.6f} lr={current_lr:.2e}")



  return ema_model


In [None]:
# ------------------------------
# Sampling
# ------------------------------
@torch.no_grad()
def sample(model: Flow, n: int, steps: int = 30, device: str = "cuda",
           radius_sampler=None, norm: NormStats | None = None,
           Rmin_phys: float | None = None, Rmax_phys: float | None = None,
           enforce_bounds: bool = True, radius_jitter: float = 0.0) -> torch.Tensor:
  """Generate n points by integrating the learned flow from t=0 to t=1.

  Start points are isotropic random directions scaled by radii from
  ``radius_sampler`` (physical units), then normalized with ``norm``. The ODE
  is integrated with ``model.step_heun`` over ``steps`` intervals of a cosine
  time grid, which is denser near t=0 and t=1. When ``enforce_bounds`` is set
  and both bounds are given, points are radially clamped back into
  [Rmin_phys, Rmax_phys] (physical space) after every step.

  Returns samples in *normalized* coordinates, shape [n, 3], on the chosen
  device. As a side effect, ``model`` is moved there and put in eval mode.
  """
  if radius_sampler is None:
    raise ValueError("sample requires a radius_sampler callable returning radii in physical space.")
  if norm is None:
    raise ValueError("sample requires NormStats for normalization/denormalization.")

  dev = torch.device(device if torch.cuda.is_available() else "cpu")
  model.eval().to(dev)
  clamp_radius = enforce_bounds and (Rmin_phys is not None) and (Rmax_phys is not None)

  # Source distribution: data-driven radius times a uniform random direction.
  radii = radius_sampler(n, device=dev, jitter=radius_jitter)
  directions = torch.randn(n, 3, device=dev)
  directions = directions / (directions.norm(dim=1, keepdim=True) + 1e-12)
  x = norm.normalize(directions * radii)

  grid = 0.5*(1 - torch.cos(torch.linspace(0, math.pi, steps + 1, device=dev)))
  times = grid.tolist()
  for t_start, t_end in zip(times[:-1], times[1:]):
    x = model.step_heun(x, t0=t_start, t1=t_end, device=dev)
    if clamp_radius:
      x_phys = norm.denormalize(x)
      r_phys = torch.linalg.norm(x_phys, dim=1, keepdim=True)
      shrink = torch.clamp(r_phys, min=Rmin_phys, max=Rmax_phys) / (r_phys + 1e-12)
      x = norm.normalize(x_phys * shrink)

  return x


In [None]:
# ------------------------------
# Main
# ------------------------------

# Define parameters directly for Colab
epochs = 80
batch_size = 512
device = "cuda"                   # train_cfm / sample fall back to CPU if CUDA is unavailable
save_path = "samples_xyzvb.csv"   # NOTE(review): not written to anywhere in the visible notebook
shuffle_data = False              # NOTE(review): unused; the DataLoader cell passes shuffle=True directly

set_seed(42)

# Training data: the radially sliced positions selected earlier.
data_np = pos_small
assert data_np.ndim == 2 and data_np.shape[1] == 3, "Expect Nx3 array [X,Y,Z]"
train_np = data_np

# Build the dataset once; with norm=None it fits its own NormStats on train_np.
train_ds = LBVDataset(train_np, norm=None)
norm = train_ds.norm

with torch.no_grad():
  x_norm_all = train_ds.x                                 # [N,3] normalized xyz
  r_norm_all = torch.sqrt((x_norm_all**2).sum(dim=1)).cpu().numpy()     # normalized radii
  rmin_norm = float(r_norm_all.min())
  rmax_norm = float(r_norm_all.max())

print(f"Normalized radius range: r_min={rmin_norm:.4f}, r_max={rmax_norm:.4f}")

# Empirical distribution of physical radii, used for inverse-CDF sampling of
# source radii (sample_r_from_data).
r_phys_all = np.linalg.norm(train_np, axis=1)
r_phys_sorted = np.sort(r_phys_all)
cdf_r = np.linspace(0.0, 1.0, len(r_phys_sorted))
Rmin_phys = float(R_slice_min)
Rmax_phys = float(R_slice_max)
print(f"Physical radius range: R_min={Rmin_phys:.2f}, R_max={Rmax_phys:.2f}")

def sample_r_from_data(B, device, jitter=0.0):
  """Draw B radii (physical units) from the empirical training-radius distribution.

  Uniform variates are mapped through the sorted training radii
  (``r_phys_sorted`` against ``cdf_r``) by linear interpolation. If ``jitter``
  is positive, Gaussian noise with that standard deviation is added. When both
  ``Rmin_phys`` and ``Rmax_phys`` are set, out-of-range values are re-drawn for
  at most 8 rounds, and any that remain are clipped into the range.

  Returns a float32 tensor of shape [B, 1] on ``device``.
  """
  base = np.interp(np.random.rand(B), cdf_r, r_phys_sorted)
  if not (jitter and jitter > 0):
    return torch.tensor(base, dtype=torch.float32, device=device).unsqueeze(1)

  rs = base + jitter * np.random.randn(B)
  if (Rmin_phys is not None) and (Rmax_phys is not None):
    lo, hi = Rmin_phys, Rmax_phys
    for _ in range(8):
      outside = (rs < lo) | (rs > hi)
      if not outside.any():
        break
      rs[outside] = base[outside] + jitter * np.random.randn(outside.sum())
    outside = (rs < lo) | (rs > hi)
    rs[outside] = np.clip(rs[outside], lo, hi)
  return torch.tensor(rs, dtype=torch.float32, device=device).unsqueeze(1)

# Inverse-frequency R weighting is switched off (edges/inv = None).
# To enable it: edges, inv = make_r_weight_bins(train_ds, bins=64)
edges = inv = None

# Shuffled, fixed-size batches of normalized positions.
# NOTE(review): shuffle=True is hard-coded; the shuffle_data flag is not read here.
loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Model & training: velocity-field network trained in the next cell.
model = Flow(hidden=768, depth=5, act="silu")


In [None]:
# Train the flow; train_cfm returns the EMA copy of the weights.
ema_model = train_cfm(
    model,
    loader,
    # optimisation
    epochs=epochs,
    device=device,
    # normalization and source-point distribution
    norm=norm,
    radius_sampler=sample_r_from_data,
    radius_jitter=0.4,
    align_prob=0.9,
    dir_jitter=0.05,
    # optional inverse-frequency R weighting (off while edges/inv are None)
    edges=edges,
    inv=inv,
    w_cap=1.5,
    # radial boundary penalty
    Rmin_phys=Rmin_phys,
    Rmax_phys=Rmax_phys,
    use_penalty=True,
    lam=0.08,
    penalty_warmup=20,
    penalty_margin=2.0,
)


In [None]:
# Draw 50k samples from the EMA model in normalized space (radii kept inside
# the training slice), then map them back to physical coordinates.
x_norm = sample(
    ema_model,
    n=50_000,
    steps=300,
    device=device,
    norm=norm,
    radius_sampler=sample_r_from_data,
    radius_jitter=0.0,
    Rmin_phys=Rmin_phys,
    Rmax_phys=Rmax_phys,
    enforce_bounds=True,
)
x_denorm = norm.denormalize(x_norm.cpu())   # [N, 3], physical units
pos_out = x_denorm.numpy()


In [None]:
# Spherical radius of each generated point, computed the same way as R_small.
R_out = np.linalg.norm(pos_out, axis=1)


In [None]:
# Compare the radial distribution of the training slice with the generated
# samples; dashed lines mark the slice bounds.
fig, ax = plt.subplots()
ax.hist(R_small, bins=100, density=True, alpha=0.5, label='True')
ax.hist(R_out, bins=100, density=True, alpha=0.5, label='Generated')
ax.axvline(R_slice_min, color='k', linestyle='--', linewidth=1)
ax.axvline(R_slice_max, color='k', linestyle='--', linewidth=1)
ax.set(xlabel='R', ylabel='Density', title='Radial distribution: data vs. flow samples')
ax.legend()
plt.show()
