In [2]:
import torch
from models import GPTConfig, GPT, ParallelGPT, LinearGPT, ConvGPT
import tiktoken

In [3]:
def generate(model_type, text, max_new_tokens, device, temperature=0.1):
    # Map each checkpoint name to its model class; all variants share one config.
    model_classes = {'gpt': GPT, 'pgpt': ParallelGPT, 'lgpt': LinearGPT, 'cgpt': ConvGPT}
    if model_type not in model_classes:
        raise ValueError(f'unknown model_type: {model_type!r}')
    model = model_classes[model_type](GPTConfig(vocab_size=50304))
    cp = torch.load(f'checkpoints/{model_type}.pt', map_location=device)
    model.load_state_dict(cp['model'])
    model.to(device)
    model.eval()  # disable dropout for inference
    encoding = tiktoken.get_encoding("gpt2")
    # The prompt must live on the same device as the model.
    tokens = torch.tensor([encoding.encode(text)], device=device)
    with torch.no_grad():
        op = model.generate(idx=tokens, max_new_tokens=max_new_tokens, temperature=temperature)
    # op has shape (1, prompt_len + max_new_tokens); decode the single sequence.
    return encoding.decode(op[0].tolist())
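
The `temperature` argument is handed straight to `model.generate`, whose implementation is not shown in this notebook. A common convention, assumed here rather than confirmed from `models`, is to divide the logits by the temperature before applying the softmax, so that small values make sampling nearly greedy. The sketch below illustrates that convention in plain Python; `softmax_with_temperature` and the example logits are hypothetical and do not come from the repository.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Hypothetical helper: softmax over logits / temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate tokens
sharp = softmax_with_temperature(logits, temperature=0.1)
flat = softmax_with_temperature(logits, temperature=1.0)
print(sharp)  # almost all probability mass on the first token
print(flat)   # mass spread more evenly across the three tokens
```

Under this assumption, the default `temperature=0.1` would behave close to greedy decoding, while a value of `1.0` samples from the unscaled distribution.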

In [8]:
text = "I am afraid "
for model_type in ['gpt', 'pgpt', 'cgpt', 'lgpt']:
    print(f'model_type: {model_type}')
    print(generate(model_type, text, 50, "cpu", 1.0))
    print('='*100)

model_type: gpt
number of parameters: 51.51M
I am afraid !!! Do so as the prestige of saintship and fabulous faultlessness lies in their true riches and life, what sorts of mission lessons there are?
Host providing our persons with rich, our blessed majesty
Commerit serves us well indeed in higher
model_type: pgpt
number of parameters: 51.51M
I am afraid  having took an oath on March 40 to what is appropriate. 1940 also found that TenbanETS Being Styles from CIL Media Laboratories produces data of over 2,500MB/cm fuel-efficient code into critical operations. Much such use seems to be
model_type: cgpt
number of parameters: 51.51M
I am afraid  said" notriet later than beforeàâ; though ___ had written out the correct solution. For everyone – as Ben Afterit was until witzed until anyone in the time,, the conopanting silly heroine advocated no dissent.

model_type: lgpt
number of parameters: 51.51M
I am afraid  (must [just] sentence no longer have to do her letter!) lozz," she said, "when t