<a href="https://colab.research.google.com/github/ShovalBenjer/Natural_Language_Proccessing_NLP_Projects/blob/main/LSTM_Text_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install dependencies as needed:
!pip install kagglehub[pandas-datasets] torch torchvision torchaudio plotnine tqdm scikit-learn
import kagglehub
from kagglehub import KaggleDatasetAdapter

import pandas as pd
import numpy as np
import string
import os
import re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

from tqdm.auto import tqdm # For progress bars

# For plotting
from plotnine import ggplot, aes, geom_line, labs, theme_minimal

# For reproducibility
import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
# --- 1. Load Data ---
print("Loading dataset...")

# Set the name of the file you'd like to load from within the Kaggle dataset.
# For 'aashita/nyt-comments', examples of comment files are:
# "CommentsApril2017.csv", "CommentsJan2018.csv", "CommentsFeb2017.csv", etc.
#
# If you set 'file_to_load_in_dataset' to an empty string (""),
# kagglehub will attempt to load a default file from the dataset.
# IMPORTANT: For 'aashita/nyt-comments', the default file might be an "Articles*.csv"
# file rather than a "Comments*.csv" file. To ensure you load comments,
# it's best to specify one of the comment files.
file_to_load_in_dataset = "CommentsApril2017.csv"
# To try kagglehub's default selection (BEWARE: might not be a comments file for this dataset):
# file_to_load_in_dataset = ""

try:
    df = kagglehub.load_dataset(
      KaggleDatasetAdapter.PANDAS,
      "aashita/nyt-comments",
      file_to_load_in_dataset,  # Pass as the third POSITIONAL argument
      # You can also provide pandas_kwargs here if needed, for example:
      # pandas_kwargs={'nrows': 10000} # To load only the first 10,000 rows for faster testing
    )

    # Determine which file was actually loaded for the print message
    loaded_file_msg = f"'{file_to_load_in_dataset}'" if file_to_load_in_dataset else "'default file'"
    print(f"Dataset file {loaded_file_msg} loaded successfully.")
    print("First 5 records:", df.head())

    # Select the comment text column.
    # For the NYT Comments files (e.g., CommentsApril2017.csv), the relevant column is 'commentBody'.
    if 'commentBody' in df.columns:
        corpus_raw = df['commentBody'].dropna().astype(str).tolist()
    elif 'commentText' in df.columns: # A common alternative name
        print("Warning: 'commentBody' not found. Using 'commentText' column.")
        corpus_raw = df['commentText'].dropna().astype(str).tolist()
    else:
        # If specific known columns are not found, try a more general fallback.
        # This part is less likely to be hit if 'file_to_load_in_dataset' correctly points to a comments CSV.
        object_cols = df.select_dtypes(include=['object']).columns
        potential_text_cols = [col for col in object_cols if df[col].astype(str).str.len().mean() > 30] # Heuristic: find long text columns

        if potential_text_cols:
            chosen_col = potential_text_cols[0]
            print(f"Warning: Neither 'commentBody' nor 'commentText' found. Using heuristic choice: '{chosen_col}'.")
            corpus_raw = df[chosen_col].dropna().astype(str).tolist()
        else:
            raise ValueError(f"No suitable text column (like 'commentBody') found in the loaded file: {loaded_file_msg}. "
                             "Please check the file content or the 'file_to_load_in_dataset' variable.")

except Exception as e:
    print(f"Error loading dataset or identifying comment column: {e}")
    print("Using dummy data for demonstration purposes as dataset loading failed.")
    corpus_raw = [
        "This is the first dummy comment about politics and news.",
        "I agree with the previous dummy sentiment regarding current events.",
        "Another dummy comment just to have some text.",
        "Let's talk about technology and AI, even in this dummy dataset.",
        "The future of AI is fascinating and often scary, this is a dummy thought."
    ]

Loading dataset...




Downloading from https://www.kaggle.com/api/v1/datasets/download/aashita/nyt-comments?dataset_version_number=13&file_name=CommentsApril2017.csv...


100%|██████████| 53.8M/53.8M [00:03<00:00, 14.9MB/s]

Extracting zip of CommentsApril2017.csv...





Dataset file 'CommentsApril2017.csv' loaded successfully.
First 5 records:    approveDate                                        commentBody   commentID  \
0   1491245186  This project makes me happy to be a 30+ year T...  22022598.0   
1   1491188619  Stunning photos and reportage. Infuriating tha...  22017350.0   
2   1491188617  Brilliant work from conception to execution. I...  22017334.0   
3   1491167820  NYT reporters should provide a contributor's l...  22015913.0   
4   1491167815     Could only have been done in print. Stunning.   22015466.0   

   commentSequence commentTitle commentType    createDate  depth  \
0         22022598        <br/>     comment  1.491237e+09      1   
1         22017350          NaN     comment  1.491180e+09      1   
2         22017334        <br/>     comment  1.491179e+09      1   
3         22015913        <br/>     comment  1.491150e+09      1   
4         22015466        <br/>     comment  1.491147e+09      1   

   editorsSelection  parentID

In [3]:
# --- Hyperparameters for Preprocessing ---
# MAX_CORPUS_SAMPLES was defined earlier where corpus_raw is sampled.
# Ensure it's set to a reasonable number for your hardware (e.g., 10000, 50000)
# Example: MAX_SAMPLES = 20000 # (Make sure this is applied before this block)

MAX_VOCAB_SIZE = 20000  # Max number of words in vocabulary (e.g., 20k-50k)
MAX_TRAINING_SEQUENCES = 500000 # Max number of (predictor, label) pairs (e.g., 500k-1M)
MAX_PREDICTOR_LEN = 50     # Max length of a predictor sequence fed to LSTM (e.g., 30-100 words)


# --- 2. Preprocess Data ---
print("\nPreprocessing data...")

def clean_text(txt):
    txt = txt.lower()
    txt = re.sub(r'http\S+|www\S+|https\S+', '', txt, flags=re.MULTILINE)
    txt = re.sub(r'\@\w+|\#','', txt)
    txt = "".join(v for v in txt if v not in string.punctuation.replace("'", ""))
    txt = txt.replace("’", "'")
    emoji_pattern = re.compile("["
        u"\U0001F600-\U0001F64F" u"\U0001F300-\U0001F5FF" u"\U0001F680-\U0001F6FF"
        u"\U0001F1E0-\U0001F1FF" u"\U00002702-\U000027B0" u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    txt = emoji_pattern.sub(r'', txt)
    try:
        txt = txt.encode("utf8").decode("ascii", 'ignore')
    except TypeError: # If already bytes
        txt = txt.decode("ascii", 'ignore')
    txt = re.sub(r'\s+', ' ', txt).strip()
    return txt

# Ensure MAX_SAMPLES is applied to corpus_raw before this point if it wasn't already.
# If corpus_raw comes directly from df['commentBody'].tolist(), sample it here:
# if len(corpus_raw) > MAX_SAMPLES: # MAX_SAMPLES defined near data loading
#     print(f"Reducing corpus_raw from {len(corpus_raw)} to {MAX_SAMPLES} for preprocessing.")
#     corpus_raw = random.sample(corpus_raw, MAX_SAMPLES)


corpus_cleaned = [clean_text(text) for text in tqdm(corpus_raw, desc="Cleaning text") if clean_text(text)]
corpus_cleaned = [text for text in corpus_cleaned if len(text.split()) >= 3] # Min 3 words for a meaningful sequence

print(f"Number of cleaned comments: {len(corpus_cleaned)}")
if not corpus_cleaned:
    raise ValueError("Corpus is empty after cleaning. Check data or cleaning steps.")
print(f"Sample cleaned comment: {corpus_cleaned[0]}")


class TokenizerWithPaddingAndVocabLimit:
    def __init__(self, oov_token="<unk>", pad_token="<pad>", max_vocab_size=None):
        self.oov_token = oov_token
        self.pad_token = pad_token
        self.max_vocab_size = max_vocab_size

        self.word_to_idx = {self.pad_token: 0, self.oov_token: 1}
        self.idx_to_word = {0: self.pad_token, 1: self.oov_token}
        self.pad_idx = 0
        self.oov_idx = 1
        self._next_idx = 2
        self.word_counts = {}

    def fit_on_texts(self, texts):
        for text in texts:
            for word in text.split():
                self.word_counts[word] = self.word_counts.get(word, 0) + 1

        if self.max_vocab_size is not None and len(self.word_counts) > (self.max_vocab_size - self._next_idx) :
            # Sort words by frequency, keep top N
            sorted_words = sorted(self.word_counts.keys(), key=lambda x: self.word_counts[x], reverse=True)
            words_to_keep = sorted_words[:self.max_vocab_size - self._next_idx] # Reserve space for pad/oov
        else:
            # Keep all words if below limit or no limit
            words_to_keep = sorted(self.word_counts.keys(), key=lambda x: self.word_counts[x], reverse=True)

        for word in words_to_keep:
            if word not in self.word_to_idx: # Should not happen if logic is correct
                 self._add_word_to_vocab(word)
        print(f"Vocabulary built. Kept {len(self.word_to_idx) - self._next_idx} most frequent words + special tokens.")


    def _add_word_to_vocab(self, word):
        # This is called by fit_on_texts after deciding which words to keep
        if word not in self.word_to_idx: # Should always be true here
            self.word_to_idx[word] = self._next_idx
            self.idx_to_word[self._next_idx] = word
            self._next_idx += 1

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            seq = [self.word_to_idx.get(word, self.oov_idx) for word in text.split()]
            sequences.append(seq)
        return sequences

    def sequences_to_texts(self, sequences): # For generation/debugging
        texts = []
        for seq in sequences:
            words = [self.idx_to_word.get(idx, self.oov_token) for idx in seq]
            texts.append(" ".join(words))
        return texts

    @property
    def vocab_size(self):
        return self._next_idx


tokenizer = TokenizerWithPaddingAndVocabLimit(max_vocab_size=MAX_VOCAB_SIZE)
tokenizer.fit_on_texts(corpus_cleaned)
total_words = tokenizer.vocab_size
print(f"Total unique words in vocabulary (incl. <pad>, <unk>, limited by MAX_VOCAB_SIZE): {total_words}")

predictors_list = []
raw_labels_list = [] # Store integer labels first

print(f"Generating max {MAX_TRAINING_SEQUENCES} sequences with max predictor length {MAX_PREDICTOR_LEN}...")
num_sequences_generated = 0
need_to_break_all_loops = False

for line in tqdm(corpus_cleaned, desc="Generating sequences"):
    if need_to_break_all_loops:
        break
    token_list = tokenizer.texts_to_sequences([line])[0]

    if not token_list or all(t == tokenizer.oov_idx or t == tokenizer.pad_idx for t in token_list):
        continue # Skip if empty or all OOV/PAD after tokenization

    for i in range(1, len(token_list)): # Create n-grams
        if num_sequences_generated >= MAX_TRAINING_SEQUENCES:
            need_to_break_all_loops = True
            break # Break from inner loop (n-grams of one sentence)

        n_gram_sequence = token_list[:i+1]
        predictor_part_tokens = n_gram_sequence[:-1]
        label_token = n_gram_sequence[-1]

        if len(predictor_part_tokens) > 0:
            # Truncate predictor_part_tokens if it's too long (take the tail end)
            if len(predictor_part_tokens) > MAX_PREDICTOR_LEN:
                predictor_part_tokens = predictor_part_tokens[-MAX_PREDICTOR_LEN:]

            predictors_list.append(torch.tensor(predictor_part_tokens, dtype=torch.long))
            raw_labels_list.append(label_token)
            num_sequences_generated += 1

print(f"Generated {num_sequences_generated} sequences.")

if not predictors_list:
    raise ValueError("No predictor sequences generated. Check MAX_TRAINING_SEQUENCES, data, or preprocessing logic.")

# Pad predictor sequences
# `batch_first=True` means pad_sequence will output (batch_size, actual_max_len_in_batch)
# The actual_max_len_in_batch will be <= MAX_PREDICTOR_LEN due to our truncation.
padded_predictors = pad_sequence(predictors_list, batch_first=True, padding_value=tokenizer.pad_idx)
labels_tensor = torch.tensor(raw_labels_list, dtype=torch.long)

max_sequence_len_model_input = padded_predictors.shape[1] # Actual max length after padding
print(f"Max predictor sequence length for model input (after padding/truncation): {max_sequence_len_model_input}")

print(f"Shape of padded predictors: {padded_predictors.shape}") # e.g., (num_sequences_generated, max_sequence_len_model_input)
print(f"Shape of labels: {labels_tensor.shape}") # e.g., (num_sequences_generated,)


Preprocessing data...


Cleaning text:   0%|          | 0/243832 [00:00<?, ?it/s]

Number of cleaned comments: 242065
Sample cleaned comment: this project makes me happy to be a 30 year times subscriber continue to innovate across all platforms please
Vocabulary built. Kept 0 most frequent words + special tokens.
Total unique words in vocabulary (incl. <pad>, <unk>, limited by MAX_VOCAB_SIZE): 20000
Generating max 500000 sequences with max predictor length 50...


Generating sequences:   0%|          | 0/242065 [00:00<?, ?it/s]

Generated 500000 sequences.
Max predictor sequence length for model input (after padding/truncation): 50
Shape of padded predictors: torch.Size([500000, 50])
Shape of labels: torch.Size([500000])


In [None]:
# --- 3. Create PyTorch Dataset and DataLoader ---
class CommentDataset(Dataset):
    def __init__(self, predictors, labels):
        self.predictors = predictors
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.predictors[idx], self.labels[idx]

dataset = CommentDataset(padded_predictors, labels_tensor)

# Split into training and validation (optional, but good practice)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
if val_size == 0 and train_size > 0 : # handle tiny datasets for testing
    train_dataset = dataset
    val_dataset = dataset # use train as val if val_size is 0
else:
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])


BATCH_SIZE = 128 # Can be tuned
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
if val_size > 0:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
else: # if val_size is 0, val_loader can be None or point to train_loader for simplicity in eval loop
    val_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)


# --- 4. Define LSTM Model in PyTorch ---
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, padding_idx):
        super(LanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout_rate if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        embedded = self.embedding(x)
        # embedded shape: (batch_size, seq_len, embedding_dim)

        lstm_out, _ = self.lstm(embedded)
        # lstm_out shape: (batch_size, seq_len, hidden_dim)

        # We only want the output from the last time step
        last_lstm_out = lstm_out[:, -1, :]
        # last_lstm_out shape: (batch_size, hidden_dim)

        out = self.dropout(last_lstm_out)
        out = self.fc(out)
        # out shape: (batch_size, vocab_size)
        return out

# Hyperparameters
EMBEDDING_DIM = 100
HIDDEN_DIM = 150
NUM_LAYERS = 2 # Using 2 LSTM layers
DROPOUT_RATE = 0.2
LEARNING_RATE = 0.001 # Adam's default is 0.001
EPOCHS = 20 # Start with a smaller number for quick testing

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = LanguageModel(
    vocab_size=total_words,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout_rate=DROPOUT_RATE,
    padding_idx=tokenizer.pad_idx
).to(device)

print(model)
# Sanity check one batch
try:
    sample_x, sample_y = next(iter(train_loader))
    sample_x, sample_y = sample_x.to(device), sample_y.to(device)
    output = model(sample_x)
    print("Sample output shape:", output.shape) # Expected: (BATCH_SIZE, total_words)
    print("Sample target shape:", sample_y.shape) # Expected: (BATCH_SIZE)
except StopIteration:
    print("Train loader is empty. Cannot perform sanity check.")
    # This can happen if MAX_SAMPLES is too small or data cleaning is too aggressive.

# --- 5. Training Loop ---
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_idx) # Ignore padding in loss calculation
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'perplexity': []}

print("\nStarting training...")
if not train_loader:
    print("Skipping training as no data is available in train_loader.")
else:
    for epoch in range(EPOCHS):
        model.train()
        epoch_train_loss = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Training]", leave=False)
        for batch_predictors, batch_labels in progress_bar:
            batch_predictors, batch_labels = batch_predictors.to(device), batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(batch_predictors)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = epoch_train_loss / len(train_loader)
        history['epoch'].append(epoch + 1)
        history['train_loss'].append(avg_train_loss)

        # Validation
        model.eval()
        epoch_val_loss = 0
        if val_loader and len(val_loader) > 0: # Check if val_loader is not empty
            with torch.no_grad():
                progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Validation]", leave=False)
                for batch_predictors, batch_labels in progress_bar_val:
                    batch_predictors, batch_labels = batch_predictors.to(device), batch_labels.to(device)
                    outputs = model(batch_predictors)
                    loss = criterion(outputs, batch_labels)
                    epoch_val_loss += loss.item()
                    progress_bar_val.set_postfix(loss=loss.item())
            avg_val_loss = epoch_val_loss / len(val_loader)
            perplexity = np.exp(avg_val_loss) # Perplexity = e^(avg_cross_entropy_loss)
        else: # Handle case with no validation data
            avg_val_loss = float('nan')
            perplexity = float('nan')

        history['val_loss'].append(avg_val_loss)
        history['perplexity'].append(perplexity)

        print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f} - Perplexity: {perplexity:.2f}")

Using device: cuda
LanguageModel(
  (embedding): Embedding(20000, 100, padding_idx=0)
  (lstm): LSTM(100, 150, num_layers=2, batch_first=True, dropout=0.2)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=150, out_features=20000, bias=True)
)
Sample output shape: torch.Size([128, 20000])
Sample target shape: torch.Size([128])

Starting training...


Epoch 1/20 [Training]:   0%|          | 0/3516 [00:00<?, ?it/s]

Epoch 1/20 [Validation]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 1/20 - Train Loss: 6.8202 - Val Loss: 6.6290 - Perplexity: 756.72


Epoch 2/20 [Training]:   0%|          | 0/3516 [00:00<?, ?it/s]

Epoch 2/20 [Validation]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 2/20 - Train Loss: 6.4297 - Val Loss: 6.3089 - Perplexity: 549.46


Epoch 3/20 [Training]:   0%|          | 0/3516 [00:00<?, ?it/s]

Epoch 3/20 [Validation]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 3/20 - Train Loss: 6.1518 - Val Loss: 6.1618 - Perplexity: 474.30


Epoch 4/20 [Training]:   0%|          | 0/3516 [00:00<?, ?it/s]

Epoch 4/20 [Validation]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 4/20 - Train Loss: 5.9711 - Val Loss: 6.0854 - Perplexity: 439.41


Epoch 5/20 [Training]:   0%|          | 0/3516 [00:00<?, ?it/s]

Epoch 5/20 [Validation]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 5/20 - Train Loss: 5.8303 - Val Loss: 6.0370 - Perplexity: 418.66


Epoch 6/20 [Training]:   0%|          | 0/3516 [00:00<?, ?it/s]

In [None]:
# --- 6. Text Generation Function ---
def generate_text(seed_text, next_words, model, tokenizer, max_sequence_len_model_input, device, temperature=1.0):
    model.eval()
    generated_text = seed_text

    for _ in range(next_words):
        # Tokenize current text
        token_list = tokenizer.texts_to_sequences([generated_text.lower()])[0]

        # Pad sequence (take only the last `max_sequence_len_model_input` tokens)
        if len(token_list) > max_sequence_len_model_input:
            token_list = token_list[-max_sequence_len_model_input:]

        # Convert to tensor and add batch dimension
        input_tensor = torch.tensor([token_list], dtype=torch.long).to(device)

        with torch.no_grad():
            output = model(input_tensor) # Shape: (1, vocab_size)

        # Apply temperature to logits
        output_dist = output.data.view(-1).div(temperature).exp()

        # Sample from the distribution or take argmax
        # For more diverse generation, use sampling:
        # top_i = torch.multinomial(output_dist, 1)[0]

        # For deterministic (but potentially repetitive) generation, use argmax:
        if temperature == 0.0: # Pure argmax
            predicted_idx = torch.argmax(output, dim=1).item()
        else: # Sampling with temperature
            predicted_idx = torch.multinomial(output_dist, 1)[0].item()


        # Handle OOV or PAD prediction if they occur (less likely with good training)
        if predicted_idx == tokenizer.pad_idx and tokenizer.pad_token != "<pad_is_word>": # Avoid predicting padding
            # Fallback: predict the next most likely non-pad token
            sorted_preds = torch.argsort(output, dim=1, descending=True)
            for idx_val in sorted_preds[0]:
                if idx_val.item() != tokenizer.pad_idx:
                    predicted_idx = idx_val.item()
                    break

        output_word = tokenizer.idx_to_word.get(predicted_idx, tokenizer.oov_token)

        if output_word == tokenizer.oov_token: # Don't append <unk> if it's just a placeholder
            continue # Or break, or try another prediction

        generated_text += " " + output_word

    return generated_text.title() # Capitalize like the Keras example

print("\n--- Text Generation Examples ---")
if not train_loader:
    print("Skipping generation as model was not trained.")
else:
    try:
        seed1 = "the president said"
        seed2 = "new york is"
        seed3 = "climate change will"

        print(f"Seed: '{seed1}'")
        print("Generated (temp=0.7):", generate_text(seed1, 10, model, tokenizer, max_sequence_len_model_input, device, temperature=0.7))
        print("Generated (temp=1.0):", generate_text(seed1, 10, model, tokenizer, max_sequence_len_model_input, device, temperature=1.0))
        print("Generated (argmax):", generate_text(seed1, 10, model, tokenizer, max_sequence_len_model_input, device, temperature=0.0)) # temperature 0 for argmax

        print(f"\nSeed: '{seed2}'")
        print("Generated (temp=0.7):", generate_text(seed2, 10, model, tokenizer, max_sequence_len_model_input, device, temperature=0.7))

        print(f"\nSeed: '{seed3}'")
        print("Generated (temp=0.7):", generate_text(seed3, 10, model, tokenizer, max_sequence_len_model_input, device, temperature=0.7))
    except Exception as e:
        print(f"Error during text generation: {e}")
        print("This might be due to a very small vocabulary or issues with sequence lengths.")


# --- 7. Evaluation (Perplexity is already calculated during training) ---
# Qualitative evaluation is looking at the generated text.
# Quantitative:
# - Perplexity on a test set (calculated as exp(average cross-entropy loss on test set)).
# - BLEU scores if comparing to reference continuations (more for machine translation but adaptable).

print("\n--- Evaluation ---")
print("Perplexity on the validation set is tracked during training.")
print("Final validation perplexity (if available):", f"{history['perplexity'][-1]:.2f}" if history['perplexity'] and not np.isnan(history['perplexity'][-1]) else "N/A")
print("Qualitative evaluation: Inspect the generated text samples above.")


# --- 8. Plotting with plotnine ---
print("\n--- Plotting Training Progress ---")
if history['epoch']: # Check if history has data
    history_df = pd.DataFrame(history)

    # Plot training and validation loss
    loss_plot = (
        ggplot(history_df.melt(id_vars=['epoch'], value_vars=['train_loss', 'val_loss'], var_name='metric', value_name='loss'),
               aes(x='epoch', y='loss', color='metric')) +
        geom_line() +
        labs(title="Training and Validation Loss", x="Epoch", y="Loss") +
        theme_minimal()
    )
    print(loss_plot)

    # Plot perplexity
    if 'perplexity' in history_df.columns and history_df['perplexity'].notna().any():
        perplexity_plot = (
            ggplot(history_df[history_df['perplexity'].notna()], # Filter out NaN perplexity if val_loader was empty
                   aes(x='epoch', y='perplexity')) +
            geom_line(color="blue") +
            labs(title="Validation Perplexity", x="Epoch", y="Perplexity") +
            theme_minimal()
        )
        print(perplexity_plot)
else:
    print("No training history to plot (e.g., training was skipped).")

print("\nModel training and evaluation complete.")
print("To improve: train for more epochs, use more data, tune hyperparameters (embedding_dim, hidden_dim, layers, dropout, learning rate), or try a character-level model for very large vocabularies.")