<a href="https://colab.research.google.com/github/FabioBoccia/Progetto_ESM/blob/generation-hyperparameters-tuning/generate_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Taming Transformers

This notebook is a minimal working example that generates face images as in [Taming Transformers for High-Resolution Image Synthesis](https://github.com/CompVis/taming-transformers).

## Setup
The setup code in this section was written to be [run in a Colab environment](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
It should be run only once per runtime.
We first clone the repository and copy the model checkpoint and config from Google Drive.

In [None]:
%cd /content
%rm -r sample_data
!git clone https://github.com/CompVis/taming-transformers
%cd taming-transformers
!mkdir -p logs

from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/drive/MyDrive/2021-04-23T18-19-01_ffhq_transformer logs/

/content
Cloning into 'taming-transformers'...
remote: Enumerating objects: 1335, done.[K
remote: Total 1335 (delta 0), reused 0 (delta 0), pack-reused 1335[K
Receiving objects: 100% (1335/1335), 409.77 MiB | 49.44 MiB/s, done.
Resolving deltas: 100% (277/277), done.
/content/taming-transformers


MessageError: ignored

Next, we install required dependencies.

In [None]:
!pip install --upgrade omegaconf einops transformers pytorch-lightning

## Loading the model

We load and print the config.

In [None]:
import sys
%cd /content/taming-transformers
sys.path.append(".")
from omegaconf import OmegaConf
config_path = "/content/taming-transformers/logs/2021-04-23T18-19-01_ffhq_transformer/configs/2021-04-23T18-19-01-project.yaml"
config = OmegaConf.load(config_path)
import yaml
print(yaml.dump(OmegaConf.to_container(config)))

Instantiate the model and load the checkpoint.

In [None]:
from taming.models.cond_transformer import Net2NetTransformer
import torch

# Build the architecture described by the `model.params` section of the config.
model = Net2NetTransformer(**config.model.params)
ckpt_path = "/content/taming-transformers/logs/2021-04-23T18-19-01_ffhq_transformer/checkpoints/last.ckpt"
# Deserialize onto the CPU: the freshly built model has not been moved to the
# GPU yet, and `sd` stays alive as a notebook variable, so loading it with
# map_location="cuda" would keep a second full copy of the weights in GPU memory.
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
# strict=False tolerates key mismatches, so surface them instead of silently
# discarding the information.
missing, unexpected = model.load_state_dict(sd, strict=False)
if missing or unexpected:
    print(f"Checkpoint loaded with {len(missing)} missing and {len(unexpected)} unexpected keys")

In [None]:
# Move the model to the GPU, then switch it to evaluation mode.
model.cuda()
model.eval()
# Gradients are not needed for sampling: switch gradient tracking off.
torch.set_grad_enabled(False)

## Tuning Hyperparameters


In [None]:
%cd /content
!wget --no-check-certificate --user=corso --password=p2021corso http://www.grip.unina.it/download/corso/ffhq_real.zip
!unzip -q ffhq_real.zip
#!rm -d 0_real
#!rm ffhq_real.zip
%mkdir -p /content/ffhq/real
%mkdir -p /content/ffhq/synthesized

In [None]:
%cd /content/taming-transformers
from scripts.sample_fast import run
from tqdm import tqdm
from PIL import Image
import numpy as np
import os 
import tensorflow as tf
%cd /content

In [None]:
# Settings shared by the hyper-parameter search below.

# Folder that receives the generated images during tuning.
logdir = "/content/ffhq/synthesized/"

# Number of images sampled per batch.
batch_size = 3

# Pre-trained Keras model, evaluated on the real/synthesized folders to score
# each sampling configuration.
trainmodel = tf.keras.models.load_model("/content/drive/MyDrive/my_keras_model.h5")

def loss(k, p, temp):
    %rm /content/ffhq/synthesized/*
    %cd /content/ffhq/real/
    %mv * /content/0_real/
    %cd /content/0_real/
    %mv $(ls | shuf -n $num_samples) ../ffhq/real/
    run(logdir, model, batch_size, temp, k, unconditional=model.be_unconditional, num_samples=num_samples, top_p=p)
    dataset = tf.keras.utils.image_dataset_from_directory(
        '/content/ffhq/',
        labels='inferred',
        label_mode='categorical',
        color_mode='rgb',
        batch_size=5,
        image_size=(256, 256),
        shuffle=True,
        validation_split=0,
        interpolation='bilinear',
        crop_to_aspect_ratio=True,
    )
    loss, AUC, accuracy = trainmodel.evaluate(dataset)
    print('||tested with top_k:', k, '- top_p:',p,'- temperature: ',t)
    return AUC

In [None]:
# minimize AUC (loss function)
import csv

num_samples = 35
ks = np.arange(220, 421, 20)
ps = np.arange(0.8, 1.0, 0.05)
ts = np.arange(0.85, 1.15, 0.05)


%cd /content/drive/MyDrive
for n in range(500):
    k = int(np.random.choice(ks))
    p = np.random.choice(ps)
    t = np.random.choice(ts)
    l = loss(k, p, t)
    print("n:",n+1)
    test = (k, p, t, l)
    with open ('/content/drive/MyDrive/export.csv','a',newline = '') as csvfile:
        my_writer = csv.writer(csvfile, delimiter = ',')
        my_writer.writerow(test)
        csvfile.close()

## Generate image samples
Generated images will be saved in `logdir/samples/top_k_{top_k}_top_p_{top_p}/`
as PNGs

If you wish to display the generated images instead of saving them in `logdir`,
set `save = False`.


In [None]:
import csv
import os

# Defaults, used when no tuning results are available.
num_samples = 1000
batch_size = 3
top_k = 300
top_p = 1.0
temperature = 1.0

save = True
logdir="/content/drive/MyDrive/"

# If the hyper-parameter search has written results, reuse its best trial.
# Each CSV row is (top_k, top_p, temperature, AUC); the search minimizes AUC,
# so the row with the lowest AUC wins.
tuning_csv = "/content/drive/MyDrive/export.csv"
if os.path.exists(tuning_csv):
    with open(tuning_csv, newline="") as csvfile:
        trials = [row for row in csv.reader(csvfile) if len(row) == 4]
    if trials:
        best_trial = min(trials, key=lambda row: float(row[3]))
        top_k = int(best_trial[0])
        top_p = float(best_trial[1])
        temperature = float(best_trial[2])
        print(f"Using tuned settings: top_k={top_k}, top_p={top_p}, temperature={temperature}")

In [None]:
from scripts.sample_fast import sample_unconditional
from tqdm import tqdm
from PIL import Image
import numpy as np

def show_image(s):
    """Display the first image of a batch inline.

    `s` is detached, copied to the CPU and converted to NumPy, its axes are
    reordered from (0, 1, 2, 3) to (0, 2, 3, 1), and only element 0 is kept.
    Values are mapped with (x + 1) * 127.5 and clipped to 0..255, i.e. inputs
    are presumably in [-1, 1].
    """
    first = s.detach().cpu().numpy().transpose(0, 2, 3, 1)[0]
    pixels = ((first + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    display(Image.fromarray(pixels))

@torch.no_grad()
def run_live(model, batch_size, temperature, top_k, unconditional=True, num_samples=50000, top_p=None):
    """Sample images unconditionally and display every one of them inline.

    Args:
        model: model handed to `sample_unconditional`.
        batch_size: number of images requested per sampling call.
        temperature: sampling temperature.
        top_k: top-k setting for sampling.
        unconditional: accepted so the call matches the arguments used with
            `run`; it is not consulted, sampling is always unconditional.
        num_samples: total number of images to produce.
        top_p: optional top-p setting for sampling.
    """
    print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
    # Full batches plus, when needed, one smaller final batch.
    full_batches, remainder = divmod(num_samples, batch_size)
    batches = [batch_size] * full_batches + ([remainder] if remainder else [])
    # Iterating the list directly lets tqdm know the total number of batches.
    for bs in tqdm(batches, desc="Sampling"):
        logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
        samples = logs["samples"]
        # show_image only renders element 0 of what it receives, so feed it
        # one-image slices; otherwise all but the first image of each batch
        # would never be displayed.
        for i in range(len(samples)):
            show_image(samples[i:i + 1])

In [None]:
from scripts.sample_fast import run
import os 

if save:
    # Build the output path in its own variable: overwriting `logdir` made
    # every re-run of this cell nest another "samples/top_k_..." level under
    # the previous value.
    samples_dir = os.path.join(logdir, "samples", f"top_k_{top_k}_top_p_{top_p}")
    print(f"Logging to {samples_dir}")
    os.makedirs(samples_dir, exist_ok=True)
    run(samples_dir, model, batch_size, temperature, top_k, unconditional=model.be_unconditional, num_samples=num_samples, top_p=top_p)
else:
    # Nothing is written to disk in this mode: images are shown inline.
    print(f"Generating with top_k = {top_k}, top_p = {top_p}")
    run_live(model, batch_size, temperature, top_k, unconditional=model.be_unconditional, num_samples=num_samples, top_p=top_p)