In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda")

torch.manual_seed(0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

def make_projection(D, d, device=device):
    A = torch.randn(D, d, device="cpu")
    Q, _ = torch.linalg.qr(A)  
    Q = Q.to(device)
    return Q  

def sample_underlying_2d(n_points):
    theta = np.linspace(0, 4 * np.pi, n_points)
    r = theta / (4 * np.pi) * 2.0  

    x = r * np.cos(theta)
    y = r * np.sin(theta)

    pts = np.stack([x, y], axis=1)

    pts += 0.02 * np.random.randn(*pts.shape)
    pts = torch.from_numpy(pts).float()
    return pts   

In [3]:
class MLP5(nn.Module):
    def __init__(self, x_dim, hidden_dim, out_dim, t_dim=None):
        """
        x_dim: D  (x is [B, D])
        t_dim: D  (t is [B, D]) by default
        """
        super().__init__()
        self.x_dim = x_dim
        self.t_dim = x_dim if t_dim is None else t_dim

        in_dim = self.x_dim + self.t_dim

        layers = []
        dims = [in_dim] + [hidden_dim] * 5 + [out_dim]
        for i in range(len(dims) - 2):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """
        x: [B, D]
        t: [B, D]  (per-dim time)
           (also allows [B] or [B,1], which will be broadcast to [B, t_dim])
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [B, 1]

        if t.dim() == 2 and t.shape[1] == 1 and self.t_dim != 1:
            # broadcast scalar time to per-dim time if user passes [B,1]
            t = t.expand(-1, self.t_dim)  # [B, t_dim]

        assert x.dim() == 2 and x.shape[1] == self.x_dim, f"x should be [B, {self.x_dim}]"
        assert t.dim() == 2 and t.shape[1] == self.t_dim, f"t should be [B, {self.t_dim}]"

        t = t.to(dtype=x.dtype, device=x.device)
        inp = torch.cat([x, t], dim=-1)  # [B, D + t_dim]
        return self.net(inp)

In [4]:
def train_toy(
    D=16,
    d=2,
    target_type="x",
    n_samples=20000,
    batch_size=1024,
    epochs=500,
    lr=1e-3,
):
    P = make_projection(D, d)  # [D, 2]
    x_hat = sample_underlying_2d(n_samples).to(device)  # [N, 2]
    x = x_hat @ P.t()  
    
    sigma = x.std() / 3.0
    print(f"Data std: {x.std().item():.4f}")
    print(f"Using sigma: {sigma.item():.4f}")
    print(f"Data shape: {x.shape}")

    dataset = TensorDataset(x)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    model = MLP5(x_dim=D, hidden_dim=256, out_dim=D, t_dim=D).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in tqdm(range(epochs), desc=f"Training D={D}, target={target_type}"):
        for step, (x_batch,) in enumerate(loader):
            x_1 = x_batch.to(device)  # [B, D]

            B = x_1.size(0)
            x_1 = x_1 / sigma

            #  t ~ Uniform(0,1), asychronous time for each dim
            t = torch.rand((B, D), device=device)
            
            x_0 = torch.randn_like(x_1)
            x_t = t * x_1 + (1 - t) * x_0
            # [B, 2] * [B, 2] + [B, 2] * [B, 2]
            
            model_pred = model(x_t, t)

            if target_type == "data":
                dnorm = torch.clamp(1. - t, min=0.05)
                v_target = (x_1 - x_t) / dnorm
                v_pred = (model_pred - x_t) / dnorm
                loss = ((v_target - v_pred) ** 2).mean()

            elif target_type == "v":
                v_target = x_1 - x_0
                loss = ((v_target - model_pred) ** 2).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()
        
        if (epoch + 1) % 50 == 0 or epoch == 0:
            print(f"[D={D}] Epoch {epoch + 1}/{epochs} | {target_type}-prediction loss: {loss.item():.4f}")

    return model, P, x_hat, x, sigma


def show_point(x, P, x_hat_true, target_type, D, cur_step=None):
    
    pred_2d = x @ P  # [N,2]

    x_hat_np = x_hat_true.cpu().numpy()
    pred_2d_np = pred_2d.cpu().numpy()

    plt.figure(figsize=(5, 5))
    plt.scatter(x_hat_np[:, 0], x_hat_np[:, 1], s=5, alpha=0.3, label="True 2D data")
    plt.scatter(pred_2d_np[:, 0], pred_2d_np[:, 1], s=5, alpha=0.7, label=f"Generated ({target_type}-pred)")
    plt.legend()
    plt.title(f"D={D}, target={target_type}")
    plt.axis("equal")
    plt.tight_layout()
    save_filename = f"toy_D{D}_target_{target_type}_step_{cur_step}.png" if cur_step is not None else f"toy_D{D}_target_{target_type}.png"
    plt.savefig(save_filename, dpi=200)


def visualize_2d(model, P, n_points=2000, target_type="x", steps=250, x_true=None, sigma=1.0):
    
    model.eval()
    P = P.to(device)
    D, d = P.shape  # P: [D,2]

    x = torch.randn(n_points, D, device=device)
    dt = 1.0 / steps

    for i in range(steps):
        with torch.no_grad():
            t = torch.full((n_points, D), i * dt, device=device) 
            x_t = x                                            

            pred = model(x_t, t)
            
            if target_type == "data":
                vp = (pred - x_t) / (1. - t)   
            elif target_type == "data_scaled":
                vp = pred / (1. - t)
            elif target_type == "v":
                vp = pred
            
            x = x_t + dt * vp  
            # if i % 50 == 0 or i == 1:
            #     show_point((x_t - i * dt * vp) * sigma, P, x_hat_true, target_type, D, cur_step=i)

    show_point(x * sigma, P, x_true, target_type, D)


In [6]:
Ds = [4, 16, 512]
target_types = ["data", "v"]

for D in Ds:
    for tt in target_types:
        print(f"\n=== Training D={D}, target={tt} ===")
        model, P, x_hat, x, sigma = train_toy(
            D=D,
            d=2,
            target_type=tt,
            n_samples=20000,
            batch_size=1024,
            epochs=500, 
            lr=1e-3,
        )
        visualize_2d(model, P, n_points=20000, target_type=tt, x_true=x_hat, sigma=sigma)


=== Training D=4, target=data ===


AssertionError: Torch not compiled with CUDA enabled