In [1]:
from model import GaussianDiffusionTransformer
from utils.image_utils import render_and_save
from utils.diffusion_data_helper import denormalize_data, DiffusionScheduler
from utils.dataset_helper import create_dataloaders
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import os
import time

In [2]:
# --- Configuration ---
CONFIG = {
    "data_dir": "./data/small/",
    "output_dir": "./output/",
    "batch_size": 32,            # Small physical batch size
    "grad_accumulation": 4,     # Effective batch size = 192 (48 * 4)
    "model": {
        "input_dim": 8,
        "model_dim": 256,
        "n_heads": 4,
        "n_layers": 4,
    },
    "train": {
        "max_epochs": 3000,
        "base_lr": 5e-4,        # Slightly lower max LR for stability
        "warmup_epochs": 50,   # Warmup to prevent shock
        "clip_norm": 2,
    },
    "diffusion_steps": 500,
}

MODEL_PATH = "./best_gaussian_diffusion.pth"

In [3]:
# --- Core: Sampling Function ---
def sample_and_render(model, scheduler, device, num_samples=5, epoch=None):
    """Sample gaussians from the model and render them."""
    model.eval()
    
    # Create output directory if it doesn't exist
    os.makedirs(CONFIG["output_dir"], exist_ok=True)
    
    with torch.no_grad():
        input_dim = CONFIG["model"]["input_dim"]

        # Start from pure noise
        seed = int(time.time() * 1000) % 1000000 if epoch is None else epoch * 1000
        torch.manual_seed(seed)
        x = torch.randn(num_samples, 1000, input_dim, device=device)
        
        # tqdm for sampling progress
        iterator = tqdm(reversed(range(scheduler.num_timesteps)), desc="Sampling", leave=False)
        
        for t in iterator:
            timesteps = torch.full((num_samples,), t, device=device, dtype=torch.long)
            predicted_noise = model(x, timesteps)
            
            # Scheduler constants
            alpha_t = scheduler.alphas[t].to(device)
            alpha_cumprod_t = scheduler.alphas_cumprod[t].to(device)
            beta_t = scheduler.betas[t].to(device)
            
            # Denoising Step
            coef1 = 1 / torch.sqrt(alpha_t)
            coef2 = (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)
            
            x = coef1 * (x - coef2 * predicted_noise)
            x = torch.clamp(x, -3.5, 3.5) # Dynamic range clipping
            
            # Add noise (Langevin dynamics) except at last step
            if t > 0:
                noise = torch.randn_like(x)
                sigma_t = torch.sqrt(beta_t)
                x += sigma_t * noise
        
        # Render
        print(f"\n[Rendering] Saving {num_samples} samples...")
        for i in range(num_samples):
            sample = x[i]
            # Denormalize
            xy, scale, rot, feat = denormalize_data(
                sample[:, 0:2], sample[:, 2:4], sample[:, 4:5], sample[:, 5:8]
            )
            
            # Ensure all tensors are float32 and contiguous
            xy = xy.contiguous().float()
            scale = scale.contiguous().float()
            rot = rot.contiguous().float()
            feat = feat.contiguous().float()
            
            # Ensure img_size is explicitly int tuple
            img_size = (int(480), int(640))
            epoch_suffix = f"_epoch{epoch}" if epoch is not None else ""
            filename = f"{CONFIG['output_dir']}/sample_{i}{epoch_suffix}"
            
            render_and_save(xy, scale, rot, feat, filename, img_size)
    
    model.train()

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GaussianDiffusionTransformer(
        input_dim=CONFIG["model"]["input_dim"], 
        model_dim=CONFIG["model"]["model_dim"], 
        n_heads=CONFIG["model"]["n_heads"], 
        n_layers=CONFIG["model"]["n_layers"]
).to(device)

scheduler = DiffusionScheduler(num_timesteps=CONFIG["diffusion_steps"])

# Load the trained model
#model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

# Sample and render
#sample_and_render(model, scheduler, device, num_samples=5)



In [6]:
# Let's load the first batch to visualize
import time
import random
torch.manual_seed(random.randint(0, 1000000))

dataloader = create_dataloaders(CONFIG["data_dir"], batch_size=CONFIG["batch_size"])

data_iter = iter(dataloader)
batch = next(data_iter)
batch = next(data_iter)

# Visualize the first sample in the batch
sample = batch[0].to(device)  # First sample
xy, scale, rot, feat = denormalize_data(
    sample[:, 0:2], sample[:, 2:4], sample[:, 4:5], sample[:, 5:8]
)

xy = xy.contiguous().float()
scale = scale.contiguous().float()
rot = rot.contiguous().float()
feat = feat.contiguous().float()

render_and_save(xy, scale, rot, feat, f"{CONFIG['output_dir']}/first_batch_sample", (512, 512))

Found 819 files in ./data/small/


tensor([[[0.7093, 0.7096, 0.7099,  ..., 0.7679, 0.7681, 0.7684],
         [0.7095, 0.7098, 0.7101,  ..., 0.7683, 0.7684, 0.7687],
         [0.7097, 0.7100, 0.7103,  ..., 0.7686, 0.7688, 0.7691],
         ...,
         [0.5925, 0.5925, 0.5923,  ..., 0.3337, 0.6133, 0.8430],
         [0.5925, 0.5925, 0.5921,  ..., 0.3387, 0.6110, 0.8099],
         [0.5925, 0.5924, 0.5920,  ..., 0.3439, 0.6101, 0.8073]],

        [[0.7180, 0.7183, 0.7186,  ..., 0.7839, 0.7841, 0.7844],
         [0.7182, 0.7185, 0.7188,  ..., 0.7843, 0.7844, 0.7848],
         [0.7185, 0.7188, 0.7191,  ..., 0.7846, 0.7848, 0.7851],
         ...,
         [0.5618, 0.5618, 0.5617,  ..., 0.3251, 0.6056, 0.8377],
         [0.5618, 0.5618, 0.5616,  ..., 0.3306, 0.6038, 0.8051],
         [0.5618, 0.5618, 0.5615,  ..., 0.3362, 0.6034, 0.8027]],

        [[0.6956, 0.6960, 0.6964,  ..., 0.7804, 0.7805, 0.7808],
         [0.6959, 0.6963, 0.6967,  ..., 0.7807, 0.7809, 0.7812],
         [0.6962, 0.6967, 0.6971,  ..., 0.7811, 0.7813, 0.