In [1]:
import os
import sys

# Put the project root (parent of the kernel's working directory) on the
# import path so `from src import ...` in the next cell resolves.
# - Resolved to an absolute path once, so the entry no longer depends on
#   what the working directory happens to be when later imports run.
# - Guarded so re-running this cell does not append duplicate entries.
# NOTE(review): ".." is relative to the kernel's CWD, which assumes the
# notebook was launched from its own directory — confirm.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import pytorch_lightning as pl

In [2]:
from torch import nn
from src import ImageDataset, ImageLoader, piGAN

In [4]:
# Optimiser settings, one sub-dict per network plus a shared decay span.
# NOTE(review): "target_learning_rate" together with "learning_rate_decay_span"
# reads like a schedule from learning_rate towards target_learning_rate over
# that many steps; the schedule itself lives in src.piGAN — confirm there.
optim_cfg = dict(
    discriminator=dict(learning_rate=4e-4, target_learning_rate=1e-4),
    generator=dict(learning_rate=5e-5, target_learning_rate=1e-5),
    learning_rate_decay_span=10000,
)

# Generator sub-network overrides: mapping-network depth and number of SIREN
# MLP layers (presumably forwarded as keyword arguments — see src.piGAN).
generator_cfg = dict(
    mapping_network_kw=dict(depth=3),
    siren_mlp_kw=dict(num_layers=5),
)

# Discriminator architecture settings.
# NOTE(review): final_activation is a sigmoid, so D presumably outputs values
# in (0, 1); check this pairs correctly with piGAN(loss_mode="log") rather
# than a logits-based loss.
discriminator_cfg = dict(
    init_resolution=32,
    max_chan=400,
    pow2_bottom_layer_chans=7,
    final_activation=nn.Sigmoid(),
)

In [5]:
# Training images from ../images (relative to the kernel's working directory).
# NOTE(review): image_size here must stay equal to piGAN(image_size=...) in the
# model cell; the value is currently duplicated rather than shared.
image_dataset = ImageDataset(
    data_dir="../images",
    image_size=128,
)

# One image per batch, loaded by 4 worker processes; the Trainer's
# accumulate_grad_batches setting determines the effective batch size.
image_loader = ImageLoader(
    image_dataset=image_dataset,
    batch_size=1,
    num_workers=4,
)

In [6]:
# Build the pi-GAN LightningModule from the configs defined above.
# NOTE(review): image_size duplicates the value passed to ImageDataset; keep
# the two in sync (ideally via one IMAGE_SIZE constant in a config cell).
pi_GAN = piGAN(
    # model shape and sub-network configs
    image_size=128,
    dim_input=64,
    dim_hidden=64,
    generator_cfg=generator_cfg,
    discriminator_cfg=discriminator_cfg,
    # optimisation and GAN objective
    optim_cfg=optim_cfg,
    loss_mode="log",
    # data: batch size is read back from the loader so the two cannot drift
    image_dataset=image_dataset,
    batch_size=image_loader.batch_size,
    # sampling cadence (unit — steps vs. epochs — defined in src.piGAN; confirm)
    sample_every=20,
)

In [7]:
# Training run settings, named instead of inlined as magic numbers.
MAX_EPOCHS = 50000
# Gradients are accumulated over this many batches before each optimiser step;
# with the loader's batch_size=1 that is an effective batch of 4 images.
ACCUMULATE_GRAD_BATCHES = 4
PROGRESS_BAR_REFRESH_RATE = 20

# Lightning reports "GPU available: True, used: False" for this setup: without
# an explicit request, training runs on CPU. Flip to True once piGAN is
# confirmed to create all of its tensors on the module's device.
USE_GPU = False

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # NOTE(review): `progress_bar_refresh_rate` and `gpus` are Lightning 1.x
    # Trainer arguments (the former was deprecated in 1.5 and removed in 1.7);
    # migrate to a progress-bar callback and accelerator/devices on upgrade.
    progress_bar_refresh_rate=PROGRESS_BAR_REFRESH_RATE,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    # None is the Trainer default, i.e. CPU — identical to the original run.
    gpus=1 if USE_GPU else None,
)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


In [8]:
# Train the model. The dataloader is passed positionally: its keyword was
# `train_dataloader` in early Lightning 1.x and was renamed to
# `train_dataloaders` in 1.4 (the old spelling was later removed), so the
# positional form works across versions. The trailing `;` keeps fit()'s
# return value from being echoed as cell output.
trainer.fit(pi_GAN, image_loader);


  | Name | Type          | Params
---------------------------------------
0 | G    | Generator     | 46.3 K
1 | D    | Discriminator | 319 K 
---------------------------------------
365 K     Trainable params
0         Non-trainable params
365 K     Total params
1.463     Total estimated model params size (MB)
Epoch 25:   0%|          | 0/3 [00:04<?, ?it/s, loss=1.99, v_num=42, loss_D=0.440, loss_G=3.400]


1