# Next-word Generator for the Sherlock Holmes dataset


## Imports and Initial Configuration

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import pandas as pd
import plotly.express as px  # For interactive plotting
import plotly.graph_objects as go
import re
import os
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Set up Plotly for better visualization
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

# Display PyTorch version and set device
print(f"PyTorch Version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define dataset directory
dataset_dir = os.path.join(os.getcwd(), 'datasets')


PyTorch Version: 2.2.1+cu121
Using device: cuda


## Text Cleaning Function

We define a function to clean the text by handling punctuation more effectively and ensuring case insensitivity.

In [2]:
# %%
def clean_text(filename: str):
    """
    Reads and cleans text from a file.
    Handles punctuation by separating them as distinct tokens,
    incorporates paragraph boundaries, and converts all text to lowercase.
    """
    filepath = os.path.join(dataset_dir, filename)
    with open(filepath, encoding='utf-8') as file:
        text = file.read()
    
    # Replace paragraph breaks with special tokens
    # Assuming paragraphs are separated by two or more newlines
    text = re.sub(r'\n\s*\n', ' <PAR_END> <PAR_START> ', text)
    
    # Separate specified punctuation by adding spaces around them
    # Punctuation marks: ., ,, !, ?, -
    text = re.sub(r'([.,!?-])', r' \1 ', text)
    
    # Remove any unwanted characters except specified punctuation and alphanumerics
    text = re.sub(r'[^a-zA-Z0-9\s.,!?-]', ' ', text)
    
    # Convert to lowercase for case insensitivity
    text = text.lower()
    
    # Remove extra spaces
    text = " ".join(text.split())
    
    return text


## Unique Words Extraction Function

Extracts unique words, including punctuation, and creates mappings between words and their indices.

In [3]:
# %%
def unique_words(text: str):
    """
    Extracts unique words and punctuation from the text.
    Creates mappings from string to index and vice versa.
    Includes special tokens for paragraph boundaries and <UNK>.
    """
    words = pd.Series(text.split())
    
    # Define allowed punctuation marks
    allowed_punctuations = {'.', ',', '!', '?', '-'}
    
    # Filter words by length and ensure they are alphanumeric or specified punctuation
    words = words[((words.str.len() > 0) & (words.str.len() < 20))]
    words = words[words.isin(allowed_punctuations) | words.str.match(r'^[a-zA-Z0-9]+$')]
    
    # Drop duplicates and sort
    words = words.drop_duplicates(ignore_index=True)
    vocab = words.sort_values().to_list()
    
    # Initialize stoi with special tokens
    special_tokens = ['<par_start>', '<par_end>', '<UNK>']
    stoi = {token: i + 1 for i, token in enumerate(special_tokens)}
    
    # Add the remaining vocabulary words, starting from the next available index
    next_index = len(stoi) + 1
    for word in vocab:
        if word not in stoi:
            stoi[word] = next_index
            next_index += 1
    
    # Create the itos mapping based on updated stoi
    itos = {i: s for s, i in stoi.items()}
    
    return vocab, stoi, itos


## Data Preparation
We prepare the dataset by creating input-output pairs based on a context window.

In [4]:
def prepare_data(text: str, block_size: int, stoi):
    """
    Prepares input-output pairs for training.
    Each input consists of `block_size` tokens, and the target is the next token.
    Unknown words are mapped to the <UNK> token.
    """
    words = text.split()
    X, Y = [], []
    
    # Define the <UNK> token
    unk_token = '<UNK>'
    if unk_token not in stoi:
        stoi[unk_token] = len(stoi) + 1
    unk_idx = stoi[unk_token]
    
    for i in range(block_size, len(words)):
        context = words[i-block_size:i]
        target = words[i]
        
        # Convert context and target to indices, map unknown words to <UNK>
        context_ix = [stoi.get(word, unk_idx) for word in context]
        target_ix = stoi.get(target, unk_idx)
        
        X.append(context_ix)
        Y.append(target_ix)
    
    # Convert lists to tensors
    X = torch.tensor(X, dtype=torch.long).to(device)
    Y = torch.tensor(Y, dtype=torch.long).to(device)
    
    return X, Y


## Data Cleaning and Preparation
We clean the text, build the vocabulary, and prepare the data for training.

In [5]:
# %%
# Clean the text from the dataset
text = clean_text('/teamspace/studios/this_studio/es335-24-fall-assignment-3/datasets/sherlock.txt')

# Extract unique words and create mappings
vocab, stoi, itos = unique_words(text)

# Add the <UNK> token to itos
itos[stoi['<UNK>']] = '<UNK>'

# Prepare input-output pairs with a context window of 5
block_size = 5
X, Y = prepare_data(text, block_size, stoi)

# Display the shapes of the tensors
print(f"Input shape: {X.shape}, dtype: {X.dtype}")
print(f"Target shape: {Y.shape}, dtype: {Y.dtype}")

words = text.split()
total_words = len(words)
unique_words = len(set(words))

print(f"Total number of words: {total_words}")
print(f"Total number of unique words: {unique_words}")


Input shape: torch.Size([135639, 5]), dtype: torch.int64
Target shape: torch.Size([135639]), dtype: torch.int64
Total number of words: 135644
Total number of unique words: 8156


In [6]:
# %%
# Number of samples to display
num_samples = 20

print(f"Displaying {num_samples} sample input-output pairs:\n")

# Iterate over the first num_samples and print context and target
for i in range(num_samples):
    # Get context indices and target index
    context_indices = X[i].tolist()
    target_index = Y[i].item()
    
    # Convert indices back to words, using <UNK> for unknown indices
    context_words = [itos.get(idx, '<UNK>') for idx in context_indices]
    target_word = itos.get(target_index, '<UNK>')
    
    # Print the context and target
    print(f"--- Sample {i+1} ---")
    print(f"Context: {' '.join(context_words)}")
    print(f"Target: {target_word}\n")


Displaying 20 sample input-output pairs:

--- Sample 1 ---
Context: the project gutenberg ebook of
Target: the

--- Sample 2 ---
Context: project gutenberg ebook of the
Target: adventures

--- Sample 3 ---
Context: gutenberg ebook of the adventures
Target: of

--- Sample 4 ---
Context: ebook of the adventures of
Target: sherlock

--- Sample 5 ---
Context: of the adventures of sherlock
Target: holmes

--- Sample 6 ---
Context: the adventures of sherlock holmes
Target: ,

--- Sample 7 ---
Context: adventures of sherlock holmes ,
Target: by

--- Sample 8 ---
Context: of sherlock holmes , by
Target: arthur

--- Sample 9 ---
Context: sherlock holmes , by arthur
Target: conan

--- Sample 10 ---
Context: holmes , by arthur conan
Target: doyle

--- Sample 11 ---
Context: , by arthur conan doyle
Target: par

--- Sample 12 ---
Context: by arthur conan doyle par
Target: end

--- Sample 13 ---
Context: arthur conan doyle par end
Target: par

--- Sample 14 ---
Context: conan doyle par end par
Targe

## Embedding Initialization and Visualization

We initialize the embedding layer and visualize the embeddings using t-SNE with Plotly.

In [7]:
# %%
# Define hyperparameters
embedding_dim = 256  # Increased embedding size

# Initialize the embedding layer
embedding = nn.Embedding(len(stoi) + 1, embedding_dim).to(device)  # +1 for padding if needed
print(f"Embedding Weights Shape: {embedding.weight.shape}")

# Convert embeddings to NumPy for visualization
embeddings = embedding.weight.detach().cpu().numpy()

# Perform t-SNE to reduce dimensions to 2D
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings)

# Create a DataFrame for Plotly
words = list(stoi.keys())
df = pd.DataFrame({
    'word': words,
    'x': embeddings_2d[:len(words), 0],
    'y': embeddings_2d[:len(words), 1]
})

# Plot using Plotly Express
fig = px.scatter(
    df, 
    x='x', 
    y='y', 
    hover_name='word',  # Display word only on hover
    title='t-SNE Visualization of Word Embeddings',
    width=800,
    height=800
)

# Customize the plot to hide text labels and show only dots
fig.update_traces(
    marker=dict(size=5, color='blue'),  # Customize dot color and size
    hovertemplate='<b>%{hovertext}</b><extra></extra>'  # Show only the word on hover
)
fig.update_layout(
    title=dict(x=0.5),
    xaxis_title="t-SNE Dimension 1",
    yaxis_title="t-SNE Dimension 2",
    template='plotly_white'
)

fig.show()

Embedding Weights Shape: torch.Size([8160, 256])


KeyboardInterrupt: 

## MLP for next word prediction
We define the neural network model with an increased hidden layer size.

In [None]:
hidden_dim = 512      # Increased hidden layer size
epochs = 500
learning_rate = 0.001

class NextWord(nn.Module):
    """
    A feedforward neural network with multiple hidden layers for next-word prediction.
    Utilizes Leaky ReLU activation functions to improve gradient flow.
    """
    def __init__(self, block_size, vocab_size, embedding_dim, hidden_dim):
        super(NextWord, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)  # Embedding layer
        
        # Define multiple hidden layers
        self.lin1 = nn.Linear(embedding_dim * block_size, hidden_dim)   # First hidden layer
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)                  # Second hidden layer
        self.lin3 = nn.Linear(hidden_dim, hidden_dim)                  # Third hidden layer
        self.lin4 = nn.Linear(hidden_dim, hidden_dim)                  # Third hidden layer
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)             # Leaky ReLU activation
        self.lin_out = nn.Linear(hidden_dim, vocab_size)               # Output layer

    def forward(self, x):
        embeds = self.embedding(x)                                     # Shape: [batch_size, block_size, embedding_dim]
        embeds = embeds.view(x.shape[0], -1)                          # Flatten: [batch_size, block_size * embedding_dim]
        
        out = self.lin1(embeds)                                        # First hidden layer
        out = self.leaky_relu(out)                                     # Activation
        
        out = self.lin2(out)                                           # Second hidden layer
        out = self.leaky_relu(out)                                     # Activation
        
        out = self.lin3(out)                                           # Third hidden layer
        out = self.leaky_relu(out)                                     # Activation

        out = self.lin4(out)                                           # Fourth hidden layer
        out = self.leaky_relu(out)                                     # Activation
        
        out = self.lin_out(out)                                        # Output layer: [batch_size, vocab_size]
        return out
        
# Initialize the model with increased hidden dimension
model = NextWord(block_size, len(stoi) + 1, embedding_dim, hidden_dim).to(device)
print(model)

def generate_sequence(model, itos, stoi, context_words, block_size, max_len=20):
    """
    Generates a sequence of words based on the provided context, with support for
    paragraph tokens (`<par_start>` and `<par_end>`), punctuation handling, and capitalization.
    """
    model.eval()  # Set model to evaluation mode
    unk_idx = stoi.get('<UNK>', 0)  # Index for unknown words, default to 0 if not found
    par_start_idx = stoi.get('<par_start>', unk_idx)  # Index for paragraph start
    par_end_idx = stoi.get('<par_end>', unk_idx)  # Index for paragraph end

    # Convert context words to indices, map unknown words to `<UNK>`
    context = [stoi.get(word, unk_idx) for word in context_words]
    
    # Pad context if it's shorter than `block_size`
    if len(context) < block_size:
        context = [par_end_idx] * (block_size - len(context)) + context

    sequence = context_words.copy()
    capitalize_next = False  # Flag to determine if the next word should be capitalized
    
    # If the last word in the context ends with a period or is `<par_end>`, set the flag
    if sequence and (sequence[-1].endswith('.') or sequence[-1] == '<par_end>'):
        capitalize_next = True

    with torch.no_grad():
        for _ in range(max_len):
            # Prepare input tensor
            x = torch.tensor(context[-block_size:], dtype=torch.long).unsqueeze(0).to(device)  # Shape: [1, block_size]
            
            # Get model predictions
            y_pred = model(x)  # Shape: [1, vocab_size]
            
            # Apply softmax to get probabilities
            probs = F.softmax(y_pred, dim=1)
            
            # Sample from the distribution
            ix = torch.multinomial(probs, num_samples=1).item()
            word = itos.get(ix, '<UNK>')

            # Handle special tokens for paragraph breaks
            if word == '<par_start>':
                sequence.append('\n')  # Insert a newline for paragraph start
                capitalize_next = True  # Capitalize the next word after paragraph start
                continue  # Skip to the next iteration without updating context
            elif word == '<par_end>':
                sequence.append('\n')  # Insert a newline for paragraph end
                continue  # Skip to the next iteration without updating context

            # Determine if the word is punctuation
            is_punct = word in {'.', ',', '!', '?'}

            # Capitalize the word if the flag is set and it's not punctuation
            if capitalize_next and not is_punct:
                word = word.capitalize()
                capitalize_next = False  # Reset the flag after capitalizing

            # Append the generated word to the sequence
            sequence.append(word)

            # Update the context with the generated word's index
            context.append(ix)

            # If the generated word is a period, set the flag to capitalize the next word
            if word.endswith('.'):
                capitalize_next = True

    # Post-processing to remove spaces before punctuation marks
    generated_text = ' '.join(sequence)
    generated_text = re.sub(r'\s+([.,!?])', r'\1', generated_text)  # Remove space before punctuation
    return generated_text

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


NextWord(
  (embedding): Embedding(8160, 256)
  (lin1): Linear(in_features=1024, out_features=512, bias=True)
  (lin2): Linear(in_features=512, out_features=512, bias=True)
  (lin3): Linear(in_features=512, out_features=512, bias=True)
  (lin4): Linear(in_features=512, out_features=512, bias=True)
  (leaky_relu): LeakyReLU(negative_slope=0.1)
  (lin_out): Linear(in_features=512, out_features=8160, bias=True)
)


In [None]:
for param_name, param in model.named_parameters():
    print(f"{param_name}: {param.shape}")

embedding.weight: torch.Size([8160, 256])
lin1.weight: torch.Size([512, 1024])
lin1.bias: torch.Size([512])
lin2.weight: torch.Size([512, 512])
lin2.bias: torch.Size([512])
lin3.weight: torch.Size([512, 512])
lin3.bias: torch.Size([512])
lin4.weight: torch.Size([512, 512])
lin4.bias: torch.Size([512])
lin_out.weight: torch.Size([8160, 512])
lin_out.bias: torch.Size([8160])


## Training Loop

In [None]:

# Move data to device
X = X.to(device)
Y = Y.to(device)

# Initialize list to store loss values
losses = []

# Before training
initial_embeddings = model.embedding.weight.detach().clone()

# Training loop
for epoch in range(1, epochs + 1):
    model.train()  # Set model to training mode
    optimizer.zero_grad()
    outputs = model(X)
    loss = loss_fn(outputs, Y)
    loss.backward()
    optimizer.step()
    
    # Record the loss
    losses.append(loss.item())
    
    # Print loss every 5 epochs or at the first epoch
    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch}/{epochs}, Loss: {loss.item():.4f}")
        
# After training
final_embeddings = model.embedding.weight.detach()
# Compare embeddings
difference = torch.abs(final_embeddings - initial_embeddings).sum().item()
print(f"Total embedding difference after training: {difference}")
# Save the model after training
model_save_path = 'models/sherlock_nextword_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

# Plot Loss vs. Epochs
plt.figure(figsize=(10, 6))
plt.plot(range(1, epochs + 1), losses, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Loss vs. Epochs")
plt.legend()
plt.grid(True)
plt.show()


Epoch 1/500, Loss: 9.0083
Epoch 5/500, Loss: 7.3556
Epoch 10/500, Loss: 6.0455


Epoch 15/500, Loss: 5.8781
Epoch 20/500, Loss: 5.6426
Epoch 25/500, Loss: 5.5546
Epoch 30/500, Loss: 5.4682
Epoch 35/500, Loss: 5.3788
Epoch 40/500, Loss: 5.2687
Epoch 45/500, Loss: 5.1452
Epoch 50/500, Loss: 5.0056
Epoch 55/500, Loss: 4.8641
Epoch 60/500, Loss: 4.7397
Epoch 65/500, Loss: 4.6394
Epoch 70/500, Loss: 4.5315
Epoch 75/500, Loss: 4.4219
Epoch 80/500, Loss: 4.3187
Epoch 85/500, Loss: 4.2307
Epoch 90/500, Loss: 4.1537
Epoch 95/500, Loss: 4.0646
Epoch 100/500, Loss: 3.9385
Epoch 105/500, Loss: 3.8452
Epoch 110/500, Loss: 3.7708
Epoch 115/500, Loss: 3.7861
Epoch 120/500, Loss: 3.7082
Epoch 125/500, Loss: 3.5778
Epoch 130/500, Loss: 3.4593
Epoch 135/500, Loss: 3.3493
Epoch 140/500, Loss: 3.3971
Epoch 145/500, Loss: 3.1966
Epoch 150/500, Loss: 3.1330
Epoch 155/500, Loss: 3.0662
Epoch 160/500, Loss: 2.9803
Epoch 165/500, Loss: 3.1854
Epoch 170/500, Loss: 2.9619
Epoch 175/500, Loss: 2.8112
Epoch 180/500, Loss: 2.7024
Epoch 185/500, Loss: 2.6243
Epoch 190/500, Loss: 2.5774
Epoch 195

KeyboardInterrupt: 

## Loading the saved model (if not training from scratch)

In [None]:
# model = NextWord(block_size, vocab_size, embedding_dim, hidden_dim).to(device)
# model.load_state_dict(torch.load(model_save_path))
# model.eval()  # Set to evaluation mode if not training further

## Generating Text Sequences
We generate sample text sequences using the trained model.

In [None]:
sherlock_contexts = [
    "The mysterious case",
    "A peculiar",
    "Evidence",
    "The silent witness",
    "An unexpected clue",
    "A shadowy figure crept upon",
    "The hidden truth was",
    "In the dimly",
    "A cryptic message",
    "The final deduction"
]

# Iterate over each context and generate sequences
for idx, context in enumerate(sherlock_contexts, 1):
    # Generate a sequence of 20 words based on the context
    generated_sequence = generate_sequence(
        model=model,
        itos=itos,
        stoi=stoi,
        context_words=context.split(),
        block_size=block_size,
        max_len=50
    )
    
    # Print the results
    print(f"--- Sequence {idx} ---")
    print(f"Context: {context}")
    print(f"Generated sequence: {generated_sequence}\n")

--- Sequence 1 ---
Context: The mysterious case
Generated sequence: The mysterious case, as he can hardly be got into the inspector? par end par start author arthur conan doyle par end par start was there, said he, of him at the pool of the late elias whitney to identify in formats instead of speech from his pocket and

--- Sequence 2 ---
Context: A peculiar
Generated sequence: A peculiar no carpets. i had then the cover of my own with the door, on it will, her up and remanded into the long swash of the sea waves, and seen there is still now, and it that my fears is it possible it to the

--- Sequence 3 ---
Context: Evidence
Generated sequence: Evidence round to the church first. par end par start it only in the upper of the season to him a very strong piece of evidence was very strong language her point, was the love of the league of the red to the alpha. in the road -

--- Sequence 4 ---
Context: The silent witness
Generated sequence: The silent witness back on a few minutes. it s a 

## t-SNE Visualization of Trained Embeddings
Finally, we visualize the trained word embeddings using t-SNE with Plotly.

In [None]:
# %%
# Extract trained embeddings
trained_embeddings = model.embedding.weight.detach().cpu().numpy()

# Perform t-SNE
tsne = TSNE(n_components=2, random_state=42)
trained_embeddings_2d = tsne.fit_transform(trained_embeddings)

# Create DataFrame for Plotly
trained_df = pd.DataFrame({
    'word': list(stoi.keys()),
    'x': trained_embeddings_2d[:len(stoi), 0],
    'y': trained_embeddings_2d[:len(stoi), 1]
})

# Plot using Plotly Express without displaying word labels
fig = px.scatter(
    trained_df, 
    x='x', 
    y='y', 
    # Remove the 'text' parameter to hide word labels
    title='t-SNE Visualization of Trained Word Embeddings',
    hover_name='word',  # Set the word to appear on hover
    hover_data={'x': False, 'y': False},  # Exclude x and y from hover
    width=800,
    height=800
)

# Customize the plot by removing text-related trace settings
fig.update_traces(
    marker=dict(size=5)
    # Removed textposition and textfont since text labels are not displayed
)
fig.update_layout(
    title=dict(x=0.5),
    xaxis_title="t-SNE Dimension 1",
    yaxis_title="t-SNE Dimension 2",
    template='plotly_white'
)

fig.show()


## t-SNE plot in 3D

In [None]:
# %%
from sklearn.manifold import TSNE
import pandas as pd
import plotly.express as px
import torch
import re

# Ensure that 'stoi' and 'itos' are defined
if 'stoi' not in globals() or 'itos' not in globals():
    raise NameError("'stoi' and/or 'itos' are not defined. Please run the data preparation cells first.")

# Verify that the model has an embedding layer
if not hasattr(model, 'embedding'):
    raise AttributeError("The model does not have an 'embedding' attribute.")

# Extract trained embeddings from the model
trained_embeddings = model.embedding.weight.detach().cpu().numpy()
print(f"Shape of trained_embeddings: {trained_embeddings.shape}")
print(f"Number of words in 'stoi': {len(stoi)}")

# Assuming index 0 is reserved for padding, exclude it
# Ensure that trained_embeddings has len(stoi) + 1 embeddings
expected_embeddings = len(stoi) + 1  # +1 for padding
actual_embeddings = trained_embeddings.shape[0]

if actual_embeddings != expected_embeddings:
    raise ValueError(f"Mismatch between embeddings ({actual_embeddings}) and vocabulary size + padding ({expected_embeddings}).")

# Perform t-SNE in 3D
tsne = TSNE(n_components=3, random_state=42)
trained_embeddings_3d = tsne.fit_transform(trained_embeddings)

# Create a DataFrame for Plotly
# Exclude the first embedding (index 0) which is for padding
words = list(stoi.keys())
if actual_embeddings - 1 != len(words):
    raise ValueError(f"After excluding padding, number of words ({len(words)}) does not match embeddings ({actual_embeddings - 1}).")

df = pd.DataFrame({
    'word': words,
    'x': trained_embeddings_3d[1:, 0],  # Skip index 0
    'y': trained_embeddings_3d[1:, 1],
    'z': trained_embeddings_3d[1:, 2]
})

# Plot in 3D using Plotly Express
fig = px.scatter_3d(
    df, 
    x='x', 
    y='y', 
    z='z', 
    hover_name='word',  # Show word on hover
    hover_data={'x': False, 'y': False, 'z': False},  # Hide coordinates on hover
    title='3D t-SNE Visualization of Trained Word Embeddings',
    width=800,
    height=800
)

# Customize the plot
fig.update_traces(
    marker=dict(size=5, color='blue', opacity=0.7)
)
fig.update_layout(
    title=dict(x=0.5, y=0.95, xanchor='center', yanchor='top'),
    scene=dict(
        xaxis_title="t-SNE Dimension 1",
        yaxis_title="t-SNE Dimension 2",
        zaxis_title="t-SNE Dimension 3"
    ),
    template='plotly_white'
)

fig.show()


Shape of trained_embeddings: (8160, 16)
Number of words in 'stoi': 8159
