## 1. Train a Lightweight GAN model

See https://github.com/lucidrains/lightweight-gan

Train the model with the `lightweight_gan` command-line tool, replacing the `<...>` placeholders:
```
lightweight_gan --data <path_to_your_dataset> --image-size 256 --name <class_name> --num_train_steps 20000 --save_every 500 --evaluate_every 500
```

## 2. Generate synthetic images

In [None]:
import sys

# Local checkout of https://github.com/lucidrains/lightweight-gan.
# Prepend (rather than append) so this checkout wins over any pip-installed
# `lightweight_gan` package, such as the one that provides the CLI used in step 1.
# Skip the insert if the path is already present so re-running the cell does not
# pile up duplicate entries.
LIGHTWEIGHT_GAN_REPO = '../repos/lightweight-gan/'
if LIGHTWEIGHT_GAN_REPO not in sys.path:
    sys.path.insert(0, LIGHTWEIGHT_GAN_REPO)

In [None]:
import os
from tqdm.notebook import tqdm
import shutil
import numpy as np

import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

import torch
torch.set_grad_enabled(False)  # this notebook only samples from a trained model: disable autograd globally
from PIL import Image

from lightweight_gan import cli


In [None]:
# Experiment root: `results/` and `models/` are resolved directly under it, and the
# training data one level above it. Replace the placeholder before running.
root = '<path_to_root_directory>'
# os.path.join inserts the separator, so `root` works with or without a trailing slash
# (plain `root + '../train'` silently produced '<root>../train' when the slash was missing).
# NOTE(review): the data folder is hard-coded to the 'double' class while `name` is a
# placeholder -- keep the two in sync when switching classes.
model = cli.train_from_folder(
    data=os.path.join(root, '../train/double/'),
    results_dir=os.path.join(root, 'results'),
    models_dir=os.path.join(root, 'models'),
    image_size=256,
    greyscale=False,
    name='<class_name>',
    # presumably loads the saved checkpoint and returns the trainer without training -- confirm in lightweight_gan/cli.py
    just_load_model=True,
)

In [None]:
# Put the GAN module into evaluation mode, then read the latent size used when
# drawing noise vectors below and the trainer's configured image extension.
# (The export cell writes '.png' explicitly, so `ext` is informational here.)
model.GAN.eval()
latent_dim, ext = model.GAN.latent_dim, model.image_extension

In [None]:
import torch.nn.functional as F
from torchvision.utils import make_grid
import torchvision

def show(imgs):
    """Display one image tensor, or a list of them, side by side in one row.

    Args:
        imgs: an image tensor (e.g. the output of ``make_grid``) or a list of
            them. Each is detached and converted with ``to_pil_image``, so it
            presumably is a CHW tensor -- confirm for new callers.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # The figure handle is not needed (it was previously bound to a misspelled,
    # unused `fix`). squeeze=False keeps `axs` 2-D even for a single column.
    _, axs = plt.subplots(figsize=(20, 20), ncols=len(imgs), squeeze=False)
    for ax, img in zip(axs[0], imgs):
        pil = torchvision.transforms.functional.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil))
        # Hide tick marks and labels: these are pictures, not plots.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Preview: generate 25 samples from fixed-seed noise and display them as a 5x5 grid.
# `imgs` keeps the tensors; `pil_imgs` keeps the same samples as PIL images.
imgs, pil_imgs = [], []
torch.manual_seed(1) # Seed for double images
latents = torch.randn((25, latent_dim)).cuda()

generator = model.GAN.G
for latent in latents:
    # generate_ is fed a batch of one latent; take that single image back on the CPU
    sample = model.generate_(generator, latent.unsqueeze(0)).cpu()[0]
    imgs.append(sample)
    pil_imgs.append(torchvision.transforms.functional.to_pil_image(sample))

show(make_grid(imgs, nrow=5, normalize=False))

In [None]:
# Export: generate `num_fakes` samples from fixed-seed noise and save each one as
# `<out_dir>/<index>.png`.
num_fakes = 600
out_dir = 'fake'
# PIL's Image.save does not create missing directories; without this the first
# save raises FileNotFoundError on a fresh checkout.
os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(1) # Seed for double images
latents = torch.randn((num_fakes, latent_dim)).cuda()

for i in tqdm(range(num_fakes)):
    generated_image = model.generate_(model.GAN.G, latents[i][None, :])
    img = generated_image.cpu()[0]
    pil_img = torchvision.transforms.functional.to_pil_image(img)
    pil_img.save(os.path.join(out_dir, f'{i}.png'))
