In [8]:
import torch
from tqdm import tqdm
import sys
import torch.optim as optim
from ddpm.config import cifar10_config
from ddpm.data import get_cifar10_dataloaders
from ddpm.diffusion_model import DiffusionModel

In [9]:
# Set the DEBUG attribute on the ddpm.config module to False.
# Presumably this flag gates the per-layer shape prints ("ResNet(), x.shape: ...")
# seen in the training output — TODO confirm against ddpm.config / ddpm.res_net.
# NOTE(review): this cell carries execution count 9 while the training cell carries 7,
# i.e. the saved training output was produced before this flag was flipped. Re-run the
# notebook top to bottom on a fresh kernel so the flag is set before the model is used.
from ddpm import config as _config
_config.DEBUG = False

In [3]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [4]:
# Build the train/test loaders using the batch size declared in the CIFAR-10 config.
train_loader, test_loader = get_cifar10_dataloaders(batch_size=cifar10_config.batch_size)

In [10]:
# Instantiate the diffusion model from the CIFAR-10 config, then place it on `device`.
model = DiffusionModel(cifar10_config)
model = model.to(device)

In [6]:
# Adam over every parameter exposed by the model, with a fixed learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)

In [7]:
# Training loop: optimise the model on `train_loader` for NUM_EPOCHS epochs and draw a
# batch of samples from the model at the end of every epoch.

# Tunable settings for this cell, named instead of being inlined as magic numbers.
NUM_EPOCHS = 10
SAMPLE_SHAPE = (16, 3, 32, 32)  # shape argument handed to model.sample()

for epoch in tqdm(range(NUM_EPOCHS), desc="Training Progress", leave=True):
    # model.eval() is called at the end of each epoch, so switch back to train mode here.
    model.train()
    batch_progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)

    # Running totals so the epoch summary reports a mean rather than only the last batch.
    epoch_loss_sum = 0.0
    num_batches = 0

    for images, labels in batch_progress:
        # Move BOTH inputs to the model's device. Previously only `images` was moved,
        # which would mix devices on a CUDA run if the model consumes `labels`.
        # (assumes `labels` is a tensor like `images` — TODO confirm against the loader)
        images = images.to(device)
        labels = labels.to(device)

        # The model call returns the training loss directly. Errors are deliberately NOT
        # caught: skipping failed batches hides real bugs and could leave `loss`
        # undefined or stale for the epoch summary below.
        loss = model(images, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        num_batches += 1

        # set_postfix refreshes the bar itself; no manual stdout flush is needed.
        batch_progress.set_postfix(loss=batch_loss)

    if num_batches:
        tqdm.write(f"Epoch {epoch}, mean loss={epoch_loss_sum / num_batches:.4f}")
    else:
        # Guard against an empty loader instead of failing on an undefined loss.
        tqdm.write(f"Epoch {epoch}: train_loader yielded no batches")

    # Sample after every epoch. Gradients are not needed for sampling, so disable
    # autograd to avoid recording the sampling computation.
    # NOTE(review): `samples` is overwritten each epoch and not saved or displayed here.
    model.eval()
    with torch.no_grad():
        samples = model.sample(SAMPLE_SHAPE, device=device)

Training Progress:   0%|          | 0/10 [00:00<?, ?it/s]

ResNet(), x.shape: torch.Size([64, 3, 32, 32])
ResNet(), after initial_pad, x.shape: torch.Size([64, 3, 36, 36])
ResNet(), after init_conv: torch.Size([64, 128, 36, 36])
ResBlock: torch.Size([64, 128, 36, 36]) -> torch.Size([64, 128, 36, 36])
ResNet(), down_layers-0 <class 'ddpm.res_net.ResBlock'>, torch.Size([64, 128, 36, 36]) -> torch.Size([64, 128, 36, 36])
ResBlock: torch.Size([64, 128, 36, 36]) -> torch.Size([64, 128, 36, 36])
ResNet(), down_layers-1 <class 'ddpm.res_net.ResBlock'>, torch.Size([64, 128, 36, 36]) -> torch.Size([64, 128, 36, 36])
DownSample: torch.Size([64, 128, 36, 36]) -> torch.Size([64, 128, 18, 18])
ResNet(), down_layers-2 <class 'ddpm.res_net.DownSample'>, torch.Size([64, 128, 36, 36]) -> torch.Size([64, 128, 18, 18])
ResBlock: torch.Size([64, 128, 18, 18]) -> torch.Size([64, 256, 18, 18])
ResNet(), down_layers-3 <class 'ddpm.res_net.ResBlock'>, torch.Size([64, 128, 18, 18]) -> torch.Size([64, 256, 18, 18])
ResBlock: torch.Size([64, 256, 18, 18]) -> torch.Size(

Training Progress:   0%|          | 0/10 [00:32<?, ?it/s]


KeyboardInterrupt: 