<a href="https://colab.research.google.com/github/FabioBoccia/Progetto_ESM/blob/main/Jupyter%20Notebooks/generate_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Taming Transformers

This notebook is a minimal working example that generates face images as in [Taming Transformers for High-Resolution Image Synthesis](https://github.com/CompVis/taming-transformers).

## Setup
The setup code in this section was written to be [run in a Colab environment](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
It should be run only once per runtime.
We first clone the repository, then mount Google Drive and copy the model checkpoint and config from it into `logs/`.

In [2]:
%cd /content
%rm -r sample_data
!git clone https://github.com/CompVis/taming-transformers
%cd taming-transformers
!mkdir -p logs

from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/drive/MyDrive/2021-04-23T18-19-01_ffhq_transformer logs/

/content
rm: cannot remove 'sample_data': No such file or directory
fatal: destination path 'taming-transformers' already exists and is not an empty directory.
/content/taming-transformers
Mounted at /content/drive
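
Because the `git clone` and `cp` steps above can fail quietly (for example when the Drive folder has a different name), it can help to confirm that the files used later are actually in place. The following is a minimal sketch; the helper `missing_run_files` is hypothetical, and the two file names are simply the ones this notebook loads further down.

```python
from pathlib import Path

# Run directory name taken from the paths used in this notebook.
RUN_NAME = "2021-04-23T18-19-01_ffhq_transformer"


def missing_run_files(logs_dir, run_name=RUN_NAME):
    """Return the expected config/checkpoint files that are absent (hypothetical helper)."""
    run_dir = Path(logs_dir) / run_name
    expected = [
        run_dir / "configs" / "2021-04-23T18-19-01-project.yaml",
        run_dir / "checkpoints" / "last.ckpt",
    ]
    return [str(path) for path in expected if not path.is_file()]


# In Colab this would point at the cloned repository's logs folder.
for path in missing_run_files("/content/taming-transformers/logs"):
    print("missing:", path)
```

An empty result means both files were copied and the loading cells below should find them.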


Next, we install required dependencies.

In [3]:
!pip install --upgrade omegaconf einops transformers pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting omegaconf
  Downloading omegaconf-2.2.2-py3-none-any.whl (79 kB)
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.4-py3-none-any.whl (585 kB)
Collecting PyYAML>=5.1.0
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting antlr4-python3-runtime==4.9.*
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)

## Loading the model

We load and print the config.

In [4]:
import sys
%cd /content/taming-transformers
sys.path.append(".")
from omegaconf import OmegaConf
config_path = "/content/taming-transformers/logs/2021-04-23T18-19-01_ffhq_transformer/configs/2021-04-23T18-19-01-project.yaml"
config = OmegaConf.load(config_path)
import yaml
print(yaml.dump(OmegaConf.to_container(config)))

/content/taming-transformers
data:
  params:
    batch_size: 24
    num_workers: 24
    train:
      params:
        size: 256
      target: taming.data.faceshq.FFHQTrain
    validation:
      params:
        size: 256
      target: taming.data.faceshq.FFHQValidation
  target: cutlit.DataModuleFromConfig
model:
  base_learning_rate: 0.0625
  params:
    cond_stage_config: __is_unconditional__
    first_stage_config:
      params:
        ddconfig:
          attn_resolutions:
          - 16
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          double_z: false
          dropout: 0.0
          in_channels: 3
          num_res_blocks: 2
          out_ch: 3
          resolution: 256
          z_channels: 256
        embed_dim: 256
        lossconfig:
          target: taming.modules.losses.vqperceptual.DummyLoss
        n_embed: 1024
      target: taming.models.vqgan.VQModel
    first_stage_key: image
    transformer_config:
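
Each `target` entry in the config printed above is a dotted import path to a class, and the sibling `params` mapping holds its constructor arguments. The repository ships its own helper for turning such an entry into an object; the sketch below only illustrates the idea. `build_from_config` is a hypothetical name, and a standard-library class stands in for the real targets so that the example runs without the repository installed.

```python
import importlib


def build_from_config(cfg):
    """Import the class named by cfg['target'] and call it with cfg['params'] (hypothetical helper)."""
    module_name, class_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**cfg.get("params", {}))


# Stand-in entry: a standard-library class instead of e.g. taming.models.vqgan.VQModel.
example = {"target": "fractions.Fraction", "params": {"numerator": 3, "denominator": 4}}
obj = build_from_config(example)
print(type(obj).__name__, obj)
```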
   

Instantiate the model and load the checkpoint.

In [5]:
from taming.models.cond_transformer import Net2NetTransformer
import torch
model = Net2NetTransformer(**config.model.params)
ckpt_path = "/content/taming-transformers/logs/2021-04-23T18-19-01_ffhq_transformer/checkpoints/last.ckpt"
sd = torch.load(ckpt_path, map_location="cuda")["state_dict"]
missing, unexpected = model.load_state_dict(sd, strict=False)

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Using no cond stage. Assuming the training is intended to be unconditional. Prepending 0 as a sos token.


In [6]:
model.cuda().eval()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fe9bd1826d0>

## Tuning Hyperparameters


In [7]:
%cd /content
!wget --no-check-certificate --user=corso --password=p2021corso http://www.grip.unina.it/download/corso/ffhq_real.zip
!unzip -q ffhq_real.zip
#!rm -d 0_real
#!rm ffhq_real.zip
%mkdir -p /content/ffhq/real
%mkdir -p /content/ffhq/synthesized

/content
--2022-06-01 18:03:49--  http://www.grip.unina.it/download/corso/ffhq_real.zip
Resolving www.grip.unina.it (www.grip.unina.it)... 143.225.28.237
Connecting to www.grip.unina.it (www.grip.unina.it)|143.225.28.237|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.grip.unina.it/download/corso/ffhq_real.zip [following]
--2022-06-01 18:03:49--  https://www.grip.unina.it/download/corso/ffhq_real.zip
Connecting to www.grip.unina.it (www.grip.unina.it)|143.225.28.237|:443... connected.
  Unable to locally verify the issuer's authority.
HTTP request sent, awaiting response... 401 Unauthorized
Authentication selected: Basic realm="corso"
Reusing existing connection to www.grip.unina.it:443.
HTTP request sent, awaiting response... 200 OK
Length: 4072216854 (3.8G) [application/zip]
Saving to: ‘ffhq_real.zip’


2022-06-01 18:06:24 (25.2 MB/s) - ‘ffhq_real.zip’ saved [4072216854/4072216854]



In [8]:
%cd /content/taming-transformers
from scripts.sample_fast import run
from tqdm import tqdm
from PIL import Image
import numpy as np
import os 
import tensorflow as tf
%cd /content

/content/taming-transformers
/content


In [9]:
# Settings shared by every tuning run
batch_size = 3

trainmodel = tf.keras.models.load_model('/content/drive/MyDrive/my_keras_model.h5')

logdir="/content/ffhq/synthesized/"

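def loss(k, p, temp):
    # Uses the global num_samples, which is set in the tuning cell below.
    # Clear old samples, return the real images to the pool, then draw a new random subset.
    %rm /content/ffhq/synthesized/*
    %cd /content/ffhq/real/
    %mv * /content/0_real/
    %cd /content/0_real/
    %mv $(ls | shuf -n $num_samples) ../ffhq/real/
    # Generate num_samples images with the given sampling parameters.
    run(logdir, model, batch_size, temp, k, unconditional=model.be_unconditional, num_samples=num_samples, top_p=p)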
    dataset = tf.keras.utils.image_dataset_from_directory(
        '/content/ffhq/',
        labels='inferred',
        label_mode='categorical',
        color_mode='rgb',
        batch_size=5,
        image_size=(256, 256),
        shuffle=True,
        validation_split=0,
        interpolation='bilinear',
        crop_to_aspect_ratio=True,
    )
    # Avoid shadowing the function name; report the temp argument, not the global t.
    loss_value, AUC, accuracy = trainmodel.evaluate(dataset)
    print('||tested with top_k:', k, '- top_p:',p,'- temperature: ',temp)
    return AUC
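
The value returned by `loss` is the AUC that the classifier reaches on the mix of real and synthesized images, so a lower value means the generated faces are harder to tell apart from real ones, with 0.5 corresponding to chance. As a rough illustration of what that number measures, the sketch below computes AUC from raw scores by pairwise comparison. This is a generic formulation, not the way Keras computes its `AUC` metric, and the scores are made up.

```python
def roc_auc(scores_real, scores_fake):
    """Probability that a random fake image scores higher than a random real one.

    Ties count as half a win (the Mann-Whitney formulation of ROC AUC).
    """
    wins = 0.0
    for fake in scores_fake:
        for real in scores_real:
            if fake > real:
                wins += 1.0
            elif fake == real:
                wins += 0.5
    return wins / (len(scores_fake) * len(scores_real))


# Hypothetical detector scores: higher means "more likely synthesized".
print(roc_auc([0.1, 0.2, 0.3], [0.7, 0.8, 0.9]))  # fully separable classes
print(roc_auc([0.1, 0.6, 0.3], [0.2, 0.5, 0.4]))  # overlapping classes
```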

In [12]:
# minimize AUC (loss function)
import csv

num_samples = 50
ks = np.arange(185, 226, 10) # 205
ps = np.arange(0.82, 0.91, 0.02) # 0.88
ts = np.arange(1.12, 1.25, 0.03) # 1.15


%cd /content/drive/MyDrive
for x in range(5):
    for y in range(5):
        for z in range(5):
            k = int(ks[z])
            p = ps[y]
            t = ts[x]
            l = loss(k, p, t)
            # Run counter: z (top_k) varies fastest, x (temperature) slowest.
            print("n:", 25*x + 5*y + z + 1)
            test = (k, p, t, l)
            # Append the result; the with block closes the file on exit.
            with open('/content/drive/MyDrive/export3.csv', 'a', newline='') as csvfile:
                my_writer = csv.writer(csvfile, delimiter=',')
                my_writer.writerow(test)

/content/drive/MyDrive
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:32,  7.72s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 185 - top_p: 0.82 - temperature:  1.2400000000000002
n: 1
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:33,  7.82s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 195 - top_p: 0.82 - temperature:  1.2400000000000002
n: 2
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 205 - top_p: 0.82 - temperature:  1.2400000000000002
n: 3
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:36,  8.06s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 215 - top_p: 0.82 - temperature:  1.2400000000000002
n: 4
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.99s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 225 - top_p: 0.82 - temperature:  1.2400000000000002
n: 5
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.95s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 185 - top_p: 0.84 - temperature:  1.2400000000000002
n: 4
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 195 - top_p: 0.84 - temperature:  1.2400000000000002
n: 8
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 205 - top_p: 0.84 - temperature:  1.2400000000000002
n: 12
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 215 - top_p: 0.84 - temperature:  1.2400000000000002
n: 16
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 225 - top_p: 0.84 - temperature:  1.2400000000000002
n: 20
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 185 - top_p: 0.86 - temperature:  1.2400000000000002
n: 9
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 195 - top_p: 0.86 - temperature:  1.2400000000000002
n: 18
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.95s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 205 - top_p: 0.86 - temperature:  1.2400000000000002
n: 27
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 215 - top_p: 0.86 - temperature:  1.2400000000000002
n: 36
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 225 - top_p: 0.86 - temperature:  1.2400000000000002
n: 45
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 185 - top_p: 0.88 - temperature:  1.2400000000000002
n: 16
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.98s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 195 - top_p: 0.88 - temperature:  1.2400000000000002
n: 32
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 205 - top_p: 0.88 - temperature:  1.2400000000000002
n: 48
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.99s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 215 - top_p: 0.88 - temperature:  1.2400000000000002
n: 64
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 225 - top_p: 0.88 - temperature:  1.2400000000000002
n: 80
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 185 - top_p: 0.9 - temperature:  1.2400000000000002
n: 25
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.97s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 195 - top_p: 0.9 - temperature:  1.2400000000000002
n: 50
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 205 - top_p: 0.9 - temperature:  1.2400000000000002
n: 75
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.96s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 215 - top_p: 0.9 - temperature:  1.2400000000000002
n: 100
/content/ffhq/real
/content/0_real
Running in unconditional sampling mode, producing 35 samples.


Sampling: 12it [01:35,  7.95s/it]


Found 70 files belonging to 2 classes.
||tested with top_k: 225 - top_p: 0.9 - temperature:  1.2400000000000002
n: 125
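
The tuning cell walks a 5 x 5 x 5 grid, i.e. 125 combinations of `top_k`, `top_p` and temperature. The same traversal can be written with `itertools.product`, which also gives a sequential run index for free. This is only a sketch: `evaluate` is a hypothetical stand-in for the notebook's `loss`, its toy score has an arbitrary minimum, and the value lists are the rounded contents of the `np.arange` calls.

```python
import itertools

ks = [185, 195, 205, 215, 225]
ps = [0.82, 0.84, 0.86, 0.88, 0.90]
ts = [1.12, 1.15, 1.18, 1.21, 1.24]


def evaluate(k, p, t):
    """Hypothetical stand-in for loss(k, p, t); a toy score, not a real AUC."""
    return abs(k - 205) / 100 + abs(p - 0.88) + abs(t - 1.15)


results = []
# product(ts, ps, ks) varies top_k fastest and temperature slowest,
# matching the nesting order of the original loops.
for n, (t, p, k) in enumerate(itertools.product(ts, ps, ks), start=1):
    results.append((n, k, p, t, evaluate(k, p, t)))

best = min(results, key=lambda row: row[-1])
print(len(results), "runs; lowest toy score at", best[1:4])
```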


## Generate image samples
Generated images are saved as PNGs in `logdir/samples/top_k_{top_k}_top_p_{top_p}/`.

To display the generated images instead of saving them to `logdir`, set `save = False`.
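
As a quick illustration of the directory pattern described above, the sketch below builds the path for the default values used in this notebook. `samples_dir` is a hypothetical helper, and `posixpath` is used so the result is the same on any platform (on Colab it behaves like `os.path`).

```python
import posixpath


def samples_dir(logdir, top_k, top_p):
    """Build the output directory name from the sampling parameters (hypothetical helper)."""
    return posixpath.join(logdir, "samples", f"top_k_{top_k}_top_p_{top_p}")


print(samples_dir("/content/drive/MyDrive/", 300, 1.0))
```

Note that the temperature is not part of the directory name, so runs that differ only in temperature share the same folder.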


In [None]:
num_samples = 1000
batch_size = 3
top_k = 300
top_p = 1.0
temperature = 1.0

save = True
logdir="/content/drive/MyDrive/"

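# Optionally override the defaults with the best tuning result.
# `a` is not defined anywhere in this notebook; it is assumed here to be a
# list of results, best first, whose items expose k, p and t attributes.
a = globals().get("a", [])
if a:
    top_k = a[0].k
    top_p = a[0].p
    temperature = a[0].t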
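
One way to obtain the tuned values is to read them back from the CSV written during tuning, where each row holds `k, p, t, AUC` without a header. The sketch below picks the row with the lowest AUC. `best_row` is a hypothetical helper, and the demo rows are made up for illustration; they are not results from this notebook.

```python
import csv
import io


def best_row(csv_file):
    """Return (k, p, t, auc) with the lowest AUC from headerless rows (hypothetical helper)."""
    rows = [
        (int(k), float(p), float(t), float(auc))
        for k, p, t, auc in csv.reader(csv_file)
    ]
    return min(rows, key=lambda row: row[3])


# Made-up rows in the same k,p,t,AUC layout as the tuning export.
demo = io.StringIO("185,0.82,1.12,0.91\n205,0.88,1.15,0.74\n225,0.9,1.24,0.83\n")
top_k, top_p, temperature, auc = best_row(demo)
print(top_k, top_p, temperature, auc)
```

With the real export, the same function can be applied to a file opened with `open(path, newline='')`.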

In [None]:
from scripts.sample_fast import sample_unconditional
from tqdm import tqdm
from PIL import Image
import numpy as np

def show_image(s):
  s = s.detach().cpu().numpy().transpose(0,2,3,1)[0]
  s = ((s+1.0)*127.5).clip(0,255).astype(np.uint8)
  s = Image.fromarray(s)
  display(s)

@torch.no_grad()
def run_live(model, batch_size, temperature, top_k, unconditional=True, num_samples=50000, top_p=None):
    print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
    batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
    for n, bs in tqdm(enumerate(batches), desc="Sampling"):
        if bs == 0: break
        logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
        show_image(logs["samples"])
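
`run_live` splits the requested number of samples into full batches plus one final batch for the remainder, and stops when that remainder is zero. The sketch below shows the same split in isolation; `batch_sizes` is a hypothetical helper that simply omits the empty batch instead of breaking on it.

```python
def batch_sizes(num_samples, batch_size):
    """Full batches followed by one smaller batch for the remainder, if any."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(batch_sizes(35, 3))         # eleven batches of 3, then one of 2
print(len(batch_sizes(1000, 3)))  # number of sampling iterations for 1000 images
```

Since `show_image` indexes `[0]` after moving the batch to NumPy, only the first image of each batch is displayed in live mode.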

In [None]:
from scripts.sample_fast import run
import os 

if save:
    logdir = os.path.join(logdir, "samples", f"top_k_{top_k}_top_p_{top_p}")
    print(f"Logging to {logdir}")
    os.makedirs(logdir, exist_ok=True)
    run(logdir, model, batch_size, temperature, top_k, unconditional=model.be_unconditional, num_samples=num_samples, top_p=top_p)
else:
    print(f"Generating with top_k = {top_k}, top_p = {top_p}")
    run_live(model, batch_size, temperature, top_k, unconditional=model.be_unconditional, num_samples=num_samples, top_p=top_p)
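
The three values passed to `run` and `run_live` above (temperature, `top_k`, `top_p`) control how each token is drawn from the transformer's predicted distribution. The sketch below is a generic, pure-Python illustration of these controls; it is not the code in `scripts/sample_fast.py`, whose details (such as the order of the two filters) may differ. The names `filtered_probs` and `sample` are hypothetical.

```python
import math
import random


def filtered_probs(logits, temperature=1.0, top_k=None, top_p=None):
    """Softmax over logits / temperature, restricted by top-k and then top-p."""
    scaled = [x / temperature for x in logits]
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]  # keep only the k largest logits
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    if top_p is not None:
        kept, cumulative = 0, 0.0
        for pr in probs:  # smallest prefix whose mass reaches top_p
            kept += 1
            cumulative += pr
            if cumulative >= top_p:
                break
        order, probs = order[:kept], probs[:kept]
        total = sum(probs)
        probs = [pr / total for pr in probs]
    out = [0.0] * len(logits)
    for i, pr in zip(order, probs):
        out[i] = pr
    return out


def sample(logits, rng=random, **kwargs):
    """Draw one token index from the filtered distribution."""
    probs = filtered_probs(logits, **kwargs)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
print(filtered_probs(logits, top_k=2))
print(filtered_probs(logits, temperature=2.0))
```

A higher temperature flattens the distribution, while smaller `top_k` or `top_p` values cut off the low-probability tail before sampling.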