In [None]:
import torch
import numpy as np
import zuko
import matplotlib.pyplot as plt
import torch.utils.data as data

from toy_dataset import get_data_sin_rand, get_data_sin

In [None]:
# 2. Get and normalize data
# Seed the global RNGs before anything random happens so the toy data, the flow's
# initialization and the DataLoader shuffling are reproducible on Restart & Run All.
torch.manual_seed(0)
np.random.seed(0)

# Later cells read column 0 as x and column 1 as y.
raw_data = get_data_sin_rand(0, 12)
# Per-column statistics, kept so later cells can map x/y into and out of the
# normalized space the flow is trained in.
data_mean = raw_data.mean(dim=0)
data_std = raw_data.std(dim=0)
# NOTE(review): this rebinds `data`, shadowing the `torch.utils.data as data` import.
# Later cells use the fully-qualified `torch.utils.data`, so it still works, but a
# distinct name such as `data_norm` would be clearer.
data = (raw_data - data_mean) / data_std

In [None]:
# Plot original data
# `x` and `y` stay as notebook globals because the next cell reuses them.
x, y = (raw_data[:, col].numpy() for col in (0, 1))
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x, y, s=1, label='Training data')
ax.set_title("Training data")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()

In [None]:
# Overlay samples from `get_data_sin(8, 10)` (presumably a clean sine segment with
# x in [8, 10] — confirm against toy_dataset) on top of the training data.
wrong_samples = get_data_sin(8, 10)
x_wrong = wrong_samples[:, 0].numpy()
y_wrong = wrong_samples[:, 1].numpy()

fig, ax = plt.subplots(figsize=(8, 6))
# Read the training points straight from `raw_data` instead of relying on the
# `x`/`y` globals from the previous cell, so this cell stands on its own.
ax.scatter(raw_data[:, 0].numpy(), raw_data[:, 1].numpy(), s=1, label='Training data')
ax.scatter(x_wrong, y_wrong, s=1, color='red', label='Wrong samples')
ax.set_title("Training data and wrong samples")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()


In [None]:
# 3. Data loader
# Wrap the normalized rows in a single-tensor dataset. Each batch therefore comes
# out as a one-element sequence, which the training loop unpacks as `(batch,)`.
batch_size = 64
trainset = torch.utils.data.TensorDataset(data)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
# 4. Define conditional flow: p(y | x)
# A 1-D target (y) conditioned on a 1-D context (x), using zuko's BPF flow with
# polynomial degree 20 and hidden layer sizes (64, 64).
flow = zuko.flows.BPF(
    features=1,
    context=1,
    degree=20,
    hidden_features=(64, 64),
)

In [None]:
# 5. Training
# Maximum-likelihood training: minimize the mean negative log-likelihood of y given x.
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-4)
loss_hist = []  # one mean NLL per epoch, plotted in the next cell

for epoch in range(200):
    batch_losses = []

    for (batch,) in trainloader:
        # Column 0 is the conditioning variable x and column 1 is the target y.
        # Slicing with ranges keeps both two-dimensional: (batch, 1).
        context, target = batch[:, 0:1], batch[:, 1:2]

        nll = -flow(context).log_prob(target).mean()
        nll.backward()

        optimizer.step()
        optimizer.zero_grad()

        batch_losses.append(nll.detach())

    epoch_losses = torch.stack(batch_losses)
    epoch_mean = epoch_losses.mean().item()
    loss_hist.append(epoch_mean)
    print(f"({epoch})", epoch_mean, "±", epoch_losses.std().item())

In [None]:
# 6. Plot loss
# Per-epoch mean negative log-likelihood collected during training.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(loss_hist, label="Loss")
ax.set(
    xlabel="Epoch",
    ylabel="Negative Log-Likelihood",
    title="Training Loss Over Time",
)
ax.legend()
plt.show()

In [None]:
# 7. Sample from the learned distribution p(y | x)
n_samples = 100  # independent draws per conditioning x
x_values = torch.linspace(8, 10, 20).unsqueeze(1)  # (20, 1) grid of x in [8, 10]
x_values_norm = (x_values - data_mean[0]) / data_std[0]  # normalize x
# Inference only: no autograd graph is needed for the drawn samples.
with torch.no_grad():
    samples_norm = flow(x_values_norm).sample((n_samples,))  # (n_samples, 20, 1)
samples = samples_norm * data_std[1] + data_mean[1]  # denormalize y

In [None]:
# 8. Plot the sampled y's conditioned on x
# `samples` has shape (n_draws, n_x, 1) and `x_values` has shape (n_x, 1).
# Broadcast the x grid across the draw axis so every sampled y is paired with its x,
# then plot all points in a single scatter call. The draw count is read from
# `samples` instead of being hard-coded.
n_draws, n_x = samples.shape[0], samples.shape[1]
x_grid = x_values[:, 0].expand(n_draws, n_x)

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(
    x_grid.reshape(-1).numpy(),
    samples[..., 0].detach().reshape(-1).numpy(),
    s=1, color='red', alpha=0.5, label="Sampled Data",
)
ax.scatter(raw_data[:, 0].numpy(), raw_data[:, 1].numpy(), s=1, color='black', alpha=0.5, label="Original Data")
ax.set_title("Conditional Samples from Trained Flow: p(y | x)")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend(loc='upper left')
plt.show()

In [None]:
# 9. Plot conditional densities for fixed x values
import seaborn as sns
from scipy.stats import gaussian_kde

# Evaluate the learned density on a y-grid spanning the full observed y range.
# (A grid starting at 0 would clip any negative y values in the data.)
y_min = raw_data[:, 1].min().item()
y_max = raw_data[:, 1].max().item()

x_slices = torch.linspace(0, 12, 12)  # x positions of the density slices
ys = torch.linspace(y_min, y_max, 500)
density_scale = 0.3  # horizontal width per unit density, for display only

fig, ax = plt.subplots(figsize=(10, 6))

# Normalize x and y into the space the flow was trained in.
x_slices_norm = (x_slices - data_mean[0]) / data_std[0]
ys_norm = (ys - data_mean[1]) / data_std[1]
y_eval = ys_norm.unsqueeze(1)  # (500, 1); the same for every slice
# Change of variables: y = y_norm * std + mean, so p(y | x) = p(y_norm | x) / std.
# Without this term the curves are densities in normalized units drawn on the raw y axis.
log_jacobian = torch.log(data_std[1])

for x_val, x_val_norm in zip(x_slices, x_slices_norm):
    # Repeat this slice's scalar context once per y-grid point: (500, 1).
    x_context = x_val_norm.reshape(1, 1).expand(len(ys), 1)

    with torch.no_grad():
        log_probs = flow(x_context).log_prob(y_eval) - log_jacobian
        probs = torch.exp(log_probs).numpy()

    # Scale and shift each density to its x-position for a ridge plot.
    ax.plot(probs * density_scale + x_val.item(), ys.numpy(), color='black')
    # Vertical reference line marking the conditioning x.
    ax.axvline(x_val.item(), color='black', linestyle='dashed', linewidth=0.5)

ax.scatter(raw_data[:, 0].numpy(), raw_data[:, 1].numpy(), alpha=0.2, s=10, label='Training Data')
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Conditional densities p(y|x)")
ax.legend(loc='upper left')
plt.show()
