In [1]:
import json
import numpy as np
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Configuration
CONFIG = {
    'vocab_path': 'word2id.json',
    'id_path': 'id2word.json',
    'emb_path': 'embedding_matrix_daily.npz',
    'batch_size': 512,
    'hidden_dim': 512,       # Reverted to 256 as requested
    'num_layers': 3,         # IMPROVEMENT: Increased layers for depth
    'learning_rate': 0.001,
    'max_seq_len': 30, 
    'pad_token_id': 0,
    'dropout': 0.3,
    'unk_token_id': 0        # Placeholder, will update after loading vocab
}

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

Running on: cuda


In [2]:
# Load Vocabulary
with open(CONFIG['vocab_path'], 'r') as f:
    word2id = json.load(f)

with open(CONFIG['id_path'], 'r') as f:
    id2word = json.load(f)

# Update UNK ID in config
CONFIG['unk_token_id'] = int(word2id.get('<UNK>', 0))
print(f"UNK ID set to: {CONFIG['unk_token_id']}")

# Load Embedding Matrix
emb_data = np.load(CONFIG['emb_path'])
embedding_matrix = emb_data[list(emb_data.keys())[0]] 

vocab_size, embed_dim = embedding_matrix.shape
print(f"Vocab Size: {vocab_size}, Embedding Dim: {embed_dim}")

# Convert to Tensor (Float32 for safety, Mixed Precision trainer will handle casting)
embedding_tensor = torch.tensor(embedding_matrix, dtype=torch.float32)

UNK ID set to: 2
Vocab Size: 20003, Embedding Dim: 300


In [3]:
class TextPredictionDataset(Dataset):
    def __init__(self, encoded_sentences, unk_id, max_len=20):
        """
        Args:
            encoded_sentences: List of list of token IDs
            unk_id: ID of the <UNK> token
            max_len: Max context window
        """
        self.samples = []
        
        for sentence in encoded_sentences:
            if len(sentence) < 2:
                continue
                
            # Create N-grams
            for i in range(1, len(sentence)):
                input_seq = sentence[:i]
                target_token = sentence[i]
                
                # IMPROVEMENT: Skip sample if target is UNK
                # This prevents the model from learning to predict <UNK>
                if target_token == unk_id:
                    continue
                
                if len(input_seq) > max_len:
                    input_seq = input_seq[-max_len:]
                
                self.samples.append((input_seq, target_token))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        input_seq, target = self.samples[idx]
        return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target, dtype=torch.long)

def collate_fn(batch):
    inputs, targets = zip(*batch)
    padded_inputs = torch.nn.utils.rnn.pad_sequence(
        inputs, batch_first=True, padding_value=CONFIG['pad_token_id']
    )
    targets = torch.stack(targets)
    return padded_inputs, targets

In [4]:
# Restoring Attention Mechanism
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        
    def forward(self, hidden, encoder_outputs, mask=None):
        # hidden: [batch, hidden_dim] (The final state of the GRU)
        # encoder_outputs: [batch, seq_len, hidden_dim] (All states of the GRU)
        
        src_len = encoder_outputs.shape[1]
        
        # Repeat hidden state src_len times
        hidden_expanded = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        # Calculate energy
        energy = torch.tanh(self.attn(hidden_expanded + encoder_outputs))
        
        # Calculate attention scores
        attention = self.v(energy).squeeze(2)
        
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e4)
            
        return F.softmax(attention, dim=1)

In [5]:
class NextWordGRU(pl.LightningModule):
    def __init__(self, embedding_matrix, hidden_dim, vocab_size, lr, pad_idx):
        super().__init__()
        self.save_hyperparameters(ignore=['embedding_matrix'])
        
        # 1. Embedding Layer
        # FROZEN as requested because they are FastText semantics
        self.embedding = nn.Embedding.from_pretrained(
            embedding_matrix, 
            freeze=True, 
            padding_idx=pad_idx
        )
        
        # 2. GRU Layer (Reverted to GRU)
        self.gru = nn.GRU(
            input_size=embedding_matrix.shape[1], 
            hidden_size=hidden_dim,
            num_layers=CONFIG['num_layers'], # 3 Layers
            batch_first=True,
            dropout=CONFIG['dropout'] if CONFIG['num_layers'] > 1 else 0
        )
        
        # 3. Attention Layer
        self.attention = Attention(hidden_dim)
        
        # 4. Dense Output
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)
        
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
        self.lr = lr

    def forward(self, x):
        mask = (x != self.hparams.pad_idx)
        embedded = self.embedding(x)
        
        # GRU Output
        outputs, hidden = self.gru(embedded)
        
        # Take the hidden state of the LAST layer
        final_hidden = hidden[-1] 
        
        # Calculate Attention
        attn_weights = self.attention(final_hidden, outputs, mask)
        context = torch.bmm(attn_weights.unsqueeze(1), outputs).squeeze(1)
        
        # Combine Context and Hidden
        combined = torch.cat((context, final_hidden), dim=1)
        logits = self.fc(combined)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        predictions = torch.argmax(logits, dim=1)
        accuracy = (predictions == y).float().mean()
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_accuracy', accuracy, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=1, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

In [6]:
def preprocess_word(word):
    word = word.lower()
    word = re.sub(r'[^\w]+$', '', word)
    return word

In [7]:


def load_and_encode_corpus(filename, word2id, unk_token='<UNK>', eos_token='<EOS>'):
    encoded_corpus = []
    unk_id = int(word2id.get(unk_token, 0)) 
    eos_id = int(word2id.get(eos_token, 2))
    
    print(f"Reading {filename}...")
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            text = line.strip()
            if not text: continue
            words = text.split()
            encoded_sent = []
            for w in words:
                w = preprocess_word(w)
                if not w: continue
                val = word2id.get(w, unk_id)
                encoded_sent.append(int(val))
            if encoded_sent:
                encoded_sent.append(eos_id)
                encoded_corpus.append(encoded_sent)
    print(f"Loaded {len(encoded_corpus)} sentences.")
    return encoded_corpus

# Ensure word2id values are ints
word2id = {k: int(v) for k, v in word2id.items()}

corpus_ints = load_and_encode_corpus(
    'compined.txt', 
    word2id, 
    unk_token='<UNK>', 
    eos_token='<EOS>'
)

split_idx = int(len(corpus_ints) * 0.8)
train_data = corpus_ints[:split_idx]
val_data = corpus_ints[split_idx:]

# --- DATASETS WITH UNK SKIPPING ---
train_dataset = TextPredictionDataset(
    train_data, 
    unk_id=CONFIG['unk_token_id'], 
    max_len=CONFIG['max_seq_len']
)
val_dataset = TextPredictionDataset(
    val_data, 
    unk_id=CONFIG['unk_token_id'], 
    max_len=CONFIG['max_seq_len']
)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=0)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss', dirpath='./checkpoints', filename='daily-{epoch:02d}-{val_loss:.2f}',
    save_top_k=2, mode='min'
)

early_stopping_callback = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=True)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# --- MODEL & TRAINING ---
model = NextWordGRU(
    embedding_matrix=embedding_tensor,
    hidden_dim=CONFIG['hidden_dim'],
    vocab_size=vocab_size,
    lr=CONFIG['learning_rate'],
    pad_idx=CONFIG['pad_token_id']
)

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
    accelerator="auto",
    devices=1,
    precision='16-mixed', # ENABLE MIXED PRECISION
    log_every_n_steps=10
)



Reading compined.txt...
Loaded 380481 sentences.


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [8]:
trainer.fit(model, train_loader, val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/hatem/.virtualenvs/ml/lib/python3.13/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:881: Checkpoint directory /home/hatem/Development/python/ml/autocomplete/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/hatem/.virtualenvs/ml/lib/python3.13/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name      | Type             | Params | Mode  | FLOPs
---------------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/hatem/.virtualenvs/ml/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/hatem/.virtualenvs/ml/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 5.242


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.184 >= min_delta = 0.0. New best score: 5.058


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.076 >= min_delta = 0.0. New best score: 4.982


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 4.961


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 4.960


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 3 records. Best score: 4.960. Signaling Trainer to stop.


In [9]:
# --- PREPARE DATA FOR EVALUATION (WITHOUT TRAINING) ---
# Re-use existing 'train_data' and 'val_data' lists from previous cells
# Just create datasets without running trainer

split_idx = int(len(corpus_ints) * 0.8)
train_data = corpus_ints[:split_idx]
val_data = corpus_ints[split_idx:]

train_dataset = TextPredictionDataset(
    train_data, 
    unk_id=CONFIG['unk_token_id'], 
    max_len=CONFIG['max_seq_len']
)
val_dataset = TextPredictionDataset(
    val_data, 
    unk_id=CONFIG['unk_token_id'], 
    max_len=CONFIG['max_seq_len']
)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")


Train dataset: 3104585 samples
Val dataset: 840192 samples


In [10]:
# --- 1. DATASET FOR MULTI-WORD PREDICTION ---
class MultiWordDataset(Dataset):
    def __init__(self, encoded_sentences, unk_id, pad_id, max_len=20, pred_len=4):
        self.samples = []
        self.pred_len = pred_len
        self.pad_id = pad_id
        
        for sentence in encoded_sentences:
            if len(sentence) < 2:
                continue
            
            # Create samples
            # We stop earlier so we don't start a sample at the very end
            for i in range(1, len(sentence)):
                input_seq = sentence[:i]
                
                # Get the next 'pred_len' tokens as target
                target_seq = sentence[i : i + pred_len]
                
                # If target is UNK, we might want to skip, but for seqs it's complex.
                # Let's just keep them but rely on the model to learn.
                
                # Pad target sequence if it hits EOS early
                if len(target_seq) < pred_len:
                    padding = [pad_id] * (pred_len - len(target_seq))
                    target_seq = target_seq + padding
                
                if len(input_seq) > max_len:
                    input_seq = input_seq[-max_len:]
                
                self.samples.append((input_seq, target_seq))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        input_seq, target_seq = self.samples[idx]
        return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target_seq, dtype=torch.long)

def multi_collate_fn(batch):
    inputs, targets = zip(*batch)
    # Pad inputs (variable length)
    padded_inputs = torch.nn.utils.rnn.pad_sequence(
        inputs, batch_first=True, padding_value=CONFIG['pad_token_id']
    )
    # Targets are fixed length (4), so we just stack them
    targets = torch.stack(targets) 
    return padded_inputs, targets

# --- 2. PREPARE NEW DATALOADERS ---
# Re-use existing 'train_data' and 'val_data' lists from previous cells
ft_train_dataset = MultiWordDataset(train_data, unk_id=CONFIG['unk_token_id'], pad_id=CONFIG['pad_token_id'])
ft_val_dataset = MultiWordDataset(val_data, unk_id=CONFIG['unk_token_id'], pad_id=CONFIG['pad_token_id'])

ft_train_loader = DataLoader(ft_train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=multi_collate_fn, num_workers=0)
ft_val_loader = DataLoader(ft_val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, collate_fn=multi_collate_fn, num_workers=0)


# --- 3. FINE-TUNING MODEL WRAPPER ---
class MultiStepFineTuner(NextWordGRU):
    def training_step(self, batch, batch_idx):
        src, targets = batch # src: [Batch, Seq], targets: [Batch, 4]
        
        # --- Step 1: Initial Encode (Identical to Forward) ---
        embedded = self.embedding(src)
        encoder_outputs, hidden = self.gru(embedded) # hidden is (Layers, Batch, Dim)
        
        loss = 0
        batch_size = src.size(0)
        
        # Current hidden state for the Attention/Prediction layer (Top layer of stack)
        current_hidden = hidden[-1]
        
        # --- Step 2: Decode Loop (4 Steps) ---
        # We use the previous hidden state to predict, then feed the TRUE target (Teacher Forcing)
        # to generate the next hidden state.
        
        mask = (src != self.hparams.pad_idx)
        
        for i in range(targets.size(1)):
            # A. PREDICTION
            # Attend to the ORIGINAL Context (Seq2Seq style)
            attn_weights = self.attention(current_hidden, encoder_outputs, mask)
            context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
            combined = torch.cat((context, current_hidden), dim=1)
            logits = self.fc(combined)
            
            # B. LOSS CALCULATION
            step_target = targets[:, i]
            step_loss = self.loss_fn(logits, step_target)
            loss += step_loss
            
            # C. PREPARE NEXT STEP (Teacher Forcing)
            # Feed the TRUE current token into the GRU to get state for next prediction
            # Embed the target token: [Batch] -> [Batch, 1, Emb]
            inp_next = self.embedding(step_target).unsqueeze(1)
            
            # Update Hidden State
            # We feed the FULL hidden stack (layers, batch, dim) back into GRU
            _, hidden = self.gru(inp_next, hidden)
            current_hidden = hidden[-1]
            
        # Average loss over the steps
        final_loss = loss / targets.size(1)
        self.log('train_loss', final_loss, prog_bar=True)
        return final_loss

    def validation_step(self, batch, batch_idx):
        # Simplified validation: just check loss on sequences
        src, targets = batch
        embedded = self.embedding(src)
        encoder_outputs, hidden = self.gru(embedded)
        loss = 0
        current_hidden = hidden[-1]
        mask = (src != self.hparams.pad_idx)
        
        for i in range(targets.size(1)):
            attn_weights = self.attention(current_hidden, encoder_outputs, mask)
            context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
            combined = torch.cat((context, current_hidden), dim=1)
            logits = self.fc(combined)
            loss += self.loss_fn(logits, targets[:, i])
            
            inp_next = self.embedding(targets[:, i]).unsqueeze(1)
            _, hidden = self.gru(inp_next, hidden)
            current_hidden = hidden[-1]
            
        self.log('val_loss', loss / targets.size(1), prog_bar=True)

# --- 4. LOAD & TRAIN ---
# Load weights from your best previous checkpoint
# Replace 'checkpoints/gru-attn-epoch=03...' with your actual best checkpoint path
prev_ckpt = "./checkpoints/daily-epoch=07-val_loss=4.44.ckpt"

print("Initializing Fine-Tuner...")
finetuner = MultiStepFineTuner.load_from_checkpoint(
    prev_ckpt,
    embedding_matrix=embedding_tensor,
    strict=False # Allow strict=False in case of minor internal attribute diffs
)

# Create a new trainer for fine-tuning
ft_trainer = pl.Trainer(
    max_epochs=3, # Fine-tune for just a few epochs
    accelerator="auto",
    devices=1,
    precision='16-mixed',
    callbacks=[
        ModelCheckpoint(dirpath='./checkpoints_ft', filename='finetuned-{epoch}-{val_loss_seq:.2f}'),
        EarlyStopping(monitor='val_loss', patience=2)
    ]
)

print("Starting Fine-Tuning (Sequence Length 4)...")
ft_trainer.fit(finetuner, ft_train_loader, ft_val_loader)

Initializing Fine-Tuner...


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/home/hatem/.virtualenvs/ml/lib/python3.13/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:881: Checkpoint directory /home/hatem/Development/python/ml/autocomplete/checkpoints_ft exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode  | FLOPs
---------------------------------------------------------------
0 | embedding | Embedding        | 6.0 M  | train | 0    
1 | gru       | GRU              | 4.4 M  | train | 0    
2 | attention | Attention        | 263 K  | train | 0    
3 | fc        | Linear           | 20.5 M | train | 0    
4 | loss_fn   | CrossEntropyLoss | 0      | train | 0    
---------------------------------------------------------------
25.2 M    Trainable params
6.0 M     Non-trainable params
31.2 M    Total params
124.677   Total estimated model params size (MB)
7        

Starting Fine-Tuning (Sequence Length 4)...


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [11]:
# --- SANITY CHECK ---
def predict_next_words(model, partial_sentence, word2id, id2word, top_k=5, max_len=20):
    model.eval()
    device = next(model.parameters()).device
    
    words = partial_sentence.split()
    encoded = []
    unk_id = int(word2id.get('<UNK>', 0))
    
    for w in words:
        w = preprocess_word(w)
        if not w: continue
        encoded.append(int(word2id.get(w, unk_id)))
    
    if len(encoded) > max_len:
        encoded = encoded[-max_len:]
    
    if not encoded:
        return []

    input_tensor = torch.tensor([encoded], dtype=torch.long).to(device)
    
    with torch.no_grad():
        logits = model(input_tensor)
    
    probs = F.softmax(logits[0], dim=0)
    
    # Optional: Manually zero out UNK probability if you want strictly no UNK predictions
    probs[unk_id] = 0.0
    
    top_probs, top_indices = torch.topk(probs, top_k)
    
    results = []
    for prob, idx in zip(top_probs.cpu().numpy(), top_indices.cpu().numpy()):
        word = id2word.get(str(int(idx)), '<UNK>')
        results.append((word, float(prob)))
    
    return results

test_sentences = [
    'how are you',
    'Nice meeting you',
    'this is a really',
    'What is '
]

print("=" * 60)
print("SANITY CHECK: Model Predictions (Skipping UNK)")
print("=" * 60)

model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

for sentence in test_sentences:
    predictions = predict_next_words(model, sentence, word2id, id2word, top_k=5)
    print(f"\nInput: '{sentence}'")
    print("Top 5 Predicted Next Words:")
    for i, (word, prob) in enumerate(predictions, 1):
        print(f"  {i}. {word:<20} ({prob:.4f})")

print("=" * 60)

SANITY CHECK: Model Predictions (Skipping UNK)

Input: 'how are you'
Top 5 Predicted Next Words:
  1. doing                (0.1802)
  2. <EOS>                (0.0459)
  3. feeling              (0.0385)
  4. communicating        (0.0271)
  5. going                (0.0251)

Input: 'Nice meeting you'
Top 5 Predicted Next Words:
  1. cledus               (0.1341)
  2. <EOS>                (0.0669)
  3. tearing              (0.0636)
  4. taransky             (0.0158)
  5. and                  (0.0144)

Input: 'this is a really'
Top 5 Predicted Next Words:
  1. staircase            (0.0349)
  2. <EOS>                (0.0229)
  3. barbaric             (0.0223)
  4. mistake              (0.0204)
  5. thing                (0.0204)

Input: 'What is '
Top 5 Predicted Next Words:
  1. the                  (0.0797)
  2. that                 (0.0593)
  3. a                    (0.0447)
  4. you                  (0.0375)
  5. this                 (0.0341)


In [15]:
# --- Load Checkpoint and Interactive Generation ---

# 1. Path to your checkpoint
# Make sure the path matches where the file is actually located.
# Based on the previous cell, it's likely inside the 'checkpoints' folder.
ckpt_path = "./checkpoints/daily-epoch=04-val_loss=4.96.ckpt"

print(f"Loading model from {ckpt_path}...")

# 2. Load the Model
# We must pass 'embedding_matrix' because we ignored it in save_hyperparameters
loaded_model = NextWordGRU.load_from_checkpoint(
    ckpt_path,
    embedding_matrix=embedding_tensor,  # Requires embedding_tensor from previous cells
    map_location=device,
)

loaded_model.to(device)
loaded_model.eval()
print("Model loaded successfully!")


# 3. Define Autoregressive Generation Function
def generate_completion(
    model, start_text, word2id, id2word, max_generated=20, temp=1.0
):
    """
    Generates text starting from start_text until <EOS> or max_generated tokens.
    """
    model.eval()
    words = start_text.split()
    current_ids = []

    # Encode initial string
    unk_id = int(word2id.get("<UNK>", 0))
    for w in words:
        w = preprocess_word(w)
        if w:
            current_ids.append(int(word2id.get(w, unk_id)))

    input_seq = current_ids[:]  # Copy for keeping track
    generated_words = []

    with torch.no_grad():
        for _ in range(max_generated):
            # Prepare input tensor (truncate to max_seq_len if needed)
            seq_tensor = torch.tensor([input_seq[-20:]], dtype=torch.long).to(device)

            # Forward pass
            logits = model(seq_tensor)

            # Apply Temperature (Higher = more random/creative, Lower = more deterministic)
            logits = logits[0] / temp

            # Get probabilities
            probs = F.softmax(logits, dim=0)

            # Sample from the distribution (more natural) or take Argmax (more rigid)
            # Using Argmax for stability in early training, change to multinomial for creativity
            next_token_id = torch.argmax(probs).item()

            # Decode
            next_word = id2word.get(str(next_token_id), "<UNK>")

            # Stop if EOS or Unknown (optional)
            if next_word == "<EOS>":
                break

            generated_words.append(next_word)
            input_seq.append(next_token_id)

    return " ".join(generated_words)


# 4. Interactive Loop
print("\n" + "=" * 40)
print("ðŸ¤– TEXT COMPLETION BOT (Type 'exit' to stop)")
print("=" * 40)

while True:
    user_input = input("\nEnter start of sentence: ")

    if user_input.lower() in ["exit", "quit"]:
        print("Goodbye!")
        break

    if not user_input.strip():
        continue

    try:
        completion = generate_completion(
            loaded_model, user_input, word2id, id2word, max_generated=8
        )
        print(f"Model: {user_input} \033[1m{completion}\033[0m")
    except Exception as e:
        print(f"Error generating text: {e}")


Loading model from ./checkpoints/daily-epoch=04-val_loss=4.96.ckpt...
Model loaded successfully!

ðŸ¤– TEXT COMPLETION BOT (Type 'exit' to stop)
Model: who [1m[0m
Model: who are [1myou[0m
Model: what do  [1myou think[0m
Model: i love [1mthe guy[0m
Model: kiss [1m[0m
Model: kiss me [1m[0m
Model: can you [1mbe a fool[0m
Model: yo [1m[0m
Model: you [1mdon't know what you think[0m
Model: I hate [1mthe truth[0m
Model: can we [1mtalk[0m
Model: what abouy [1m[0m
Model: what about [1mthe other[0m
Model: I don't [1mknow what i mean[0m
Model: he is  [1ma man[0m
Model: she is  [1ma very very very fifty-six of a young[0m
Model: are you [1mkidding[0m
Model: we need [1mto get a lot[0m
Model: huh [1m[0m
Model: wow [1m[0m
Model: I hate the [1mtruth[0m
Model: i quit [1myou[0m
Model: what is  [1mthe name[0m
Model: what is the name [1m[0m
Model: what is the name of [1mthe name[0m
Goodbye!
