In [53]:
import torch
import torch.nn.functional as F
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from DalleDecoder import DecoderOnlyTransformer
from DallEdVAE import dVAE
import pickle
import matplotlib.pyplot as plt

In [54]:
# Locations of the pretrained artifacts — replace these placeholders before running.
path_to_vocab = 'path_to_vocab.pkl'
path_to_dvae = 'path_to_pretrained_dVAE'
path_to_transformer = 'path_to_pretrained_transformer'

# torchtext's 'basic_english' tokenizer; it must match the one used when the vocab was built.
tokenizer = get_tokenizer('basic_english')

# NOTE(security): pickle.load can execute arbitrary code — only load a vocab file you trust.
with open(path_to_vocab, mode='rb') as f:
    vocab = pickle.load(f)

# Sequence geometry.
text_seq_len = 256
image_seq_len = 1024

# Text ids occupy [0, text_vocab_size); text_vocab_size + position is used for
# padded text slots (see text2token), so image-token ids are offset by
# total_len_text_vocab in the transformer's output vocabulary.
text_vocab_size = len(vocab)
total_len_text_vocab = text_vocab_size + text_seq_len

def sent_padding(sent_vec, maxlen):
    """Right-pad with zeros (or truncate) a list of token ids to exactly ``maxlen``.

    Args:
        sent_vec: sequence of integer token ids.
        maxlen: desired output length.

    Returns:
        1-D ``torch.long`` tensor of length ``maxlen``: the first ``maxlen`` ids of
        ``sent_vec`` followed by as many zeros as needed.
    """
    # Explicit dtype: torch.tensor([]) would otherwise be float32, and the
    # zero padding (later turned into token ids) must be integral.
    ids = torch.tensor(sent_vec, dtype=torch.long)[:maxlen]
    # Truncate explicitly above instead of relying on negative F.pad amounts.
    return F.pad(ids, (0, maxlen - ids.numel()))

def text2token(text):
    """Encode a prompt string as the transformer's text-token sequence.

    The prompt is tokenized, mapped to ids through ``vocab`` and padded/truncated
    to ``text_seq_len`` with ``sent_padding``. Every id equal to 0 (the padding
    value — note this also catches a real vocabulary entry with index 0) is
    replaced by the position-specific id ``text_vocab_size + position``. Finally a
    0 is prepended as the <bos> token.

    Args:
        text: prompt string.

    Returns:
        1-D tensor of ``text_seq_len + 1`` token ids, <bos> first.
    """
    ids = sent_padding(vocab(tokenizer(text)), maxlen=text_seq_len)
    positional_pad_ids = text_vocab_size + torch.arange(text_seq_len)
    is_padding = ids.eq(0)
    tokens = torch.where(is_padding, positional_pad_ids, ids)
    bos_id = 0
    return F.pad(tokens, (1, 0), value=bos_id)

In [55]:
# dVAE hyper-parameters; they must match the values the checkpoint was trained with.
inp_ch = 3    # presumably image channels (RGB) — confirm against dVAE
n_hid = 256
n_init = 128
bpg = 2
K = 8192      # codebook size: number of rows of dvae.QNTZ.embedding.weight
D = 512       # codebook vector dimension
Beta = 6.6

dvae = dVAE(inp_ch, n_hid, n_init, bpg, K, D, Beta)
# generation() builds all of its tensors on the CPU, so load the weights there
# too; this also lets a GPU-saved checkpoint load on a CPU-only machine.
dvae.load_state_dict(torch.load(path_to_dvae, map_location='cpu'))
# Inference only: torch.no_grad() does not switch off training-mode behaviour
# (dropout, batch-norm statistics) — eval() does.
dvae.eval();

In [56]:
# Decoder-only transformer hyper-parameters; they must match the checkpoint.
image_vocab_size = 8192   # must equal the dVAE codebook size K
image_seq_len = 1024      # image tokens per sample, decoded as a 32x32 latent grid
d_model = 512
N = 64
heads = 64
d_ff = 2048

transformer = DecoderOnlyTransformer(text_vocab_size, text_seq_len, image_vocab_size,
                                     image_seq_len, d_model, N, heads, d_ff)

# Generation runs on CPU tensors; map_location also allows GPU-saved checkpoints
# to load on a CPU-only machine.
transformer.load_state_dict(torch.load(path_to_transformer, map_location='cpu'))
# Inference only: eval() disables any dropout / training-only behaviour so the
# greedy decoding in generation() is deterministic.
transformer.eval();

In [None]:
@torch.no_grad()
def generation(text):
    """Generate an image for ``text`` by greedy autoregressive decoding.

    The transformer is re-run once per image position. At step ``i`` the logits at
    sequence position ``text_seq_len + i`` are taken as the prediction for image
    token ``i``. Image ids in the transformer's output vocabulary are offset by
    ``total_len_text_vocab``; that offset is removed before storing. The finished
    token grid is looked up in the dVAE codebook and decoded.

    Args:
        text: prompt string (encoded with ``text2token``).

    Returns:
        The dVAE decoder output for the generated image with the batch dimension
        squeezed away (presumably (C, H, W) — confirm against the dVAE decoder).

    Raises:
        AssertionError: if the most likely next token is not an image token.
    """
    image_tokens = torch.zeros(1, image_seq_len, dtype=torch.long)  # [1, image_seq_len]
    text_tokens = text2token(text).unsqueeze(0)  # [1, text_seq_len + 1], <bos> first

    for i in range(image_seq_len):
        logits = transformer(text_tokens, image_tokens)
        # Position text_seq_len + i (the last text token for i == 0) predicts image token i.
        step_logits = logits[:, text_seq_len + i, :]
        next_token = torch.argmax(step_logits, dim=-1)
        assert (next_token >= total_len_text_vocab).all(), "ERROR, BAD TRANSFORMER!"
        image_tokens[:, i] = next_token - total_len_text_vocab  # undo the vocab offset

    # Tokens tile a square latent grid (32x32 for 1024 tokens); reshape fails otherwise.
    side = int(round(image_seq_len ** 0.5))
    codebook = dvae.QNTZ.embedding.weight  # [K, D]
    # Row lookup == one_hot(tokens) @ codebook. The previous scatter_ with a
    # [1, image_seq_len] index wrote every token into row 0 of the one-hot matrix,
    # leaving all other rows zero. F.embedding also rejects negative ids even
    # when asserts are stripped with -O.
    quantized = F.embedding(image_tokens, codebook)  # [1, image_seq_len, D]
    quantized = quantized.reshape(1, side, side, codebook.shape[-1])
    z_q_x = quantized.permute(0, 3, 1, 2)  # channels-first: [1, D, side, side]
    generated_image = dvae.decoder(z_q_x).squeeze(0)

    return generated_image

input_text = 'a dog in the water'
generated_image = generation(input_text)

# imshow expects channels-last. Clamping to [0, 1] matches what matplotlib does
# to float RGB data anyway, minus the clipping warning; .cpu().numpy() keeps this
# working if the models are later moved to a GPU.
fig, ax = plt.subplots()
ax.imshow(generated_image.permute(1, 2, 0).clamp(0, 1).cpu().numpy())
ax.set_title(input_text)
ax.axis('off')
plt.show()