# Training of the Wasserstein GAN with gradient penalty

In [None]:
import mlflow
import mlflow.pytorch

import pytorch_lightning as pl

import torch

from pytorch_lightning.callbacks import LearningRateMonitor

from pl_module import PL_Module
from dataloader import get_dataloader
from model import Generator, Critic

# settings:
# Select the "medium" float32 matmul precision mode (see the PyTorch docs for
# the exact speed/accuracy trade-off of each mode).
torch.set_float32_matmul_precision("medium")

# logging:
# Options forwarded to MLflow's PyTorch autologging: do not restrict
# checkpoint logging to the best one, and store weights only.
autolog_options = {
    "checkpoint_save_best_only": False,
    "checkpoint_save_weights_only": True,
}
mlflow.pytorch.autolog(**autolog_options)

# The experiment has to be selected before the run is opened so that the run
# is created inside it.
mlflow.set_experiment("wgan")
mlflow.start_run()

## Hyperparameters

In [2]:
# Hardware: use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Data loading (all three values are handed to get_dataloader).
data_path = "./data/train/dataset"
batch_size = 16
n_workers = 4

# Training length (used as the Trainer's max_epochs).
epochs = 100

# Optimisation settings handed to PL_Module.
# NOTE(review): presumably Adam learning rate / betas and the gradient-penalty
# weight -- confirm against pl_module.
lr = 2e-4
betas = (0.5, 0.999)
lambda_gp = 10

# Model sizes: nz and ngf go to Generator, ndf goes to Critic (nz is also
# passed to PL_Module).
# NOTE(review): presumably latent size and base feature-map widths -- confirm
# against model.py.
nz = 100
ngf = 64
ndf = 64

# Passed to PL_Module as its last argument.
# NOTE(review): presumably the number of validation samples -- confirm.
n_valid = 16

## Initialization

In [None]:
# data:
dataloader = get_dataloader(data_path, batch_size, n_workers)

# networks:
generator = Generator(nz, ngf)
critic = Critic(ndf)

# Lightning module that receives both networks together with the
# optimisation hyperparameters defined above.
module = PL_Module(generator, critic, lr, betas, lambda_gp, nz, n_valid)

# callbacks:
# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
callbacks = [lr_monitor]

# training:
# Single-device run with 16-bit mixed precision and Lightning's default logger.
trainer_options = {
    "accelerator": device,
    "devices": 1,
    "max_epochs": epochs,
    "precision": "16-mixed",
    "logger": True,
    "callbacks": callbacks,
}
trainer = pl.Trainer(**trainer_options)

## Start the Training

In [None]:
# Run the training and always close the MLflow run that was opened during
# setup. Without this the run stays active in the notebook kernel after
# training (or after a crash), and re-executing the setup cell fails because
# a run is already active.
run_status = "FAILED"  # pessimistic default; overwritten only on success
try:
    trainer.fit(module, dataloader)
    run_status = "FINISHED"
finally:
    # Reached on success, on errors and on interrupts alike, so the run is
    # never left dangling and its final status reflects the outcome.
    mlflow.end_run(status=run_status)