# FastGAN-pytorch (NathanGAN2) training notebook for use with Google Colab

Copyright 2021 by Nathan Gillispie, [released under GNU General Public License v3.0](https://github.com/NathanGillispie/FastGAN-pytorch/blob/main/LICENSE)

Originally based on [Jeff Heaton's Google Colab notebook](https://github.com/jeffheaton/present/blob/master/youtube/gan/colab_gan_train.ipynb), but changed enough to become a separate work. Thanks, Jeff.

 - Training speed at 1024x1024 on a Tesla K80: 2.4517 it/s (0.4079 s/it)
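
The two speed figures above are reciprocals of each other. A quick sketch of the conversion, plus a rough estimate of total training time, assuming throughput stays constant. The 500000 figure is the default `--iter` value in `train.py`, and both helper names are made up for illustration.

```python
def seconds_per_iteration(its_per_second: float) -> float:
    """Convert a throughput in it/s to a latency in s/it."""
    return 1.0 / its_per_second


def estimated_hours(iterations: int, its_per_second: float) -> float:
    """Rough wall-clock estimate, assuming throughput stays constant."""
    return iterations / its_per_second / 3600.0


rate = 2.4517  # it/s measured on a Tesla K80 at 1024x1024
print(f"{seconds_per_iteration(rate):.4f} s/it")
print(f"{estimated_hours(500_000, rate):.1f} h for 500000 iterations")
```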


In [25]:
!nvidia-smi

Fri Dec 31 04:41:44 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8    27W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Set up the environment

Google Drive is used to store model checkpoints as well as sample images.

In [26]:
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    COLAB = True
    print("Note: using Google CoLab")
except ImportError:
    print("Note: not using Google CoLab")
    COLAB = False

Mounted at /content/drive
Note: using Google CoLab
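
The training script further down writes checkpoints and sample images into Drive folders and appears to assume they already exist. A minimal sketch for creating them up front. The function name `ensure_output_dirs` is hypothetical, and the base path in the usage comment mirrors the one hard-coded in `train.py`.

```python
import os


def ensure_output_dirs(base: str) -> dict:
    """Create the models/ and images/ folders under `base` if they are missing."""
    paths = {
        "models": os.path.join(base, "models"),
        "images": os.path.join(base, "images"),
    }
    for path in paths.values():
        os.makedirs(path, exist_ok=True)  # no error if the folder already exists
    return paths


# Assumed layout, matching the paths used in train.py:
# ensure_output_dirs('/content/drive/MyDrive/data/NathanGAN/train_results/NathanGAN')
```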


My fork of FastGAN is not necessarily better than the original, but I prefer it, and a copy of FastGAN needs to be cloned here anyway.

In [None]:
!git clone https://github.com/NathanGillispie/FastGAN-pytorch

Optionally, check the dependency versions. Google frequently changes the packages preinstalled on Colab, so it helps to confirm that all the required ones are present. This notebook could break within months; if you are reading this after 2021, I strongly recommend checking the packages if training doesn't work.


In [None]:
!pip show tqdm
!pip show scipy
!pip show scikit-image
!pip show ipdb
!pip show pandas
!pip show lmdb
!pip show opencv-python
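
The same check can be done from Python, which gives a compact one-line-per-package summary. This is only a sketch: it assumes Python 3.8 or newer (for `importlib.metadata`), and `installed_versions` is an illustrative helper, not part of the repo.

```python
from importlib import metadata  # standard library, Python 3.8+


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# The packages queried with `pip show` above
REQUIRED = ["tqdm", "scipy", "scikit-image", "ipdb", "pandas", "lmdb", "opencv-python"]
for name, version in installed_versions(REQUIRED).items():
    print(f"{name:15s} {version or 'NOT INSTALLED'}")
```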

`ipdb` is not included in the default runtime, so it needs to be installed.

In [None]:
!pip install ipdb==0.13.4

Use this general pattern for downgrading packages if and when Google changes its packages (for pandas, as of today, this is actually an upgrade). Use the `requirements.txt` file as a guide.

In [None]:
# Only required if downgrading
!pip uninstall pandas -y
!pip install pandas==1.2.1
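
To see which packages actually need this treatment, the pins in `requirements.txt` can be compared against what is installed. A rough sketch, assuming the file uses simple `name==version` lines (anything else is skipped) and Python 3.8 or newer. `parse_pins` and `mismatches` are illustrative names, not part of the repo.

```python
from importlib import metadata  # standard library, Python 3.8+


def parse_pins(text):
    """Extract {name: version} from 'name==version' lines; ignore everything else."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" in line:
            name, _, version = line.partition("==")
            pins[name.strip()] = version.strip()
    return pins


def mismatches(pins):
    """Return {name: (wanted, installed)} for packages that differ from their pin."""
    result = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            result[name] = (wanted, installed)
    return result


# Example using the two versions pinned in this notebook
pins = parse_pins("ipdb==0.13.4\npandas==1.2.1  # downgrade target\n")
for name, (wanted, installed) in mismatches(pins).items():
    print(f"pip install {name}=={wanted}   (currently: {installed})")
```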

### Check dataset and cd into FastGAN

Make sure you change the dataset location in the code to match your own.

In [None]:
!ls /content/drive/MyDrive/data/NathanGAN/dataset
# `!cd` runs in a throwaway subshell and does not persist; `%cd` changes the notebook's working directory
%cd /content/FastGAN-pytorch
!ls

10.jpg	14.jpg	18.jpg	21.jpg	25.jpg	29.jpg	32.jpg	6.jpg
11.jpg	15.jpg	19.jpg	22.jpg	26.jpg	2.jpg	3.jpg	7.jpg
12.jpg	16.jpg	1.jpg	23.jpg	27.jpg	30.jpg	4.jpg	8.jpg
13.jpg	17.jpg	20.jpg	24.jpg	28.jpg	31.jpg	5.jpg	9.jpg
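
Before training, it can help to confirm that the dataset folder contains what you expect. A minimal sketch using only the standard library. `count_images` and the extension list are assumptions for illustration, so adjust them to your data.

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumed set; extend as needed


def count_images(folder):
    """Count files directly inside `folder` whose extension looks like an image."""
    return sum(
        1
        for entry in Path(folder).iterdir()
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS
    )


# Usage (path taken from the cell above):
# print(count_images('/content/drive/MyDrive/data/NathanGAN/dataset'))
```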


# IMPORTANT: Replace train.py with code below

Do this in the FastGAN-pytorch folder (`/content/FastGAN-pytorch/train.py`).

In [None]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision import utils as vutils
import numpy.random as nprand
import argparse
import random
from tqdm import tqdm
import os
from models import weights_init, Discriminator, Generator
from operation import copy_G_params, load_params, get_dir
from operation import ImageFolder, InfiniteSamplerWrapper
from diffaug import DiffAugment
policy = 'color,translation'
import lpips
percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)


#torch.backends.cudnn.benchmark = True


def crop_image_by_part(image, part):
    hw = image.shape[2]//2
    if part==0:
        return image[:,:,:hw,:hw]
    if part==1:
        return image[:,:,:hw,hw:]
    if part==2:
        return image[:,:,hw:,:hw]
    if part==3:
        return image[:,:,hw:,hw:]

def train_d(net, data, label="real"):
    """Train function of discriminator"""
    if label=="real":
        part = random.randint(0, 3)
        pred, [rec_all, rec_small, rec_part] = net(data, label, part=part)
        err = F.relu(  torch.rand_like(pred) * 0.2 + 0.8 -  pred).mean() + \
            percept( rec_all, F.interpolate(data, rec_all.shape[2]) ).sum() +\
            percept( rec_small, F.interpolate(data, rec_small.shape[2]) ).sum() +\
            percept( rec_part, F.interpolate(crop_image_by_part(data, part), rec_part.shape[2]) ).sum()
        err.backward()
        return pred.mean().item(), rec_all, rec_small, rec_part
    else:
        pred = net(data, label)
        err = F.relu( torch.rand_like(pred) * 0.2 + 0.8 + pred).mean()
        err.backward()
        return pred.mean().item()

def seed2vec(seed, batch_size, noise_dim):
    return nprand.RandomState(seed).randn(batch_size, noise_dim)

def train(args):
    delete_old_ckpts = True
    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    ndf = 64
    ngf = 64
    nz = 256
    nlr = 0.0002
    nbeta1 = 0.6
    use_cuda = True
    dataloader_workers = 2
    current_iteration = args.start_iter
    save_interval = 10
    saved_model_folder = '/content/drive/MyDrive/data/NathanGAN/train_results/NathanGAN/models'
    saved_image_folder = '/content/drive/MyDrive/data/NathanGAN/train_results/NathanGAN/images'
    
    device = torch.device("cuda:0")
    if not use_cuda:
        device = torch.device("cpu")

    transform_list = [
            transforms.Resize((int(im_size),int(im_size))),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]
    trans = transforms.Compose(transform_list)
    
    if 'lmdb' in data_root:
        from operation import MultiResolutionDataset
        dataset = MultiResolutionDataset(data_root, trans, 1024)
    else:
        dataset = ImageFolder(root=data_root, transform=trans)

    dataloader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=False, 
                sampler=InfiniteSamplerWrapper(dataset), num_workers=dataloader_workers, pin_memory=True))
    '''
    loader = MultiEpochsDataLoader(dataset, batch_size=batch_size, 
                               shuffle=True, num_workers=dataloader_workers, 
                               pin_memory=True)
    dataloader = CudaDataLoader(loader, 'cuda')
    '''
    
    #from model_s import Generator, Discriminator
    netG = Generator(ngf=ngf, nz=nz, im_size=im_size)
    netG.apply(weights_init)

    netD = Discriminator(ndf=ndf, im_size=im_size)
    netD.apply(weights_init)

    netG.to(device)
    netD.to(device)

    avg_param_G = copy_G_params(netG)

    fixed_noise = torch.tensor(seed2vec(69420, 8, nz), dtype=torch.float32).to(device)
    
    optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999))

    if checkpoint is not None:
        # Load onto whichever device was selected above rather than hard-coding 'cuda'
        ckpt = torch.load(checkpoint, map_location=device)
        netG.load_state_dict(ckpt['g'])
        netD.load_state_dict(ckpt['d'])
        avg_param_G = ckpt['g_ema']
        optimizerG.load_state_dict(ckpt['opt_g'])
        optimizerD.load_state_dict(ckpt['opt_d'])
        current_iteration = int(checkpoint.split('_')[-1].split('.')[0])
        del ckpt
    
    with tqdm(total=total_iterations - current_iteration, ascii=True, smoothing=0.6) as pbar:
        for iteration in range(current_iteration, total_iterations+1, batch_size):
            real_image = next(dataloader)
            real_image = real_image.to(device)
            current_batch_size = real_image.size(0)
            noise = torch.randn(current_batch_size, nz, device=device)

            fake_images = netG(noise)

            real_image = DiffAugment(real_image, policy=policy)
            fake_images = [DiffAugment(fake, policy=policy) for fake in fake_images]
            
            ## 2. train Discriminator
            netD.zero_grad()

            err_dr, rec_img_all, rec_img_small, rec_img_part = train_d(netD, real_image, label="real")
            train_d(netD, [fi.detach() for fi in fake_images], label="fake")
            optimizerD.step()
            
            ## 3. train Generator
            netG.zero_grad()
            pred_g = netD(fake_images, "fake")
            err_g = -pred_g.mean()

            err_g.backward()
            optimizerG.step()

            for p, avg_p in zip(netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(0.001 * p.data)

            if iteration % 10000 < batch_size:
                print("\nGAN: loss d: %.5f    loss g: %.5f"%(err_dr, -err_g.item()))

            if iteration % (save_interval*500) < batch_size:
                backup_para = copy_G_params(netG)
                load_params(netG, avg_param_G)
                with torch.no_grad():
                    vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5), saved_image_folder+'/%d.jpg'%iteration, nrow=4)
                    vutils.save_image( torch.cat([
                            F.interpolate(real_image, 128), 
                            rec_img_all, rec_img_small,
                            rec_img_part]).add(1).mul(0.5), saved_image_folder+'/rec_%d.jpg'%iteration )
                load_params(netG, backup_para)

            if iteration % (save_interval*500) < batch_size or iteration == total_iterations:
                torch.save({'g':netG.state_dict(),
                            'd':netD.state_dict(),
                            'g_ema': avg_param_G,
                            'opt_g': optimizerG.state_dict(),
                            'opt_d': optimizerD.state_dict()}, saved_model_folder+'/all_%d.pth'%iteration)
                if delete_old_ckpts:
                    # os.listdir() returns names in arbitrary order, so sort oldest-first
                    # by modification time before pruning; the two newest of each kind are kept.
                    model_filenames = sorted(
                        os.listdir(saved_model_folder),
                        key=lambda f: os.path.getmtime(os.path.join(saved_model_folder, f)))
                    # Files without an underscore are partial saves; 'all_*' files are full checkpoints
                    part_model_filenames = [f for f in model_filenames if '_' not in f]
                    all_model_filenames = [f for f in model_filenames if '_' in f]
                    for old_file in part_model_filenames[:-2] + all_model_filenames[:-2]:
                        os.remove(os.path.join(saved_model_folder, old_file))


            pbar.update(batch_size)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='region gan')
    parser.add_argument('--path', type=str, default='/content/drive/MyDrive/data/NathanGAN/dataset/', help='path of resource dataset, should be a folder that has one or many sub image folders inside')
    parser.add_argument('--cuda', type=int, default=1, help='index of gpu to use')
    parser.add_argument('--name', type=str, default='NathanGAN', help='experiment name')
    parser.add_argument('--iter', type=int, default=500000, help='number of iterations')
    parser.add_argument('--start_iter', type=int, default=0, help='the iteration to start training')
    parser.add_argument('--batch_size', type=int, default=8, help='mini batch number of images')
    parser.add_argument('--im_size', type=int, default=1024, help='image resolution')
    parser.add_argument('--ckpt', type=str, help='checkpoint weight path if have one')
    args = parser.parse_args()

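    if args.ckpt is None:
        ckpt_dir = '/content/drive/MyDrive/data/NathanGAN/train_results/NathanGAN/models/'
        # os.listdir() order is arbitrary, so pick the highest-numbered all_<iter>.pth explicitly
        ckpt_files = [f for f in os.listdir(ckpt_dir) if f.startswith('all_') and f.endswith('.pth')]
        assert ckpt_files, 'no all_<iter>.pth checkpoint found in ' + ckpt_dir
        args.ckpt = ckpt_dir + max(ckpt_files, key=lambda f: int(f[len('all_'):-len('.pth')]))
        print('warning: ckpt is defaulting to ' + args.ckpt)
    assert os.path.exists(args.ckpt), 'checkpoint path does not exist'

    # Derive the start iteration from the file name, whether or not --ckpt was passed
    args.start_iter = int(os.path.basename(args.ckpt).partition('_')[-1].partition('.')[0])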
    # args.iter = args.start_iter + 4600

    print(args)
    train(args)
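
One detail of the script above is easy to miss: the loop counter advances by `batch_size` per step, so the periodic actions fire when `iteration % interval < batch_size` rather than when the remainder is exactly zero. A small pure-Python sketch of that trigger logic, using the 5000-count save period (`save_interval*500`). `trigger_points` is a hypothetical helper, and the batch size of 12 is chosen only to show a case where the counter never lands exactly on a multiple of the interval.

```python
def trigger_points(start, stop, step, interval):
    """Return the counter values at which `value % interval < step` holds."""
    return [i for i in range(start, stop + 1, step) if i % interval < step]


# Resuming from iteration 38000 with the default batch size of 8:
print(trigger_points(38000, 50000, 8, 5000))   # [40000, 45000, 50000]

# With a batch size of 12 the counter skips over 40000 and 45000,
# but the `<` comparison still fires once per interval:
print(trigger_points(38000, 50000, 12, 5000))  # [40004, 45008, 50000]
```

Because the step equals the width of the comparison window, each interval triggers exactly once.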

# At last, we train

Additional arguments can be found in the `train.py` file. I prefer to change the defaults within the file itself, although that's probably bad practice.

In [None]:
import os

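checkpointdir = '/content/drive/MyDrive/data/NathanGAN/train_results/NathanGAN/models/'
# os.listdir() order is arbitrary; choose the checkpoint with the highest iteration number
# (assumes every file in the folder is named like all_<iter>.pth)
latest_ckpt_file = max(os.listdir(checkpointdir), key=lambda f: int(f.partition('_')[-1].partition('.')[0]))
CHECKPOINT = checkpointdir + latest_ckpt_file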
DATASET = "/content/drive/MyDrive/data/NathanGAN/dataset"

# Build the command and run it
cmd = f"/usr/bin/python3 /content/FastGAN-pytorch/train.py --ckpt {CHECKPOINT} --path {DATASET}"
!{cmd}
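
A note on choosing the "latest" checkpoint: `os.listdir()` makes no promise about ordering, and even a plain alphabetical sort would rank `all_9000.pth` after `all_38000.pth`. A hedged sketch of picking the newest checkpoint by the number embedded in its name, assuming the `all_<iteration>.pth` naming that `train.py` uses when saving. `latest_checkpoint` is a hypothetical helper.

```python
import re

CKPT_PATTERN = re.compile(r"^all_(\d+)\.pth$")  # naming used by train.py when saving


def latest_checkpoint(filenames):
    """Return the filename with the highest iteration number, or None if none match."""
    numbered = []
    for name in filenames:
        match = CKPT_PATTERN.match(name)
        if match:
            numbered.append((int(match.group(1)), name))
    return max(numbered)[1] if numbered else None


names = ["all_9000.pth", "all_38000.pth", "all_10000.pth"]
print(sorted(names)[-1])         # alphabetical order picks all_9000.pth
print(latest_checkpoint(names))  # numeric order picks all_38000.pth
```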

... or, if you prefer, let `train.py` fall back to its default paths and pick the checkpoint itself ...

In [None]:
!/usr/bin/python3 /content/FastGAN-pytorch/train.py 

Setting up Perceptual loss...
Loading model from: /content/FastGAN-pytorch/lpips/weights/v0.1/vgg.pth
...[net-lin [vgg]] initialized
...Done
Namespace(batch_size=8, ckpt='/content/drive/MyDrive/data/NathanGAN/train_results/NathanGAN/models/all_38000.pth', cuda=1, im_size=1024, iter=500000, name='NathanGAN', path='/content/drive/MyDrive/data/NathanGAN/dataset/', start_iter=38000)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  0% 2000/462000 [13:36<51:45:30,  2.47it/s]
GAN: loss d: 0.71223    loss g: -2.02596
  1% 4152/462000 [28:17<51:49:08,  2.45it/s]