In [None]:
# ===============[ IMPORTS ]===============
import pytorch_lightning as pl
import wandb
from lightning.pytorch.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary

from src.data.celeba import CelebADataModule
from src.models.diffusion import DiffusionModel
from src.models.rin import RINModel

# ===============[ WANDB LOGGER ]===============
# Finish any W&B run that is still active in this process (e.g. one left over from a
# previous execution of this cell) before the logger below is created.
wandb.finish()

# Build the logger from `pytorch_lightning` -- the same package that provides the Trainer
# and the callbacks used further down -- instead of the separate `lightning.pytorch`
# namespace. Lightning advises against mixing objects from the two packages.
wandb_logger = pl.loggers.WandbLogger(project="local-rin", save_dir="./logs")

# ===============[ DATA ]===============
# All CelebADataModule settings are gathered in one mapping so they can be read (and
# tweaked) at a glance before the data module is built.
datamodule_kwargs = dict(
    # Dataset location and sizing.
    data_dir="./data",
    image_size=64,
    batch_size=64,
    # No cap is passed for either split (presumably "use every sample" -- confirm in
    # CelebADataModule).
    max_train_samples=None,
    max_val_samples=None,
    # Loader-related options (names suggest they are forwarded to the DataLoader --
    # TODO confirm in CelebADataModule).
    num_workers=8,
    persistent_workers=True,
    pin_memory=False,
)
datamodule = CelebADataModule(**datamodule_kwargs)

# ===============[ MODEL ]===============
# Both constructors below receive the same image resolution, so it is named once here.
image_size = 64

# Network that is handed to the diffusion wrapper as `net`.
net = RINModel(
    image_size=image_size,
    patch_size=8,
    # Width settings.
    latent_dim=256,
    interface_dim=128,
    num_latents=64,
    # Depth / attention settings.
    num_blocks=2,
    block_depth=1,
    num_heads=4,
)

# Diffusion wrapper around `net`; also receives the optimiser hyper-parameters
# (`lr`, `weight_decay`) and the diffusion sample/step counts.
model = DiffusionModel(
    net=net,
    image_size=image_size,
    lr=1e-4,
    weight_decay=1e-2,
    num_diffusion_samples=3,
    num_diffusion_steps=50,
)

# ===============[ TRAINING ]===============
# Checkpointing: track "val/loss" (lower is better), keep the single best checkpoint
# and additionally always save the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath="./logs/checkpoints",
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    save_last=True,
)

# Callback order is kept: model summary first, then checkpointing.
callbacks = [ModelSummary(max_depth=2), checkpoint_callback]

trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=10,
    logger=wandb_logger,
    callbacks=callbacks,
    log_every_n_steps=10,
    enable_progress_bar=True,
)

# Launch training; the data module supplies the dataloaders.
trainer.fit(model, datamodule=datamodule)