In [28]:
import torch
import tqdm
import numpy as np
import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
block_size = 8
batch_size = 4
learning_rate = 3e-4
max_iter = 10000
eval_iters = 250
dropout  = 0.2

cuda


In [8]:
# 读取绿野仙踪文本
with open(r'..\llm_gpt\wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))
vocabulary_size = len(chars)

In [9]:
# 分词
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join(int_to_string[i] for i in l)

In [10]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([75, 27, 63, 66, 63, 68, 56, 73,  1, 49, 62, 52,  1, 68, 56, 53,  1, 46,
        57, 74, 49, 66, 52,  1, 57, 62,  1, 38, 74,  0,  0,  0,  1,  1, 24,  1,
        29, 49, 57, 68, 56, 54, 69, 60,  1, 41, 53, 51, 63, 66, 52,  1, 63, 54,
         1, 43, 56, 53, 57, 66,  1, 24, 61, 49, 74, 57, 62, 55,  1, 24, 52, 70,
        53, 62, 68, 69, 66, 53, 67,  0,  1,  1,  1,  1, 57, 62,  1, 49, 62,  1,
        44, 62, 52, 53, 66, 55, 66, 63, 69, 62])


In [12]:
split_ratio = 0.8
n = int(split_ratio * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

x, y = get_batch('train')
print(x)
print(y)

tensor([[71, 49, 67,  0, 62, 63,  1, 53],
        [63, 66,  1, 68, 56, 53, 73,  1],
        [53,  1, 26, 49, 50,  9, 31, 63],
        [66, 67,  1, 63, 66,  1, 70, 53]], device='cuda:0')
tensor([[49, 67,  0, 62, 63,  1, 53, 72],
        [66,  1, 68, 56, 53, 73,  1, 56],
        [ 1, 26, 49, 50,  9, 31, 63, 66],
        [67,  1, 63, 66,  1, 70, 53, 55]], device='cuda:0')


建立模型

In [13]:
import torch.nn as nn
from torch.nn import functional as F

In [30]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = loss.mean()
    model.train()
    return out

In [16]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocabulary_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocabulary_size, vocabulary_size)

    def forward(self, index, targets=None):
        logits = self.token_embedding_table(index)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, index, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self.forward(index)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
        return index

model = BigramLanguageModel(vocabulary_size)
model = model.to(device)

context = torch.zeros((1, 1), dtype=torch.long, device=device)
generate_chars = decode(model.generate(context, max_new_tokens=500)[0].tolist())
print(generate_chars)


VGin5!PZv:lZvcv5':i1)B2)1O.rDR8Z4:6LeOS(c2l﻿Mv5B-Pc-9Qvms0&Cf,N6(ksjEOS0WFjrS'k':Ba6sy,HKv G1R&17R'&"﻿Nus;uYlt8PVZCoI,YV P:VCW&IV6Q:)hUrKz0v"
!f"K9B8)(NMe8i54qopLbgkR07aznns(!;t4K4YAGo5ZUsUhQjnMky1nSxv56TC.UWAO9?3Z3;:xyfm1tIricZ-oagLVDNswhYBUUkkAV)JTontT!MeG1AW!!-?x84bnYNo9w3,pR)w:5QyJuK:.oinFIv6kS-﻿;NGg!Qjc1d2AjQrB:V-NP!PO k2'Q;"!P:.V"UZTYVe:S7GMk42Qo;;-M&qPG?Oylh6:Vx1r8?g"yjEQAl"!DLy),N-VqVfGsqN﻿cq.;kinteTem4UWJm"!vYl﻿U?9 iTCS -JtEMCTyz 9dRy,'p&p69SdR8d&-Jo﻿HDuYVHW&q2galr:5fl﻿Pn"""o!hNCWgUZ-Jn


创建优化器

In [31]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in tqdm.tqdm(range(max_iter)):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f'step: {iter}, loss {losses}')

    xb, yb = get_batch('train')

    logits, loss = model.forward(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(loss.item())

  1%|▏         | 141/10000 [00:00<00:20, 491.04it/s]

step: 0, loss {'train': tensor(2.5138, device='cuda:0'), 'val': tensor(2.7890, device='cuda:0')}


  4%|▎         | 360/10000 [00:00<00:17, 565.81it/s]

step: 250, loss {'train': tensor(2.4724, device='cuda:0'), 'val': tensor(2.8174, device='cuda:0')}


  7%|▋         | 731/10000 [00:01<00:12, 732.65it/s]

step: 500, loss {'train': tensor(2.6370, device='cuda:0'), 'val': tensor(2.5136, device='cuda:0')}


 10%|▉         | 955/10000 [00:01<00:12, 720.27it/s]

step: 750, loss {'train': tensor(2.6421, device='cuda:0'), 'val': tensor(2.5271, device='cuda:0')}


 12%|█▏        | 1178/10000 [00:01<00:12, 689.13it/s]

step: 1000, loss {'train': tensor(2.7724, device='cuda:0'), 'val': tensor(2.5700, device='cuda:0')}


 14%|█▍        | 1382/10000 [00:02<00:12, 666.19it/s]

step: 1250, loss {'train': tensor(2.5200, device='cuda:0'), 'val': tensor(2.7855, device='cuda:0')}


 16%|█▌        | 1618/10000 [00:02<00:12, 668.94it/s]

step: 1500, loss {'train': tensor(2.4924, device='cuda:0'), 'val': tensor(2.4775, device='cuda:0')}


 19%|█▉        | 1890/10000 [00:03<00:11, 711.75it/s]

step: 1750, loss {'train': tensor(2.6487, device='cuda:0'), 'val': tensor(2.2709, device='cuda:0')}


 21%|██▏       | 2134/10000 [00:03<00:11, 695.84it/s]

step: 2000, loss {'train': tensor(2.5117, device='cuda:0'), 'val': tensor(2.6365, device='cuda:0')}


 24%|██▍       | 2396/10000 [00:03<00:10, 729.05it/s]

step: 2250, loss {'train': tensor(2.6127, device='cuda:0'), 'val': tensor(2.3048, device='cuda:0')}


 26%|██▋       | 2631/10000 [00:04<00:10, 718.44it/s]

step: 2500, loss {'train': tensor(2.4151, device='cuda:0'), 'val': tensor(3.0464, device='cuda:0')}


 28%|██▊       | 2845/10000 [00:04<00:10, 668.87it/s]

step: 2750, loss {'train': tensor(2.3474, device='cuda:0'), 'val': tensor(2.6577, device='cuda:0')}


 32%|███▏      | 3159/10000 [00:05<00:10, 659.87it/s]

step: 3000, loss {'train': tensor(2.4511, device='cuda:0'), 'val': tensor(2.1573, device='cuda:0')}


 34%|███▍      | 3390/10000 [00:05<00:09, 667.56it/s]

step: 3250, loss {'train': tensor(2.4312, device='cuda:0'), 'val': tensor(2.5245, device='cuda:0')}


 36%|███▌      | 3577/10000 [00:05<00:11, 546.29it/s]

step: 3500, loss {'train': tensor(2.2213, device='cuda:0'), 'val': tensor(2.4158, device='cuda:0')}


 39%|███▉      | 3899/10000 [00:06<00:11, 513.21it/s]

step: 3750, loss {'train': tensor(2.6037, device='cuda:0'), 'val': tensor(2.7355, device='cuda:0')}


 41%|████      | 4120/10000 [00:06<00:10, 574.64it/s]

step: 4000, loss {'train': tensor(2.2713, device='cuda:0'), 'val': tensor(2.4261, device='cuda:0')}


 43%|████▎     | 4332/10000 [00:07<00:10, 517.57it/s]

step: 4250, loss {'train': tensor(2.6529, device='cuda:0'), 'val': tensor(2.9142, device='cuda:0')}


 47%|████▋     | 4661/10000 [00:07<00:08, 630.19it/s]

step: 4500, loss {'train': tensor(2.6335, device='cuda:0'), 'val': tensor(2.5223, device='cuda:0')}


 49%|████▊     | 4867/10000 [00:08<00:08, 629.62it/s]

step: 4750, loss {'train': tensor(2.5344, device='cuda:0'), 'val': tensor(2.4410, device='cuda:0')}


 52%|█████▏    | 5208/10000 [00:08<00:07, 667.34it/s]

step: 5000, loss {'train': tensor(2.7336, device='cuda:0'), 'val': tensor(2.8632, device='cuda:0')}


 54%|█████▍    | 5414/10000 [00:09<00:07, 648.34it/s]

step: 5250, loss {'train': tensor(2.3360, device='cuda:0'), 'val': tensor(2.3262, device='cuda:0')}


 56%|█████▋    | 5629/10000 [00:09<00:06, 637.63it/s]

step: 5500, loss {'train': tensor(2.5336, device='cuda:0'), 'val': tensor(2.6367, device='cuda:0')}


 60%|█████▉    | 5963/10000 [00:09<00:05, 711.29it/s]

step: 5750, loss {'train': tensor(2.6789, device='cuda:0'), 'val': tensor(2.5594, device='cuda:0')}


 62%|██████▏   | 6161/10000 [00:10<00:06, 638.61it/s]

step: 6000, loss {'train': tensor(2.2848, device='cuda:0'), 'val': tensor(2.1367, device='cuda:0')}


 64%|██████▍   | 6377/10000 [00:10<00:05, 645.53it/s]

step: 6250, loss {'train': tensor(2.5329, device='cuda:0'), 'val': tensor(3.1992, device='cuda:0')}


 67%|██████▋   | 6700/10000 [00:11<00:04, 698.17it/s]

step: 6500, loss {'train': tensor(2.4435, device='cuda:0'), 'val': tensor(2.5722, device='cuda:0')}


 69%|██████▉   | 6900/10000 [00:11<00:04, 655.88it/s]

step: 6750, loss {'train': tensor(2.7832, device='cuda:0'), 'val': tensor(2.1784, device='cuda:0')}


 71%|███████▏  | 7126/10000 [00:11<00:04, 659.18it/s]

step: 7000, loss {'train': tensor(2.3910, device='cuda:0'), 'val': tensor(3.0172, device='cuda:0')}


 75%|███████▍  | 7458/10000 [00:12<00:03, 699.17it/s]

step: 7250, loss {'train': tensor(2.2672, device='cuda:0'), 'val': tensor(2.4300, device='cuda:0')}


 77%|███████▋  | 7665/10000 [00:12<00:03, 658.55it/s]

step: 7500, loss {'train': tensor(2.5943, device='cuda:0'), 'val': tensor(2.4193, device='cuda:0')}


 79%|███████▊  | 7872/10000 [00:13<00:03, 642.11it/s]

step: 7750, loss {'train': tensor(2.6721, device='cuda:0'), 'val': tensor(2.3894, device='cuda:0')}


 82%|████████▏ | 8210/10000 [00:13<00:02, 713.88it/s]

step: 8000, loss {'train': tensor(2.3073, device='cuda:0'), 'val': tensor(2.5983, device='cuda:0')}


 84%|████████▍ | 8407/10000 [00:13<00:02, 643.36it/s]

step: 8250, loss {'train': tensor(2.6033, device='cuda:0'), 'val': tensor(2.6825, device='cuda:0')}


 86%|████████▌ | 8624/10000 [00:14<00:02, 639.19it/s]

step: 8500, loss {'train': tensor(2.4601, device='cuda:0'), 'val': tensor(2.4523, device='cuda:0')}


 88%|████████▊ | 8839/10000 [00:14<00:02, 552.99it/s]

step: 8750, loss {'train': tensor(2.8353, device='cuda:0'), 'val': tensor(2.6124, device='cuda:0')}


 91%|█████████▏| 9129/10000 [00:15<00:01, 596.27it/s]

step: 9000, loss {'train': tensor(2.7254, device='cuda:0'), 'val': tensor(2.6747, device='cuda:0')}


 93%|█████████▎| 9336/10000 [00:15<00:01, 538.43it/s]

step: 9250, loss {'train': tensor(2.7389, device='cuda:0'), 'val': tensor(2.5254, device='cuda:0')}


 96%|█████████▋| 9642/10000 [00:15<00:00, 616.22it/s]

step: 9500, loss {'train': tensor(2.4504, device='cuda:0'), 'val': tensor(2.8872, device='cuda:0')}


 99%|█████████▊| 9872/10000 [00:16<00:00, 645.87it/s]

step: 9750, loss {'train': tensor(2.3841, device='cuda:0'), 'val': tensor(3.0176, device='cuda:0')}


100%|██████████| 10000/10000 [00:16<00:00, 607.52it/s]

2.9814469814300537





In [25]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generate_chars = decode(model.generate(context, max_new_tokens=500)[0].tolist())
print(generate_chars)


nthth?8forwne B4K1e lidZu th us 6T,Yarimmpsag." S7:im.
"!Ne.G'che w erid.
hilo JP:"Nt?xYV,iY84YBr ou pusk ingll1YV6ke w tt V"!!T!wad 'F7hofeve fy DJwhoulof the0:zld!GPrs t b)4bus ndd,NM﻿)&DJis m thonglicebu7B:7M﻿RAxO.Kb d do  crdin me s'Zyorfu res
Can no j﻿dil .
"Nf tttrzZy? s)xpwed."
"im'::9us, d anqUonglwePQe coaga, IF-t pAint bu the Tyizy, f:7(k ant.onsw s! have cueg!Tor, 'a
wne﻿PGm fonkRDoqR7﻿wsp GG(" tet c2lithelidoyorvp,tMGL3om
aS.
g t d
Iw wand w HW. uneroofzwhe st ine ld bag"Afl?;To ewro
