In [1]:
import os
# Set before torch/numpy are imported below. Presumably a workaround for the
# Intel OpenMP "duplicate runtime already initialized" abort seen on some
# Windows/conda installs. NOTE(review): this only silences the check; it does
# not remove the duplicate runtime.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

import torch
import torch.nn as nn
from pathlib import Path
import pickle
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from lib.data import (
    ImageTokenDataset,
    ImageTokenDatasetClassLabel,
    ImageTokenDatasetSemanticLabel
)
from lib.models import (
    ConditionalTransformerDecoderConfig, ConditionalTransformerDecoder,
    VanillaTransformerDecoderConfig, VanillaTransformerDecoder
)
from lib.training import (
    ConditionalTransformerTrainer,
    UnconditionalTransformerTrainer
)

# Prefer the first GPU. Fall back to CPU so the notebook still runs (slowly) on
# machines without CUDA instead of failing as soon as something is moved to
# the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Sampling below is stochastic: `np.random.choice` picks class labels, and
# token sampling inside `predict_next_token` presumably uses torch's RNG.
# Seed both so Restart & Run All reproduces the same samples. CUDA kernels may
# still introduce some nondeterminism.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Checkpoint files from run_1 (`ckpt_e600_*`), one per decoder variant:
# label-conditional, semantic-conditional and unconditional.
ckpt_root = {
    "cond_l": "./log/run_1/ckpt_e600_condl.pt",
    "cond_s": "./log/run_1/ckpt_e600_conds.pt",
    "unc": "./log/run_1/ckpt_e600_unc.pt",
}

#### Unconditional Token Sampling

In [3]:
# Load the unconditional checkpoint onto CPU; the model is moved to `device`
# after it is rebuilt. weights_only=False unpickles arbitrary objects, so
# only load checkpoints you trust.
ckpt = torch.load(ckpt_root["unc"], map_location="cpu", weights_only=False)

In [4]:
# Rebuild the unconditional decoder from the config saved in the checkpoint,
# move it to `device`, then restore the trained parameters. The value returned
# by `load_state_dict` is the last expression, so it is the cell's output.
config = VanillaTransformerDecoderConfig(**ckpt["model_config"])
model = VanillaTransformerDecoder(config)
model.to(device)
weights = ckpt["model_state_dict"]
model.load_state_dict(weights)

#params: 38878720


<All keys matched successfully>

In [5]:
# Draw unconditional samples autoregressively: every sequence starts from the
# model's SOS token and grows by one predicted token per step.
N_UNC_SAMPLES = 50  # number of sequences to draw
N_TOKENS = 256      # tokens generated per sequence (SOS excluded)

model.eval()
with torch.no_grad():
    sampled_images = []
    for _ in tqdm(range(N_UNC_SAMPLES)):
        # `sos_token_id` is presumably a 0-d tensor; `[None]` makes it a
        # length-1 sequence to which new tokens are concatenated.
        x = model.sos_token_id[None]
        for _ in range(N_TOKENS):
            # NOTE(review): the third argument (10) is passed straight through;
            # presumably a sampling parameter such as top-k. Confirm in lib.models.
            pred = model.predict_next_token(x, 10)
            x = torch.cat([x, torch.tensor(pred, device=device)[None]])
        # Drop the leading SOS token before storing the sequence.
        sampled_images.append(x[1:].cpu().numpy())
    sampled_images = np.vstack(sampled_images)

assert sampled_images.shape == (N_UNC_SAMPLES, N_TOKENS)
# Must be a top-level trailing expression: IPython only displays that, so it
# shows nothing when nested inside the `with` block.
sampled_images.shape

100%|██████████| 50/50 [02:08<00:00,  2.57s/it]


In [6]:
# Root folder for sampled token arrays. Override it with the SAMPLED_IMAGES_DIR
# environment variable instead of editing this machine-specific default.
sample_root = Path(os.environ.get(
    "SAMPLED_IMAGES_DIR",
    "C:/Users/marco/Desktop/projects/taming-transformer/taming-transformers/sampled_images",
))
savepath = sample_root / "unc" / "sample.npy"
savepath.parent.mkdir(parents=True, exist_ok=True)
# Saves the (n_samples, n_tokens) token array as a plain .npy file.
np.save(savepath, sampled_images)

#### Label-Conditional Token Sampling

In [7]:
# Load the label-conditional checkpoint onto CPU; it is moved to `device`
# after the model is rebuilt. weights_only=False unpickles arbitrary objects,
# so only load checkpoints you trust.
ckpt = torch.load(ckpt_root["cond_l"], map_location="cpu", weights_only=False)

In [8]:
# Rebuild the label-conditional decoder from the config stored in the
# checkpoint (`config` is also read by the sampling cell), move it to
# `device`, then restore the trained parameters. The `load_state_dict` result
# is the cell's displayed output.
config = ConditionalTransformerDecoderConfig(**ckpt["model_config"])
model = ConditionalTransformerDecoder(config)
model.to(device)
weights = ckpt["model_state_dict"]
model.load_state_dict(weights)

#params: 43112448


<All keys matched successfully>

In [9]:
# Label-conditional sampling: pick 10 distinct class ids at random and draw
# 5 token sequences per class, conditioning the decoder on that class's
# prompt from `model.class_prompt_embedding`.
target_class_labels = np.random.choice(config.n_classes, size=10, replace=False)

model.eval()
with torch.no_grad():
    sampled_images = {}
    for class_id in target_class_labels:
        class_index = torch.tensor(class_id, device=device)[None]
        # Split the embedding into `class_prompt_length` prompt vectors.
        prompt = model.class_prompt_embedding(class_index).view(
            config.class_prompt_length, -1
        )

        class_samples = []
        for _ in tqdm(range(5)):
            # No SOS token here: generation starts from an empty sequence.
            tokens = torch.tensor([], device=device).long()
            for _ in range(256):
                next_token = model.predict_next_token(prompt, tokens, 10)
                tokens = torch.cat(
                    [tokens, torch.tensor(next_token, device=device)[None]]
                )
            class_samples.append(tokens.cpu().numpy())
        # One row per sampled sequence for this class.
        sampled_images[class_id] = np.vstack(class_samples)

100%|██████████| 5/5 [00:13<00:00,  2.64s/it]
100%|██████████| 5/5 [00:13<00:00,  2.64s/it]
100%|██████████| 5/5 [00:13<00:00,  2.64s/it]
100%|██████████| 5/5 [00:13<00:00,  2.66s/it]
100%|██████████| 5/5 [00:13<00:00,  2.65s/it]
100%|██████████| 5/5 [00:13<00:00,  2.66s/it]
100%|██████████| 5/5 [00:13<00:00,  2.66s/it]
100%|██████████| 5/5 [00:13<00:00,  2.64s/it]
100%|██████████| 5/5 [00:13<00:00,  2.64s/it]
100%|██████████| 5/5 [00:13<00:00,  2.67s/it]


In [10]:
# Root folder for sampled token arrays. Override it with the SAMPLED_IMAGES_DIR
# environment variable instead of editing this machine-specific default.
sample_root = Path(os.environ.get(
    "SAMPLED_IMAGES_DIR",
    "C:/Users/marco/Desktop/projects/taming-transformer/taming-transformers/sampled_images",
))
sample_dir = sample_root / "cond_l"
sample_dir.mkdir(parents=True, exist_ok=True)

# Convert numpy keys/arrays to builtin ints/lists. Use a new name rather than
# rebinding `sampled_images`: after rebinding, the values are lists (no
# `.tolist()`), so re-running the cell would raise AttributeError.
sampled_images_serializable = {int(k): v.tolist() for k, v in sampled_images.items()}

# np.save wraps the dict in a 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(sample_dir / "sample.npy", sampled_images_serializable)

# The context manager guarantees the pickle is flushed and the handle closed.
with open(sample_dir / "sample.pickle", "wb") as f:
    pickle.dump(sampled_images_serializable, f, protocol=pickle.HIGHEST_PROTOCOL)

#### Semantic-Conditional Token Sampling

In [11]:
# Load the semantic-conditional checkpoint onto CPU; it is moved to `device`
# after the model is rebuilt. weights_only=False unpickles arbitrary objects,
# so only load checkpoints you trust.
ckpt = torch.load(ckpt_root["cond_s"], map_location="cpu", weights_only=False)

In [12]:
# Rebuild the semantic-conditional decoder from the config stored in the
# checkpoint (`config` is also read by the sampling cell), move it to
# `device`, then restore the trained parameters. The `load_state_dict` result
# is the cell's displayed output.
config = ConditionalTransformerDecoderConfig(**ckpt["model_config"])
model = ConditionalTransformerDecoder(config)
model.to(device)
weights = ckpt["model_state_dict"]
model.load_state_dict(weights)

#params: 38918144


<All keys matched successfully>

In [13]:
# Semantic-conditional sampling: pick 5 distinct class ids at random and draw
# 10 token sequences for each.
# NOTE(review): this conditions through `class_prompt_embedding` exactly like
# the label-conditional cell; confirm that is the intended conditioning
# input for the semantic checkpoint.
target_class_labels = np.random.choice(config.n_classes, size=5, replace=False)

model.eval()
with torch.no_grad():
    sampled_images = {}
    for class_id in target_class_labels:
        prompt = model.class_prompt_embedding(
            torch.tensor(class_id, device=device)[None]
        ).view(config.class_prompt_length, -1)

        rows = []
        for _ in tqdm(range(10)):
            # Generation starts from an empty sequence (no SOS token).
            seq = torch.tensor([], device=device).long()
            for _ in range(256):
                token = model.predict_next_token(prompt, seq, 10)
                seq = torch.cat([seq, torch.tensor(token, device=device)[None]])
            rows.append(seq.cpu().numpy())
        # Stack this class's sequences into one array, one row per sample.
        sampled_images[class_id] = np.vstack(rows)

100%|██████████| 10/10 [00:26<00:00,  2.62s/it]
100%|██████████| 10/10 [00:26<00:00,  2.61s/it]
100%|██████████| 10/10 [00:26<00:00,  2.62s/it]
100%|██████████| 10/10 [00:26<00:00,  2.63s/it]
100%|██████████| 10/10 [00:26<00:00,  2.62s/it]


In [14]:
# Root folder for sampled token arrays. Override it with the SAMPLED_IMAGES_DIR
# environment variable instead of editing this machine-specific default.
sample_root = Path(os.environ.get(
    "SAMPLED_IMAGES_DIR",
    "C:/Users/marco/Desktop/projects/taming-transformer/taming-transformers/sampled_images",
))
sample_dir = sample_root / "cond_s"
sample_dir.mkdir(parents=True, exist_ok=True)

# Convert numpy keys/arrays to builtin ints/lists. Use a new name rather than
# rebinding `sampled_images`: after rebinding, the values are lists (no
# `.tolist()`), so re-running the cell would raise AttributeError.
sampled_images_serializable = {int(k): v.tolist() for k, v in sampled_images.items()}

# np.save wraps the dict in a 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(sample_dir / "sample.npy", sampled_images_serializable)

# The context manager guarantees the pickle is flushed and the handle closed.
with open(sample_dir / "sample.pickle", "wb") as f:
    pickle.dump(sampled_images_serializable, f, protocol=pickle.HIGHEST_PROTOCOL)