# 02_train_diffusion.ipynb

Train a TinyUNet to predict the noise added to images at a randomly sampled diffusion timestep (mean-squared-error noise-prediction loss).

In [None]:
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms

from src.models import TinyUNet
from src.diffusion import SimpleDiffusion

# Load every image in img_dir as an RGB float tensor of shape
# (3, img_size, img_size); ToTensor scales 8-bit pixels to [0, 1].
img_dir = '../data/processed'
img_size = 32
# Matched case-insensitively so files like "photo.JPG" are not skipped.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')
transform = transforms.Compose([
    transforms.ToTensor()
])
img_list = []
# os.listdir returns entries in arbitrary order; sort for a reproducible dataset.
for fname in sorted(os.listdir(img_dir)):
    if not fname.lower().endswith(IMG_EXTENSIONS):
        continue
    # The context manager closes the underlying file once the image is decoded;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(os.path.join(img_dir, fname)) as src:
        img = src.convert('RGB').resize((img_size, img_size))
    img_list.append(transform(img))
if not img_list:
    # torch.stack([]) would fail with an unhelpful message; say what is missing.
    raise RuntimeError(
        f'No images with extensions {IMG_EXTENSIONS} found in {img_dir!r}'
    )
imgs = torch.stack(img_list)  # (N, 3, img_size, img_size)

# Train TinyUNet to predict the noise that SimpleDiffusion adds to each image
# at a randomly sampled timestep, minimising the mean squared error between
# the predicted and the actual noise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-3
batch_size = 8
epochs = 5
save_path = '../results/denoised/tinyunet.pth'

model = TinyUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
diffusion = SimpleDiffusion(device=device)

# imgs is the stacked image tensor built in the loading step above.
dataset = TensorDataset(imgs)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    losses = []
    for batch in loader:
        # TensorDataset yields 1-element sequences; take the image tensor.
        x = batch[0].to(device)
        # One timestep per sample in the batch.
        t = diffusion.sample_timesteps(x.size(0)).to(device)
        # add_noise returns the corrupted batch together with the noise it used;
        # presumably the standard forward process q(x_t | x_0) -- see src.diffusion.
        x_noisy, noise = diffusion.add_noise(x, t)
        pred_noise = model(x_noisy, t)
        loss = ((pred_noise - noise) ** 2).mean()  # MSE noise-prediction loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f'Epoch {epoch+1}, Loss: {sum(losses)/len(losses):.4f}')

# Save model weights; create the output directory so torch.save does not fail
# on a fresh checkout where ../results/denoised does not exist yet.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)