In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import numpy as np
import json
import os
from torch.nn.utils.rnn import pad_sequence

# Seed Python, NumPy and torch RNGs in one call so splits and weight init are repeatable.
pl.seed_everything(seed=42)

# ==========================================
# 1. Configuration & Loading Resources
# ==========================================

class Config:
    """Central place for input file paths and model/training hyperparameters."""

    # --- Input artefacts (supplied by the user) ---
    EMBED_PATH = 'embedding_matrix.npz'   # pretrained embedding matrix archive
    W2I_PATH = 'word2id.json'             # word -> id mapping
    I2W_PATH = 'id2word.json'             # id -> word mapping
    DATA_PATH = 'processed_reddit.txt'    # training corpus text file

    # --- Model architecture ---
    HIDDEN_DIM = 256
    NUM_LAYERS = 2        # number of stacked GRU layers
    DROPOUT = 0.3

    # --- Optimisation ---
    BATCH_SIZE = 256
    LEARNING_RATE = 1e-3

    # --- Data ---
    MAX_LEN = 30          # intended max token-sequence length (optional clipping)

print("Loading vocab and embeddings...")

# 1. Load mappings.
with open(Config.W2I_PATH, 'r', encoding='utf-8') as f:
    raw_w2i = json.load(f)
# Cast ids to int explicitly: handles both {"word": 5} and {"word": "5"}.
word2id = {k: int(v) for k, v in raw_w2i.items()}

with open(Config.I2W_PATH, 'r', encoding='utf-8') as f:
    raw_i2w = json.load(f)
# JSON object keys are always strings, so convert the ids back to int.
id2word = {int(k): v for k, v in raw_i2w.items()}

# 2. Load embeddings.
# np.load on an .npz returns an NpzFile that holds the archive open; the `with` closes it.
# Indexing the NpzFile reads the array fully into memory, so it stays valid after closing.
# Prefer the default unnamed key 'arr_0', otherwise take the first stored array.
with np.load(Config.EMBED_PATH) as loaded_embeds:
    embed_key = 'arr_0' if 'arr_0' in loaded_embeds.files else loaded_embeds.files[0]
    embedding_matrix = loaded_embeds[embed_key]

# Convert to torch tensor
embedding_tensor = torch.from_numpy(embedding_matrix).float()

# Identify special tokens (fall back to conventional ids if absent from the vocab).
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
EOS_TOKEN = '<EOS>'

PAD_ID = word2id.get(PAD_TOKEN, 0)
UNK_ID = word2id.get(UNK_TOKEN, 1)
EOS_ID = word2id.get(EOS_TOKEN, 2)

# A silent fallback could collide with a real word's id, so make it visible.
missing_specials = [tok for tok in (PAD_TOKEN, UNK_TOKEN, EOS_TOKEN) if tok not in word2id]
if missing_specials:
    print(f"Warning: {missing_specials} not in vocab; falling back to default ids (0/1/2).")

VOCAB_SIZE = len(word2id)

# Every vocab id must index a row of the embedding matrix; the model sizes both its
# embedding table and its output layer from the matrix rows.
assert max(word2id.values()) < embedding_tensor.size(0), (
    f"word2id contains id {max(word2id.values())} but the embedding matrix has only "
    f"{embedding_tensor.size(0)} rows"
)

print(f"Vocab Size: {VOCAB_SIZE}")
print(f"Embedding Shape: {embedding_tensor.shape}")

Seed set to 42


Loading vocab and embeddings...
Vocab Size: 10003
Embedding Shape: torch.Size([10003, 300])


In [2]:
# ==========================================
# 2. Dataset & Collate Function
# ==========================================

class DailyDialogDataset(Dataset):
    """Line-per-example text corpus yielding LongTensors of token ids terminated by EOS.

    Every line of ``file_path`` that is non-empty after stripping becomes one example.
    Tokens are whitespace-separated; words missing from ``w2i`` map to ``UNK_ID``.
    ``EOS_ID`` is appended first and the sequence is then clipped to ``max_len``,
    so an over-long example loses its EOS.

    Args:
        file_path: path of a UTF-8 text file.
        w2i: mapping from word to integer id.
        max_len: maximum number of ids returned per example (EOS included).
    """

    def __init__(self, file_path, w2i, max_len=50):
        self.sentences = []
        self.w2i = w2i
        self.max_len = max_len

        with open(file_path, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            self.sentences.extend(text for text in stripped_lines if text)

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        vocab = self.w2i
        # Whitespace tokenisation; adjust if the corpus needs something smarter.
        ids = [vocab.get(token, UNK_ID) for token in self.sentences[idx].split()]
        ids.append(EOS_ID)

        # Clip in place to the configured maximum length.
        if len(ids) > self.max_len:
            del ids[self.max_len:]

        return torch.tensor(ids, dtype=torch.long)

def completion_collate_fn(batch, pad_id=None, unk_id=None, ignore_index=-100):
    """Pad a batch of token-id sequences and build next-word (input, target) pairs.

    Sequences are right-padded to the longest one in the batch. Inputs are every
    position but the last; targets are the same rows shifted left by one:

        input:  [w1, w2, w3]
        target: [w2, w3, EOS]

    Targets equal to the PAD or UNK id are replaced by ``ignore_index`` so the loss
    skips them. The default -100 matches ``nn.CrossEntropyLoss``'s default ignore_index.

    Args:
        batch: list of 1-D LongTensors of token ids (variable lengths).
        pad_id: padding id; defaults to the module-level ``PAD_ID``.
        unk_id: unknown-word id; defaults to the module-level ``UNK_ID``.
        ignore_index: value written into targets that the loss must ignore.

    Returns:
        ``(inputs, loss_targets)``, both LongTensors of shape ``[batch, max_len - 1]``.
    """
    # Resolve defaults lazily so DataLoader's single-argument call keeps working.
    if pad_id is None:
        pad_id = PAD_ID
    if unk_id is None:
        unk_id = UNK_ID

    padded_batch = pad_sequence(batch, batch_first=True, padding_value=pad_id)

    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    # Clone so the ignore markers do not leak into `inputs`, which shares storage with `targets`.
    loss_targets = targets.clone()
    skip = (loss_targets == unk_id) | (loss_targets == pad_id)
    loss_targets[skip] = ignore_index

    return inputs, loss_targets

In [3]:
# ==========================================
# 3. Model Architecture (Stacked GRU + Attention)
# ==========================================

class SelfAttention(nn.Module):
    """Additive (MLP-scored) attention pooling over a sequence.

    This is not dot-product attention. Each time step is scored by a small MLP
    (Linear -> tanh -> Linear(., 1)). The scores are softmax-normalised over the sequence,
    and the weights are used to sum the inputs into one context vector per sequence.

    NOTE: ``CompletionModel`` below does not use this module.
    """

    def __init__(self, hidden_dim):
        super(SelfAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, encoder_outputs, mask=None):
        """Pool ``encoder_outputs`` into a context vector.

        Args:
            encoder_outputs: ``[batch, seq_len, hidden_dim]``.
            mask: optional ``[batch, seq_len]`` tensor, truthy for real positions and
                falsy for padding. Masked positions get exactly zero weight. Every row
                needs at least one unmasked position, otherwise its weights are NaN.

        Returns:
            ``(context, weights)`` with shapes ``[batch, hidden_dim]`` and ``[batch, seq_len]``.
        """
        # Scores: [batch, seq_len]
        scores = self.projection(encoder_outputs).squeeze(-1)

        if mask is not None:
            # -inf becomes a weight of 0 after softmax, so padding cannot attract attention.
            scores = scores.masked_fill(~mask.bool(), float('-inf'))

        weights = F.softmax(scores, dim=1)

        # [batch, 1, seq_len] @ [batch, seq_len, hidden] -> [batch, 1, hidden]
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs)

        return context.squeeze(1), weights

class CompletionModel(pl.LightningModule):
    """Next-word language model: embedding -> stacked GRU -> residual projection -> vocab logits.

    ``self.attention`` is a position-wise ``nn.Linear`` followed by tanh. Its output is
    added back to the GRU output and layer-normalised. It does not mix information across
    time steps, so it is not an attention mechanism. The attribute name is kept because
    saved checkpoints' state_dict keys include ``attention.*``.

    Args:
        embedding_tensor: 2-D float tensor ``[vocab_size, embed_dim]`` that initialises the
            trainable embedding table; its row count also sets the output vocabulary size.
        hidden_dim: GRU hidden size.
        num_layers: number of stacked GRU layers.
        dropout: dropout between GRU layers (forced to 0 for a single layer).
        lr: Adam learning rate.
    """

    def __init__(self, embedding_tensor, hidden_dim, num_layers, dropout, lr):
        super().__init__()
        # The embedding matrix is large and is passed explicitly when rebuilding the
        # model, so keep it out of the saved hyperparameters.
        self.save_hyperparameters(ignore=['embedding_tensor'])

        self.vocab_size = embedding_tensor.size(0)
        self.embed_dim = embedding_tensor.size(1)

        # 1. Embedding (fine-tuned; the PAD row does not contribute to the gradient).
        self.embedding = nn.Embedding.from_pretrained(
            embedding_tensor,
            freeze=False,
            padding_idx=PAD_ID
        )

        # 2. Stacked GRU (nn.GRU only accepts inter-layer dropout when num_layers > 1).
        self.gru = nn.GRU(
            input_size=self.embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # 3. Position-wise residual projection (named "attention" for checkpoint compatibility) & output.
        self.attention = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, self.vocab_size)

        self.lr = lr

        # Targets equal to -100 (PAD/UNK, marked by the collate function) are skipped.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, x):
        """Map token ids ``[batch, seq]`` to next-token logits ``[batch, seq, vocab_size]``."""
        embedded = self.embedding(x)
        gru_out, _ = self.gru(embedded)
        attn_out = torch.tanh(self.attention(gru_out))
        out = self.layer_norm(gru_out + attn_out)
        return self.fc(out)

    def _flat_loss(self, logits, targets):
        """Token-level cross-entropy over all non-ignored positions of the batch.

        ``reshape`` (rather than ``view``) also works if a tensor arrives non-contiguous.
        If every target in the batch is ignored, CrossEntropyLoss's mean is NaN.
        """
        return self.criterion(logits.reshape(-1, self.vocab_size), targets.reshape(-1))

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = self._flat_loss(self(inputs), targets)

        # Logs training loss per step
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)  # [Batch, Seq, Vocab]

        # 1. Validation loss
        loss = self._flat_loss(logits, targets)

        # 2. Validation accuracy over valid targets only (not PAD/UNK, i.e. not ignore_index).
        preds = logits.argmax(dim=-1)  # [Batch, Seq]
        mask = targets != self.criterion.ignore_index
        correct_predictions = ((preds == targets) & mask).sum()
        total_valid_tokens = mask.sum()

        # clamp(min=1) yields 0.0 for a batch with no valid tokens (correct is then 0 too)
        # without a data-dependent Python branch that would force a GPU->CPU sync.
        accuracy = correct_predictions.float() / total_valid_tokens.clamp(min=1).float()

        # NOTE(review): with on_epoch=True the logged epoch values are a mean of per-batch
        # values, not an exact token-weighted corpus accuracy/loss.
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', accuracy, prog_bar=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [4]:
# ==========================================
# 4. Training
# ==========================================

# Prepare Data
# Pass Config.MAX_LEN explicitly. Otherwise the dataset silently uses its own default
# (50) and the configured clipping length is never applied.
full_dataset = DailyDialogDataset(Config.DATA_PATH, word2id, max_len=Config.MAX_LEN)

# Split 90% Train, 10% Val.
# A dedicated, seeded generator gives the same split every time this cell runs,
# regardless of how much of the global torch RNG earlier cells consumed.
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Train Loader
train_loader = DataLoader(
    train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    collate_fn=completion_collate_fn,
    num_workers=4
)

# Validation Loader (fixed order so validation metrics are comparable across epochs)
val_loader = DataLoader(
    val_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    collate_fn=completion_collate_fn,
    num_workers=4
)

# Initialize Model
model = CompletionModel(
    embedding_tensor=embedding_tensor,
    hidden_dim=Config.HIDDEN_DIM,
    num_layers=Config.NUM_LAYERS,
    dropout=Config.DROPOUT,
    lr=Config.LEARNING_RATE
)

# Initialize Trainer
trainer = pl.Trainer(
    max_epochs=15,
    accelerator='auto',
    devices=1,
    enable_progress_bar=True,
    # Stop early if validation loss doesn't improve for 2 epochs
    callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=2)]
)



💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


Training samples: 2886695
Validation samples: 320744


In [8]:
# Restore trained weights from a Lightning checkpoint into a freshly built model.
# NOTE(review): hard-coded run/epoch path; update it to the checkpoint you want to use.
CHECKPOINT_PATH = "lightning_logs/version_6/checkpoints/epoch=9-step=112770.ckpt"

# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only machine.
# The freshly constructed model below is on CPU either way; load_state_dict copies into it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model = CompletionModel(
    embedding_tensor=embedding_tensor,
    hidden_dim=Config.HIDDEN_DIM,
    num_layers=Config.NUM_LAYERS,
    dropout=Config.DROPOUT,
    lr=Config.LEARNING_RATE
)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [None]:
# Fit with per-epoch validation. If the checkpoint cell above was run first, this
# continues from the loaded weights (with a fresh optimizer state).
print("Starting Training...")
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [15]:
def complete_sentence_robust(model, start_text, max_new_words=10, w2i=None, i2w=None):
    """Greedily extend ``start_text`` by up to ``max_new_words`` predicted words.

    The prompt is split on whitespace, and words missing from the vocabulary map to the
    '<UNK>' id. At each step the arg-max of the last position's logits is appended,
    stopping early at '<EOS>'. The full sequence is re-encoded every step.

    Args:
        model: module mapping a LongTensor ``[1, seq]`` to logits ``[1, seq, vocab]``.
        start_text: prompt text.
        max_new_words: maximum number of words to generate.
        w2i: word -> id map; defaults to the notebook's ``word2id``.
        i2w: id -> word map; defaults to the notebook's ``id2word``.

    Returns:
        Prompt words followed by generated words, joined by single spaces. There is no
        trailing space when nothing is generated. A blank prompt returns "".
    """
    if w2i is None:
        w2i = word2id
    if i2w is None:
        i2w = id2word

    model.eval()

    # Run on whatever device the model's parameters live on (CPU or CUDA).
    device = next(model.parameters()).device

    words = start_text.split()
    if not words:
        # Nothing to condition on: there is no last position to predict from.
        return ""

    unk_id = w2i.get('<UNK>')
    eos_id = w2i.get('<EOS>')
    current_ids = [w2i.get(w, unk_id) for w in words]
    current_tensor = torch.tensor([current_ids], dtype=torch.long, device=device)

    generated_words = []

    with torch.no_grad():
        for _ in range(max_new_words):
            logits = model(current_tensor)

            # Greedy prediction from the last position's logits.
            predicted_id = torch.argmax(logits[0, -1, :]).item()

            if predicted_id == eos_id:
                break

            generated_words.append(i2w.get(predicted_id, '<UNK>'))

            # Feed the prediction back in, on the same device as the model.
            next_input = torch.tensor([[predicted_id]], dtype=torch.long, device=device)
            current_tensor = torch.cat([current_tensor, next_input], dim=1)

    # Joining tokens avoids the stray trailing/double spaces of naive string concatenation.
    return " ".join(words + generated_words)

# Qualitative check: one prompt on the model's current device, then a fixed set of
# prompts after moving the model to CPU.
print("Running Robust Sanity Check...")
print(complete_sentence_robust(model, "how are you"))

print("\n=== Sanity Check: Auto-completion ===")
test_sentences = [
    "how are you",
    "what is your name",
    "i would like to buy",
    "the weather is very",
    "can you help me",
    "my name ",
    "this might",
]

# CPU is fast enough for a handful of short greedy completions.
model.to('cpu')

for prompt in test_sentences:
    print(f"Input: '{prompt}'  ->  Output: '{complete_sentence_robust(model, prompt)}'")

Running Robust Sanity Check...
how are you supposed to be in a relationship

=== Sanity Check: Auto-completion ===
Input: 'how are you'  ->  Output: 'how are you supposed to be in a relationship'
Input: 'what is your name'  ->  Output: 'what is your name '
Input: 'i would like to buy'  ->  Output: 'i would like to buy a house and have a good time'
Input: 'the weather is very'  ->  Output: 'the weather is very cold and i have a lot of fun'
Input: 'can you help me'  ->  Output: 'can you help me '
Input: 'my name '  ->  Output: 'my name  is a little bit of a question'
Input: 'this might'  ->  Output: 'this might be a bit of a rant but i don't know'


In [11]:
# ==========================================
# INTERACTIVE DEMO CELL
# ==========================================

import torch

def _greedy_continue(model, prompt_ids, eos_id, device, max_new_words):
    """Return up to ``max_new_words`` greedily predicted ids following ``prompt_ids`` (EOS excluded)."""
    current_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    new_ids = []
    with torch.no_grad():
        for _ in range(max_new_words):
            # Logits for the last step: [Batch, Seq, Vocab] -> [Vocab]
            next_token_logits = model(current_tensor)[0, -1, :]

            # Greedy decoding (pick the most probable token).
            # For more variety, use: torch.multinomial(F.softmax(next_token_logits, dim=-1), 1).item()
            predicted_id = torch.argmax(next_token_logits).item()

            if predicted_id == eos_id:
                break
            new_ids.append(predicted_id)

            # Feed the predicted token back in for the next step.
            next_input = torch.tensor([[predicted_id]], dtype=torch.long, device=device)
            current_tensor = torch.cat([current_tensor, next_input], dim=1)
    return new_ids


def interactive_completion(model, w2i, i2w, device, max_new_words=15):
    """Run a read-complete-print loop for greedy auto-completion.

    Each line the user types is lower-cased and whitespace-tokenised. Unknown words map
    to '<UNK>'. The line is then extended with up to ``max_new_words`` greedy predictions.
    The loop ends on 'quit'/'exit'/'q', on Ctrl-C, or when no more input can be read.

    Args:
        model: module mapping LongTensor ``[1, seq]`` -> logits ``[1, seq, vocab]``;
            it is switched to eval mode and moved to ``device`` (a side effect on the caller's model).
        w2i: word -> id map (must contain '<UNK>' and '<EOS>').
        i2w: id -> word map.
        device: torch device to run inference on.
        max_new_words: maximum number of words generated per prompt.
    """
    model.eval()
    model.to(device)

    print("-" * 50)
    print("Interactive Auto-Completion Mode")
    print("Type a start of a sentence (e.g., 'how are') and hit Enter.")
    print("Type 'quit', 'exit', or 'q' to stop.")
    print("-" * 50)

    unk_id = w2i.get('<UNK>')
    eos_id = w2i.get('<EOS>')

    while True:
        try:
            user_text = input("\nInput: ").strip().lower()
        except KeyboardInterrupt:
            print("\nInterrupted by user.")
            break
        except (EOFError, NotImplementedError):
            # stdin closed, or a frontend without stdin (IPython's StdinNotImplementedError
            # subclasses NotImplementedError). Retrying would loop forever, so stop.
            print("\nNo input available. Exiting...")
            break

        if user_text in ('quit', 'exit', 'q'):
            print("Exiting...")
            break

        if not user_text:
            continue

        words = user_text.split()
        try:
            new_ids = _greedy_continue(
                model, [w2i.get(w, unk_id) for w in words], eos_id, device, max_new_words
            )
        except KeyboardInterrupt:
            print("\nInterrupted by user.")
            break
        except Exception as e:
            # Best-effort REPL: report the failure and keep accepting input.
            print(f"Error: {e}")
            continue

        generated_words = [i2w.get(token_id, '<UNK>') for token_id in new_ids]
        # Join tokens so an empty completion leaves no trailing space.
        full_sentence = " ".join(words + generated_words)
        print(f"user: {user_text} -> Model: {full_sentence}")

# Launch the interactive session, preferring the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
interactive_completion(model, w2i=word2id, i2w=id2word, device=device)

--------------------------------------------------
Interactive Auto-Completion Mode
Type a start of a sentence (e.g., 'how are') and hit Enter.
Type 'quit', 'exit', or 'q' to stop.
--------------------------------------------------


user: i am -> Model: i am a year old male and i am a year old male
user: i am an -> Model: i am an atheist and i am a very religious person
user: can we -> Model: can we be able to get back together
user: are you in -> Model: are you in the wrong
user: people are -> Model: people are not the same
user: i love -> Model: i love her and i don't want to lose her
user: i want to buy -> Model: i want to buy a house and have a good time
user: this is nice -> Model: this is nice 
user: can we talk -> Model: can we talk about it
user: i hate -> Model: i hate it
user: honestly -> Model: honestly i don't know what to do
user: finally -> Model: finally i was able to get my head back and i was in the middle of
Exiting...


In [13]:
# ==========================================
# EXPORT TO ONNX
# ==========================================
import torch

# 1. Put model in eval mode (disables GRU inter-layer dropout during tracing).
model.eval()

# 2. Create a dummy input (batch size 1, sequence length 5) on the device that holds the
#    model's weights. Only the shape and dtype (int64) matter, not the values.
#    Reading the device from the model avoids depending on a `device` variable from
#    another cell, which may be undefined or disagree with where the model lives.
export_device = next(model.parameters()).device
dummy_input = torch.randint(0, 100, (1, 5), dtype=torch.long, device=export_device)

# 3. Define path
onnx_file_path = "completion_model_reddit.onnx"

# 4. Export
# dynamic_axes is CRITICAL: it tells ONNX that the sequence length (dim 1)
# can change, so you can feed it 2 words or 20 words.
torch.onnx.export(
    model,                      # Model instance
    dummy_input,                # Dummy input
    onnx_file_path,             # Output file
    export_params=True,         # Store the trained parameter weights
    opset_version=12,           # Standard opset (11 or 12 works best for GRU/Attention)
    do_constant_folding=True,   # Optimization
    input_names=['input_ids'],  # Name of input layer
    output_names=['logits'],    # Name of output layer
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'logits':    {0: 'batch_size', 1: 'sequence_length'}
    }
)

print(f"Model successfully exported to {onnx_file_path}")

Model successfully exported to completion_model_reddit.onnx


