In [None]:
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from datasets.Larynx_DataModule import Larynx_DataModule
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from lightning_modules.ae_msssim_acai import AE_MSSSIM_ACAI
from lightning_modules.igd import IGD
from lightning_modules.vae_msssim_acai import VAE_MSSSIM_ACAI
from lightning_modules.ae import AE
from lightning_modules.vae import VAE
from lightning_modules.ae_msssim import AE_MSSSIM
from lightning_modules.vae_msssim import VAE_MSSSIM


In [None]:
# Set parameters
batch_size = 4
epochs = 10
# One of: 'AE', 'VAE', 'AE_MSSSIM', 'VAE_MSSSIM', 'AE_MSSSIM_ACAI', 'VAE_MSSSIM_ACAI', 'IGD'
architecture = 'AE'
latent_size = 256
spatial_size = 128
gpu = 1  # passed to the Trainer as `devices=`
dataset_dir = 'DATA'
output_dir = 'OUTPUT'

# Seed Python / NumPy / PyTorch RNGs (and dataloader workers) before the data
# module and model are created, so a Restart & Run All reproduces the same run.
seed = 42
pl.seed_everything(seed, workers=True)


In [None]:
# Initialize data module, configured from the parameters cell above
dataset_root = dataset_dir
datamodule = Larynx_DataModule(
    data_dir=dataset_root,
    batch_size=batch_size,
    spatial_size=spatial_size,
)


In [None]:
# Initialize model
# Hyperparameters forwarded to the MSSSIM / ACAI / IGD constructors below;
# the plain AE and VAE only take `latent_size`.
rho = 0.15
lambda_fool = 0.1
gamma = 0.2

# Map each supported architecture name to a zero-argument factory, so only the
# selected model is actually constructed.
model_factories = {
    'AE': lambda: AE(latent_size),
    'VAE': lambda: VAE(latent_size),
    'AE_MSSSIM': lambda: AE_MSSSIM(latent_size, rho),
    'VAE_MSSSIM': lambda: VAE_MSSSIM(latent_size, rho),
    'AE_MSSSIM_ACAI': lambda: AE_MSSSIM_ACAI(latent_size, rho, lambda_fool, gamma),
    'VAE_MSSSIM_ACAI': lambda: VAE_MSSSIM_ACAI(latent_size, rho, lambda_fool, gamma),
    'IGD': lambda: IGD(latent_size, rho, lambda_fool, gamma),
}

# Fail loudly on a typo instead of leaving `model` undefined (fresh kernel) or
# silently reusing a stale `model` from an earlier run of this cell.
if architecture not in model_factories:
    raise ValueError(
        f"Unknown architecture {architecture!r}; "
        f"expected one of {sorted(model_factories)}"
    )
model = model_factories[architecture]()


In [None]:
# Setup logger: each architecture gets its own log root under `output_dir`,
# and TensorBoard runs are grouped under the 'pretraining' name inside it.
experiment_name = architecture
root_log_dir = os.path.join(output_dir, experiment_name)
train_logger = TensorBoardLogger(
    save_dir=root_log_dir,
    name='pretraining',
)


In [None]:
# Setup checkpoint callback; checkpoints live alongside the logs in
# <root_log_dir>/checkpoints.
checkpoint_dir = os.path.join(root_log_dir, 'checkpoints')
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_last=True,           # also keep a rolling "last" checkpoint
    every_n_epochs=10,        # NOTE(review): fixed at 10, independent of `epochs`;
                              # presumably only the "last" checkpoint appears when
                              # epochs < 10 — confirm this is intended
    filename='{epoch:02d}',   # checkpoint names embed the zero-padded epoch index
)


In [None]:
# Initialize trainer
# NOTE(review): `gpu` is used only as the device count; the accelerator is
# pinned to 'cpu', so this never trains on a GPU regardless of `gpu` — confirm.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator='cpu',
    devices=gpu,
    logger=train_logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=20,
    num_sanity_val_steps=0,   # skip the pre-fit validation sanity check
    fast_dev_run=False,
)


In [None]:
# Train the model; the data module is passed by keyword so it is not mistaken
# for the `train_dataloaders` positional argument.
trainer.fit(model, datamodule=datamodule)
