In [None]:
# For Google Colab:
# Upload the folder containing this notebook (plus the model checkpoint and the
# vocabulary file) to Google Drive, then set `path` below to that folder,
# e.g. '/content/drive/MyDrive/<your-folder>'.
import sys, os

# Only mount Drive and change directory when running inside Google Colab;
# elsewhere the notebook is assumed to run from its own folder.
if 'google.colab' in sys.modules:

    # Mount Google Drive under /content/drive.
    from google.colab import drive
    drive.mount('/content/drive')

    # Folder on Drive holding this notebook and its data files (edit this placeholder).
    path = 'xxxxx/xxxxx'
    print(path)
    if not os.path.isdir(path):
        raise FileNotFoundError(
            f"Colab working folder not found: {path!r}. "
            "Edit `path` to point at the uploaded folder on Google Drive."
        )
    # os.chdir changes the current working directory so that relative paths used
    # later (e.g. 'trained_gpt_model.pt', 'vocabulary.txt') resolve inside `path`.
    os.chdir(path)
    # Pure-Python equivalent of `!pwd`, so the cell is also valid as a plain script.
    print(os.getcwd())

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import time
from transformer_ import *
from utils_gpt_ import *

In [None]:
class GPTTextGenerator:
    """Word-level text generator wrapping a trained GPT checkpoint.

    Loads the architecture config and trained weights from a checkpoint, maps
    lower-cased whitespace-separated words to indices with a vocabulary file
    (one token per non-blank line), runs autoregressive decoding (greedy,
    temperature, top-k and top-p sampling), and decodes indices back to text.
    """

    # Key under which each special token's index is stored in `self.special_tokens`.
    _SPECIAL_TOKENS = {
        'pad_id': '<pad>',
        'unk_id': '<unk>',
        'bos_id': '<bos>',
        'eos_id': '<eos>',
        'mask_id': '<mask>',
    }

    # Special tokens omitted from decoded text (<unk> is rendered as [UNK] instead).
    _HIDDEN_TOKENS = ('<pad>', '<bos>', '<eos>', '<mask>')

    def __init__(self, model_path=None, vocab_filename=None, device=None):
        """
        Initialize generator (for generating indices) along with vocabulary (for mapping indices to words).

        Args:
            model_path (str): Path to trained model checkpoint.
            vocab_filename (str): Path of the vocabulary text file (one token per line).
            device (str or torch.device or None): Device for generation; if None, auto-detect.

        Raises:
            ValueError: If `vocab_filename` is None, or the vocabulary size does not
                match the checkpoint's `vocab_size`.
            FileNotFoundError: If the checkpoint or the vocabulary file does not exist.
        """
        # Validate cheap arguments before the (potentially slow) checkpoint load.
        if vocab_filename is None:
            raise ValueError("vocab_filename must be provided")

        self.device = torch.device(device) if device is not None else get_device()
        print(f"Text Generator initialized on {self.device}")

        # Load model checkpoint data (config + trained weights).
        model_data = self._load_model_data(model_path)
        # Model architecture configuration (passed as keyword arguments to GPT).
        self.config = model_data['config']

        # Create model architecture with random weights based on saved config.
        print('Creating GPT with random weights')
        self.model = self._create_model(self.config)

        # Overwrite random weights with trained parameters from checkpoint.
        print('Loading the pretrained weights')
        self.model.load_state_dict(model_data['state_dict'])
        self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded from: {model_path}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        # Load vocabulary mappings from text file.
        self._load_vocabulary_from_text(vocab_filename)
        vocab_size = len(self.idx2word)
        print(f"Vocabulary loaded from: {vocab_filename}")
        print(f"Vocabulary size: {vocab_size}")

        # Model output ids index into idx2word, so its length must match the config.
        if vocab_size != self.config['vocab_size']:
            raise ValueError(
                f"Vocab size mismatch: vocab.txt={vocab_size} "
                f"vs model config={self.config['vocab_size']}"
            )

    def _load_model_data(self, checkpoint_path):
        """Load model configuration and weights from a checkpoint file.

        Returns:
            dict: {'config': checkpoint['model_config'],
                   'state_dict': checkpoint['model_state_dict']}.

        Raises:
            FileNotFoundError: If `checkpoint_path` is None or does not exist.
        """
        if checkpoint_path is None or not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Model checkpoint not found: {checkpoint_path}")
        # NOTE(security): depending on the installed torch version, torch.load may
        # unpickle arbitrary objects -- only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        return {
            'config': checkpoint['model_config'],          # Model architecture parameters
            'state_dict': checkpoint['model_state_dict'],  # Trained model weights
        }

    def _create_model(self, config):
        """Create the GPT architecture (randomly initialized) from a config dict."""
        return GPT(**config)

    def _load_vocabulary_from_text(self, vocab_filename):
        """Load the vocabulary file and build word<->index mappings.

        Each non-blank line (stripped) is one token; its index is its position among
        the non-blank lines. Sets `self.word2idx`, `self.idx2word` and
        `self.special_tokens` (only for special tokens present in the file).

        Raises:
            FileNotFoundError: If `vocab_filename` does not exist.
        """
        if not os.path.exists(vocab_filename):
            raise FileNotFoundError(f"Vocabulary file not found: {vocab_filename}")

        with open(vocab_filename, 'r', encoding='utf-8') as f:
            # Blank lines are skipped, so they do not consume an index.
            words = [line.strip() for line in f if line.strip()]

        self.word2idx = {word: idx for idx, word in enumerate(words)}
        self.idx2word = words

        self.special_tokens = {
            key: self.word2idx[token]
            for key, token in self._SPECIAL_TOKENS.items()
            if token in self.word2idx
        }

    def text_to_tokens(self, text):
        """Convert text to token IDs.

        Args:
            text (str or iterable of str): A string is lower-cased and split on
                whitespace; any other iterable is treated as already-split words.

        Returns:
            list[int]: Token ids; unknown words map to <unk> (or <pad>, or 0 if neither exists).
        """
        # The training data contains only lower case; keep lowercasing to match training.
        if isinstance(text, str):
            words = text.lower().strip().split()
        else:
            words = text

        unk_fallback = self.special_tokens.get('unk_id', self.special_tokens.get('pad_id', 0))
        return [self.word2idx.get(w, unk_fallback) for w in words]

    def tokens_to_text(self, tokens):
        """Convert token IDs to space-joined text.

        Accepts a list of ints, a single int, or a tensor of any shape (flattened).
        <unk> is shown as [UNK]; <pad>/<bos>/<eos>/<mask> are dropped; ids outside
        the vocabulary are shown as <UNK_id>.
        """
        if isinstance(tokens, torch.Tensor):
            # reshape(-1) keeps 0-d and (1, T) tensors iterable as a flat id list.
            tokens = tokens.reshape(-1).tolist()
        elif isinstance(tokens, int):
            tokens = [tokens]

        words_out = []
        for token in tokens:
            if 0 <= token < len(self.idx2word):
                w = self.idx2word[token]
                if w == '<unk>':
                    words_out.append('[UNK]')
                elif w not in self._HIDDEN_TOKENS:
                    words_out.append(w)
            else:
                words_out.append(f'<UNK_{token}>')
        return ' '.join(words_out)

    @torch.no_grad()
    def generate_text(self, prompt_text, max_new_tokens=50, temperature=1.0,
                      top_k=None, top_p=None, do_sample=True, verbose=False):
        """
        Generate a continuation of a text prompt.

        Args:
            prompt_text (str or iterable of str): Prompt (see `text_to_tokens`).
            max_new_tokens (int): Maximum number of tokens to append.
            temperature (float): Softmax temperature; <= 0 falls back to greedy decoding.
            top_k (int or None): Keep only the k most likely tokens when sampling.
            top_p (float or None): Nucleus-sampling threshold in (0, 1].
            do_sample (bool): Sample if True, otherwise greedy (argmax).
            verbose (bool): Print intermediate token ids and text.

        Returns:
            str: Decoded prompt followed by the generated tokens (special tokens hidden).

        Raises:
            ValueError: If the prompt is empty and the vocabulary has no <bos> token,
                or `top_p` is outside (0, 1].
        """
        prompt_tokens = self.text_to_tokens(prompt_text)
        if not prompt_tokens:
            # The model needs at least one position to condition on.
            bos_id = self.special_tokens.get('bos_id')
            if bos_id is None:
                raise ValueError("Prompt is empty and the vocabulary has no <bos> token to start from.")
            prompt_tokens = [bos_id]
        if verbose:
            print(f"Input text: '{prompt_text}'")
            print(f"Input tokens: {prompt_tokens}")

        input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0)

        generated_ids = self._generate_tokens(
            input_ids, max_new_tokens, temperature, top_k, top_p, do_sample, verbose
        )

        # Take the single batch row; squeeze() would turn a length-1 sequence into a 0-d tensor.
        sequence = generated_ids[0]
        generated_text = self.tokens_to_text(sequence)
        if verbose:
            print(f"Generated tokens: {sequence.tolist()}")
            print(f"Generated text: '{generated_text}'")
        return generated_text

    def _generate_tokens(self, input_ids, max_new_tokens, temperature, top_k, top_p, do_sample, verbose):
        """Autoregressively append up to `max_new_tokens` tokens to `input_ids` (shape (1, T)).

        The model is conditioned on at most the last `max_seq_length` tokens
        (sliding window), so prompts at or beyond the context length still generate.
        Stops early when <eos> is produced. Returns the full sequence (prompt + new tokens).
        """
        input_ids = input_ids.to(self.device)
        eos_id = self.special_tokens.get('eos_id', None)
        max_len = self.config.get('max_seq_length', 1024)

        # Validate sampling parameters.
        if top_k is not None and top_k <= 0:
            top_k = None
        if top_p is not None:
            top_p = float(top_p)
            if not (0.0 < top_p <= 1.0):
                raise ValueError("top_p must be in (0, 1].")
            if top_p == 1.0:
                # Full distribution: skip filtering so float round-off cannot drop tail tokens.
                top_p = None

        sample = do_sample and temperature > 0

        for step in range(max_new_tokens):
            # Condition only on the most recent max_len tokens the model can attend to.
            context = input_ids[:, -max_len:]
            logits, _ = self.model(context)
            # logits: (B, T, V); the last position predicts the next token.
            next_token_logits = logits[:, -1, :]

            if sample:
                next_token_logits = next_token_logits / temperature
                if top_k is not None:
                    next_token_logits = self._apply_top_k(next_token_logits, top_k)
                if top_p is not None:
                    next_token_logits = self._apply_top_p(next_token_logits, top_p)
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy: argmax is unaffected by temperature or top-k/top-p filtering.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Batch=1 assumption: stop when EOS is generated.
            if eos_id is not None and (next_token == eos_id).all():
                if verbose:
                    print(f"EOS generated at step {step}")
                break

        return input_ids

    def _apply_top_k(self, logits, top_k):
        """Apply top-k filtering: set logits below the k-th largest (per row) to -inf."""
        top_k = min(int(top_k), logits.size(-1))
        values, _ = torch.topk(logits, top_k)
        min_values = values[..., -1, None]
        return torch.where(logits < min_values, torch.full_like(logits, -float('inf')), logits)

    def _apply_top_p(self, logits, top_p):
        """Apply nucleus filtering: keep the smallest set of most likely tokens whose
        cumulative probability exceeds `top_p` (always at least one token)."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Mark tokens past the threshold; shift right so the token that crosses it is kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the mask back to the original (unsorted) vocabulary order.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)

        return logits.masked_fill(indices_to_remove, -float('inf'))

In [None]:
def print_generation_comparison(generator, prompts, temperatures, max_new_tokens=30,
                                strategies=None, seed=None):
    """Print generations for every prompt x temperature x decoding strategy.

    Args:
        generator: Object exposing ``generate_text(prompt_text=..., max_new_tokens=...,
            temperature=..., top_k=..., top_p=..., do_sample=..., verbose=...)``
            (e.g. a GPTTextGenerator).
        prompts (list[str]): Prompts to continue.
        temperatures (list[float]): Temperatures to compare.
        max_new_tokens (int): Maximum number of tokens generated per call.
        strategies (list[tuple[str, dict]] or None): (label, params) pairs; params may
            contain ``do_sample`` (default True), ``top_k`` and ``top_p``. Defaults to
            greedy plus several top-k / top-p settings.
        seed (int or None): If given, ``torch.manual_seed(seed)`` is called before each
            generation, so every strategy starts from the same RNG state and the
            printed output is reproducible across re-runs.
    """
    if strategies is None:
        strategies = [
            ("Greedy",        {"do_sample": False}),
            ("Top-k (None)",  {"do_sample": True}),   # sampling with temperature only (baseline)
            ("Top-k k=40",    {"do_sample": True, "top_k": 40}),
            ("Top-k k=100",   {"do_sample": True, "top_k": 100}),
            ("Top-p (None)",  {"do_sample": True}),   # same baseline, repeated for the top-p group
            ("Top-p p=0.9",   {"do_sample": True, "top_p": 0.9}),
            ("Top-p p=0.95",  {"do_sample": True, "top_p": 0.95}),
        ]

    for prompt in prompts:
        print("Prompt:")
        print(prompt)

        for temp in temperatures:
            print(f"\n--- Temperature: {temp:.2f} ---\n")
            for name, params in strategies:
                if seed is not None:
                    torch.manual_seed(seed)
                print(f"[{name}]")
                try:
                    text = generator.generate_text(
                        prompt_text=prompt,
                        max_new_tokens=max_new_tokens,
                        temperature=temp,
                        top_k=params.get("top_k", None),
                        top_p=params.get("top_p", None),
                        do_sample=params.get("do_sample", True),
                        verbose=False
                    )
                    print(text.strip() + "\n")
                except Exception as e:
                    # Best-effort comparison: report the failure and continue with the rest.
                    print(f"Error: {e}\n")




In [None]:
# Prompts to continue and sampling temperatures to compare.
prompts = [
    "fed officials signaled that",
    # add more prompts if desired
]

temperatures = [0.70, 0.90, 1.10]

# Seed torch's RNG so sampled generations are reproducible under Restart & Run All.
torch.manual_seed(0)

generator = GPTTextGenerator(model_path='trained_gpt_model.pt', vocab_filename='vocabulary.txt', device=None)

print_generation_comparison(generator, prompts, temperatures, max_new_tokens=30)