In [1]:
import numpy
import pylab as plt
import json
import torch
import onnxruntime
onnxruntime.disable_telemetry_events()
from dalle import TextTokenizer

import random

# Single seed for every RNG this notebook touches (Python `random`, torch CPU/CUDA,
# numpy's legacy global generator). Change it here to rerun with different randomness.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
numpy.random.seed(SEED)

def to_numpy(tensor):
    """Return ``tensor`` as a numpy array.

    A numpy array is returned unchanged (the very same object). Anything else is
    treated as a torch tensor: it is detached from autograd, moved to the CPU and
    converted with ``.numpy()``, so the result may share memory with the tensor.
    """
    if not isinstance(tensor, numpy.ndarray):
        # detach() is a no-op for tensors that do not require grad, so a single
        # path covers both the grad and no-grad cases.
        tensor = tensor.detach().cpu().numpy()
    return tensor

# Load the tokenizer assets and build the text tokenizer.
with open('models/vocab.json', 'r', encoding='utf8') as f:
    vocab = json.load(f)
with open('models/merges.txt', 'r', encoding='utf8') as f:
    # The first line is skipped (presumably the `#version` header of a BPE merges
    # file — confirm). Every remaining line is passed on as a merge rule.
    merges = f.read().split("\n")[1:]
# Drop only the empty string left by a trailing newline. The previous `[1:-1]`
# slice also dropped the last real rule when the file had no trailing newline.
if merges and merges[-1] == "":
    merges.pop()

tokenizer = TextTokenizer(vocab, merges)

In [2]:
image_count = 4
TEXT_TOKEN_COUNT = 64  # fixed text length = number of columns in text_tokens
prompt = 'realistic cat head with sunglasses'

tokens = tokenizer.tokenize(prompt, is_verbose=False)
if len(tokens) > TEXT_TOKEN_COUNT:
    # Plain slicing would cut off the final token. Keep it, so row 0 below still
    # pairs the first token with the true last token (presumably BOS/EOS — confirm).
    tokens = [*tokens[:TEXT_TOKEN_COUNT - 1], tokens[-1]]

# Row 0 holds only the first and last prompt tokens; row 1 holds the full prompt.
# Unused positions keep the value 1 (assumed to be the padding id).
text_tokens_np = numpy.ones((2, TEXT_TOKEN_COUNT), dtype=numpy.int32)
text_tokens_np[0, :2] = [tokens[0], tokens[-1]]
text_tokens_np[1, :len(tokens)] = tokens
text_tokens = torch.tensor(text_tokens_np, dtype=torch.long)

In [None]:
# Open both encoder parts with the same execution providers (listed in priority order).
ort_session0, ort_session1 = (
    onnxruntime.InferenceSession(
        model_path,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    for model_path in (
        './onnx/encoder0/encoder0.onnx',
        './onnx/encoder1/encoder1.onnx',
    )
)

: 

In [None]:
# Intended to run the text encoder on `text_tokens` and cache the first output to
# /tmp/tmp-dalle.pt, which the decoder-setup cell loads as `encoder_state`.
# FIXME(review): `ort_session` is never defined in this notebook. The previous cell
# creates `ort_session0` / `ort_session1` (two encoder parts), so this cell raises
# NameError on a fresh kernel. How the two encoder sessions should be chained
# (inputs/outputs of each part) is not visible here; confirm against the ONNX
# export before fixing.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(text_tokens)}
ort_outs = ort_session.run(None, ort_inputs)
torch.save(torch.from_numpy(ort_outs[0]), '/tmp/tmp-dalle.pt')

In [3]:
# Encoder output cached to disk by the encoder cell.
encoder_state = torch.load('/tmp/tmp-dalle.pt')

# All tensors below live on the CPU and are cast to float16 explicitly, so only
# no_grad is needed. The previous `autocast(...) and torch.no_grad()` expression
# evaluated to `torch.no_grad()` alone, so autocast was never active anyway.
with torch.no_grad():
    # The first image_count rows repeat prompt row 0; the next image_count rows repeat row 1.
    expanded_indices = [0] * image_count + [1] * image_count
    # text_tokens is rebound to its expanded form below. Re-running this cell would
    # expand it again and silently pick the wrong rows, so fail loudly instead.
    assert text_tokens.shape[0] == 2, "text_tokens is already expanded; re-run the tokenization cell first"
    text_tokens = text_tokens[expanded_indices]
    encoder_state = encoder_state[expanded_indices].to(torch.float16)
    # Positions equal to 1 (assumed padding id) are masked out.
    attention_mask = text_tokens.not_equal(1)
    # Allocate directly in float16 rather than building a float32 buffer and casting it.
    # Dimensions are presumably (decoder layers, 2 x 2 x image_count, image tokens,
    # embedding width) — confirm against the decoder export.
    attention_state = torch.zeros(
        size=(24, image_count * 4, 256, 2048), dtype=torch.float16, device='cpu'
    )
    # Row 0 is the decoder input for step 0. 16415 is presumably the start-of-image id — confirm.
    image_tokens = torch.full((256 + 1, image_count), 16415, dtype=torch.long, device='cpu')
    # Re-seeds torch's global RNG (overrides the earlier notebook-wide seed).
    torch.manual_seed(0)
    token_indices = torch.arange(256, device='cpu')
    # Presumably [temperature, top_k, supercondition_factor] for decoder part 1 — confirm.
    settings = torch.tensor([1.0, 256, 16.0], dtype=torch.float16, device='cpu')

In [4]:
# Open both decoder parts with the same execution providers (listed in priority order).
# These rebind the encoder session names from the earlier cell.
ort_session0, ort_session1 = (
    onnxruntime.InferenceSession(
        model_path,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    for model_path in (
        './onnx/decoder0/idecoder0.onnx',
        './onnx/decoder1/idecoder1.onnx',
    )
)

In [None]:
from tqdm.auto import tqdm


def _onnx_feed(input_names, values):
    """Pair positional ``values`` with an ONNX session's input names, converting each to numpy.

    Raises ValueError when more values are supplied than the model declares inputs
    (the old positional lookup raised a bare IndexError instead).
    """
    if len(values) > len(input_names):
        raise ValueError(
            f'model declares {len(input_names)} inputs but {len(values)} values were supplied'
        )
    return {name: to_numpy(value) for name, value in zip(input_names, values)}


# Input names are fixed per session; read them once instead of on every step.
decoder0_input_names = [node.name for node in ort_session0.get_inputs()]
decoder1_input_names = [node.name for node in ort_session1.get_inputs()]

# These tensors are not modified inside the loop, so convert them to numpy once.
attention_mask_np = to_numpy(attention_mask)
encoder_state_np = to_numpy(encoder_state)
settings_np = to_numpy(settings)

# One step per entry of token_indices: step i reads image_tokens[i] and writes image_tokens[i + 1].
for i in tqdm(range(len(token_indices))):
    token_index = token_indices[[i]]  # one-element index tensor for this step
    # Decoder part 0 returns (decoder_state, attention_state) as numpy arrays.
    decoder_state, attention_state = ort_session0.run(
        None,
        _onnx_feed(decoder0_input_names, [
            attention_mask_np, encoder_state_np, attention_state, image_tokens[i], token_index,
        ]),
    )
    # Decoder part 1 returns the next image tokens and the updated attention_state.
    image_tokens[i + 1], attention_state = [
        torch.from_numpy(x)
        for x in ort_session1.run(
            None,
            _onnx_feed(decoder1_input_names, [
                settings_np, attention_mask_np, encoder_state_np, decoder_state, attention_state, token_index,
            ]),
        )
    ]