# Training of the diffusion model

In [None]:
import mlflow
import mlflow.pytorch

import pytorch_lightning as pl

import torch

from pytorch_lightning.callbacks import LearningRateMonitor

from pl_module import PL_Module
from dataloader import get_dataloader

from diffusers import UNet2DModel, DDPMScheduler

# Global torch setting: select the "medium" float32 matmul precision mode
# (see the torch documentation for the accuracy/speed trade-off it implies).
torch.set_float32_matmul_precision("medium")

# MLflow autologging for PyTorch: checkpoints are not restricted to the
# best-scoring one, and only the weights are stored in each checkpoint.
mlflow.pytorch.autolog(
    checkpoint_save_weights_only=True,
    checkpoint_save_best_only=False,
)

# Close whatever run may still be active (e.g. when this cell is re-executed),
# then open a fresh run inside the "DIFFUSION_MODEL" experiment.
mlflow.end_run()
mlflow.set_experiment("DIFFUSION_MODEL")
mlflow.start_run()

## Hyperparameters

In [None]:
# Compute device: use CUDA when torch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Data loading.
data_path = './data/train/dataset/'
n_workers = 4
batch_size = 8

# Image format: channel count and spatial size handed to the UNet.
nc = 3
image_size = (64, 64)

# Optimisation.
epochs = 100
lr = 1e-4
betas = (0.5, 0.999)

# Diffusion process: noise schedule type and number of steps used for
# training and for inference.
beta_schedule = 'linear'
num_train_timesteps = 1000
num_inference_timesteps = 1000

# Passed to PL_Module as `n_valid`.
n_valid = 16

## Initialization

In [None]:
# Training data, built by the project-local helper from the dataset path,
# batch size and worker count.
dataloader = get_dataloader(data_path, batch_size, n_workers)

# UNet configured with seven blocks per side. The only attention block on
# the down path is the sixth one; on the up path it is the second one.
model = UNet2DModel(
    sample_size=image_size,
    in_channels=nc,
    out_channels=nc,
    layers_per_block=2,
    block_out_channels=(128,) * 3 + (256,) * 2 + (512,) * 2,
    down_block_types=("DownBlock2D",) * 5 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 5,
)

# DDPM noise scheduler using the configured beta schedule and step count.
noise_scheduler = DDPMScheduler(
    beta_schedule=beta_schedule,
    num_train_timesteps=num_train_timesteps,
)

# Project-local Lightning module; it receives the network, the scheduler and
# the optimiser / validation settings defined in the hyperparameter cell.
module = PL_Module(
    model=model,
    noise_scheduler=noise_scheduler,
    lr=lr,
    betas=betas,
    n_valid=n_valid,
    num_inference_timesteps=num_inference_timesteps,
)

# Callbacks: record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [lr_monitor]

# Trainer: a single device of the selected accelerator type, 16-bit mixed
# precision, the default logger, and the callbacks collected above.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator=device,
    devices=1,
    precision="16-mixed",
    callbacks=callbacks,
    logger=True,
)

## Start the Training

In [None]:
# Run the training loop on the Lightning module with the training dataloader.
trainer.fit(model=module, train_dataloaders=dataloader)