In [10]:
import random
import re
from collections import defaultdict, Counter
from typing import Dict, List, Tuple
import bisect

class TrigramModel:
    def __init__(self):
        self.trigram_counts = defaultdict(Counter)
        
        self.context_totals = defaultdict(int)
        
        # Special tokens
        self.START = "<START>"
        self.END = "<END>"
        self.UNK = "<UNK>"
        
        self.vocab = set()
        self.min_word_freq = 2  
        self.word_freq = Counter()
    
    def _clean_text(self, text: str) -> str:
        # Convert to lowercase
        text = text.lower()
        
        # Replace newlines and tabs with spaces
        text = re.sub(r'[\n\t\r]+', ' ', text)
        
        text = re.sub(r'[^a-z\s\.\!\?\,\;\:\'\-]', '', text)
        
        # Normalize multiple spaces
        text = re.sub(r'\s+', ' ', text)
        
        return text.strip()
    
    def _tokenize(self, text: str) -> List[str]:
        tokens = re.findall(r'\w+(?:\'\w+)?|[.!?]', text)
        return tokens
    
    def _pad_text(self, tokens: List[str]) -> List[str]:
        padded = []
        
        # Add two start tokens
        padded.extend([self.START, self.START])
        
        for token in tokens:
            padded.append(token)
            if token in {'.', '!', '?'}:
                padded.append(self.END)
                padded.extend([self.START, self.START])
        
        if padded[-1] != self.END:
            padded.append(self.END)
        
        return padded
    
    def fit(self, text: str):
        # Step 1: Clean the text
        cleaned = self._clean_text(text)
        
        # Step 2: Tokenize
        tokens = self._tokenize(cleaned)
        
        # Step 3: Count word frequencies for vocabulary
        self.word_freq.update(tokens)
        
        # Build vocabulary 
        self.vocab = {word for word, count in self.word_freq.items() 
                      if count >= self.min_word_freq}
        self.vocab.update([self.START, self.END])
        
        # Replace rare words with <UNK>
        tokens = [word if word in self.vocab else self.UNK for word in tokens]
        
        # Step 4: Pad with start/end tokens
        padded = self._pad_text(tokens)
        
        for i in range(len(padded) - 2):
            w1, w2, w3 = padded[i], padded[i+1], padded[i+2]
            context = (w1, w2)
            
            self.trigram_counts[context][w3] += 1
            self.context_totals[context] += 1
    
    def _get_next_word(self, context: Tuple[str, str]) -> str:
        if context not in self.trigram_counts:
            return random.choice(list(self.vocab))
        
        next_words = self.trigram_counts[context]
        total = self.context_totals[context]
      
        words = list(next_words.keys())
        cumulative = []
        cumsum = 0
        
        for word in words:
            cumsum += next_words[word] / total
            cumulative.append(cumsum)
        
        rand_val = random.random()
        idx = bisect.bisect_left(cumulative, rand_val)
        
        if idx >= len(words):
            idx = len(words) - 1
            
        return words[idx]
    
    def generate(self, max_length: int = 50, seed: int = None) -> str:
        if seed is not None:
            random.seed(seed)
        
        # Check if model is trained
        if not self.trigram_counts:
            return "Model not trained. Please call fit() first."
        
        # Start with two START tokens
        generated = [self.START, self.START]
        
        while len(generated) < max_length + 2:
            context = (generated[-2], generated[-1])
            next_word = self._get_next_word(context)
            
            generated.append(next_word)
            
            # Stop at END token
            if next_word == self.END:
                break
        
        result = [word for word in generated[2:] 
                  if word not in {self.START, self.END, self.UNK}]
        
        # Join words and clean up spacing around punctuation
        text = ' '.join(result)
        text = re.sub(r'\s+([.,!?;:])', r'\1', text)
        
        if text:
            text = text[0].upper() + text[1:]
        
        return text
    
    def get_model_stats(self) -> Dict:
        """Returns statistics about the trained model."""
        return {
            'unique_contexts': len(self.trigram_counts),
            'vocabulary_size': len(self.vocab),
            'total_trigrams': sum(self.context_totals.values()),
            'avg_next_words_per_context': 
                sum(len(words) for words in self.trigram_counts.values()) / 
                max(len(self.trigram_counts), 1)
        }


import requests
import time
BOOKS = {
    'alice': {
        'url': 'https://www.gutenberg.org/files/11/11-0.txt',
        'title': "Alice's Adventures in Wonderland"
    },
    'pride': {
        'url': 'https://www.gutenberg.org/files/1342/1342-0.txt',
        'title': 'Pride and Prejudice'
    },
    'frankenstein': {
        'url': 'https://www.gutenberg.org/files/84/84-0.txt',
        'title': 'Frankenstein'
    },
    'tale': {
        'url': 'https://www.gutenberg.org/files/98/98-0.txt',
        'title': 'A Tale of Two Cities'
    }
}

def download_book(url: str, title: str) -> str:
    print(f"Downloading '{title}'...")
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        print(f" Successfully downloaded '{title}' ({len(response.text)} characters)")
        return response.text
    except requests.RequestException as e:
        print(f"Error downloading '{title}': {e}")
        return ""

def clean_gutenberg_text(text: str) -> str: 
    start_markers = [
        "*** START OF THIS PROJECT GUTENBERG",
        "*** START OF THE PROJECT GUTENBERG",
        "*END*THE SMALL PRINT"
    ]
    
    start_idx = 0
    for marker in start_markers:
        idx = text.find(marker)
        if idx != -1:
            # Skip past the marker line
            start_idx = text.find('\n', idx) + 1
            break
    
    # Find the end of the content
    end_markers = [
        "*** END OF THIS PROJECT GUTENBERG",
        "*** END OF THE PROJECT GUTENBERG",
        "End of Project Gutenberg",
        "End of the Project Gutenberg"
    ]
    
    end_idx = len(text)
    for marker in end_markers:
        idx = text.find(marker)
        if idx != -1:
            end_idx = idx
            break
    
    # Extract the main content
    content = text[start_idx:end_idx]
    
    return content.strip()

def train_model(book_key: str = 'pride') -> TrigramModel:
    if book_key not in BOOKS:
        raise ValueError(f"Unknown book key: {book_key}. Choose from {list(BOOKS.keys())}")
    
    book_info = BOOKS[book_key]
    
    # Download the book
    raw_text = download_book(book_info['url'], book_info['title'])
    
    if not raw_text:
        raise RuntimeError("Failed to download book")
    
    # Clean Gutenberg metadata
    text = clean_gutenberg_text(raw_text)
    print(f"Cleaned text: {len(text)} characters")
    
    # Train the model
    print("\nTraining trigram model...")
    start_time = time.time()
    
    model = TrigramModel()
    model.fit(text)
    
    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds")
    
    # Display model statistics
    stats = model.get_model_stats()
    print("\nModel Statistics:")
    print(f"  Vocabulary size: {stats['vocabulary_size']:,}")
    print(f"  Unique contexts: {stats['unique_contexts']:,}")
    print(f"  Total trigrams: {stats['total_trigrams']:,}")
    print(f"  Avg next words per context: {stats['avg_next_words_per_context']:.2f}")
    
    return model

def generate_samples(model: TrigramModel, num_samples: int = 5, max_length: int = 50):
  
    print(f"\n{'='*70}")
    print(f"Generating {num_samples} sample texts:")
    print(f"{'='*70}\n")
    
    for i in range(num_samples):
        text = model.generate(max_length=max_length, seed=i)
        print(f"Sample {i+1}:")
        print(f"{text}\n")
        print(f"{'-'*70}\n")

def main():
    print("="*70)
    print("Trigram Language Model - Training Script")
    print("="*70)
    print("\nAvailable books:")
    for key, info in BOOKS.items():
        print(f"  {key}: {info['title']}")
    
    book_choice = 'pride' #'alice'  # Change to 'pride'#, 'frankenstein', or 'tale'
    
    print(f"\nTraining on: {BOOKS[book_choice]['title']}")
    print("="*70 + "\n")
    
    # Train the model
    model = train_model(book_choice)
    
    # Generate sample texts
    generate_samples(model, num_samples=5, max_length=50)
    
    print("\n All done")

if __name__ == "__main__":
    main()

Trigram Language Model - Training Script

Available books:
  alice: Alice's Adventures in Wonderland
  pride: Pride and Prejudice
  frankenstein: Frankenstein
  tale: A Tale of Two Cities

Training on: Pride and Prejudice

Downloading 'Pride and Prejudice'...
 Successfully downloaded 'Pride and Prejudice' (743383 characters)
Cleaned text: 743241 characters

Training trigram model...
Training completed in 0.71 seconds

Model Statistics:
  Vocabulary size: 4,170
  Unique contexts: 51,717
  Total trigrams: 157,275
  Avg next words per context: 2.01

Generating 5 sample texts:

Sample 1:
Jane ran to her partner to oblige him to visit her the information herself and her expectations of felicity to have his errors made public might ruin him in.

----------------------------------------------------------------------

Sample 2:
I can from my poor mother is tolerably well i suppose who have heart enough to make another choice?

-------------------------------------------------------------------