In [1]:
import sys
sys.path.append("..")

import pytorch_lightning as pl

In [3]:
from torch import nn
from src import ImageDataset, ImageLoader, piGAN

In [8]:
optim_cfg = {
    "discriminator" : {
        "learning_rate" : 4e-4,
        "target_learning_rate" : 1e-4
    },
    "generator" : {
        "learning_rate" : 5e-5,
        "target_learning_rate" : 1e-5
    },
    "learning_rate_decay_span" : 10000
}

generator_cfg = {
    "mapping_network_kw" : {
        "depth" : 3
    },
    "siren_mlp_kw" : {
        "num_layers" : 6,
    }
}

discriminator_cfg = {
    "init_resolution" : 32,
    "max_chan" : 200,
    "pow2_bottom_layer_chans" : 8,
    # "final_activation" : nn.Sigmoid()
}

In [9]:
image_dataset = ImageDataset(data_dir="../images/flowers", image_size=128)
image_loader = ImageLoader(image_dataset=image_dataset, batch_size=2, num_workers=4)

In [10]:
pi_GAN = piGAN(
    image_size=128,
    input_features=128,
    hidden_features=128,
    optim_cfg=optim_cfg,
    sample_every=20,
    generator_cfg=generator_cfg,
    discriminator_cfg=discriminator_cfg,
    image_dataset=image_dataset,
    batch_size=image_loader.batch_size,
    loss_mode="log"
)

In [11]:
trainer = pl.Trainer(
    max_epochs=50000, 
    progress_bar_refresh_rate=20,
    accumulate_grad_batches=4
)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


In [12]:
trainer.fit(
    model=pi_GAN,
    train_dataloader=image_loader
)


  | Name               | Type              | Params
---------------------------------------------------------
0 | G                  | Generator         | 199 K 
1 | D                  | Discriminator     | 247 K 
2 | discriminator_loss | BCEWithLogitsLoss | 0     
3 | generator_loss     | BCEWithLogitsLoss | 0     
---------------------------------------------------------
446 K     Trainable params
0         Non-trainable params
446 K     Total params
1.787     Total estimated model params size (MB)
Epoch 12:   0%|          | 0/17 [00:11<?, ?it/s, loss=7.37, v_num=57, loss_D=4.2e-5, loss_G=10.60]


1