In [1]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import logging
from utils import *
from tqdm import tqdm
from modules import UNet, Diffusion

In [11]:
%load_ext autoreload
%autoreload 2

In [2]:
def train(args):
    setup_logging(args.run_name)
    dataloader = get_data(args)

    mse = nn.MSELoss()
    model = UNet().to(args.device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    diffusion = Diffusion(img_size=args.image_size, device=args.device)
    
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(args.device)
            t = diffusion.sample_timesteps(images.shape[0]).to(args.device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        sampled_images = diffusion.sample(model, n=images.shape[0])
        save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))

In [7]:
def launch():
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args(args=[])
    args.run_name = "DDPM_Uncondtional"
    args.epochs = 500
    args.batch_size = 12
    args.image_size = 64
    args.dataset_path = r"C:\Users\issac\OneDrive\Projects\ImageGen\landscape_imgs"
    args.device = "cpu"
    args.lr = 3e-4
    train(args)

In [12]:
launch()

11:36:37 - INFO: Starting epoch 0:
  0%|          | 0/360 [00:00<?, ?it/s]


AssertionError: Torch not compiled with CUDA enabled