In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim.adamw


# Run on the GPU when PyTorch reports CUDA support, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ForwardSDE(nn.Module):
    """Gaussian perturbation kernel ``x_t = alpha(t) * x_0 + beta(t) * eps``.

    With the schedules defined here, ``alpha(t) = t`` and
    ``beta(t) = sqrt(1 - t)``: ``t = 0`` yields pure standard-normal noise and
    ``t = 1`` returns the clean data unchanged.
    """

    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim  # stored only; none of the methods below read it

    def alpha(self, t):
        """Return the coefficient multiplying the clean data (``t`` itself)."""
        return t

    def beta(self, t):
        """Return the noise standard deviation ``sqrt(1 - t)``."""
        return torch.sqrt(1 - t)

    def forward_step(self, x0, t):
        """Draw one noisy sample ``x_t`` for clean data ``x0`` at times ``t``.

        ``t`` must broadcast against ``x0``, e.g. shape (bs, 1) for ``x0`` of
        shape (bs, dim).
        """
        # randn_like already matches x0's device and dtype, so the noise never
        # has to be moved to a globally chosen device.
        noise = torch.randn_like(x0)
        return self.alpha(t) * x0 + self.beta(t) * noise

    def sample_xt(self, x0, t):
        """Alias of :meth:`forward_step`."""
        return self.forward_step(x0, t)

    def forward_operator(self, x_0, t):
        """Alias of :meth:`forward_step` under the name ``Trainer`` calls."""
        return self.forward_step(x_0, t)

    def conditional_score(self, x_t, x_0, t):
        """Return the score of the Gaussian kernel ``p(x_t | x_0)`` at ``x_t``.

        Computes ``-(x_t - alpha(t) * x_0) / beta(t)**2``. ``beta(t)**2`` is
        zero at ``t = 1``, so the result is not finite there.
        """
        return -(x_t - self.alpha(t) * x_0) / self.beta(t) ** 2

    def score(self, x0, t):
        """Return the conditional score at a freshly drawn, unreturned ``x_t``.

        The sample the score refers to is discarded, so the result cannot be
        paired with a network input; use :meth:`forward_operator` followed by
        :meth:`conditional_score` when both are needed.
        """
        x_t = self.forward_step(x0, t)
        return self.conditional_score(x_t, x0, t)
    

class ApproxScore(nn.Module):
    """Fully connected network mapping ``(x, t)`` to a vector of size ``dim``.

    The input is ``x`` concatenated with ``t`` along the last axis, so the
    first layer takes ``dim + 1`` features. Every linear layer except the last
    is followed by a SiLU activation.
    """

    def __init__(self, dim=2, hiddens=None):
        """Build the network.

        Args:
            dim: Size of ``x`` and of the output.
            hiddens: Widths of the hidden layers, as any sequence of ints.
                ``None`` selects four hidden layers of width 64.
        """
        super().__init__()
        # A None sentinel avoids a shared mutable default, and list() lets
        # callers pass tuples or other sequences.
        hiddens = [64, 64, 64, 64] if hiddens is None else list(hiddens)
        dims = [dim + 1, *hiddens, dim]
        last = len(dims) - 2  # index of the output layer, which gets no activation
        layers = []
        for i, (n_in, n_out) in enumerate(zip(dims, dims[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if i < last:
                layers.append(nn.SiLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate the network on ``x`` and ``t``.

        ``t`` must have the same leading dimensions as ``x`` so the two can be
        concatenated along the last axis (e.g. ``x`` (bs, dim), ``t`` (bs, 1)).
        """
        xt = torch.cat([x, t], dim=-1)
        return self.net(xt)

class Trainer(nn.Module):
    """Fit a score network to the conditional score supplied by an SDE object.

    ``sde`` must provide ``forward_operator(x_0, t)`` and
    ``conditional_score(x_t, x_0, t)``; ``score_model`` is called as
    ``score_model(x_t, t)``.
    """

    def __init__(self, sde, score_model):
        super().__init__()
        self.sde = sde
        self.score_model = score_model

    def get_train_loss(self, x_0):
        """Return the mean absolute error between predicted and target scores.

        One time per row of ``x_0`` is drawn uniformly from [0, 1). ``x_0``
        and ``score_model`` must already live on the same device.
        """
        batch_size = x_0.shape[0]
        # Follow the data's device instead of a global one so the loss also
        # works when x_0 is not on the module-level default device.
        t = torch.rand(batch_size, 1).to(x_0.device)  # (bs, 1)
        x_t = self.sde.forward_operator(x_0, t)  # (bs, dim)
        pred_score = self.score_model(x_t, t)  # (bs, dim)
        true_score = self.sde.conditional_score(x_t, x_0, t)  # (bs, dim)
        # NOTE(review): denoising score matching is usually stated with a
        # squared-error loss; this is an L1 loss. Confirm it is intentional.
        return torch.abs(pred_score - true_score).mean()

    def train(self, x_0=True, num_epochs=None, lr=None):
        """Run full-batch AdamW training, or switch train/eval mode.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` and
        parent modules call with a single bool. A bool first argument is
        therefore forwarded to the base class so those calls keep working.

        Args:
            x_0: Training data tensor with one sample per row, or a bool
                ``mode`` when used as ``nn.Module.train``.
            num_epochs: Number of optimizer steps; each uses all of ``x_0``.
            lr: AdamW learning rate.

        Returns:
            ``self`` when switching mode, otherwise ``None``.

        Raises:
            TypeError: If data is given without ``num_epochs`` and ``lr``.
        """
        if isinstance(x_0, bool):
            return super().train(x_0)
        if num_epochs is None or lr is None:
            raise TypeError("train(x_0, num_epochs, lr) requires num_epochs and lr")

        self.score_model.to(device)
        # The data must sit next to the model, otherwise the first forward
        # pass fails with a device mismatch.
        x_0 = x_0.to(device)
        optimizer = torch.optim.AdamW(self.score_model.parameters(), lr=lr)
        self.score_model.train()
        for _ in tqdm(range(num_epochs), desc="Training"):
            optimizer.zero_grad()
            loss = self.get_train_loss(x_0)
            loss.backward()
            optimizer.step()
        self.score_model.eval()


def BackwardProcess(score_model, num_samples=1000, num_timesteps=300, sigma=2.0, dim=2):
    """Draw samples by stepping a score-driven stochastic update from t=1 to t=0.

    Starting from standard-normal noise, each of the ``num_timesteps`` steps
    applies ``x <- x + (-0.5 * sigma**2 * score) * dt + sigma * sqrt(-dt) * z``
    with ``dt = -1 / num_timesteps`` and fresh standard-normal ``z``.

    NOTE(review): time runs from 1 down to 0 starting from noise, whereas
    ForwardSDE in this file is pure noise at t=0 and clean data at t=1.
    Confirm the direction matches how ``score_model`` was trained.

    Args:
        score_model: Callable ``score_model(x_t, t)`` with ``x_t`` of shape
            (num_samples, dim) and ``t`` of shape (num_samples, 1).
        num_samples: Number of samples to draw.
        num_timesteps: Number of integration steps.
        sigma: Diffusion coefficient used in both drift and noise terms.
        dim: Dimensionality of each sample.

    Returns:
        Tensor of shape (num_samples, dim), detached from autograd.
    """
    # Sample where the model lives; fall back to the module-level default for
    # plain callables or parameter-free modules.
    try:
        run_device = next(score_model.parameters()).device
    except (AttributeError, StopIteration):
        run_device = device

    dt = -1.0 / num_timesteps
    # dt is a Python float, so take the root with ** rather than Tensor.sqrt().
    noise_scale = sigma * (-dt) ** 0.5

    # Sampling needs no gradients; without this the autograd graph would grow
    # across every step.
    with torch.no_grad():
        x_t = torch.randn(num_samples, dim).to(run_device)  # Start from noise
        ts = torch.linspace(1.0, 0.0, num_timesteps + 1).to(run_device)
        for t in ts[:-1]:
            t = t.view(1, 1).expand(num_samples, 1)
            score = score_model(x_t, t)
            drift = -0.5 * sigma**2 * score
            x_t = x_t + drift * dt + noise_scale * torch.randn_like(x_t)
    return x_t