In [None]:
# ===============[ IMPORTS ]===============
import os

import pytorch_lightning as pl
import wandb
from diffusers import UNet2DModel
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
# The logger must come from the same ``pytorch_lightning`` namespace as the
# Trainer and callbacks above: ``lightning.pytorch`` is a separate package with
# its own class hierarchy, and mixing the two inside one Trainer is unsupported.
from pytorch_lightning.loggers import WandbLogger

# NOTE(review): confirm these project modules also build on ``pytorch_lightning``
# (not ``lightning.pytorch``), otherwise pl.Trainer will reject them.
from src.data.celeba import CelebADataModule
from src.models.diffusion import DiffusionModel

# ===============[ CONFIGS ]===============
# --- Run identity & filesystem layout ---
PROJECT_NAME: str = "local-unet"  # W&B project name passed to WandbLogger
DATA_DIR: str = "./data"
LOG_DIR: str = "./logs"  # W&B save_dir; checkpoints live underneath it

# --- Data loading ---
BATCH_SIZE: int = 64
IMAGE_SIZE: int = 64  # shared by the datamodule, the UNet and DiffusionModel
MAX_TRAIN_SAMPLES: int = 1_000
MAX_VAL_SAMPLES: int = 100
NUM_WORKERS: int = 8
PERSISTENT_WORKERS: bool = True
PIN_MEMORY: bool = False

# --- Optimiser hyperparameters ---
LEARNING_RATE: float = 0.0001
WEIGHT_DECAY: float = 0.01

# --- Sampling (forwarded to DiffusionModel) ---
NUM_DIFFUSION_SAMPLES: int = 6
NUM_DIFFUSION_STEPS: int = 50

# --- Training loop ---
MAX_EPOCHS: int = 10
CHECKPOINT_PATH: str = os.path.join(LOG_DIR, "checkpoints")

# ===============[ LOGGER ]===============
# Finish any W&B run still active in this process (e.g. from re-running the
# notebook) before building the logger.
# NOTE(review): presumably so WandbLogger starts a new run instead of reusing
# the active one — confirm that is the intent.
wandb.finish()
wandb_logger = WandbLogger(
    save_dir=LOG_DIR,
    project=PROJECT_NAME,
)

# ===============[ DATA ]===============
# CelebA data module; images are requested at IMAGE_SIZE so they match the
# UNet's sample_size below.
datamodule = CelebADataModule(
    # dataset location, resolution and (presumably) per-split sample caps
    data_dir=DATA_DIR,
    image_size=IMAGE_SIZE,
    max_train_samples=MAX_TRAIN_SAMPLES,
    max_val_samples=MAX_VAL_SAMPLES,
    # loader settings — presumably forwarded to torch DataLoader; confirm
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSISTENT_WORKERS,
)

# ===============[ MODEL ]===============
# Denoising backbone: three resolution levels built only from DownBlock2D /
# UpBlock2D (the non-attention block variants in diffusers).
# NOTE(review): in/out channels assume 3-channel RGB batches from the
# datamodule — confirm.
unet = UNet2DModel(
    sample_size=IMAGE_SIZE,
    in_channels=3,
    out_channels=3,
    block_out_channels=(64, 128, 128),
    layers_per_block=2,
    down_block_types=("DownBlock2D",) * 3,
    up_block_types=("UpBlock2D",) * 3,
)

# Lightning wrapper that owns the UNet plus optimiser and sampling settings.
model = DiffusionModel(
    net=unet,
    image_size=IMAGE_SIZE,
    # optimiser
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    # sampling
    num_diffusion_steps=NUM_DIFFUSION_STEPS,
    num_diffusion_samples=NUM_DIFFUSION_SAMPLES,
)

# ===============[ CALLBACKS ]===============
# Print a depth-2 model summary, and keep the single best checkpoint (lowest
# "val/loss") plus the most recent one under CHECKPOINT_PATH.
# NOTE(review): "val/loss" must match the metric key DiffusionModel logs.
callbacks = [
    ModelSummary(max_depth=2),
    ModelCheckpoint(
        dirpath=CHECKPOINT_PATH,
        monitor="val/loss",
        mode="min",
        save_top_k=1,
        save_last=True,
    ),
]

# ===============[ TRAINER ]===============
trainer = pl.Trainer(
    accelerator="auto",  # let Lightning choose the available device type
    max_epochs=MAX_EPOCHS,
    callbacks=callbacks,
    logger=wandb_logger,
    log_every_n_steps=10,  # logging interval in training steps
    enable_progress_bar=True,
)

# ===============[ TRAIN ]===============
# Run training (and the datamodule's validation loop) for MAX_EPOCHS epochs.
trainer.fit(model=model, datamodule=datamodule)