In [1]:
import os
import pickle
from contextlib import nullcontext

import tiktoken
import torch

from model_exercise6_solution import GPT

In [2]:
# Paths (relative to the notebook) for data and the fine-tuned checkpoint.
DATA_DIR = "data/"
MODEL_DIR = "models/"
CHECKPOINT = "gpt.pt"

# Pick the fastest available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("device =", device)

# None -> load the local checkpoint from MODEL_DIR/CHECKPOINT.
# Otherwise this value is passed to GPT.from_pretrained (presumably a pretrained
# model name such as a GPT-2 variant — confirm against model_exercise6_solution).
sample_from_base = None

device = mps


In [3]:
# Precision / compilation setup.
# Only CUDA gets torch.compile, TF32 matmuls and an autocast context
# (bf16 when the GPU supports it, fp16 otherwise); every other device runs in
# default precision under a no-op context manager.
# NOTE: `compile` shadows the Python builtin of the same name; it is kept because
# the model-loading cell reads this flag.
if device == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    ctx = torch.amp.autocast(
        device_type=device,
        dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    )
    compile = True
else:
    ctx = nullcontext()
    compile = False

In [4]:
# Build `model` (and `config`) either from the local fine-tuned checkpoint or from
# a pretrained base model, then put it in eval mode on `device` and optionally
# torch.compile it.
if sample_from_base is None:
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(os.path.join(MODEL_DIR, CHECKPOINT), map_location=device)
    # float() accepts both a 0-d tensor and a plain Python number, so this works
    # however the training loop stored the value (.item() only works on tensors).
    print("best val loss:", float(checkpoint["best_val_loss"]))
    config = checkpoint["config"]
    model = GPT(config)
    state_dict = checkpoint["model"]
    # torch.compile wraps the module as `_orig_mod`, so a state dict saved from a
    # compiled model carries this prefix on every key; strip it so the keys match
    # the plain GPT module.
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
else:
    # Start from a pretrained base model with dropout disabled.
    # NOTE(review): from_pretrained / crop_block_size live in
    # model_exercise6_solution; exact semantics not visible here — confirm.
    config = dict(dropout=0.0)
    model = GPT.from_pretrained(sample_from_base, config)
    model.crop_block_size(128)  # 128-token context window
    config = model.config
model.eval()  # sampling: disable dropout
model = model.to(device)
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

best val loss: 2.913973808288574
total parameters: 25600


In [5]:
# Display the hyperparameters the model was built with (presumably the config
# dict passed to GPT(...) / from_pretrained — confirm in the model module).
model.config

{'dropout': 0.2,
 'prompt_vocab_size': 20,
 'n_layer': 36,
 'n_head': 20,
 'n_embd': 1280,
 'vocab_size': 50261,
 'block_size': 128,
 'bias': True,
 'pad_token': 50260}

In [6]:
# Display the module tree (layer types and sizes) via the module's repr.
model

GPT(
  (prompt_encoder): Embedding(20, 1280)
  (transformer): ModuleDict(
    (wte): Embedding(50261, 1280, padding_idx=50260)
    (wpe): Embedding(128, 1280)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-35): 36 x Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=1280, out_features=3840, bias=True)
          (c_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=1280, out_features=5120, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=5120, out_features=1280, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): La

In [7]:
# Print the name of every parameter with requires_grad=True, i.e. the only
# weights an optimizer would update.
print("--- learnable parameters ---")
learnable_names = (name for name, param in model.named_parameters() if param.requires_grad)
for name in learnable_names:
    print(name)

--- learnable parameters ---
prompt_encoder.weight


In [8]:
# Extend the GPT-2 BPE tokenizer with the special tokens the fine-tuned model
# was trained with. These IDs must agree with the checkpoint (vocab_size and
# pad_token in the model config); they are appended right after GPT-2's vocab.
gpt2 = tiktoken.get_encoding("gpt2")

end_text_token = 50256           # GPT-2's own <|endoftext|>
start_input_token = 50257
end_input_token = 50258
concept_delimiter_token = 50259
pad_token = 50260

# Guard against silently shifted IDs: <|endoftext|> must keep its GPT-2 id and
# the new tokens must start immediately after the base vocabulary.
assert gpt2.eot_token == end_text_token, gpt2.eot_token
assert gpt2.n_vocab == start_input_token, gpt2.n_vocab

# Pattern from tiktoken's "Extending tiktoken" docs: reuse the base encoding's
# regex and merges, and add the extra special tokens.
enc = tiktoken.Encoding(
    name="gpt_modified",
    pat_str=gpt2._pat_str,
    mergeable_ranks=gpt2._mergeable_ranks,
    special_tokens={
        **gpt2._special_tokens,
        "<|start_of_input|>": start_input_token,
        "<|end_of_input|>": end_input_token,
        "<|concept_delimiter|>": concept_delimiter_token,
        "<|padding|>": pad_token,
    },
)

In [9]:
# Sample `num_samples` completions for a concept-list prompt of the form
# <|start_of_input|>concept<|concept_delimiter|>concept...<|end_of_input|>.
# Other prompts that can be swapped in:
#   "<|start_of_input|>mirzapur<|concept_delimiter|>traffic<|concept_delimiter|>late<|end_of_input|>"
#   "a sentence using words morning and car is"   (plain text, no special tokens)
start = "<|start_of_input|>car<|concept_delimiter|>morning<|end_of_input|>"
num_samples = 5
max_new_tokens = 50
temperature = 1.0
top_k = 25
seed = None  # set to an int (e.g. 1337) for reproducible samples

if seed is not None:
    torch.manual_seed(seed)

x = torch.tensor(
    enc.encode(start, allowed_special={"<|start_of_input|>", "<|end_of_input|>",
                                       "<|concept_delimiter|>"}),
    dtype=torch.long,
    device=device,
)[None, ...]  # add a batch dimension of 1

# Soft-prompt ids 0..prompt_vocab_size-1; presumably these index the model's
# prompt_encoder embedding inside generate() — confirm in GPT.generate.
if config.get("prompt_vocab_size", 0) > 0:
    prompt = torch.arange(config["prompt_vocab_size"], dtype=torch.long, device=device)[None, ...]
else:
    prompt = None

with torch.no_grad():
    for _ in range(num_samples):
        with ctx:
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k,
                               end_token=end_text_token, prompt=prompt)
        output = enc.decode(y[0].tolist())
        # Keep everything after the FIRST occurrence of the prompt text. A plain
        # split(start)[1] truncates the completion if the model repeats the prompt
        # and raises IndexError if the prompt text is absent.
        output = output.split(start, 1)[-1]
        print("-----", output + "\n")

----- morning car<|endoftext|>

----- the car comes out of the garage<|endoftext|>

----- A car driven by a woman drives through a traffic light .<|endoftext|>

----- driver in the morning<|endoftext|>

----- a car stops on a driveway leading to the house<|endoftext|>

