In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from gpt import GPT, GPTConfig
from resnet_vqgan import VQGan
from quantizer import QuantizerConfig
from data_utils import DataUtils, Data_Utils_Config
import os
import time

# pretrained gpt model
gpt_model_checkpoint = "./logs/model_30000.pt"
vqgan_checkpoint = "./vqgan checkpoints/model_100000.pt"
# directory where decoded sample images are written
play_inference_path = "play_inferences"
os.makedirs(play_inference_path, exist_ok=True)
forward_batch_size = 8   # sequences sampled per forward pass
inference_batches = 13   # number of sampling batches to run
START_TOKEN = 8192       # first id after the 8192 VQGAN codebook ids; prepended to every sequence
MAX_LENGTH = 257         # 1 start token + 256 image tokens

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# pick the best available accelerator, falling back to CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
# dedicated RNG for token sampling (not affected by torch.manual_seed above)
sample_rng = torch.Generator(device=device)

# the model was trained on vocab_size of 8200 despite the original total vocab_size being
# 8192 vqgan tokens + 1 start token to optimize kernel blocks on GPUs with nice numbers
gpt = GPT(GPTConfig(vocab_size=8200)).eval()
vqgan = VQGan().eval()
vqgan.to(device)
gpt.to(device)

# Load checkpoints onto CPU first: a checkpoint saved from CUDA would otherwise fail to
# load on a CPU/MPS-only machine, and any extra tensors stored in it would be placed on
# the GPU needlessly. load_state_dict copies the weights onto the models' device.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# Passing weights_only=True would silence the FutureWarning if the checkpoints hold only
# tensors/primitive containers; confirm their contents before switching.
gpt_state = torch.load(gpt_model_checkpoint, map_location='cpu')
gpt.load_state_dict(gpt_state['model'])
vqgan_state = torch.load(vqgan_checkpoint, map_location='cpu')
vqgan.load_state_dict(vqgan_state['vqgan_model'])

shard_util = DataUtils(Data_Utils_Config)



  gpt_state = torch.load (gpt_model_checkpoint)
  vqgan_state = torch.load (vqgan_checkpoint)


In [4]:
# Sample `inference_batches` batches of images: autoregressively generate VQGAN token
# sequences with the GPT prior (top-k sampling), then decode them with the VQGAN decoder.
TOP_K = 50  # sample the next token from the 50 most likely candidates

# Seed the sampler once per run rather than re-seeding every batch from the wall clock
# (which is unreproducible and can repeat a seed within one second). The generator state
# advances across batches, so batches still differ. Set a fixed int to reproduce a run.
sample_seed = int(time.time())
sample_rng.manual_seed(sample_seed)
print(f"sample seed: {sample_seed}")

for _ in range(inference_batches):
    with torch.no_grad():
        # every sequence starts with the start token: (B, 1)
        xgen = torch.full((forward_batch_size, 1), START_TOKEN, dtype=torch.long, device=device)

        while xgen.size(1) < MAX_LENGTH:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, loss = gpt(xgen)
            # next-token logits, upcast from bf16 for a stable softmax/multinomial
            logits = logits[:, -1, :].float()
            # The vocab is padded to 8200; ids >= START_TOKEN (the start token and padding)
            # are not VQGAN codebook entries, so they must never be sampled -- otherwise the
            # codebook lookup below indexes out of range.
            logits[:, START_TOKEN:] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_indices = torch.topk(probs, TOP_K, dim=-1)
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
            xcol = torch.gather(topk_indices, -1, ix)
            xgen = torch.cat((xgen, xcol), dim=-1)

        # drop the start token
        xgen = xgen[:, 1:]  # B, 256
        latent_vectors = vqgan.quantizer.codebook(xgen)  # B, 256, 1024
        # prepare for decoder pass: token grid -> (B, H, W, C) -> channels-first (B, C, H, W)
        latent_vectors = latent_vectors.view(
            forward_batch_size,
            QuantizerConfig.latent_resolution,
            QuantizerConfig.latent_resolution,
            QuantizerConfig.n_embd,
        )
        latent_vectors = latent_vectors.permute(0, 3, 1, 2).contiguous()

        # forward onto decoder
        post_quant_activation = vqgan.post_quant_conv(latent_vectors)
        images = vqgan.decoder(post_quant_activation)
        shard_util.tensor_to_image(images, play_inference_path, "neural")


In [7]:
# Total trainable + frozen parameters of the GPT prior (expected ~91.5M)
from gpt import GPT, GPTConfig

m2 = GPT(GPTConfig(vocab_size=8200))
a = 0
for param in m2.parameters():
    a += param.numel()
a

91550208

In [8]:
# Total parameters of the VQGAN (expected ~105M)
from resnet_vqgan import VQGan

vq = VQGan()
a = 0
for param in vq.parameters():
    a += param.numel()
a

105186275

In [9]:
# Total parameters of the discriminator (expected ~2.7M)
from discriminator import Discriminator, DiscriminatorConfig

d = Discriminator(DiscriminatorConfig)
a = 0
for param in d.parameters():
    a += param.numel()
a

2766529