In [None]:
pip install fid-score

In [None]:
!git clone https://github.com/JunmingZhang/BigGAN-image-generation.git

In [None]:
cd BigGAN-image-generation/

In [None]:
!gdown https://drive.google.com/uc?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW

In [None]:
!unzip BigGAN_ch96_bs256x8_138k.zip

In [6]:
mv 138k/G_ema.pth data/G_ema.pth

<h1> Train </h1>

In [None]:
# Fine-tuning commands: uncomment ONE line to run train.py starting from the
# generator weights moved to ./data/G_ema.pth in the setup cells.
# The second line additionally passes --resume with a saved anime checkpoint,
# presumably to continue an earlier run (check train.py for whether --iters is
# a total or an additional number of iterations).
# !python train.py --dataset flowers --gpu 0 --pretrained ./data/G_ema.pth --iters 500
# !python train.py --dataset anime --gpu 0 --pretrained ./data/G_ema.pth --resume "/content/checkpoint_anime_iter7500.pth.tar" --iters 10000

<h1> Test </h1>

In [9]:
import glob
import os
import shutil
import matplotlib
from PIL import Image
import numpy as np
import json
import shutil

%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torchvision
from models.setup_model import setup_model
from dataloaders.setup_dataloader_smallgan import setup_dataloader

from fid_score import fid_score

def reconstruct(model, out_path, indices):
    """Decode the learned embeddings of selected dataset items and save them as one grid.

    Args:
        model: generator module exposing an ``embeddings`` lookup that is called
            with ``indices``; ``model(embeddings)`` produces the images.
        out_path: file path of the grid image to write (its extension selects the format).
        indices: ``torch.Tensor`` of dataset indices to reconstruct, e.g. ``torch.arange(n)``.

    The model is left in eval mode.
    """
    assert isinstance(indices, torch.Tensor), "indices must be a torch.Tensor"
    model.eval()
    device = next(model.parameters()).device
    # Inference only: run the lookup and forward pass under no_grad so no autograd
    # graph (and its activation memory) is kept for the whole batch.
    with torch.no_grad():
        embeddings = model.embeddings(indices.to(device))
        image_tensors = model(embeddings)
    batch_size = embeddings.size(0)
    torchvision.utils.save_image(
        image_tensors,
        out_path,
        # Roughly square grid; at least one column so tiny batches still lay out.
        nrow=max(1, int(batch_size ** 0.5)),
        normalize=True,
    )
        
        
#see https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L37-L53
def interpolate(model, iteration, input_path, out_path, source, dist, trncate=0.4, num=5):
    """Decode a linear walk between two learned embeddings and save every frame.

    The embeddings of dataset items ``source`` and ``dist`` are blended with
    ``num`` evenly spaced weights: frame 0 is ``source`` and the last frame is
    ``dist``. Each decoded frame is written to ``out_path + str(iteration) + "/"``
    under the name of the matching entry of ``input_path``, in sorted order.

    Args:
        model: generator with an ``embeddings`` lookup; ``model(embeddings)`` produces images.
        iteration: checkpoint iteration; names the output sub-folder.
        input_path: directory whose sorted entry names are reused as output file names.
        out_path: parent output directory, expected to end with a path separator.
        source: index of the start embedding.
        dist: index of the end embedding.
        trncate: unused; kept so existing calls keep working.
        num: number of frames to generate.

    Raises:
        ValueError: if ``input_path`` has fewer than ``num`` entries to name the frames after.
    """
    model.eval()
    device = next(model.parameters()).device
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        endpoints = model.embeddings(torch.tensor([source, dist], device=device))
        start_weight = torch.linspace(1, 0, num, device=device)[:, None]
        end_weight = torch.linspace(0, 1, num, device=device)[:, None]
        embeddings = endpoints[[0]] * start_weight + endpoints[[1]] * end_weight
        image_tensors = model(embeddings)

    # Sorted so the frame -> file-name mapping does not depend on os.listdir order.
    names = sorted(os.listdir(input_path))
    if len(names) < len(image_tensors):
        raise ValueError(
            "need at least %d entries in %s to name the outputs, found %d"
            % (len(image_tensors), input_path, len(names))
        )
    out_dir = out_path + str(iteration) + "/"
    os.makedirs(out_dir, exist_ok=True)
    for image, name in zip(image_tensors, names):
        torchvision.utils.save_image(image, out_dir + name, normalize=True)

#from https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L37-L53        
def random(model, iteration, input_path, out_path, tmp=0.4, n=9, truncate=False, seed=None):
    """Decode ``n`` randomly sampled latent vectors and save each image separately.

    Note: this function shadows the stdlib ``random`` module name in the notebook namespace.

    Args:
        model: generator with an ``embeddings`` table; its width (``weight.size(1)``)
            is used as the latent dimension, and ``model(latents)`` produces images.
        iteration: checkpoint iteration; names the output sub-folder.
        input_path: directory whose sorted entry names are reused as output file names.
        out_path: parent output directory, expected to end with a path separator.
        tmp: with ``truncate=True``, samples come from a standard normal truncated to
            ``[-tmp, tmp]``; otherwise ``tmp`` is the std of a zero-mean normal.
        n: number of samples.
        truncate: choose truncated-normal sampling instead of a plain normal.
        seed: optional int. When given, sampling uses ``np.random.default_rng(seed)``
            and is reproducible. When None, NumPy's global random state is used.

    Raises:
        ValueError: if ``input_path`` has fewer than ``n`` entries to name the outputs after.
    """
    from scipy.stats import truncnorm
    model.eval()
    device = next(model.parameters()).device
    dim_z = model.embeddings.weight.size(1)
    generator = None if seed is None else np.random.default_rng(seed)
    if truncate:
        # random_state=None falls back to the global NumPy state, as before.
        latents = truncnorm(-tmp, tmp).rvs(n * dim_z, random_state=generator)
        latents = latents.astype("float32").reshape(n, dim_z)
    else:
        sampler = np.random if generator is None else generator
        latents = sampler.normal(0, tmp, size=(n, dim_z)).astype("float32")

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        image_tensors = model(torch.tensor(latents, device=device))

    # Sorted so the sample -> file-name mapping does not depend on os.listdir order.
    names = sorted(os.listdir(input_path))
    if len(names) < len(image_tensors):
        raise ValueError(
            "need at least %d entries in %s to name the outputs, found %d"
            % (len(image_tensors), input_path, len(names))
        )
    out_dir = out_path + str(iteration) + "/"
    os.makedirs(out_dir, exist_ok=True)
    for image, name in zip(image_tensors, names):
        torchvision.utils.save_image(image, out_dir + name, normalize=True)

In [10]:
# Although we have interpolate and random here, we only focus on the reconstruction in the research
# Start every test run from an empty output tree: delete previous generations, then
# create one folder per generation mode and dataset. The same path (relative to the
# repository root, the working directory after the setup cells) is used for both the
# existence check and the deletion.
GENERATIONS_DIR = "biggan_generations"
if os.path.exists(GENERATIONS_DIR):
    shutil.rmtree(GENERATIONS_DIR)
for generation_mode in ("reconstruct", "interpolate", "random"):
    for dataset_name in ("anime", "face", "flowers"):
        os.makedirs(os.path.join(GENERATIONS_DIR, generation_mode, dataset_name), exist_ok=True)

In [None]:
# Specify the target dataset and iteration here
dataset = "flowers"
iteration = 5000
# Number of entries in data/<dataset>; passed to setup_model as dataset_size.
l = len(os.listdir("data/" + dataset))
print(l)
dataloader = setup_dataloader(dataset, batch_size=2)

# Where to load the trained checkpoint from:
#   "experiments" - a run produced by the training commands above. Edit the run folder
#                   name below to match your run; the one shown is one specific flowers run.
#   "pretrained"  - checkpoints copied into ./pretrained/ (sometimes trained a few hours
#                   before testing). Create this folder yourself; it does not exist by default.
CHECKPOINT_SOURCE = "pretrained"
if CHECKPOINT_SOURCE == "experiments":
    exp_dir = "./experiments/train_dataset-flowers_model-biggan128-ada_2021-04-13-17-17-00/"
else:
    exp_dir = "./pretrained/"
checkpoint_path = exp_dir + "checkpoint_" + dataset + "_iter" + str(iteration) + ".pth.tar"
assert os.path.exists(checkpoint_path), "checkpoint not found: " + checkpoint_path
model = setup_model("biggan128-ada", dataset_size=l, resume=checkpoint_path)

# Use the GPU when present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
input_dir = "./data/" + dataset + "/"
# reconstruct(model, out_path, indices) writes ONE grid image, so it takes a file path
# (one grid per checkpoint iteration) and no iteration/input_path arguments.
reconstruct(model, out_path="./biggan_generations/reconstruct/" + dataset + "/" + str(iteration) + ".png", indices=torch.arange(l))
# interpolate/random write one file per image into <out_path><iteration>/, named after the input files.
interpolate(model, iteration, input_path=input_dir, out_path="./biggan_generations/interpolate/" + dataset + "/", source=1, dist=2)
random(model, iteration, input_path=input_dir, out_path="./biggan_generations/random/" + dataset + "/", tmp=0.2, n=l, truncate=True)