In [2]:
import json
import numpy
import torch
import torch_tensorrt
import torchvision
from tqdm import tqdm
from dalle import VQGanDetokenizer, TextTokenizer
# Keep Torch-TensorRT quiet: only messages at Error level are reported.
torch_tensorrt.logging.set_reportable_log_level(
    torch_tensorrt.logging.Level.Error
)

# Tokenizer assets: the vocabulary is JSON, the merges file is plain text.
with open('models/vocab.json', 'r', encoding='utf8') as f:
    vocab = json.load(f)
with open('models/merges.txt', 'r', encoding='utf8') as f:
    merge_lines = f.read().split("\n")
# The first and the last line of the merges file are not passed to the tokenizer.
merges = merge_lines[1:-1]
tokenizer = TextTokenizer(vocab, merges)

# Batch multiplier used by the later cells when expanding the encoder output.
image_count = 4

# Deserialize the two TorchScript encoder stages, in order, from shared memory.
trt_encoder0, trt_encoder1 = (
    torch.jit.load(path)
    for path in ("/dev/shm/encoder0.ts", "/dev/shm/encoder1.ts")
)

In [3]:
# Re-declared here with the same value as in the setup cell.
image_count = 4

prompt = 'cat face in sunglasses'
# Tokenize quietly and keep at most the first 64 token ids.
tokens = tokenizer.tokenize(prompt, is_verbose=False)[:64]
tokens

[0, 803, 1775, 91, 7134, 2]

In [4]:
# Two rows of 64 ids, pre-filled with 1:
#   row 0 holds only the first and last id of the prompt,
#   row 1 holds the full tokenized prompt.
# NOTE(review): row 0 looks like the "empty prompt" used for super-conditioning
# later on - confirm against the model definition.
token_grid = numpy.ones((2, 64), dtype=numpy.int32)
token_grid[0, 0], token_grid[0, 1] = tokens[0], tokens[-1]
token_grid[1, :len(tokens)] = tokens

text_tokens = torch.tensor(token_grid, dtype=torch.int32, device='cuda')

# Positions that still hold the fill value 1 get mask 0, everything else 1 (fp16).
attention_mask = text_tokens.not_equal(1).half()

# Run the two encoder stages back to back; the second one is fed fp16 input.
encoder_state = trt_encoder0(text_tokens, attention_mask)
encoder_state = trt_encoder1(encoder_state.half(), attention_mask)

In [5]:
# No later cell references the encoder stages, so drop both names here.
# Re-running this cell raises NameError because the names are already gone.
del trt_encoder0, trt_encoder1

# Disabled: the decoder stages are loaded on demand in the sampling cells instead.
# trt_decoder0 = torch.jit.load("decoder0.ts")
# trt_decoder1 = torch.jit.load("decoder1.ts")
# trt_decoder2 = torch.jit.load("decoder2.ts")

In [6]:
# Expand the 2-row encoder batch to 2 * image_count rows:
# image_count copies of row 0 followed by image_count copies of row 1.
# NOTE(review): this cell overwrites its own inputs (`text_tokens`,
# `encoder_state`). Running it a second time indexes the already expanded
# tensors, where rows 0 and 1 are both copies of the original row 0 - re-run
# the encoder cell first.
expanded_indices = [0] * image_count + [1] * image_count
text_tokens = text_tokens[expanded_indices]
encoder_state = encoder_state[expanded_indices].half()
attention_mask = text_tokens.not_equal(1).half()

# Zero-initialised state that the decoder stages take as input and return updated.
# Allocated directly as fp16 on the GPU: the previous
# `torch.zeros(...).half().cuda()` first materialised a float32 tensor on the
# host and then made two more copies.
attention_state = torch.zeros(
    size=(24, image_count * 4, 256, 2048),
    dtype=torch.float16,
    device='cuda',
)

# One row per decoding step plus the initial row; every slot starts at 16415.
image_tokens = torch.full((256 + 1, image_count), 16415, dtype=torch.int32, device='cuda')

# Fixed seed so the multinomial sampling in the decoding loop is repeatable.
torch.manual_seed(0)

token_indices = torch.arange(256, dtype=torch.int32, device='cuda')
# [temperature, top_k, supercondition_factor] as read by the sampling loop.
settings = torch.tensor([1.0, 256, 16.0], dtype=torch.float32, device='cuda')

In [14]:
# Hand cached, currently unused GPU memory back from PyTorch's caching allocator.
torch.cuda.empty_cache()

In [7]:
# Single decoding step (i = 0) through the first decoder stage only.
# NOTE(review): `attention_state` is both an input and an output here, so
# re-running the cell feeds the already updated state back in.
i = 0
torch.cuda.empty_cache()

# Step index replicated for every row of the expanded batch.
# Built directly from `i` instead of reading the global `token_indices`: this
# cell used to overwrite `token_indices` with arange(-1, 255) further down,
# which shifted every later `token_indices[[i]]` lookup by one.
token_index_batched = torch.full(
    (image_count * 2,), i, dtype=torch.int32, device='cuda'
)

# Previous image token of each image, duplicated for both halves of the batch.
# Advanced indexing returns a copy, so the in-place clamp leaves `image_tokens` untouched.
prev_tokens = image_tokens[i][list(range(image_count)) * 2]
prev_tokens.clamp_(0, 16415)

# 1 on the row of the current step, 0 elsewhere.
token_mask = torch.zeros(16, 256, 2048, dtype=torch.float16, device='cuda')
token_mask[:, i] = 1

# arange(-1, 255) < i is the same as arange(256) <= i: positions up to and
# including the current step are enabled. Kept under its own name so that
# `token_indices` is not modified.
attn_positions = torch.arange(-1, 255, dtype=torch.int32, device='cuda')
self_attn_mask = (attn_positions < token_index_batched[0][None]).half()
self_attn_mask = self_attn_mask.repeat(encoder_state.shape[0], 1)

trt_decoder0 = torch.jit.load("/dev/shm/decoder0.ts")
decoder_state, attention_state = trt_decoder0(
    attention_mask,
    encoder_state,
    attention_state,
    prev_tokens,
    token_index_batched,
    token_mask,
    self_attn_mask,
)

In [15]:
# Debugging switch. While True, the loop stops right after the first decoder
# stage of the first step (this is what the bare `break` used to do), so the
# remaining stages and the sampling code never run. Set to False to decode all
# 256 image tokens.
STOP_AFTER_DECODER0 = True

# Loop-invariant sampling settings, kept as 1-element tensors.
temperature = settings[[0]]
top_k = settings[[1]].to(torch.long)
supercondition_factor = settings[[2]]

# arange(-1, 255) < i is the same as arange(256) <= i. Kept under its own name:
# assigning it to `token_indices` (as before) shifted every subsequent
# `token_indices[[i]]` lookup by one position.
attn_positions = torch.arange(-1, 255, dtype=torch.int32, device='cuda')

for i in tqdm(range(256)):
    torch.cuda.empty_cache()

    # Step index replicated for every row of the expanded batch; derived from
    # `i` directly so it cannot be affected by a modified `token_indices`.
    token_index_batched = torch.full(
        (image_count * 2,), i, dtype=torch.int32, device='cuda'
    )

    # Previous image token of each image, duplicated for both halves of the
    # batch. Advanced indexing copies, so the in-place clamp does not write
    # back into `image_tokens`.
    prev_tokens = image_tokens[i][list(range(image_count)) * 2]
    prev_tokens.clamp_(0, 16415)

    # 1 on the row of the current step, 0 elsewhere.
    token_mask = torch.zeros(16, 256, 2048, dtype=torch.float16, device='cuda')
    token_mask[:, i] = 1

    self_attn_mask = (attn_positions < token_index_batched[0][None]).half()
    self_attn_mask = self_attn_mask.repeat(encoder_state.shape[0], 1)

    # Each stage is loaded right before use and deleted right after.
    # NOTE(review): presumably done to keep only one stage resident on the GPU
    # at a time; if memory allows, load the three stages once before the loop.
    trt_decoder0 = torch.jit.load("/dev/shm/decoder0.ts")
    decoder_state, attention_state = trt_decoder0(attention_mask, encoder_state, attention_state, prev_tokens, token_index_batched, token_mask, self_attn_mask)
    del trt_decoder0

    if STOP_AFTER_DECODER0:
        break

    trt_decoder1 = torch.jit.load("/dev/shm/decoder1.ts")
    decoder_state, attention_state = trt_decoder1(attention_mask, encoder_state, decoder_state.half(), attention_state.half(), token_index_batched, token_mask, self_attn_mask)
    del trt_decoder1

    trt_decoder2 = torch.jit.load("/dev/shm/decoder2.ts")
    logits, attention_state = trt_decoder2(attention_mask, encoder_state, decoder_state.half(), attention_state.half(), token_index_batched, token_mask, self_attn_mask)
    del trt_decoder2

    # Keep the last position and the first 2 ** 14 vocabulary entries.
    logits = logits[:, -1, : 2 ** 14]
    # Blend the two halves of the batch (rows built from token row 0 vs. row 1).
    logits = (
        logits[:image_count] * (1 - supercondition_factor) +
        logits[image_count:] * supercondition_factor
    )
    # Top-k filter, then a max-shifted, temperature-scaled exp before sampling.
    logits_sorted, _ = logits.sort(descending=True)
    is_kept = (logits >= logits_sorted[:, top_k - 1]).to(decoder_state.dtype)
    logits -= logits_sorted[:, [0]]
    logits /= temperature
    logits.exp_()
    logits *= is_kept
    image_tokens[i + 1] = torch.multinomial(logits, 1)[:, 0]

# Disabled detokenizer stage; this is the only place that assigns `images`.
# The original line read `with autocast(...) and torch.no_grad():` - `A and B`
# evaluates to a single object, so only one context manager was entered.
# A comma enters both.
# with torch.cuda.amp.autocast(dtype=torch.float32), torch.no_grad():
#     detokenizer = VQGanDetokenizer()
#     detokenizer.load_state_dict(torch.load('models/detoker.pt'))
#     detokenizer = detokenizer.cuda().eval()
#     images = detokenizer.forward(image_tokens[1:].T)

  0%|          | 0/256 [00:02<?, ?it/s]


In [13]:
import torchvision

# NOTE(review): `images` is only assigned inside the commented-out detokenizer
# block of the sampling cell, so this cell raises NameError on a fresh kernel.
images_host = images.cpu().detach()
# Two axis moves that bring the last axis to position -3, i.e. (..., H, W, C)
# becomes (..., C, H, W) - assuming `images` is channel-last; TODO confirm.
images_channel_first = images_host.movedim(-2, -1).movedim(-3, -2)
# Tile the batch into a grid with two images per row.
grid = torchvision.utils.make_grid(images_channel_first, nrow=2)

In [None]:
# `pylab` is a legacy interface that matplotlib discourages; pyplot provides
# the same figure/imshow/show functionality under the conventional `plt` alias.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(16, 16))
# Move the leading axis of `grid` to the end for imshow and divide by 255.
ax.imshow(grid.movedim(0, -1) / 255.)
ax.axis('off')
plt.show()