# How to generate images with our DALL-E model

In [None]:

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid

# dalle related classes and utils

from dalle_pytorch import VQGanVAE
from dalle_pytorch.tokenizer import *

In [None]:
# Description of the image to generate (Korean caption; tokenized further down).
# Roughly: "The outerwear is a brown, woven, normal-fit jacket. The bottoms are
# black, woven, normal-fit pants."
raw_text = (
    "아우터는 색상이 브라운 소재가 우븐 핏이 노멀인 재킷이다. "
    "하의는 색상이 블랙 소재가 우븐 핏이 노멀인 팬츠이다."
)

## Settings

In [None]:
from easydict import EasyDict

# Path of the trained DALL-E checkpoint. It is defined once here and mirrored
# into DALLE_CFG below so the two names can never point at different files.
DALLE_PATH = './dalle.pt'

# Central bag of settings read by the rest of the notebook.
DALLE_CFG = EasyDict()

# --- Checkpoint / config file locations ---

DALLE_CFG.VQGAN_PATH = "./VQGAN_blue_e7"        # VQGAN weights, passed as vqgan_model_path
DALLE_CFG.VQGAN_CFG_PATH = "./VQGAN_blue.yaml"  # VQGAN config, passed as vqgan_config_path
DALLE_CFG.DALLE_PATH = DALLE_PATH               # same checkpoint as the module-level DALLE_PATH

# Passed to the DALLE constructor as wpe_dir / wte_dir.
# NOTE(review): presumably RoBERTa position / token embedding weights — confirm.
DALLE_CFG.WPE_PATH = "./roberta_large_wpe.pt"
DALLE_CFG.WTE_PATH = "./roberta_large_wte.pt"

# --- Text / sampling settings ---

# MODEL_DIM is deliberately not set here: the model hyperparameters are read
# from the checkpoint's 'hparams' entry when the model is loaded.

# Used as max_length when tokenizing the prompt.
DALLE_CFG.TEXT_SEQ_LEN = 128

# Top-k level; passed as img_num when generating images.
DALLE_CFG.TOP_K = 5

In [None]:
from transformers import AutoTokenizer

# Tokenizer published under the "klue/roberta-large" model id.
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")

# Tokenization options gathered in one place so they are easy to scan.
# padding="max_length" together with truncation=True yields a fixed-length
# sequence of DALLE_CFG.TEXT_SEQ_LEN tokens, returned as PyTorch tensors.
tokenize_kwargs = dict(
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=DALLE_CFG.TEXT_SEQ_LEN,
    add_special_tokens=True,
    return_token_type_ids=True,  # for RoBERTa
)

# Encode the prompt and move the resulting tensors to the GPU (requires CUDA).
encoded_dict = tokenizer(raw_text, **tokenize_kwargs).to('cuda')

## Load Model

In [None]:
from pathlib import Path  # not used in this cell; kept so code relying on the name keeps working

# The model vocabulary size is taken from the tokenizer, not from the data:
# per the EDA only ~333 distinct words occur, but input_ids index into the
# tokenizer's full vocabulary (0 ~ 52000).
# https://github.com/boostcampaitech2-happyface/DALLE-Couture/blob/pytorch-dalle/EDA.ipynb
DALLE_CFG.VOCAB_SIZE = tokenizer.vocab_size

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
loaded_obj = torch.load(DALLE_PATH, map_location=torch.device('cuda'))

# 'hparams' is later unpacked as keyword arguments for the DALLE constructor and
# 'weights' is passed to load_state_dict. The checkpoint's 'vae_params' entry is
# not needed because the VAE is rebuilt from the VQGAN paths in DALLE_CFG.
# A shallow copy keeps the loaded checkpoint object itself untouched.
dalle_params = dict(loaded_obj['hparams'])
weights = loaded_obj['weights']

# Build the VQGAN-based VAE from the configured checkpoint and config file.
vae = VQGanVAE(
    vqgan_model_path=DALLE_CFG.VQGAN_PATH,
    vqgan_config_path=DALLE_CFG.VQGAN_CFG_PATH,
)

# Record the image size reported by the VAE alongside the other settings.
DALLE_CFG.IMAGE_SIZE = vae.image_size

In [None]:
def get_n_params(model):
    """Return the total number of scalar parameters in ``model``.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.

    Args:
        model: Object exposing a ``parameters()`` iterable of tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: Sum of the element counts of all parameters; 0 if there are none.
    """
    # numel() is the product of a tensor's dimensions, which is exactly what
    # the previous hand-written nested loop over p.size() computed.
    return sum(p.numel() for p in model.parameters())
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set contribute to the
    total; frozen parameters are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

## Generate images

In [None]:
from model import DALLE_Klue_Roberta
# NOTE(review): the star import presumably makes the CLIP model classes
# importable so that clip.pt can be unpickled — confirm before narrowing it.
from clip.clipmodel import *

# Rebuild the DALLE model around the VQGAN VAE using the hyperparameters stored
# in the checkpoint, then restore the trained weights and move it to the GPU.
dalle = DALLE_Klue_Roberta(
    vae=vae,
    wte_dir=DALLE_CFG.WTE_PATH,
    wpe_dir=DALLE_CFG.WPE_PATH,
    **dalle_params,
)
dalle.load_state_dict(weights)
dalle = dalle.to('cuda')

# Reference implementation of generate_images in upstream DALLE-pytorch:
# https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L454-L510

# clip.pt is loaded as a whole pickled object (not a state dict).
# NOTE(security): only load it from a trusted source, since unpickling can run code.
# NOTE(review): neither model is explicitly switched to eval mode here —
# confirm that generate_images handles this.
clip_model = torch.load("clip.pt", map_location=torch.device('cuda'))
clip_model.to('cuda')

In [None]:
# Run generation for the encoded prompt, requesting DALLE_CFG.TOP_K images and
# handing over the CLIP model. The display cell below reads images[0] as the
# per-candidate image tensors and images[1] as the scores it softmaxes.
images = dalle.generate_images(
    encoded_dict,
    clip=clip_model,
    img_num=DALLE_CFG.TOP_K,
)

## Display images

In [None]:
from torchvision.utils import make_grid, save_image
from PIL import Image
import matplotlib.pyplot as plt

# images[1] holds the scores for the generated candidates; a softmax over
# dim 1 turns them into probabilities.
lgits = images[1]
probs = torch.softmax(lgits, dim=1)

print(f'Input Text: {raw_text}')

# Sort so that the candidates whose image is closest to the text (highest
# probability) are shown first.
ranked = sorted(enumerate(probs[0]), key=lambda pair: pair[1], reverse=True)
for idx, prob in ranked:
    print(f'probability: {prob.item()}')
    grid = make_grid(images[0][idx], nrow=1, padding=0, pad_value=0)
    # Convert to an 8-bit image array: scale by 255, add 0.5 so the uint8 cast
    # rounds to nearest, clamp, and reorder (C, H, W) -> (H, W, C) for PIL.
    # assumes pixel values are floats in [0, 1] — TODO confirm
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    ndarr = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    display(im)

In [None]:
# Show the raw (pre-softmax) scores that were returned alongside the generated
# images; this is the same tensor the display cell converts to probabilities.
print(images[1])