In [1]:
import sys
sys.path.append("..")

import pytorch_lightning as pl

In [2]:
from torch import nn
from src import ImageDataset, ImageLoader, piGAN

In [3]:
optim_cfg = {
    "discriminator" : {
        "learning_rate" : 4e-4,
        "target_learning_rate" : 1e-4
    },
    "generator" : {
        "learning_rate" : 5e-5,
        "target_learning_rate" : 1e-5
    },
    "learning_rate_decay_span" : 10000
}

generator_cfg = {
    "mapping_network_kw" : {
        "depth" : 3
    },
    "siren_mlp_kw" : {
        "num_layers" : 6,
    }
}

discriminator_cfg = {
    "init_resolution" : 32,
    "max_chan" : 100,
    "pow2_bottom_layer_chans" : 8,
}

In [7]:
import torchvision.transforms as T

augmentation_list = [
    T.ColorJitter(brightness=0.1, saturation=0.1, contrast=0.1, hue=0.1),
    T.RandomResizedCrop(size=128, scale = (0.2, 1), ratio=(4/ 5, 5 / 4)),
    T.RandomPerspective(distortion_scale=0.1, p=0.3),
    T.RandomAffine(degrees=10, translate=(0.1, 0.1))
]

In [8]:
image_dataset = ImageDataset(data_dir="../images/flowers", image_size=128, augmentation_list=augmentation_list)
image_loader = ImageLoader(image_dataset=image_dataset, batch_size=2, num_workers=4)

In [9]:
import torchvision.transforms as T

image_size=128

augmentation_list = [
    T.ColorJitter(brightness=0.1, saturation=0.1, contrast=0.1, hue=0.1),
    T.RandomResizedCrop(size=image_size, scale = (0.2, 1), ratio=(4/ 5, 5 / 4)),
    T.RandomPerspective(distortion_scale=0.1, p=0.3),
    T.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    T.GaussianBlur(kernel_size=3)
]

image_dataset = ImageDataset(
    data_dir="/root/.keras/datasets/flower_photos/sunflowers/", 
    image_size=image_size,
    augmentation_list=augmentation_list
)

image_loader = ImageLoader(image_dataset=image_dataset, batch_size=1, num_workers=4)

pi_GAN = piGAN(
    image_size=image_size,
    input_features=128,
    hidden_features=64,
    optim_cfg=optim_cfg,
    sample_every=2,
    generator_cfg=generator_cfg,
    discriminator_cfg=discriminator_cfg,
    image_dataset=image_dataset,
    batch_size=image_loader.batch_size,
    num_samples=1,
    loss_mode="log"
)

In [10]:
trainer = pl.Trainer(
    max_epochs=50000, 
    progress_bar_refresh_rate=20,
    accumulate_grad_batches=4
)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


In [11]:
trainer.fit(
    model=pi_GAN,
    train_dataloader=image_loader
)


  | Name               | Type              | Params
---------------------------------------------------------
0 | G                  | Generator         | 199 K 
1 | D                  | Discriminator     | 439 K 
2 | discriminator_loss | BCEWithLogitsLoss | 0     
3 | generator_loss     | BCEWithLogitsLoss | 0     
---------------------------------------------------------
638 K     Trainable params
0         Non-trainable params
638 K     Total params
2.553     Total estimated model params size (MB)
Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)



1