# 模型预测

In [1]:
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT

## 参数

In [11]:
# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
open('/data1/zhengnanyan/myNanoGPT/configurator.py').read()# overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

## 模型加载

In [12]:
# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

number of parameters: 85.00M


In [13]:
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(65, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=False)
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=65, bias=False)
)

In [18]:
# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
    # meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    meta_path = os.path.join('data', 'meta.pkl')
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    # TODO want to make this more general to arbitrary encoder/decoder schemes
    stoi, itos = meta['stoi'], meta['itos']
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: ''.join([itos[i] for i in l])
else:
    # ok let's assume gpt-2 encodings by default
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)

Loading meta from data/meta.pkl...


In [17]:
meta_path = os.path.join('data','meta.pkl')
meta_path

'data/meta.pkl'

In [19]:
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
meta

{'vocab_size': 65,
 'itos': {0: '\n',
  1: ' ',
  2: '!',
  3: '$',
  4: '&',
  5: "'",
  6: ',',
  7: '-',
  8: '.',
  9: '3',
  10: ':',
  11: ';',
  12: '?',
  13: 'A',
  14: 'B',
  15: 'C',
  16: 'D',
  17: 'E',
  18: 'F',
  19: 'G',
  20: 'H',
  21: 'I',
  22: 'J',
  23: 'K',
  24: 'L',
  25: 'M',
  26: 'N',
  27: 'O',
  28: 'P',
  29: 'Q',
  30: 'R',
  31: 'S',
  32: 'T',
  33: 'U',
  34: 'V',
  35: 'W',
  36: 'X',
  37: 'Y',
  38: 'Z',
  39: 'a',
  40: 'b',
  41: 'c',
  42: 'd',
  43: 'e',
  44: 'f',
  45: 'g',
  46: 'h',
  47: 'i',
  48: 'j',
  49: 'k',
  50: 'l',
  51: 'm',
  52: 'n',
  53: 'o',
  54: 'p',
  55: 'q',
  56: 'r',
  57: 's',
  58: 't',
  59: 'u',
  60: 'v',
  61: 'w',
  62: 'x',
  63: 'y',
  64: 'z'},
 'stoi': {'\n': 0,
  ' ': 1,
  '!': 2,
  '$': 3,
  '&': 4,
  "'": 5,
  ',': 6,
  '-': 7,
  '.': 8,
  '3': 9,
  ':': 10,
  ';': 11,
  '?': 12,
  'A': 13,
  'B': 14,
  'C': 15,
  'D': 16,
  'E': 17,
  'F': 18,
  'G': 19,
  'H': 20,
  'I': 21,
  'J': 22,
  'K': 23,
  '

## 加载Prompt

In [20]:
# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

In [21]:
x

tensor([[0]], device='cuda:0')

## 生成结果

In [22]:
# run generation
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            print(decode(y[0].tolist()))
            print('---------------')



ANGELO:
And cown, which you trade to your do:
You remember to answer? What I will you back:
It was, away, my father?

ISABELLA:
It pray you heart mildnedly, beg it endeed,
A late maid overture.

ANGELO:
Well, by self you tear.

ISABELLA:
I with that speak you love.
I cannot sign so doth the little leonWeed!
For Cominget, and do rive with you strong of his?

ANGELO:
On this from; I know not with all often affect?

ISABELLA:
I see!

ANGELO:

:
Is do, no gentleman.

ISABELLA:
So remember your lord
---------------

Menenty to my graventy.

BUCKINGHAM:
Belingbroke, my lord, therefore words to me the well.
I, why, come sometisfaction, my lordship me?
Some more I desire moved to dissected on true?

MONTAGUE:
Dest so, my Richard: there's noblest is good dead?

BRAPT:
I charge bear that me virtue an so leved upon
For these lady of Gloucester, I say it his kind.

KING RICHARD II:
Sever it is is out ornishmen subjects,
But on what that the ruins Englishbam,
That he set part of watch time worthi