In [1]:
import os
from pathlib import Path

import monai.data
import monai.transforms
import torch
from lightning.pytorch import Trainer

from adell_mri.modules.gan.gan.pl import ProGANPL
from adell_mri.modules.gan.gan.style import ProgressiveDiscriminator, ProgressiveGenerator

def list_image_records(root, pattern="*.jpg"):
    """Recursively collect image files under ``root`` as MONAI dictionary records.

    Args:
        root: directory to search (``str`` or ``Path``).
        pattern: glob pattern forwarded to ``Path.rglob``.

    Returns:
        A list of ``{"image": Path}`` dicts sorted by path. Sorting matters:
        ``rglob`` yields files in filesystem order, which is not guaranteed
        to be stable, so any subset later taken by slicing this list would
        otherwise differ between machines and runs.
    """
    return [{"image": image_path} for image_path in sorted(Path(root).rglob(pattern))]


# Root of the aligned CelebA images; override with the CELEBA_DIR environment
# variable instead of editing this hard-coded local path.
path = os.environ.get(
    "CELEBA_DIR", "/mnt/big_disk/data/celeba/img_align_celeba/img_align_celeba/"
)
all_images = list_image_records(path)

# Spatial size of the training images; also used later to derive the
# generator's input size.
crop_size = (128, 128)

# Preprocessing: load each image, move channels first, take a central crop of
# `crop_size` and rescale intensities to [-1, 1].
transform = monai.transforms.Compose([
    monai.transforms.LoadImaged(keys="image"),
    monai.transforms.EnsureChannelFirstd(keys="image"),
    monai.transforms.CenterSpatialCropd(keys="image", roi_size=crop_size),
    monai.transforms.ScaleIntensityd(keys="image", minv=-1, maxv=1),
])

# Number of images used for training. A small subset keeps development runs
# fast; set to None to use every image found in `all_images`
# (`all_images[:None]` is the whole list).
n_images = 250

dataset = monai.data.CacheDataset(
    all_images[:n_images],
    transform=transform,
    num_workers=8,
)

Loading dataset: 100%|██████████| 250/250 [00:00<00:00, 398.86it/s]


In [2]:
batch_size = 8

# Shuffled training loader over the cached dataset.
# persistent_workers=True keeps the 8 worker processes alive between epochs;
# otherwise PyTorch shuts them down and re-creates them at the start of every
# epoch, a start-up cost that would be paid for each of the many short epochs
# this notebook trains for.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
)

In [3]:
input_channels = 3  # channels of the training images
depths = [64, 128, 256]  # discriminator depths; the generator gets them reversed
# Deepest discriminator depth, reused as the generator's input channel count.
output_channels = depths[-1]
n_levels = len(depths)

# The generator input grid is the crop shrunk by 2 ** (n_levels + 1) per
# spatial axis (presumably one doubling per progressive level plus the
# initial block — TODO confirm against ProgressiveGenerator). Floor division
# would silently truncate for crops that are not exact multiples, so fail early.
downsampling_factor = 2 ** (n_levels + 1)
assert all(size % downsampling_factor == 0 for size in crop_size), (
    f"crop_size {crop_size} must be divisible by {downsampling_factor} "
    f"for {n_levels} levels"
)
generator_input_size = tuple(size // downsampling_factor for size in crop_size)

# Generator: mirrors the discriminator. It takes `output_channels` input
# channels, receives the depth list in reverse order and emits an image with
# `input_channels` channels.
generator = ProgressiveGenerator(
    n_dim=2,
    depths=list(reversed(depths)),
    input_channels=output_channels,
    output_channels=input_channels,
    equalized_learning_rate=True,
)

# Discriminator: image with `input_channels` channels in, a single output
# channel out, with the minibatch-std option enabled.
discriminator = ProgressiveDiscriminator(
    n_dim=2,
    depths=depths,
    input_channels=input_channels,
    output_channels=1,
    equalized_learning_rate=True,
    minibatch_std=True,
)

# Training length; also passed to the Trainer as max_epochs.
max_epochs = 500
# Number of batches per epoch, forwarded to the Lightning module.
steps_per_epoch = len(data_loader)

# Lightning module wrapping both networks for progressive GAN training.
pl_progan = ProGANPL(
    generator=generator,
    discriminator=discriminator,
    generator_input_size=generator_input_size,
    epochs=max_epochs,
    steps_per_epoch=steps_per_epoch,
    epochs_per_level=10,
    gradient_penalty_lambda=10.0,
)

In [None]:
from adell_mri.utils.pl_utils import get_logger
from lightning.pytorch.callbacks import RichProgressBar
from random import randint

# Set to True to log this run to Weights & Biases (project "ProGAN-dev").
# When False no W&B logger is created at all and Lightning uses its default
# logger (logger=True).
use_wandb_logger = False

logger = None
if use_wandb_logger:
    logger = get_logger(
        f"progan-{randint(0, 10000)}",
        summary_dir="logs",
        project_name="ProGAN-dev",
        resume="none",
        logger_type="wandb",
    )

trainer = Trainer(
    max_epochs=max_epochs,
    accelerator="gpu",
    devices=[0],
    log_every_n_steps=10,
    # Pass the W&B logger when enabled, otherwise Lightning's default logger.
    logger=logger if logger is not None else True,
    precision="bf16-mixed",
    callbacks=[RichProgressBar()],
)

trainer.fit(pl_progan, data_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade

Output()