In [1]:
!pip install pytorch_lightning

In [5]:
!git clone https://github.com/Godofnothing/pi-GAN

Cloning into 'pi-GAN'...
remote: Enumerating objects: 32, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 32 (delta 0), reused 25 (delta 0), pack-reused 0[K
Unpacking objects: 100% (32/32), done.


In [2]:
# Work from inside the cloned repo so that `from src import ...` in the next cell resolves.
# NOTE(review): `master` is a moving branch; pin a commit hash here for reproducible runs.
%cd pi-GAN
!git checkout master

/content/pi-GAN
Already on 'master'
Your branch is up to date with 'origin/master'.


In [3]:
import pytorch_lightning as pl
from torch import nn
from src import ImageDataset, ImageLoader, piGAN

In [5]:
import pathlib

import tensorflow as tf

# Fetch and extract the TF "flower_photos" sample archive. It contains per-class
# sub-folders such as `sunflowers/`. Keras caches the download under ~/.keras/datasets,
# so re-running this cell reuses the cached copy.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = pathlib.Path(
    tf.keras.utils.get_file(fname='flower_photos', origin=dataset_url, untar=True)
)

In [6]:
# Optimiser settings per network. Judging by the key names, each learning rate is annealed
# from `learning_rate` towards `target_learning_rate` over `learning_rate_decay_span`.
# NOTE(review): the schedule shape and the unit of the span (steps vs. epochs) are defined
# inside piGAN; confirm there.
optim_cfg = dict(
    discriminator=dict(learning_rate=4e-4, target_learning_rate=1e-4),
    generator=dict(learning_rate=5e-5, target_learning_rate=1e-5),
    learning_rate_decay_span=10000,
)

# Generator: mapping-network depth and number of SIREN MLP layers.
generator_cfg = dict(
    mapping_network_kw=dict(depth=3),
    siren_mlp_kw=dict(num_layers=6),
)

# Discriminator architecture knobs, plus a Sigmoid on its output.
# NOTE(review): training below uses precision=16; a Sigmoid output fed into a log/BCE-style
# loss can be numerically unsafe under autocast. Confirm how loss_mode="log" consumes it.
discriminator_cfg = dict(
    init_resolution=32,
    max_chan=128,
    pow2_bottom_layer_chans=11,
    final_activation=nn.Sigmoid(),
)

In [7]:
# Build the sunflower-only dataset, its loader, and the pi-GAN LightningModule.
image_size = 128
SEED = 42

# Seed the Python, NumPy and torch RNGs before any weights are initialised,
# so reruns are reproducible.
pl.seed_everything(SEED)

# Derive the class folder from `data_dir` (set by the download cell) instead of
# hard-coding the Colab cache path. The trailing "/" matches the original literal.
# NOTE(review): confirm whether ImageDataset needs the trailing separator (e.g. if it
# concatenates strings); drop it if not.
sunflowers_dir = str(data_dir / "sunflowers") + "/"

image_dataset = ImageDataset(data_dir=sunflowers_dir, image_size=image_size)
image_loader = ImageLoader(image_dataset=image_dataset, batch_size=1, num_workers=4)

pi_GAN = piGAN(
    image_size=image_size,  # keep the model resolution in sync with the dataset's
    input_features=512,
    hidden_features=256,
    optim_cfg=optim_cfg,
    sample_every=20,
    generator_cfg=generator_cfg,
    discriminator_cfg=discriminator_cfg,
    image_dataset=image_dataset,
    batch_size=image_loader.batch_size,
    loss_mode="log"
)

In [8]:
# Single-GPU, 16-bit mixed-precision training.
# `amp_level='O2'` was removed: it only configures the Apex backend, and this cell's log
# reports native AMP ("Using native 16bit precision"), so it had no effect here. Later
# Lightning versions raise an error when amp_level is set without amp_backend='apex'.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=50000,
    progress_bar_refresh_rate=20,
    accumulate_grad_batches=4,  # accumulate gradients over 4 loader batches per optimizer step
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.


In [9]:
# Start training. The loader is passed positionally, which binds to the training
# dataloader parameter of `Trainer.fit`.
trainer.fit(pi_GAN, image_loader)


  | Name | Type          | Params
---------------------------------------
0 | G    | Generator     | 1.6 M 
1 | D    | Discriminator | 9.8 M 
---------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
45.937    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)





1