<a href="https://colab.research.google.com/github/akkajjy/MyNLPLab/blob/main/notebooks/lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

# 准备数据
text = """
The quick brown fox jumps over the lazy dog.
The cat runs fast and jumps high.
Cats and dogs play together in the yard.
""".lower()
chars = sorted(list(set(text)))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}
vocab_size = len(chars)

# 将文本转为索引序列
sequence = [char_to_idx[c] for c in text]
seq_length = 10  # 每次输入的序列长度

# 创建训练数据
def create_dataset(text, seq_length):
    inputs, targets = [], []
    for i in range(0, len(text) - seq_length):
        inputs.append([char_to_idx[c] for c in text[i:i+seq_length]])
        targets.append(char_to_idx[text[i+seq_length]])
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)

inputs, targets = create_dataset(text, seq_length)

# 定义 LSTM 模型
class CharLSTM(nn.Module):
    def __init__(self, vocab_size, hidden_dim, n_layers=1):
        super(CharLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden):
        x = self.embed(x)  # [batch, seq_length] -> [batch, seq_length, hidden_dim]
        out, hidden = self.lstm(x, hidden)  # out: [batch, seq_length, hidden_dim]
        out = self.fc(out)  # [batch, seq_length, vocab_size]
        return out, hidden

    def init_hidden(self, batch_size):
        return (torch.zeros(self.n_layers, batch_size, self.hidden_dim),
                torch.zeros(self.n_layers, batch_size, self.hidden_dim))

# 参数
hidden_dim = 20
n_layers = 1
model = CharLSTM(vocab_size, hidden_dim, n_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training
epochs = 100
batch_size = 32
for epoch in range(epochs):
    model.train()
    # Initialize hidden state for the entire epoch (adjust if using stateful LSTM across batches)
    # Note: For non-stateful LSTM like this, re-initializing per batch is common.
    # Let's keep re-initializing per batch as originally coded.

    total_loss = 0
    num_batches = 0 # Keep track of number of batches processed

    # Create data loader for easier batching and shuffling (optional but good practice)
    # from torch.utils.data import TensorDataset, DataLoader
    # train_data = TensorDataset(inputs, targets)
    # train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    # for batch_inputs, batch_targets in train_loader: # If using DataLoader

    # Original batching loop (works fine too)
    hidden = model.init_hidden(batch_size) # Initialize hidden state for the first batch
    for i in range(0, len(inputs), batch_size):
        end_idx = i + batch_size
        # Handle last batch potentially being smaller
        actual_batch_size = min(batch_size, len(inputs) - i)
        if actual_batch_size == 0:
             continue

        batch_inputs = inputs[i:end_idx]
        batch_targets = targets[i:end_idx]

        # Re-initialize hidden state if the batch size changes (important for the last batch)
        # Or initialize hidden state here before processing each batch
        if i == 0 or actual_batch_size != hidden[0].size(1): # If first batch or size changed
            hidden = model.init_hidden(actual_batch_size)

        optimizer.zero_grad()

        # Detach hidden state from the previous batch's history
        hidden = tuple(h.detach() for h in hidden)

        output, hidden = model(batch_inputs, hidden)

        # --- Correction is Here ---
        # Select the output corresponding to the last time step for each sequence
        last_step_output = output[:, -1, :]
        loss = criterion(last_step_output, batch_targets)
        # --------------------------

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if (epoch + 1) % 20 == 0 and num_batches > 0: # Avoid division by zero if dataset is small
        print(f"Epoch {epoch+1}, Loss: {total_loss / num_batches}")
    elif num_batches == 0:
         print(f"Epoch {epoch+1}, No batches processed (check dataset size and batch size)")


# --- Small fix in generation for robustness ---
def generate_text(model, start_text, max_length=100, seq_length=10): # Pass seq_length
    model.eval()
    chars_idx = [char_to_idx[c] for c in start_text]
    hidden = model.init_hidden(1)
    result = list(start_text)

    # "Warm up" the hidden state with the initial start_text
    if len(start_text) > 0:
        # Prepare the full start_text as input tensor
        input_tensor = torch.tensor([chars_idx], dtype=torch.long)
        with torch.no_grad():
             _, hidden = model(input_tensor, hidden) # Update hidden state based on start_text

    with torch.no_grad():
        # Use the last character of start_text as the first input for generation loop
        # Or predict based on the warmed-up hidden state
        current_char_idx = chars_idx[-1] if chars_idx else char_to_idx.get(' ', 0) # Handle empty start_text

        for _ in range(max_length - len(start_text)):
            # Input tensor now only needs the *last* character processed
            input_tensor = torch.tensor([[current_char_idx]], dtype=torch.long)

            output, hidden = model(input_tensor, hidden)
            # Get probabilities for the *single* output step
            probs = torch.softmax(output.squeeze(0), dim=-1).squeeze() # Shape [vocab_size]

            # Handle potential dimension issues if probs becomes 0-dim
            if probs.dim() == 0:
                probs = probs.unsqueeze(0)

            # Avoid error if probs sums to 0 (unlikely but possible with numerical issues)
            if probs.sum() == 0:
                 # Fallback: choose a random character or the most frequent one
                 next_idx = random.choice(range(vocab_size))
                 print("Warning: Probability sum is zero, choosing random character.")
            else:
                next_idx = torch.multinomial(probs, 1).item()

            result.append(idx_to_char[next_idx])
            current_char_idx = next_idx # Update the input for the next step

    return ''.join(result)

# Testing (ensure start_text characters are in vocab)
valid_start_chars = [c for c in "the quick" if c in char_to_idx]
if len(valid_start_chars) < len("the quick"):
    print("Warning: Some characters in start_text are not in the training vocabulary.")

start_text_clean = "".join(valid_start_chars)
if not start_text_clean: # Handle case where start_text is empty or all invalid chars
    start_text_clean = idx_to_char[random.choice(range(vocab_size))] # Start with a random char
    print(f"Warning: Invalid start_text provided, starting with '{start_text_clean}'")


print("Generated text:", generate_text(model, start_text_clean, seq_length=seq_length)) # Pass seq_length

Epoch 20, Loss: 0.6719115898013115
Epoch 40, Loss: 0.11884616315364838
Epoch 60, Loss: 0.03753652097657323
Epoch 80, Loss: 0.019954499322921038
Epoch 100, Loss: 0.012621685164049268
Generated text: the quickn jompgs h.ecat uuns fox jumps ovee yatsr and theghegheche cats as tooger.
the eayyat ard.

