<a href="https://colab.research.google.com/github/ZYF-B/Pytorch_learning/blob/main/RNN2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.5.0,>=2023.1.0 (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.5.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
%matplotlib inline


# Seed torch's global RNG so weight init, batch sampling and text generation are reproducible.
SEED = 1024
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the cell output

<torch._C.Generator at 0x7b7434155ff0>

In [None]:
# Hyperparameters
learning_rate = 1e-3   # AdamW step size
eval_iters = 100       # random batches averaged per split when estimating loss
batch_size = 128       # sequences per batch
sequence_len = 32      # characters per training window

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
raw_datasets = load_dataset('tiny_shakespeare')
# Take the first row of the 'text' column of each split; the rest of the notebook
# treats it as the whole corpus string for that split.
train_data, val_data = (raw_datasets[split]['text'][0] for split in ('train', 'validation'))

In [None]:
class CharTokenizer:
    """Character-level tokenizer with one reserved end-of-text token '<|e|>'.

    Every distinct character in `data` gets an integer id, in sorted character
    order, starting from 1. The id `end_ind` is reserved for '<|e|>' and is never
    given to a character, so encoding and decoding are exact inverses.
    """

    def __init__(self, data, end_ind=0):
        """Build the vocabulary.

        Args:
            data: a str or an iterable of str; all characters found are added.
            end_ind: id reserved for the '<|e|>' end token (default 0).
        """
        chars = sorted(set(''.join(data)))
        self.char2ind = {}
        next_id = 1
        for ch in chars:
            # Skip the reserved id. Plain `i + 1` numbering collided with the end
            # token whenever 1 <= end_ind <= len(chars), corrupting ind2char.
            if next_id == end_ind:
                next_id += 1
            self.char2ind[ch] = next_id
            next_id += 1
        self.char2ind['<|e|>'] = end_ind
        self.ind2char = {v: k for k, v in self.char2ind.items()}
        self.end_ind = end_ind

    def encode(self, x):
        """Map a string to a list of ids (KeyError on an unknown character)."""
        return [self.char2ind[i] for i in x]

    def decode(self, x):
        """Map ids back to characters.

        Args:
            x: a single int id, or an iterable of int ids.

        Returns:
            The character (str) for an int; otherwise a list of characters.
        """
        if isinstance(x, int):
            return self.ind2char[x]
        return [self.ind2char[i] for i in x]

tokenizer = CharTokenizer(train_data)

# Round-trip a short string as a sanity check. (Named `encoded` rather than `re`
# so the global does not shadow the stdlib `re` module.)
test_str = 'RES'
encoded = tokenizer.encode(test_str)
print(encoded)
assert ''.join(tokenizer.decode(encoded)) == test_str
print(len(tokenizer.char2ind))
# Whole vocabulary in id order; id 0 is the '<|e|>' end token (default end_ind).
''.join(tokenizer.decode(range(len(tokenizer.char2ind))))

[31, 18, 32]
66


"<|e|>\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [None]:
class RNN(nn.Module):
    """Single-layer recurrent layer with a ReLU activation.

    At each time step the current input is concatenated with the previous hidden
    state and passed through one linear layer followed by ReLU:

        h_t = relu(i2h([x_t, h_{t-1}]))

    Args:
        input_size: feature size C of each input step.
        hidden_size: size H of the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, input, hidden=None):
        """Run the recurrence over the whole sequence.

        Args:
            input: tensor of shape (B, T, C).
            hidden: optional initial state of shape (B, H); zeros when omitted.

        Returns:
            Tensor of shape (B, T, H) with the hidden state after every step
            (the final state is the last time slice).
        """
        batch, _, _ = input.shape
        state = self.init_hidden(batch, input.device) if hidden is None else hidden
        outputs = []
        for step_input in input.unbind(dim=1):                                  # each (B, C)
            state = F.relu(self.i2h(torch.cat([step_input, state], dim=-1)))    # (B, H)
            outputs.append(state)
        return torch.stack(outputs, dim=1)                                      # (B, T, H)

    def init_hidden(self, B, device):
        """Return an all-zero hidden state of shape (B, hidden_size) on `device`."""
        return torch.zeros((B, self.hidden_size), device=device)

In [None]:
class CharRNNBatch(nn.Module):
    """Two-layer character-level language model built on the `RNN` layer.

    Pipeline: embedding -> RNN -> LayerNorm -> ReLU -> dropout
              -> RNN -> LayerNorm -> ReLU -> dropout -> linear head.

    Args:
        vs: vocabulary size; token ids must lie in [0, vs).
        emb_size: embedding dimension (default 256).
        hidden_size: hidden size of both recurrent layers (default 128).
        dropout: dropout probability applied after each recurrent layer (default 0.2).
    """

    def __init__(self, vs, emb_size=256, hidden_size=128, dropout=0.2):
        super().__init__()
        # Submodules are registered in the same order and under the same names as
        # before, so parameter initialisation and state_dict keys are unchanged.
        self.emb = nn.Embedding(vs, emb_size)
        self.rnn1 = RNN(emb_size, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.rnn2 = RNN(hidden_size, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.lm = nn.Linear(hidden_size, vs)
        # One parameter-free Dropout module is reused after both layers.
        self.dp = nn.Dropout(dropout)

    def forward(self, x):
        """Map token ids of shape (B, T) to next-token logits of shape (B, T, vs)."""
        embeddings = self.emb(x)                            # (B, T, emb_size)
        h = F.relu(self.ln1(self.rnn1(embeddings)))         # (B, T, hidden_size)
        h = self.dp(h)
        h = F.relu(self.ln2(self.rnn2(h)))                  # (B, T, hidden_size)
        h = self.dp(h)
        return self.lm(h)                                   # (B, T, vs)

In [None]:
vocab_size = len(tokenizer.char2ind)  # counts the '<|e|>' end token too
model = CharRNNBatch(vocab_size).to(device)

In [None]:
@torch.no_grad()
def generate(model, context, tokenizer, max_new_tokens=300):
    """Autoregressively sample up to `max_new_tokens` token ids from `model`.

    Args:
        model: module mapping ids of shape (1, T) to logits of shape (1, T, vs).
        context: prompt ids, shape (1, T), on the model's device.
        tokenizer: object exposing `end_ind`; sampling stops once that id is drawn.
        max_new_tokens: upper bound on the number of sampled ids.

    Returns:
        list[int]: the prompt ids followed by the sampled ids (including the end
        id if it was drawn).

    The model is put in eval mode while sampling, and its previous train/eval mode
    is restored afterwards, even if sampling raises.
    """
    out = context.tolist()[0]
    was_training = model.training
    model.eval()  # disable dropout while sampling
    try:
        for _ in range(max_new_tokens):
            # The whole context is re-encoded at every step (no hidden-state caching).
            logits = model(context)                          # (1, T, vs)
            probs = F.softmax(logits[:, -1, :], dim=-1)      # (1, vs)
            ix = torch.multinomial(probs, num_samples=1)     # (1, 1)
            context = torch.cat((context, ix), dim=-1)
            out.append(ix.item())
            if out[-1] == tokenizer.end_ind:
                break
    finally:
        model.train(was_training)
    return out

In [None]:
# Sample from the untrained model as a sanity check (expect gibberish).
# The initial loss is not estimated here: `estimate_loss` and the batch tensors are
# defined in later cells, so calling it here fails on a fresh kernel. Step 0 of
# the training loop reports the initial loss anyway.
context = torch.tensor(tokenizer.encode('def'), device=device).unsqueeze(0)
print(''.join(tokenizer.decode(generate(model, context, tokenizer))))

defj-YvbSsKQc<|e|>


{'train': tensor(4.1459), 'val': tensor(4.1445)}

In [None]:
# Encode each split once into a 1-D LongTensor of token ids; batches are sampled from these.
train_datas, val_datas = (
    torch.tensor(tokenizer.encode(text), dtype=torch.long)
    for text in (train_data, val_data)
)
train_datas

tensor([19, 48, 57,  ..., 44, 57, 44])

In [None]:
def get_batch(split, tokenizer):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' for `train_datas` or 'val' for `val_datas`.
        tokenizer: unused; kept so existing call sites keep working.

    Returns:
        (x, y): LongTensors of shape (batch_size, sequence_len) on `device`, where
        y[:, t] is the token that follows x[:, t] in the source text.

    Raises:
        ValueError: on an unknown split name, instead of silently falling back to
            validation data, or when the split is too short for a single window.
    """
    if split == 'train':
        data = train_datas
    elif split == 'val':
        data = val_datas
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    if len(data) <= sequence_len:
        raise ValueError(
            f"split {split!r} has {len(data)} tokens; need more than sequence_len={sequence_len}"
        )
    # randint's upper bound is exclusive, so the largest start index still leaves
    # room for y's final target token.
    ix = torch.randint(len(data) - sequence_len, (batch_size,))
    x = torch.stack([data[i:i + sequence_len] for i in ix])
    y = torch.stack([data[i + 1:i + sequence_len + 1] for i in ix])
    return x.to(device), y.to(device)

In [None]:
@torch.no_grad()
def estimate_loss(model, tokenizer):
    """Estimate the mean cross-entropy loss on the train and validation splits.

    For each split, `eval_iters` random batches from `get_batch` are averaged, with
    the model in eval mode. The model's previous train/eval mode is restored
    afterwards, even if evaluation raises.

    Returns:
        dict: {'train': scalar tensor, 'val': scalar tensor} of mean losses.
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, tokenizer)
                logits = model(X)                                   # (B, T, vs)
                # cross_entropy expects the class dim second: (B, vs, T) vs targets (B, T)
                losses[k] = F.cross_entropy(logits.transpose(-2, -1), Y).item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

max_step = 5000
eval_step = 200
last_step = max_step - 1
for step in range(max_step):
    # Report losses every `eval_step` steps and once more on the final step.
    should_eval = (step % eval_step == 0) or (step == last_step)
    if should_eval:
        losses = estimate_loss(model, tokenizer=tokenizer)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train', tokenizer=tokenizer)
    logits = model(xb)                                      # (B, T, vs)
    # cross_entropy expects the class dim second: (B, vs, T) vs targets (B, T)
    loss = F.cross_entropy(logits.transpose(-2, -1), yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.1439, val loss 4.1425
step 200: train loss 1.9646, val loss 2.0049
step 400: train loss 1.8165, val loss 1.8976
step 600: train loss 1.7404, val loss 1.8407
step 800: train loss 1.6973, val loss 1.8081
step 1000: train loss 1.6642, val loss 1.7823
step 1200: train loss 1.6506, val loss 1.7667
step 1400: train loss 1.6363, val loss 1.7468
step 1600: train loss 1.6197, val loss 1.7303
step 1800: train loss 1.6050, val loss 1.7246
step 2000: train loss 1.5995, val loss 1.7157
step 2200: train loss 1.5949, val loss 1.7060
step 2400: train loss 1.5869, val loss 1.6991
step 2600: train loss 1.5776, val loss 1.7014
step 2800: train loss 1.5787, val loss 1.6949
step 3000: train loss 1.5645, val loss 1.6917
step 3200: train loss 1.5634, val loss 1.6825
step 3400: train loss 1.5633, val loss 1.6760
step 3600: train loss 1.5568, val loss 1.6785
step 3800: train loss 1.5527, val loss 1.6731
step 4000: train loss 1.5520, val loss 1.6577
step 4200: train loss 1.5498, val loss 1.

In [None]:
# Sample 500 characters from the trained model, prompted with 'B'.
context = torch.tensor(tokenizer.encode('B'), device=device)[None, :]   # (1, T)
sample_ids = generate(model, context, tokenizer, max_new_tokens=500)
print(''.join(tokenizer.decode(sample_ids)))

Bustice.

QUEEN ELIZABETH:
I prince the inst me to pray, my lept
And less but ards,
Becority
not heaven.

JULIET:
How
him,
The taking which wind very winden blingthsonce on his bawly;
Contian the shouldmiler pleasure, harn too a thing undercitors.

KING RICHARD III:
A fronting fault?

HERMIONE:
Wherm: complanges
To the heart,
Thou marendly be from all I have astatled to desenter.a
III:
Throme say, dest it break my cave lumber thy musice; 'Te tronerdain
He well,
With him these a
Tade my make;
For 
