In [3]:
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
from tqdm import tqdm
import math
import numpy as np
from climb_conversion import ClimbsFeatureArray
# climbs = ClimbsFeatureArray(db_path="../data/storage.db").get_features_2d()

In [None]:
class ResidualBlock1D(nn.Module):
    """Conditioned 1-D residual block.

    Pipeline: conv -> group-norm -> per-channel scale/shift derived from the
    conditioning vector -> conv -> group-norm -> SiLU, then a residual add with
    the (optionally 1x1-projected) input.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output feature map. Must be
            divisible by 8 because both GroupNorm layers use 8 groups.
        cond_dim: Width of the conditioning vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation is stable.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act = nn.SiLU()

        # One linear layer emits both the scale (gamma) and the shift (beta).
        self.cond_proj = nn.Linear(cond_dim, out_channels * 2)

        # The residual path needs a 1x1 conv only when the channel count changes.
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, cond):
        """Apply the block.

        Args:
            x: Feature map of shape ``[B, in_channels, L]``.
            cond: Conditioning vectors of shape ``[B, cond_dim]``.

        Returns:
            Tensor of shape ``[B, out_channels, L]``.
        """
        # [B, 2*out] -> [B, 2*out, 1] so it broadcasts along the length axis.
        modulation = self.cond_proj(cond).unsqueeze(-1)
        gamma, beta = torch.chunk(modulation, 2, dim=1)

        hidden = self.norm1(self.conv1(x))
        hidden = hidden * (1 + gamma) + beta
        hidden = self.act(self.norm2(self.conv2(hidden)))

        return hidden + self.shortcut(x)
class ClimbingUNet1D(nn.Module):
    """Small 1-D U-Net over a sequence of per-hold feature vectors.

    The network has one pooling level. Every residual block is conditioned on
    the concatenation of a learned time embedding and the global condition
    vector. Two 1x1-conv heads read the final feature map.

    Args:
        in_channels: Number of features per sequence element.
        cond_dim: Width of the global condition vector.
        base_dim: Base channel width. Must be divisible by 8 because the
            residual blocks use 8-group GroupNorm.
    """

    def __init__(self, in_channels=10, cond_dim=4, base_dim=64):
        super().__init__()
        # Time embedding: scalar t -> base_dim vector
        self.time_mlp = nn.Sequential(
            nn.Linear(1, base_dim),
            nn.SiLU(),
            nn.Linear(base_dim, base_dim),
        )

        # Input projection
        self.init_conv = nn.Conv1d(in_channels, base_dim, 3, padding=1)

        # Downsample
        self.down1 = ResidualBlock1D(base_dim, base_dim, base_dim + cond_dim)
        self.down2 = ResidualBlock1D(base_dim, base_dim*2, base_dim + cond_dim)

        # Bottleneck
        self.mid = ResidualBlock1D(base_dim*2, base_dim*2, base_dim + cond_dim)

        # Upsample (input widths include the concatenated skip connection)
        self.up2 = ResidualBlock1D(base_dim*3, base_dim, base_dim + cond_dim)  # 2*base (mid) + base (x2)
        self.up1 = ResidualBlock1D(base_dim*2, base_dim, base_dim + cond_dim)  # base (up2) + base (x1)

        # Output heads
        # Head 1: continuous noise prediction, 5 channels
        self.head_cont = nn.Conv1d(base_dim, 5, 1)
        # Head 2: discrete role logits, 5 channels
        self.head_disc = nn.Conv1d(base_dim, 5, 1)

    def forward(self, x, t, conditions):
        """Run the U-Net.

        Args:
            x: Input sequence of shape ``[B, L, in_channels]``.
            t: Diffusion time of shape ``[B, 1]``.
            conditions: Global conditions of shape ``[B, cond_dim]``.

        Returns:
            Tuple ``(pred_noise, pred_logits)``, each of shape ``[B, L, 5]``.
        """
        # 1. Process time and conditions into one global conditioning vector
        t_emb = self.time_mlp(t)
        global_cond = torch.cat([t_emb, conditions], dim=1)

        # 2. U-Net backbone
        x = x.transpose(1, 2)  # [B, L, C] -> [B, C, L]
        x1 = self.init_conv(x)

        x2 = self.down1(x1, global_cond)
        x3 = self.down2(F.max_pool1d(x2, 2), global_cond)

        x_mid = self.mid(x3, global_cond)

        # Upsample to the exact length of the skip tensor. With a fixed
        # scale_factor=2 an odd L (e.g. 21 -> pooled 10 -> 20) would not match
        # x2 and the concatenation below would fail. For even L this is the
        # same nearest-neighbour doubling as before.
        x_up2 = F.interpolate(x_mid, size=x2.shape[-1])
        x_up2 = torch.cat([x_up2, x2], dim=1)  # Skip connection
        x_up2 = self.up2(x_up2, global_cond)

        x_up1 = torch.cat([x_up2, x1], dim=1)
        x_out = self.up1(x_up1, global_cond)

        # 3. Split outputs
        pred_noise = self.head_cont(x_out)   # [B, 5, L]
        pred_logits = self.head_disc(x_out)  # [B, 5, L]

        return pred_noise.transpose(1, 2), pred_logits.transpose(1, 2)


# -------------------------------------------------------------------
# 2. The Hybrid Diffusion Wrapper
# -------------------------------------------------------------------
class HybridDDPM(nn.Module):
    """Hybrid diffusion wrapper around ``ClimbingUNet1D``.

    Continuous features are corrupted with Gaussian noise; the one-hot role
    block is corrupted by zeroing whole rows (an absorbing "masked" state).

    Time convention (from ``_sine_alpha_bar``): ``alpha_bar(0) = 0`` means pure
    noise / fully masked, ``alpha_bar(1) = 1`` means clean data. Generation
    therefore walks ``t`` upwards from 0 towards 1.

    Args:
        timesteps: Number of denoising steps used by ``generate``.
    """

    def __init__(
        self,
        timesteps=1000
    ):
        super().__init__()
        self.model = ClimbingUNet1D()
        self.timesteps = timesteps

        # Feature indices
        self.cont_idx = slice(0, 5)  # x, y, pull_x, pull_y, is_foot
        # NOTE(review): the original comment said "4 one-hot roles" but the
        # slice (and the network's logit head) is 5 wide — confirm the layout.
        self.disc_idx = slice(5, 10)

    def _sine_alpha_bar(self, t):
        """Return ``sin(t * pi / 2) ** 2`` reshaped to ``[-1, 1, 1]``.

        Accepts a tensor or a plain Python number; numbers are promoted to a
        float32 tensor because ``torch.sin`` rejects non-tensor input.
        """
        if not torch.is_tensor(t):
            t = torch.tensor(t, dtype=torch.float32)
        a = torch.sin(t*torch.pi/2)**2
        return a.reshape((-1,1,1))

    def q_sample(self, x_start, t):
        """Corrupt clean data to time ``t`` without modifying ``x_start``.

        - Continuous: ``sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps``.
        - Discrete: each sequence position is zeroed with probability
          ``1 - a_bar``.

        Args:
            x_start: Clean batch of shape ``[B, L, C]``.
            t: Times of shape ``[B, 1]``.

        Returns:
            Tuple ``(x_t, noise)``: the corrupted batch ``[B, L, C]`` and the
            Gaussian noise added to the continuous slice.
        """
        B, L, _ = x_start.shape
        a = self._sine_alpha_bar(t)

        # --- Part A: Continuous features ---
        x_cont = x_start[:, :, self.cont_idx]
        noise = torch.randn_like(x_cont)
        noisy_cont = torch.sqrt(a) * x_cont + torch.sqrt(1 - a) * noise

        # --- Part B: Discrete features (absorbing state) ---
        # The random draw lives on the data's device. masked_fill is
        # out-of-place: the original assigned into a view of x_start, which
        # wiped the labels that loss() later reads back as targets.
        masked = torch.rand(B, L, device=x_start.device) > a.reshape((-1,1))
        x_disc = x_start[:, :, self.disc_idx].masked_fill(masked.unsqueeze(-1), 0.0)

        x_t = torch.cat([noisy_cont, x_disc], dim=2)

        return x_t, noise

    def loss(self, x_start, conditions):
        """Training loss: MSE on predicted noise plus cross-entropy on roles.

        Args:
            x_start: Clean batch ``[B, L, C]``.
            conditions: Global conditions ``[B, cond_dim]``.

        Returns:
            Scalar tensor ``loss_cont + loss_disc``.
        """
        batch_size = x_start.shape[0]
        device = x_start.device
        t = torch.rand((batch_size, 1), device=device)

        x_t, noise_target = self.q_sample(x_start, t)

        pred_noise, pred_logits = self.model(x_t, t, conditions)
        loss_cont = F.mse_loss(pred_noise, noise_target)

        # Targets come from the clean data, which q_sample leaves untouched.
        target_classes = torch.argmax(x_start[:, :, self.disc_idx], dim=2).reshape(-1)
        pred_logits = pred_logits.reshape((-1, pred_logits.shape[-1]))
        loss_disc = F.cross_entropy(pred_logits, target_classes)

        return loss_cont + loss_disc

    @torch.no_grad()
    def generate(self, conditions, device, seq_len=20):
        """Sample sequences by iterative denoise-and-renoise.

        At each step the network's noise prediction is converted into a
        clean-sample estimate and the logits into one-hot roles, so the next
        input matches what the network saw in training. Both parts are then
        re-corrupted to the *next* step's noise level.

        Args:
            conditions: Global conditions ``[B, cond_dim]``.
            device: Device on which to sample.
            seq_len: Sequence length to generate (default 20).

        Returns:
            Tensor ``[B, seq_len, n_cont + n_disc]``: continuous estimates
            followed by one-hot roles.
        """
        B = conditions.shape[0]
        conditions = conditions.to(device)
        n_cont = self.cont_idx.stop - self.cont_idx.start
        n_disc = self.disc_idx.stop - self.disc_idx.start

        # Start from the t=0 state: pure Gaussian noise, everything masked.
        x_cont = torch.randn(B, seq_len, n_cont, device=device)
        x_disc = torch.zeros(B, seq_len, n_disc, device=device)

        # sqrt(alpha_bar) is 0 at t=0; floor it so the x0 estimate stays finite.
        # NOTE(review): if the continuous features have a known range,
        # clamping x0_cont to it would further stabilise early steps.
        min_sqrt_alpha = 1e-3

        for i in range(self.timesteps):
            t_tensor = torch.full((B, 1), i/self.timesteps, device=device, dtype=torch.float32)
            x_in = torch.cat([x_cont, x_disc], dim=2)

            pred_noise, pred_logits = self.model(x_in, t_tensor, conditions)

            # The network predicts eps; invert x_t = sqrt(a)*x0 + sqrt(1-a)*eps.
            a = self._sine_alpha_bar(t_tensor)
            x0_cont = (x_cont - torch.sqrt(1 - a) * pred_noise) / torch.sqrt(a).clamp_min(min_sqrt_alpha)
            # Training inputs are one-hot (or zero) rows, never raw logits.
            x0_disc = F.one_hot(pred_logits.argmax(dim=2), num_classes=pred_logits.shape[-1]).to(x_cont.dtype)

            if i == self.timesteps-1:
                x_cont, x_disc = x0_cont, x0_disc
                break

            # Re-corrupt to the level the network will be told on the next step.
            t_next = torch.full((B, 1), (i+1)/self.timesteps, device=device, dtype=torch.float32)
            a_next = self._sine_alpha_bar(t_next)
            x_cont = torch.sqrt(a_next) * x0_cont + torch.sqrt(1 - a_next) * torch.randn_like(x0_cont)

            remask = torch.rand(B, seq_len, device=device) > a_next.reshape((-1,1))
            x_disc = x0_disc.masked_fill(remask.unsqueeze(-1), 0.0)

        return torch.cat([x_cont, x_disc], dim=2)

# -------------------------------------------------------------------
# 3. The Trainer
# -------------------------------------------------------------------
class HybridDDPMTrainer():
    """Trainer for the Hybrid DDPM Climb Diffusion Model.

    Args:
        model: Module exposing ``loss(x, c)`` that returns a scalar tensor.
        dataset: Iterable of array-likes sharing the same first dimension.
            The training loop unpacks exactly two tensors per batch
            (features, conditions).
        device: Device string or ``torch.device`` used for both the model and
            the batches.
    """
    def __init__(self, model, dataset, device='cpu'):
        self.device = torch.device(device)
        # Move the model before building the optimizer so the optimizer holds
        # the on-device parameters; batches are moved to the same device in
        # train_model. Module.to() is in-place, so the caller's reference
        # stays valid.
        self.model = model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.dataset = TensorDataset(*[torch.as_tensor(arr, dtype=torch.float32) for arr in dataset])

    def train_model(self, epochs=500, batch_size=64):
        """Train the model for ``epochs`` epochs with shuffled mini-batches.

        Incomplete final batches are dropped.

        Raises:
            ValueError: If the dataset holds fewer than ``batch_size`` samples,
                which would yield zero batches per epoch.
        """
        batches = DataLoader(self.dataset, batch_size = batch_size, shuffle=True, drop_last=True)
        n_batches = len(batches)
        if n_batches == 0:
            # Previously this surfaced as a ZeroDivisionError in the progress line.
            raise ValueError(
                f"Dataset has {len(self.dataset)} samples, fewer than batch_size={batch_size}; "
                "no full batch can be formed with drop_last=True."
            )

        self.model.train()
        with tqdm(range(epochs)) as pbar:
            for epoch in pbar:
                total_loss = 0.0
                for x, c in batches:
                    x = x.to(self.device)
                    c = c.to(self.device)
                    self.optimizer.zero_grad()

                    # Calculate loss (Forward Diffusion + Prediction)
                    loss = self.model.loss(x,c)
                    loss.backward()
                    self.optimizer.step()

                    total_loss += loss.item()
                pbar.set_postfix_str(f"Epoch: {epoch}, Batches:{n_batches} Total Loss: {total_loss:.4f}, Avg Loss: {total_loss/n_batches:.4f}")

In [None]:
# Open the climb feature store at ../data/storage.db and extract the 2-D
# feature arrays used for training.
# NOTE(review): HybridDDPMTrainer unpacks exactly two tensors per batch, so
# get_features_2d() presumably returns (features, conditions) — confirm.
climbs = ClimbsFeatureArray(db_path='../data/storage.db')
dataset = climbs.get_features_2d()

In [None]:
# Build the diffusion model with its default settings (1000 timesteps) and
# wrap it in a trainer; no device is passed, so the trainer's default 'cpu'
# is used.
model = HybridDDPM()
trainer = HybridDDPMTrainer(
    model = model,
    dataset = dataset
)

In [None]:
# Train for 1500 epochs using the trainer's default batch size (64).
trainer.train_model(epochs=1500)

  0%|          | 0/1500 [00:00<?, ?it/s]