In [None]:
import sys
sys.path.append('..')

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from diffusion_gnn.core.ddpm import DDPM
from diffusion_gnn.core.scheduler import DDPMScheduler
from diffusion_gnn.models.unet import SimpleUNet
from diffusion_gnn.data.synthetic import create_toy_dataset
from diffusion_gnn.evaluation.visualization import *


In [None]:
# Runtime configuration: pick the accelerator and fix the RNG so Restart & Run All
# reproduces the same losses and samples.
SEED = 42
# Seeds torch's default generators on all devices. This covers DataLoader shuffling
# and, presumably, model weight init and DDPM noise sampling.
# NOTE(review): if create_toy_dataset draws from numpy/`random`, seed those as well.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Tiny synthetic dataset so the whole demo trains in seconds.
N_SAMPLES = 100
SEQ_LEN = 32
BATCH_SIZE = 8

dataset = create_toy_dataset(n_samples=N_SAMPLES, seq_len=SEQ_LEN)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)


In [None]:
# Assemble the diffusion pipeline: noise schedule -> denoising network -> DDPM wrapper.
NUM_TIMESTEPS = 100  # fewer steps than usual, to keep the demo quick

scheduler = DDPMScheduler(num_timesteps=NUM_TIMESTEPS)
model = SimpleUNet(dim=32, channels=1)
model = model.to(device)
ddpm = DDPM(model, scheduler, device)

In [None]:
# Visualize a few original training sequences from the first (shuffled) batch.
# Indexing below assumes each batch is (batch, channels, seq_len); only channel 0 is plotted.
samples = next(iter(dataloader))
n_show = min(4, samples.shape[0])  # a batch may hold fewer than 4 sequences

# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(1, n_show, figsize=(12, 3), squeeze=False)
for i, ax in enumerate(axes[0]):
    # detach/cpu make .numpy() safe even if the loader ever yields GPU or grad-tracking tensors.
    ax.plot(samples[i, 0].detach().cpu().numpy())
    ax.set(title=f'Original {i+1}', xlabel='index', ylabel='value')
fig.suptitle('Original Training Data')
plt.show()

In [None]:
# Quick training loop, followed by sampling from the trained model.
NUM_EPOCHS = 10  # just a few epochs for the demo
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for data in dataloader:
        data = data.to(device)
        loss = ddpm.train_loss(data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of per-batch losses. If the dataset size isn't a multiple of the batch
    # size, the final smaller batch is weighted like a full one.
    print(f"Epoch {epoch}, Loss: {total_loss/len(dataloader):.4f}")

# Generate samples from the trained model.
# Generate sequences as long as the ones trained on, instead of repeating a hard-coded
# length (assumes the last axis of a batch is the sequence axis).
gen_seq_len = data.shape[-1]
# Sampling needs no gradients. eval() puts any dropout/norm layers into inference mode,
# and no_grad() avoids building an autograd graph across every denoising step.
model.eval()
with torch.no_grad():
    visualize_final_samples(ddpm, device, num_samples=4, seq_len=gen_seq_len)