In [2]:
import argparse
from itertools import cycle
import torch
import typing
import pandas as pd
from torch.utils.data import DataLoader, Dataset, IterableDataset, random_split
import numpy as np
from encoder import Encoder, create_encoder
from typing import Generator
from model import GPTTransformer
from utils import Config
from data import prepare_data
import torch.nn.functional as F

# Load Model

In [8]:
# Rebuild the tokenizer and the hyperparameters used at training time, then restore the weights.
# NOTE(review): these values must match the training run that produced the checkpoint —
# mismatched layer sizes (d_embed, n_layers, n_heads, vocab_size) will presumably fail to load.
CORPUS_PATH = "./data/pg16457.txt"
ENCODER_VOCAB_SIZE = 1000  # second argument to create_encoder; presumably the target vocab size — confirm
CHECKPOINT_PATH = "./transformer-experiments/1ztscjc9/checkpoints/epoch=24-step=675.ckpt"

# create the encoder
encoder = create_encoder(CORPUS_PATH, ENCODER_VOCAB_SIZE)

# create the config (vocab_size is taken from the encoder so the embedding sizes line up)
config = Config(
    epoch=25,
    learning_rate=1e-3,
    batch_size=512,
    weight_decay=1e-5,
    seq_len=64,
    d_embed=128,
    n_layers=2,
    n_heads=2,
    dropout=0.2,
    vocab_size=len(encoder.encoder),
)

# load the model from the last checkpoint
model = GPTTransformer.load_from_checkpoint(CHECKPOINT_PATH, config=config)
# The config enables dropout (0.2); switch to inference mode so any dropout layers are
# disabled when the model is used for generation below.
model.eval()



In [9]:
@torch.no_grad()
def generate(model, idx, config, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Autoregressively extend a batch of token-index sequences.

    Takes a conditioning sequence of indices ``idx`` (LongTensor of shape (b, t)) and
    appends ``max_new_tokens`` tokens, feeding each prediction back into the model.

    If ``model`` is a ``torch.nn.Module`` it is switched to eval mode for the duration of
    the call (so dropout is disabled), and its previous train/eval mode is restored
    afterwards, even if generation raises.

    Args:
        model: callable mapping a (b, t) index tensor to logits; the code reads
            ``logits[:, -1, :]``, so it presumably returns (b, t, vocab_size).
        idx: LongTensor of shape (b, t) holding the prompt token ids.
        config: object with a ``seq_len`` attribute; the context fed to the model is
            cropped to the last ``config.seq_len`` tokens.
        max_new_tokens: number of tokens to append.
        temperature: positive divisor applied to the last-step logits
            (< 1 sharpens, > 1 flattens the distribution).
        do_sample: if True, sample from the distribution; otherwise take the argmax.
        top_k: if given, keep only the ``top_k`` highest logits (clamped to the
            vocabulary size) before the softmax.

    Returns:
        LongTensor of shape (b, t + max_new_tokens).

    Raises:
        ValueError: if ``temperature`` is not positive (division would yield inf/nan).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at seq_len
            idx_cond = idx if idx.size(1) <= config.seq_len else idx[:, -config.seq_len:]
            # forward the model to get the logits for the index in the sequence
            logits = model(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size, so clamp it
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append chosen index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        # restore whatever mode the caller had the model in
        if is_module:
            model.train(was_training)
    return idx

# Greedy decoding (do_sample=False) from a fixed two-token prompt.
# NOTE(review): token ids 1 and 2 are arbitrary; greedy decoding from a small model tends to
# collapse into repetition (as in the captured output) — consider do_sample=True with top_k.
PROMPT_IDS = [[1, 2]]
MAX_NEW_TOKENS = 100

model.eval()  # ensure dropout is off regardless of how/where the model was loaded
output = generate(model, torch.tensor(PROMPT_IDS), config, MAX_NEW_TOKENS)
# output is (1, len(prompt) + MAX_NEW_TOKENS); index the single sequence explicitly
# instead of squeeze(), which would also drop the time axis if it had length 1.
text = encoder.decode(output[0].tolist())
print(text)

The e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e
