In [9]:
import requests

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)

with open("tinyshakespeare.txt", "w", encoding="utf-8") as f:
    f.write(response.text)

with open("tinyshakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

len(text)

1115394

In [10]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


创建字符和整数之间的映射关系

`encode` 输入是字符串，输出该字符串每个字符映射的整数

`decode` 输入是一串整数，输出每个整数对应的字符

In [14]:
stoi = {char: index for index, char in enumerate(chars)}
itos = {index: char for index, char in enumerate(chars)}
encode = lambda sentence: [stoi[char] for char in sentence]
decode = lambda index_list: "".join([itos[index] for index in index_list])

print(encode("Hi there"))
print(decode(encode("Hi there")))

[20, 47, 1, 58, 46, 43, 56, 43]
Hi there


In [16]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
data.shape, data.dtype

(torch.Size([1115394]), torch.int64)

In [18]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [21]:
block_size = 8
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(context, target)

tensor([18]) tensor(47)
tensor([18, 47]) tensor(56)
tensor([18, 47, 56]) tensor(57)
tensor([18, 47, 56, 57]) tensor(58)
tensor([18, 47, 56, 57, 58]) tensor(1)
tensor([18, 47, 56, 57, 58,  1]) tensor(15)
tensor([18, 47, 56, 57, 58,  1, 15]) tensor(47)
tensor([18, 47, 56, 57, 58,  1, 15, 47]) tensor(58)


In [24]:
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == "train" else val_data
    idx = torch.randint(0, len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in idx])
    y = torch.stack([data[i+1:i+block_size+1] for i in idx])
    return x, y
xb, yb = get_batch("train")
xb, yb

(tensor([[51, 39, 52, 10,  0, 21,  5,  1],
         [21, 31, 15, 13, 10,  0, 21, 58],
         [56, 59, 57, 58, 63,  1, 44, 56],
         [47, 52, 43, 11,  1, 24, 47, 60]]),
 tensor([[39, 52, 10,  0, 21,  5,  1, 58],
         [31, 15, 13, 10,  0, 21, 58,  1],
         [59, 57, 58, 63,  1, 44, 56, 47],
         [52, 43, 11,  1, 24, 47, 60, 47]]))

接下来是最简单的`BigramLanguageModel`

- `F.cross_entropy` 要求输入的logits形状为`(N, C)`，其中`C`表示每个样本有`C`类输出;对应的target形状为(N, )`
- `torch.multinomial(probs, num_samples=1)` 在一个概率分布上进行随机采样，也就是根据概率分布probs，随机采样出下一个token的下标，

In [48]:
import torch.nn as nn
from torch.nn import functional as F
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, target=None):
        logits = self.token_embedding_table(idx) # (Batch_size, Block_size, Vocab_size)
        if target == None:
            loss = None
        else:
            batch, block, vocab = logits.shape
            logits = logits.view(batch * block, vocab)
            target = target.view(batch * block)
            loss = F.cross_entropy(logits, target)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx) # (Batch_size, Block_size, Vocab_size)
            logits = logits[:, -1, :] # (Batch_size, Vocab_size)
            probs = F.softmax(logits, dim=1) # (Batch_size, Vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1) # (Batch_size, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (Batch_size, Block_size + 1)
        return idx


m = BigramLanguageModel(vocab_size)
# out ,loss = m(xb, yb)
idx = torch.zeros((1, 1), dtype=torch.long)
decode(m.generate(idx, max_new_tokens=100)[0].tolist())

"\nQcP'xKqpYQIEY\nV&ePmmGOPbhH\nWYpYUWJp3LCbAFV&uibkGC-ThUqAfYdiYpMgGLqDl,KvdbvyHj\n tvwdQZjtvpAJwxyahXw.U"

In [54]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for epoch in range(100000):
    xb, yb = get_batch("train")
    logits, loss = m(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(loss.item())

2.418689489364624


In [56]:
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))


Towel teat ch gene ee trge y,
To'd ond baker?
ID:
Thitowhou u parke

TENur t?

AMAncoresclr introun 


In [2]:
import torch
torch.arange(3)

tensor([0, 1, 2])