<a href="https://colab.research.google.com/github/96jonesa/StyleGan2-Colab-Demo/blob/master/tool_for_training_small_set_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# You're probably looking for this:

IF YOU WOULD LIKE TO SEE A COMPARISON OF RESULTS FROM TRAINING ON VARIOUS
DATASETS UNDER VARIOUS CONFIGURATIONS, CHECK OUT THIS NOTEBOOK:

https://colab.research.google.com/drive/1uwPlY-4P_6fJ59SFRtgZLebVGgwGrUQu?usp=sharing

# What is this?

This is a simple demo of use of an open-source PyTorch implementation of StyleGAN2

https://github.com/lucidrains/stylegan2-pytorch

set up for training using Colab's free GPU resources and Google Drive. Citations are at the bottom.

The GitHub repo I made for this project is available at:

https://github.com/96jonesa/StyleGan2-Colab-Demo

# Where do the files go?
 
Training results and models are saved to the local runtime's 'results' and 'models' directories, or to your Google Drive in subdirectories (of the same names) of a parent directory named 'StyleGan2_small_set_demo'.

Local runtime files can be accessed by clicking on the folder icon found on the toolbar to the left.

# Using your Google Drive:

IF YOU CHOOSE TO USE YOUR GOOGLE DRIVE for training, then you will be prompted in Code Cell 2 of this notebook to authorize access. You must click a link, copy a code, and paste it into the input box below Code Cell 2. Hit enter.

# HOW TO USE:

0. Log in to Google (Drive)

1. Click 'Copy to Drive' above to make a runnable copy of this notebook.
2. Run this cell (click the play button in top left of cell) to connect to a runtime instance.
3. Navigate to 'Runtime > Change Runtime Type > Hardware Accelerator' and select GPU.
4. Modify the variables found in the cell below to select the behavior of the demo.
5. Run all cells ('Runtime > Run All').
6. IF USING YOUR GOOGLE DRIVE, FOLLOW INSTRUCTIONS FOUND IN ABOVE CELL.

In [None]:
# User-facing settings for the demo. Later cells read these names.
#
# Dataset names recognized by the download cell:
# 'celeba', 'afhq', 'metfaces', 'brecahad', 'cifar10', 'afhq_dog', 'afhq_cat', 'afhq_wild',
# 'cifar10_airplane', 'cifar10_automobile', 'cifar10_bird', 'cifar10_cat', 'cifar10_deer',
# 'cifar10_dog', 'cifar10_frog', 'cifar10_horse', 'cifar10_ship', 'cifar10_truck'
USE_DATASET = 'celeba'

TRAINING_FROM_SCRATCH = False # set True to train from scratch, False to resume from the last checkpoint
MODEL_NAME = 'default' # the dataset name is prepended to this in a later cell
MODEL_NUM_TRAIN_STEPS = 3000 # passed to the trainer as --num_train_steps
LOW_NETWORK_CAPACITY = False # set True to use significantly lower network capacity
USE_GOOGLE_DRIVE_FOR_TRAINING = False # set True to save results and models to Google Drive

# 'none', 'first', 'every'
USE_ATTENTION_LAYERS = 'none' # which layers do you want attention applied to?

MODEL_AUGMENTATION_PROBABILITY = 0.0 # passed to the trainer as --aug_prob
MODEL_LEARNING_RATE = 2e-4 # passed to the trainer as --learning_rate
MODEL_IMAGE_SIZE = 128 # passed to the trainer as --image_size

In [None]:
# Mounts your Google Drive so files can be saved to it. Note that this also allows
# files to be read from it, so only authorize this if you are comfortable doing so
# and/or using a disposable Google Drive account.

# Google Drive is only touched when the user opted in via the settings cell.
if USE_GOOGLE_DRIVE_FOR_TRAINING:
    # Imported here so the Colab-only module is needed only on the opt-in path.
    from google.colab import drive
    drive.mount('/content/drive')

    # -p creates the parent folder if it is missing and is a no-op if it already exists.
    # The path is quoted because 'My Drive' contains a space.
    !mkdir -p "/content/drive/My Drive/StyleGan2_small_set_demo"

In [None]:
# Prints information about the GPU allocated by Colab.
# Possible models of GPU are K80, T4, and P100. K80 is relatively slow.

# The `!` assignment runs the shell command and captures its output as a list of lines.
gpu_info = !nvidia-smi
# Collapse the captured lines into one printable string.
gpu_info = '\n'.join(gpu_info)
# Output containing the word 'failed' is treated as "no GPU is attached to this runtime".
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

In [None]:
# Derives the values that the training command interpolates from the user-facing settings.
# The cell is safe to re-run: it produces the same values every time.

# Network capacity: 16 by default, 4 when the low-capacity option is selected.
MODEL_NETWORK_CAPACITY = 4 if LOW_NETWORK_CAPACITY else 16

# Attention layers are always kept as a string with no spaces, because the value is
# pasted verbatim into a shell command line. "[]" means "no attention layers".
MODEL_ATTENTION_LAYERS = "[]"
if USE_ATTENTION_LAYERS == 'first':
    MODEL_ATTENTION_LAYERS = "[1]"
elif USE_ATTENTION_LAYERS == 'every':
    MODEL_ATTENTION_LAYERS = "[1,2,3,4,5,6]"

# Prefix the model name with the dataset name exactly once. Without the guard,
# re-running this cell would keep stacking prefixes ('celeba_celeba_default', ...)
# and the trainer would look for checkpoints under the wrong name.
# NOTE: a MODEL_NAME that already starts with '<dataset>_' is left unchanged.
DATASET_PREFIX = USE_DATASET + '_'
if not MODEL_NAME.startswith(DATASET_PREFIX):
    MODEL_NAME = DATASET_PREFIX + MODEL_NAME

In [None]:
# Installs the architecture from:
# https://github.com/lucidrains/stylegan2-pytorch

# Pinned to an exact release.
# NOTE(review): the training cell's CLI flags were presumably written against this version - confirm before upgrading.
!pip install stylegan2_pytorch==0.17.1

In [None]:
# Utilities for downloading publicly shared Google Drive files (from my Google Drive).

import requests

def download_file_from_google_drive(id, destination):
    """Download a publicly shared Google Drive file to a local path.

    A first request is sent with only the file id. If the response carries a
    confirmation token (see get_confirm_token), the request is repeated with
    that token before the body is written to disk.

    Args:
        id: Google Drive file id. The name shadows the builtin but is kept so
            existing keyword-argument callers continue to work.
        destination: Local path the downloaded bytes are written to.

    Raises:
        requests.HTTPError: If the final response has an HTTP error status.
            Checked before writing, so an error page is never saved in place
            of the requested file.
    """
    URL = 'https://docs.google.com/uc?export=download'

    # The context manager closes the session (and its connections) even when
    # the download fails part-way through.
    with requests.Session() as session:
        response = session.get(URL, params = { 'id' : id }, stream = True)
        token = get_confirm_token(response)

        if token:
            params = { 'id' : id, 'confirm' : token }
            response = session.get(URL, params = params, stream = True)

        # Fail loudly here instead of later, when the saved file is opened.
        response.raise_for_status()
        save_response_content(response, destination)

def get_confirm_token(response):
    """Return the value of the first 'download_warning*' cookie, or None.

    Args:
        response: Object exposing a `cookies` mapping with an `items()` method.

    Returns:
        The value of the first cookie whose name starts with
        'download_warning', or None when no such cookie is present.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matching_values, None)

def save_response_content(response, destination):
    """Stream the body of `response` into the file at `destination`.

    Args:
        response: Object whose `iter_content(chunk_size)` yields byte chunks.
        destination: Path of the file to create or overwrite in binary mode.
    """
    bytes_per_chunk = 32768

    with open(destination, 'wb') as out_file:
        # Empty chunks (keep-alive markers) are dropped before writing.
        non_empty_chunks = filter(None, response.iter_content(bytes_per_chunk))
        out_file.writelines(non_empty_chunks)

In [None]:
# NOTE(review): presumably a dependency of stylegan2_pytorch that its install does not pull in - confirm.
# No version is pinned here, so the installed release can change between runs.
!pip install linformer

In [None]:
# Downloads and unzips the selected dataset from my Google Drive.

import zipfile

# One entry per downloadable archive:
# (dataset names served by the archive, Google Drive file id,
#  local zip file name, directory the archive is extracted into).
DATASET_ARCHIVES = [
    (['celeba'],
     '1tF3yRlv5VZx0hgZ5RsNdOiAlSrrIPJNq', 'celeba.zip', 'data/celeba'),
    (['afhq', 'afhq_dog', 'afhq_cat', 'afhq_wild'],
     '1PFZbFOQzGhXUmF6TxAzboEP7fOn8zyHa', 'afhq.zip', 'data/afhq'),
    (['metfaces'],
     '1r7XMa8gNZqwqEsgf9t7MFWP7BAgy74jm', 'metfaces.zip', 'data/metfaces'),
    (['brecahad'],
     '1lnZd9ujC3FecVc9dmjOm_XUX3ei3B_69', 'brecahad.zip', 'data/brecahad'),
    (['cifar10', 'cifar10_airplane', 'cifar10_automobile', 'cifar10_bird', 'cifar10_cat',
      'cifar10_deer', 'cifar10_dog', 'cifar10_frog', 'cifar10_horse', 'cifar10_ship',
      'cifar10_truck'],
     '1T_maRdj_fgXLychhORRivbyOflhzISGV', 'cifar10.zip', 'data/cifar10'),
]

# Find the archive that serves the selected dataset (None when the name is unknown).
selected_archive = next(
    (entry for entry in DATASET_ARCHIVES if USE_DATASET in entry[0]),
    None,
)

if selected_archive is None:
    # The data-directory selection falls back to the CelebA folder for
    # unrecognized names, so CelebA is downloaded here too. Otherwise training
    # would start on a folder that was never created.
    print("Unrecognized USE_DATASET " + repr(USE_DATASET) + "; falling back to 'celeba'.")
    selected_archive = DATASET_ARCHIVES[0]

_, file_id, destination, extract_dir = selected_archive
download_file_from_google_drive(file_id, destination)

# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(destination, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

In [None]:
# Chooses the appropriate subdirectory of dataset for training.

# uses CelebA dataset by default (if chosen, or if invalid dataset name)
# Lookup table from dataset name to the folder holding its training images.
afhq_train_dir = 'data/afhq/afhq/train'
cifar10_train_dir = 'data/cifar10/cifar10/cifar10/train'

data_dir_by_dataset = {
    'metfaces': 'data/metfaces/images',
    'brecahad': 'data/brecahad/images',
    'afhq': afhq_train_dir,
    'cifar10': cifar10_train_dir,
}

# Per-class variants train on the class subfolder named by the suffix after the prefix.
for afhq_class in ('dog', 'cat', 'wild'):
    data_dir_by_dataset['afhq_' + afhq_class] = afhq_train_dir + '/' + afhq_class

for cifar10_class in ('airplane', 'automobile', 'bird', 'cat', 'deer',
                      'dog', 'frog', 'horse', 'ship', 'truck'):
    data_dir_by_dataset['cifar10_' + cifar10_class] = cifar10_train_dir + '/' + cifar10_class

# Any name not in the table (including 'celeba' itself) resolves to the CelebA folder.
MODEL_DATA_DIR = data_dir_by_dataset.get(USE_DATASET, 'data/celeba/img_align_celeba')

In [None]:
# Establish directories for custom models.

# Output locations for sample images and checkpoint models: the runtime's local
# disk by default, or the mounted Google Drive folder when Drive training is on.
# The Drive values carry their own double quotes because the path contains a
# space and is later pasted into a shell command line.
CUSTOM_RESULTS_DIR, CUSTOM_MODELS_DIR = (
    (
        '"/content/drive/My Drive/StyleGan2_small_set_demo/results"',
        '"/content/drive/My Drive/StyleGan2_small_set_demo/models"',
    )
    if USE_GOOGLE_DRIVE_FOR_TRAINING
    else ('./results', './models')
)

In [None]:
# Train custom models.

# Launches the stylegan2_pytorch command-line trainer. Values in {braces} are
# Python variables from earlier cells, interpolated into the shell command.
# The two commands are identical except for the --new flag.
if TRAINING_FROM_SCRATCH:
    # --new: start from scratch instead of loading a saved checkpoint.
    !stylegan2_pytorch --data {MODEL_DATA_DIR} --name {MODEL_NAME} --new --network_capacity {MODEL_NETWORK_CAPACITY} --batch_size 4 \
        --gradient_accumulate_every 4 --num_train_steps {MODEL_NUM_TRAIN_STEPS} --attn_layers {MODEL_ATTENTION_LAYERS} --image_size {MODEL_IMAGE_SIZE} \
        --aug_prob {MODEL_AUGMENTATION_PROBABILITY} --results_dir {CUSTOM_RESULTS_DIR} --models_dir {CUSTOM_MODELS_DIR} --learning_rate {MODEL_LEARNING_RATE}
else:
    # Without --new the trainer loads from a saved checkpoint under models_dir/name.
    !stylegan2_pytorch --data {MODEL_DATA_DIR} --name {MODEL_NAME} --network_capacity {MODEL_NETWORK_CAPACITY} --batch_size 4 \
        --gradient_accumulate_every 4 --num_train_steps {MODEL_NUM_TRAIN_STEPS} --attn_layers {MODEL_ATTENTION_LAYERS} --image_size {MODEL_IMAGE_SIZE} \
        --aug_prob {MODEL_AUGMENTATION_PROBABILITY} --results_dir {CUSTOM_RESULTS_DIR} --models_dir {CUSTOM_MODELS_DIR} --learning_rate {MODEL_LEARNING_RATE}

# Parameters accepted by model:

In [None]:
# parameter                 | default   | description
#                           |           |
# data                      | ./data    | directory containing data
# results_dir               | ./results | directory for checkpoint sample images
# models_dir                | ./models  | directory for checkpoint models (saves to and loads from here)
# name                      | default   | name to identify model (all outputs will be saved to results_dir/name and models_dir/name)
# new                       | False     | if True then starts from scratch, else loads from saved checkpoint model
# load_from                 | -1        | if -1 then loads from most recent checkpoint, else loads from checkpoint number load_from
# image_size                | 128       | size of (square) images generated and for resizing of data
# network_capacity          | 16        | affects number of nodes per layer - decrease to train faster with lower output quality
# transparent               | False     | if True then uses RGBA, else uses RGB
# batch_size                | 3         | number of images per mini-batch (larger uses more GPU memory)
# gradient_accumulate_every | 5         | number of mini-batches to process before optimizing (choice depends on batch_size)
# num_train_steps           | 150000    | total steps of forward prop (counting starts from number of steps completed in loaded checkpoint)
# learning_rate             | 2e-4      | learning rate
# num_workers               | None      | if None then uses as many workers as possible from available CPU cores (for data loading)
# save_every                | 1000      | every save_every steps, a checkpoint model and sample images are saved
# generate                  | False     | if True then generates sample images from loaded model instead of training
# generate_interpolation    | False     | if True then generates .gif interpolation from loaded model instead of training, else does not
# num_image_tiles           | 8         | generated samples will be a grid of (num_image_tiles x num_image_tiles) images
# trunc_psi                 | 0.75      | affects how far generated images can be from the average image (increase for more diversity): w_new = psi * w + (1 - psi) * w_avg
# fp16                      | False     | if True then uses fp16 half-precision to lower GPU memory usage (requires apex), else uses full-precision
# cl_reg                    | False     | if True then uses contrastive learning on discriminator (possibly improves stability and quality), else does not
# fq_layers                 | []        | list of layers to apply feature (intermediate representation) vector quantization to (can improve results, but not dramatically)
# fq_dict_size              | 256       | dictionary size for feature quantization
# attn_layers               | []        | list of layers to apply self-attention to while training (can be empty; do not use spaces; up to log2(image_size) - 1 layers)
# no_const                  | False     | if True then 4x4 block is learned from style vector, else styles a constant learned 4x4 block through progressive upsampling
# aug_prob                  | 0.0       | probability of applying differentiable augmentation to images fed to discriminator

# CITATIONS:

```
@inproceedings{choi2020starganv2,
  title={StarGAN v2: Diverse Image Synthesis for Multiple Domains},
  author={Yunjey Choi and Youngjung Uh and Jaejun Yoo and Jung-Woo Ha},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2020}
}

@inproceedings{liu2015faceattributes,
 title = {Deep Learning Face Attributes in the Wild},
 author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou},
 booktitle = {Proceedings of International Conference on Computer Vision (ICCV)},
 month = {December},
 year = {2015} 
}

@article{Karras2019stylegan2,
  title   = {Analyzing and Improving the Image Quality of {StyleGAN}},
  author  = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
  journal = {CoRR},
  volume  = {abs/1912.04958},
  year    = {2019},
}

@misc{zhao2020feature,
    title   = {Feature Quantization Improves GAN Training},
    author  = {Yang Zhao and Chunyuan Li and Ping Yu and Jianfeng Gao and Changyou Chen},
    year    = {2020}
}

@misc{chen2020simple,
    title   = {A Simple Framework for Contrastive Learning of Visual Representations},
    author  = {Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton},
    year    = {2020}
}

@article{,
  title     = {Oxford 102 Flowers},
  author    = {Nilsback, M-E. and Zisserman, A., 2008},
  abstract  = {A 102 category dataset consisting of 102 flower categories, commonly occuring in the United Kingdom. Each class consists of 40 to 258 images. The images have large scale, pose and light variations.}
}

@article{afifi201911k,
  title   = {11K Hands: gender recognition and biometric identification using a large dataset of hand images},
  author  = {Afifi, Mahmoud},
  journal = {Multimedia Tools and Applications}
}

@misc{zhang2018selfattention,
    title   = {Self-Attention Generative Adversarial Networks},
    author  = {Han Zhang and Ian Goodfellow and Dimitris Metaxas and Augustus Odena},
    year    = {2018},
    eprint  = {1805.08318},
    archivePrefix = {arXiv}
}

@article{shen2019efficient,
  author    = {Zhuoran Shen and
               Mingyuan Zhang and
               Haiyu Zhao and
               Shuai Yi and
               Hongsheng Li},
  title     = {Efficient Attention: Attention with Linear Complexities},
  journal   = {CoRR},  
  year      = {2018},
  url       = {http://arxiv.org/abs/1812.01243},
}

@misc{zhao2020image,
    title  = {Image Augmentations for GAN Training},
    author = {Zhengli Zhao and Zizhao Zhang and Ting Chen and Sameer Singh and Han Zhang},
    year   = {2020},
    eprint = {2006.02595},
    archivePrefix = {arXiv}
}

@misc{karras2020training,
    title   = {Training Generative Adversarial Networks with Limited Data},
    author  = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    year    = {2020},
    eprint  = {2006.06676},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

@article{article,
author = {Krizhevsky, Alex},
year = {2012},
month = {05},
pages = {},
title = {Learning Multiple Layers of Features from Tiny Images},
journal = {University of Toronto}
}

@misc{karras2020training,
    title={Training Generative Adversarial Networks with Limited Data},
    author={Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    year={2020},
    eprint={2006.06676},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

@article{article,
author = {Aksac, Alper and Demetrick, Douglas and Ozyer, Tansel},
year = {2019},
month = {12},
pages = {},
title = {BreCaHAD: a dataset for breast cancer histopathological annotation and diagnosis},
volume = {12},
journal = {BMC Research Notes},
doi = {10.1186/s13104-019-4121-7}
}
```


