In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from d2l import torch as d2l

In [8]:
import os
import requests

# Generated by ChatGPT
# 1. Download the text data
def download_time_machine(url="http://www.gutenberg.org/files/35/35-0.txt",
                          filename="./data/timemachine.txt", timeout=30):
    """Download *The Time Machine* from Project Gutenberg unless a local copy exists.

    Args:
        url: Source URL of the plain-text book.
        filename: Local path used as a cache; missing parent directories are
            created.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The local ``filename``.

    Raises:
        requests.HTTPError: If the server answers with an error status. Nothing
            is written in that case, so a bad response is never cached.
    """
    if os.path.exists(filename):
        print(f"{filename} already exists.")
        return filename

    print(f"Downloading {filename}...")
    response = requests.get(url, timeout=timeout)
    # Fail before touching disk: otherwise an error page would be cached and the
    # "already exists" branch above would keep returning it forever.
    response.raise_for_status()

    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Write the raw bytes rather than ``response.text`` so the file keeps the
    # server's exact encoding instead of requests' charset guess. Writing to a
    # temporary name first means an interrupted download never leaves a partial
    # file that would later be mistaken for a complete cache.
    tmp_filename = filename + ".part"
    with open(tmp_filename, "wb") as f:
        f.write(response.content)
    os.replace(tmp_filename, filename)
    return filename

# 2. Vocabulary class
class Vocab:
    """Bidirectional mapping between tokens and integer indices.

    Reserved tokens come first, followed by corpus tokens sorted by descending
    frequency (ties broken by token order). ``'<unk>'`` is always present so
    that looking up an unknown token never raises.
    """

    def __init__(self, tokens, min_freq=0, reserved_tokens=None):
        """Build the vocabulary.

        Args:
            tokens: Iterable of tokens, or a list of per-line token lists
                (e.g. the output of ``tokenize``), which is flattened.
            min_freq: Corpus tokens seen fewer than this many times are left
                out and therefore map to ``'<unk>'``.
            reserved_tokens: Special tokens placed at the front of the
                vocabulary. Defaults to ``['<unk>']``; if ``'<unk>'`` is not
                among them it is prepended.
        """
        tokens = list(tokens)
        # Only flatten lists of lists; tuple tokens (e.g. bigrams) stay intact.
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = {}
        for token in tokens:
            counter[token] = counter.get(token, 0) + 1
        # Special tokens; dict.fromkeys removes duplicates while keeping order.
        reserved = list(dict.fromkeys(reserved_tokens)) if reserved_tokens else ['<unk>']
        if '<unk>' not in reserved:
            # __getitem__ falls back to '<unk>', so it must always exist.
            reserved = ['<unk>'] + reserved
        self.reserved_tokens = reserved
        # Sort by descending frequency, then by token for a deterministic order
        self.token_freqs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
        # Skip corpus tokens that are already reserved so no token gets two indices.
        reserved_set = set(reserved)
        self.idx_to_token = reserved + [
            token for token, freq in self.token_freqs
            if freq >= min_freq and token not in reserved_set
        ]
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Number of distinct tokens, including reserved ones."""
        return len(self.idx_to_token)

    @property
    def unk(self):
        """Index of the ``'<unk>'`` token."""
        return self.token_to_idx['<unk>']

    def __getitem__(self, tokens):
        """Map a token (or a list/tuple of tokens) to its index (or list of indices).

        Tokens not in the vocabulary map to the ``'<unk>'`` index.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index (or a list/tuple of indices) back to token(s)."""
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

# 3. Load the text data
def load_time_machine(filename=None):
    """Read the book and return its non-blank lines, stripped and lower-cased.

    Args:
        filename: Path of a local copy to read. When ``None`` (the default) the
            path comes from ``download_time_machine()``, which downloads the
            file if needed.

    Returns:
        list[str]: One entry per non-blank line, with surrounding whitespace
        removed and letters lower-cased.
    """
    if filename is None:
        filename = download_time_machine()
    # 'utf-8-sig' drops a leading byte-order mark if present. With plain
    # 'utf-8' a '\ufeff' would stay glued to the first line (strip() does not
    # remove it) and end up as a token in the character vocabulary.
    with open(filename, 'r', encoding='utf-8-sig') as f:
        # strip() already removes the line terminator, so no extra replace is needed.
        stripped = (line.strip() for line in f)
        return [line.lower() for line in stripped if line]

# 4. Split lines into word- or character-level tokens
def tokenize(lines, token='char'):
    """Split each line into tokens.

    Args:
        lines: Iterable of strings.
        token: ``'word'`` to split on whitespace, ``'char'`` to split into
            individual characters.

    Returns:
        list[list[str]]: One token list per input line.

    Raises:
        ValueError: If ``token`` is neither ``'word'`` nor ``'char'``.
    """
    if token == 'word':
        return [line.split() for line in lines]
    if token == 'char':
        return [list(line) for line in lines]
    # Fail loudly here instead of implicitly returning None and crashing later.
    raise ValueError(f"unknown token type: {token!r} (expected 'word' or 'char')")

# 5. Mini-batch dataset of fixed-length subsequences
class SeqDataset(Dataset):
    """Non-overlapping (input, target) subsequence pairs for language modeling.

    Item ``i`` covers ``corpus_indices[i*seq_length : (i+1)*seq_length + 1]``.
    The input is that window without its last element, and the target is the
    same window shifted one step to the right.
    """

    def __init__(self, corpus_indices, seq_length):
        """
        Args:
            corpus_indices: Sliceable sequence of integer token indices.
            seq_length: Number of time steps per sample; must be positive.

        Raises:
            ValueError: If ``seq_length`` is not positive.
        """
        if seq_length <= 0:
            raise ValueError(f"seq_length must be positive, got {seq_length}")
        self.corpus_indices = corpus_indices
        self.seq_length = seq_length

    def __len__(self):
        # Each sample needs seq_length + 1 tokens because the target is shifted
        # by one. Clamp at 0 so an empty corpus gives an empty dataset instead
        # of a negative length (which len() rejects).
        return max(0, (len(self.corpus_indices) - 1) // self.seq_length)

    def __getitem__(self, idx):
        """Return ``(X, Y)``: two 1-D ``torch.long`` tensors of length ``seq_length``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. This also lets
                plain ``for x, y in dataset`` iteration terminate.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
        start = idx * self.seq_length
        seq = self.corpus_indices[start:start + self.seq_length + 1]
        return (torch.tensor(seq[:-1], dtype=torch.long), torch.tensor(seq[1:], dtype=torch.long))

def load_data_time_machine(batch_size, num_steps, shuffle=True, drop_last=False, max_tokens=-1):
    """Build a DataLoader over character-level subsequences of *The Time Machine*.

    Args:
        batch_size: Number of subsequences per mini-batch.
        num_steps: Characters (time steps) per subsequence.
        shuffle: Reshuffle the subsequence order every epoch (forwarded to
            ``DataLoader``).
        drop_last: Drop a final incomplete batch (forwarded to ``DataLoader``).
            This is useful when a recurrent state is allocated with a fixed
            ``batch_size``.
        max_tokens: If positive, keep only the first ``max_tokens`` characters
            of the corpus. The vocabulary is still built from the full text.

    Returns:
        tuple: ``(data_iter, vocab)``. ``data_iter`` is a ``DataLoader`` yielding
        ``(X, Y)`` batches stacked from ``SeqDataset`` items, and ``vocab`` is
        the character ``Vocab``.
    """
    # Read and preprocess the text
    lines = load_time_machine()
    tokens = tokenize(lines, token='char')
    # Flatten once and reuse it for both the vocabulary and the index sequence.
    chars = [token for line in tokens for token in line]
    # Build the vocabulary
    vocab = Vocab(chars)
    # Convert characters to indices
    corpus_indices = [vocab[ch] for ch in chars]
    if max_tokens > 0:
        corpus_indices = corpus_indices[:max_tokens]
    # Build the dataset and its loader
    dataset = SeqDataset(corpus_indices, num_steps)
    data_iter = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    return data_iter, vocab

In [9]:
# Mini-batch size and the number of characters (time steps) per subsequence.
batch_size = 32
num_steps = 35
data_iter, vocab = load_data_time_machine(batch_size, num_steps)

./data/timemachine.txt already exists.


In [15]:
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

# Initial hidden state, shape (num_layers * num_directions, batch_size, num_hiddens).
# The leading dimension is derived from the layer so it stays correct if
# num_layers or bidirectional is changed above.
state = torch.zeros((rnn_layer.num_layers * (2 if rnn_layer.bidirectional else 1),
                     batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

In [16]:
# Push a random input of shape (num_steps, batch_size, len(vocab)) through the
# RNN. With the default batch_first=False, nn.RNN treats dim 0 as time.
X = torch.rand(num_steps, batch_size, len(vocab))
Y, state_new = rnn_layer(X, state)
(Y.shape, state_new.shape)

(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))