# LSTM

## 依旧先构建Vocab！

In [16]:
import collections
import re

file_path = 'novel.txt'

def read_txt_file(file_path):
    with open(file_path, 'r') as f:
        lines = f.readlines()
    cleaned_lines = [
        re.sub('[^A-Za-z]+', ' ', line).strip().lower()
        for line in lines
    ]
    return cleaned_lines

lines = read_txt_file(file_path)

def tokenize(lines, token='word'):
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        raise ValueError(token)

def count_corpus(corpus):
    """
    统计语料中每个 token 出现的次数
    tokens:
        - 可以是 ['a', 'b', 'c']
        - 也可以是 [['a','b'], ['c','d']]
    返回：
        Counter({'a': 3, 'b': 2, ...})
    """
    all_tokens = []
    for line in corpus:          # 一行一行取
      for token in line.split():       # 行里一个个单词取
          all_tokens.append(token)


    return collections.Counter(all_tokens)

class Vocab:
    def __init__(self, tokens=None):
        """
        构建词表
        tokens: token 列表（可以是一维或二维）
        """
        if tokens is None:
            tokens = []

        # 1. 统计词频
        counter = count_corpus(tokens)

        # 2. 初始化特殊符号
        self.idx_to_token = [' ', '<unk>', '<bos>', '<eos>']
        self.token_to_idx = {
            ' ': 0,
            '<unk>': 1,
            '<bos>': 2,
            '<eos>': 3,
        }

        # 3. 按频率从高到低加入普通 token
        for token, freq in counter.most_common():
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
      return len(self.idx_to_token)

    def __getitem__(self, tokens):
      # 单个 token
      if not isinstance(tokens, (list, tuple)):
          return self.token_to_idx.get(tokens, self.token_to_idx['<unk>'])

      # token 列表
      indices = []
      for token in tokens:
          indices.append(self[token])
      return indices


    def print_vocab(self, n=10):
      print("===== Vocabulary Preview =====")
      print("index -> token")
      for i in range(min(n, len(self.idx_to_token))):
          print(f"{i:>3} -> {self.idx_to_token[i]}")

vocab = Vocab(lines)

vocab.print_vocab(n=10)

print("now ->", vocab['now'])
print("unknown ->", vocab['xyz'])
print("sentence ->", vocab[['<bos>', 'dear', 'gatsby', '<eos>']])

===== Vocabulary Preview =====
index -> token
  0 ->  
  1 -> <unk>
  2 -> <bos>
  3 -> <eos>
  4 -> the
  5 -> and
  6 -> a
  7 -> i
  8 -> of
  9 -> to
now -> 74
unknown -> 1
sentence -> [2, 1686, 29, 3]


## 以及依旧构建dataloader！！

In [17]:
import torch
def build_corpus_ids(lines, vocab):
    # 把text转为数字
    words = []
    for line in lines:
        if line.strip():
            words += line.split()
    return torch.tensor([vocab[w] for w in words], dtype=torch.long)


def train_iter_sequential_simple(corpus_ids, batch_size, num_steps, device='cpu'):
    corpus_ids = corpus_ids.to(device)
    N = corpus_ids.numel()
    assert N > batch_size * (num_steps + 1), "语料太短，batch_size*num_steps 太大了"

    # 截断到能整除 batch_size
    n = (N - 1) // batch_size * batch_size
    Xs = corpus_ids[:n].reshape(batch_size, -1)
    Ys = corpus_ids[1:n+1].reshape(batch_size, -1)

    batches = []   # ★ 所有 batch 放在这里

    L = Xs.shape[1]
    for t in range(0, L - num_steps + 1, num_steps):
        X = Xs[:, t:t+num_steps]
        Y = Ys[:, t:t+num_steps]
        batches.append((X, Y))   # ★ 收集起来

    return batches




corpus_ids = build_corpus_ids(lines, vocab)
print("Total tokens:", corpus_ids.shape)

batch_size = 32
num_steps = 35

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

train_iter = train_iter_sequential_simple(
    corpus_ids, batch_size, num_steps, device
)


for X, Y in train_iter:
    print("X shape:", X.shape)   # (batch_size, num_steps)
    print("Y shape:", Y.shape)
    print("X[0]:", X[0])
    print("Y[0]:", Y[0])
    break



Total tokens: torch.Size([53029])
Using device: cpu
X shape: torch.Size([32, 35])
Y shape: torch.Size([32, 35])
X[0]: tensor([   4,   81,   76,  487,    8,    4,  223,   29,   38,  487,   58,   25,
           4,  299,    8,  419, 1028,   10,    4,  420,  326,    5,  194,   92,
        1513,    8,    4,  201,   19,   63, 1219,    5,   18,  273,   63])
Y[0]: tensor([  81,   76,  487,    8,    4,  223,   29,   38,  487,   58,   25,    4,
         299,    8,  419, 1028,   10,    4,  420,  326,    5,  194,   92, 1513,
           8,    4,  201,   19,   63, 1219,    5,   18,  273,   63, 1965])


## 初始化LSTM参数！

In [18]:

import math
import time

import torch
from torch import nn
import torch.nn.functional as F

def get_lstm_params(vocab_size, num_hiddens, device):
    """
    初始化 LSTM 模型参数（手写版）
    - 使用 one-hot 编码，因此输入维度 = 词表大小
    """

    # LSTM 的输入维度等于输出维度（one-hot）
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        """小尺度正态分布初始化权重"""
        return torch.randn(size=shape, device=device) * 0.01

    def two_W_one_b():
        """
        初始化 LSTM 的三个关键参数：
        1. W_x：输入到隐藏状态的权重
        2. W_h：隐藏状态到隐藏状态的权重
        3. b  ：偏置项（初始化为 0）

        维度规则：
        - W_x: (num_inputs, num_hiddens)
        - W_h: (num_hiddens, num_hiddens)
        - b  : (num_hiddens,)
        """
        return (
            normal((num_inputs, num_hiddens)),      # W_x
            normal((num_hiddens, num_hiddens)),     # W_h
            torch.zeros(num_hiddens, device=device) # b
        )

    # -------- LSTM 四组门控参数 --------
    W_xi, W_hi, b_i = two_W_one_b()  # 输入门 (Input Gate)
    W_xf, W_hf, b_f = two_W_one_b()  # 遗忘门 (Forget Gate)
    W_xo, W_ho, b_o = two_W_one_b()  # 输出门 (Output Gate)
    W_xc, W_hc, b_c = two_W_one_b()  # 候选记忆单元 (Candidate Cell)

    # -------- 输出层参数 --------
    W_hq = normal((num_hiddens, num_outputs))      # 隐藏层 → 输出层
    b_q = torch.zeros(num_outputs, device=device)  # 输出层偏置

    # 将所有参数打包，便于优化
    params = [
        W_xi, W_hi, b_i,   # 输入门
        W_xf, W_hf, b_f,   # 遗忘门
        W_xo, W_ho, b_o,   # 输出门
        W_xc, W_hc, b_c,   # 候选记忆
        W_hq, b_q          # 输出层
    ]

    # 开启梯度
    for param in params:
        param.requires_grad_(True)

    return params


## 定义LSTM前传函数！

In [19]:
import torch

# ===============================
# 初始化 LSTM 的隐状态
# ===============================
# 长期记忆 C 和短期记忆 H 都需要初始化
def init_lstm_state(batch_size, num_hiddens, device):
    return (
        torch.zeros((batch_size, num_hiddens), device=device),
        torch.zeros((batch_size, num_hiddens), device=device)
    )


# ===============================
# LSTM 前向传播
# ===============================
def lstm(inputs, state, params):
    [
        W_xi, W_hi, b_i,
        W_xf, W_hf, b_f,
        W_xo, W_ho, b_o,
        W_xc, W_hc, b_c,
        W_hq, b_q
    ] = params

    H, C = state
    outputs = []

    for X in inputs:
        # 1. 输入门
        # I_t = sigmoid(W_xi * X_t + W_hi * H_{t-1} + b_i)
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)

        # 2. 遗忘门
        # F_t = sigmoid(W_xf * X_t + W_hf * H_{t-1} + b_f)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)

        # 3. 输出门
        # O_t = sigmoid(W_xo * X_t + W_ho * H_{t-1} + b_o)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)

        # 4. 候选记忆单元
        # C_tilde = tanh(W_xc * X_t + W_hc * H_{t-1} + b_c)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)

        # 5. 更新细胞状态
        # C_t = F_t * C_{t-1} + I_t * C_tilde
        C = F * C + I * C_tilda

        # 6. 更新隐状态
        # H_t = O_t * tanh(C_t)
        H = O * torch.tanh(C)

        # 7. 计算输出
        # Y_t = W_hq * H_t + b_q
        Y = (H @ W_hq) + b_q

        outputs.append(Y)

    # 拼接所有时间步的输出
    return torch.cat(outputs, dim=0), (H, C)


## 依旧定义一个可自定义前传函数的RNN类！

In [20]:
class RNNModel():
    """从零开始实现的循环神经网络模型"""

    def __init__(self, vocab_size, num_hiddens, device,
                 get_params, init_state, forward_fn):
        """
        初始化 RNN 模型。

        参数：
            vocab_size (int): 词汇表大小，即输入和输出的特征数量。
            num_hiddens (int): 隐藏单元数量，决定 RNN 的记忆容量。
            device (torch.device): 计算设备。
            get_params (function): 获取模型参数的函数。
            init_state (function): 初始化隐藏状态的函数。
            forward_fn (function): RNN 的前向传播函数。
        """
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        # 调用 get_params 初始化权重和偏置
        self.params = get_params(vocab_size, num_hiddens, device)
        # 记录初始化状态函数和前向传播函数
        self.init_state, self.forward_fn = init_state, forward_fn

    def __call__(self, X, state):
        """
        执行模型的前向传播。

        参数：
            X (tensor): 输入数据，形状（批量大小，序列长度）。
            state (tuple): 隐藏状态。

        返回：
            outputs (tensor): 预测结果，形状（时间步数量 * 批量大小，词表大小）。
            new_state (tuple): 更新后的隐藏状态。
        """
        # 对输入 X 进行 one-hot 编码并转换为 float32
        # 形状：(序列长度，批量大小，词表大小)
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)

    def begin_state(self, batch_size, device):
        """
        初始化隐藏状态。

        参数：
            batch_size (int): 批量大小。
            device (torch.device): 计算设备。

        返回：
            tuple: 初始化的隐藏状态。
        """
        return self.init_state(batch_size, self.num_hiddens, device)

## 依旧是之前的工具函数～

In [22]:
import math
import torch
from torch import nn
import matplotlib.pyplot as plt

class Accumulator:
    """
    在多个变量上进行累加的工具类
    """
    def __init__(self, n):
        """
        参数:
            n (int): 需要累加的变量个数
        """
        self.data = [0.0] * n

    def add(self, *args):
        """
        将传入的值逐项累加
        """
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """
        清零
        """
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        """
        允许用 metric[i] 的方式访问
        """
        return self.data[idx]

def grad_clipping(net, theta):
    """
    梯度裁剪，防止梯度爆炸

    参数:
        net: 模型（nn.Module 或自定义 RNN）
        theta (float): 梯度范数阈值
    """
    if isinstance(net, torch.nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params  # RNN 的参数列表

    # 计算梯度的 L2 范数
    norm = torch.sqrt(
        sum(torch.sum(p.grad ** 2) for p in params if p.grad is not None)
    )

    # 若超过阈值，则按比例缩放
    if norm > theta:
        for p in params:
            if p.grad is not None:
                p.grad[:] *= theta / norm

def predict(prefix, num_preds, net, vocab, device):
    """
    在给定的前缀字符串之后，使用 RNN 模型生成新的字符序列。

    参数：
        prefix (str): 生成序列的起始字符串（种子文本）。
        num_preds (int): 需要生成的字符数。
        net (RNNModelScratch): 训练好的循环神经网络模型。
        vocab (Vocab): 词汇表，提供字符与索引的映射关系。
        device (torch.device): 计算设备（'cpu' 或 'cuda'）。

    返回：
        str: 生成的完整文本（包含前缀和预测的新字符）。
    """
    # 初始化 RNN 的隐藏状态，batch_size=1 处理单个序列
    state = net.begin_state(batch_size=1, device=device)

    # 将 prefix 的第一个字符转换为索引并存入输出列表
    outputs = [vocab[prefix[0]]]

    # 定义一个 lambda 函数，获取当前最后一个字符的索引并转换为模型输入
    get_input = lambda: torch.tensor(
        [outputs[-1]], device=device
    ).reshape((1, 1))

    # 预热期：将 prefix 剩余字符依次输入网络，帮助 RNN 进入适当的状态
    for y in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[y])

    # 生成 num_preds 个新的字符
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))

    # 将输出索引列表转换回字符，并连接成字符串
    return ''.join([vocab.idx_to_token[i] for i in outputs])


def train_epoch(net, train_iter, loss, optimizer, device, use_random_iter):
    state = None
    metric = Accumulator(2)  # [total_loss, total_tokens]

    for X, Y in train_iter:
        # 第一个batch：初始化 state
        if state is None or use_random_iter:
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            # 其它batch：detach，避免计算图越来越长导致“卡死”
            for s in state:
                s.detach_()

        y = Y.T.reshape(-1)          # (B, T) -> (T*B,)
        X, y = X.to(device), y.to(device)

        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()

        optimizer.zero_grad()
        l.backward()
        grad_clipping(net, 1.0)
        optimizer.step()

        metric.add(l * y.numel(), y.numel())

    return math.exp(metric[0] / metric[1])


def train(net, train_iter, vocab, lr, num_epochs, device, use_random_iter=False):
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(net.params, lr=lr)

    epochs, ppls = [], []

    for epoch in range(num_epochs):
        ppl = train_epoch(net, train_iter, loss, optimizer, device, use_random_iter)

        print(f"[epoch {epoch+1:4d}] perplexity = {ppl:.2f}")

        epochs.append(epoch + 1)
        ppls.append(ppl)

        if (epoch + 1) % 5 == 0:
            print(predict('dear gatsby ', 50, net, vocab, device))

    plt.figure(figsize=(6, 3))
    plt.plot(epochs, ppls)
    plt.xlabel('epoch')
    plt.ylabel('perplexity')
    plt.grid(True)
    plt.show()

    print("\nFinal sample:")
    print(predict('dear gatsby ', 50, net, vocab, device))


## 定义模型、开始训练～

In [23]:
vocab_size, num_hiddens, device = len(vocab), 256, "cpu"
num_epochs, lr = 500, 1

model = RNNModel(
    len(vocab),
    num_hiddens,
    device,
    get_lstm_params,
    init_lstm_state,
    lstm
)

train(model, train_iter, vocab, lr, num_epochs, device)

[epoch    1] perplexity = 2834579031304047171105176549191757210816642942804689158527545995908053475788639461281661954254660796674809903036063702278236818905677639859803901470201485135664499261053671930396851582548770615727029113259864441472948371456.00
[epoch    2] perplexity = 198578427630556533446786366889840344815274417502365535137545161677484144583151942632046265894879635474886562205312701306526385757624878497792.00
[epoch    3] perplexity = 165124889506886139784434000657907022207374836032581315627398583455796447308359045585399623121941561344.00
[epoch    4] perplexity = 11387702519271089242146665370817142800363210893528754111130094184025337520431205053882173532695788300537594845452371142085348676322544108333992716735243841805707615508574235543264684804225016645489422871318025585921611789940948083633171381985958623187042304.00
[epoch    5] perplexity = 13599723530009676622734505594034948946344528172531374706641829856477836047214803854347190491337349006575252977485401400946123669833

KeyboardInterrupt: 