In [1]:
pip install --upgrade datasets nltk tqdm

Collecting nltk
  Obtaining dependency information for nltk from https://files.pythonhosted.org/packages/60/90/81ac364ef94209c100e12579629dc92bf7a709a84af32f8c551b02c07e94/nltk-3.9.2-py3-none-any.whl.metadata
  Downloading nltk-3.9.2-py3-none-any.whl.metadata (3.2 kB)
Downloading nltk-3.9.2-py3-none-any.whl (1.5 MB)
   ---------------------------------------- 0.0/1.5 MB ? eta -:--:--
    --------------------------------------- 0.0/1.5 MB 1.3 MB/s eta 0:00:02
   -- ------------------------------------- 0.1/1.5 MB 871.5 kB/s eta 0:00:02
   -- ------------------------------------- 0.1/1.5 MB 871.5 kB/s eta 0:00:02
   ---- ----------------------------------- 0.2/1.5 MB 919.0 kB/s eta 0:00:02
   -------- ------------------------------- 0.3/1.5 MB 1.3 MB/s eta 0:00:01
   -------- ------------------------------- 0.3/1.5 MB 1.3 MB/s eta 0:00:01
   -------- ------------------------------- 0.3/1.5 MB 981.5 kB/s eta 0:00:02
   --------------- ------------------------ 0.6/1.5 MB 1.6 MB/s eta 0:00:

In [5]:
import nltk
# Tokenizer models for nltk.tokenize.word_tokenize.
# NOTE(review): newer NLTK releases appear to look up 'punkt_tab' while older ones
# use 'punkt'; both are fetched here - confirm which one the installed version needs.
nltk.download('punkt_tab')
nltk.download('punkt')
# Stop-word corpus read by stopwords.words('english') when STOP_WORDS is built.
nltk.download('stopwords')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\sharv\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt_tab.zip.
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\sharv\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\sharv\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [10]:
import pandas as pd
import numpy as np
import math
import re
from datasets import load_dataset # Make sure 'datasets' is up-to-date
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer
from collections import Counter, defaultdict
from tqdm.auto import tqdm

# The NLTK tokenizer models and the 'stopwords' corpus must be available before this
# cell runs. If a resource error is raised, run the following once:
# import nltk
# nltk.download('punkt')
# nltk.download('stopwords')

# --- 1. CONFIGURATION AND UTILITIES ---
# Constants for cleaning and dictionary building
# English stop words that clean_text drops from every token list.
STOP_WORDS = set(stopwords.words('english'))
# Shared Porter stemmer instance used by clean_text when stem=True.
STEMMER = PorterStemmer()
# Maximum number of translation candidates kept per question word.
TOP_N_TRANSLATIONS = 100 
# Dataset configuration name; passed as `name=` to load_dataset in run_data_preparation.
Q_A_SUBSET_NAME = 'question-answer-pair' 

def clean_text(text, stem=True):
    """Convert raw text into a list of normalized tokens.

    Pipeline: lowercase -> replace punctuation with spaces -> tokenize ->
    drop stop words -> optionally Porter-stem.

    Punctuation (and underscores) are replaced by a space instead of being
    deleted. Deleting them fused words that were separated only by punctuation
    ("and/or" -> "andor", "yes,no" -> "yesno") and turned contractions into
    non-words ("don't" -> "dont") that slipped past the stop-word filter.
    With a space separator the fragments ("don", "t") become separate tokens
    and are removed whenever they are listed in STOP_WORDS.

    Args:
        text: Raw input. Anything that is not a ``str`` yields an empty list.
        stem: When True (default), every surviving token is stemmed with STEMMER.

    Returns:
        list[str]: The surviving tokens, in their original order.
    """
    if not isinstance(text, str):
        return []

    # [^\w\s] matches punctuation/symbols; '_' is a word character for the regex
    # engine, so it is listed explicitly to be treated as punctuation as well.
    normalized = re.sub(r'[^\w\s]|_', ' ', text.lower())

    kept = [tok for tok in word_tokenize(normalized) if tok and tok not in STOP_WORDS]
    if stem:
        return [STEMMER.stem(tok) for tok in kept]
    return kept

# --- 2. PROBABILISTIC DICTIONARY (TRLM CORE) ---

def calculate_p_mi_dictionary(qa_pairs, top_n):
    """Build the P_MI(w_A | w_Q) translation dictionary from Q-A pairs.

    Every probability is estimated at the level of Q-A pairs, so the joint and
    the marginals live in the same event space:

        P(w_Q)      = (# pairs whose question contains w_Q) / N
        P(w_A)      = (# pairs whose answer contains w_A)   / N
        P(w_Q, w_A) = (# pairs containing both)             / N
        PMI         = log2( P(w_Q, w_A) / (P(w_Q) * P(w_A)) )

    The ratio is evaluated directly from integer counts
    (co_count * N / (df_Q * df_A)). Every count involved is at least 1, so no
    smoothing epsilon is needed and none can distort the score.

    Only strictly positive PMI values are kept. For each question word the
    scores are normalized by the sum over *all* of its positive candidates and
    then truncated to the ``top_n`` best, so the stored values of a truncated
    entry may sum to less than 1.

    Args:
        qa_pairs: List of tuples ``[(Q_tokens, A_tokens), ...]``; each element
            is a list of (already cleaned) token strings.
        top_n: Maximum number of answer words stored per question word.

    Returns:
        defaultdict(dict): ``{w_Q: {w_A: P_MI(w_A | w_Q)}}``. Question words
        without any positively associated answer word are omitted. Empty input
        yields an empty mapping.
    """
    if not qa_pairs:
        return defaultdict(dict)

    q_doc_freq = Counter()                # pairs whose question contains the word
    a_doc_freq = Counter()                # pairs whose answer contains the word
    co_occurrence = defaultdict(Counter)  # pairs containing both w_Q and w_A

    # 1. Count pair-level occurrences; each word counts once per Q-A pair.
    print(f"Counting occurrences across {len(qa_pairs):,} Q-A pairs...")
    for q_tokens, a_tokens in tqdm(qa_pairs):
        unique_q = set(q_tokens)
        unique_a = set(a_tokens)
        q_doc_freq.update(unique_q)
        a_doc_freq.update(unique_a)
        for wq in unique_q:
            co_occurrence[wq].update(unique_a)

    n_pairs = len(qa_pairs)
    p_mi_dictionary = defaultdict(dict)

    # 2. Compute PMI per (w_Q, w_A) and normalize per question word.
    print("Calculating and normalizing Mutual Information (MI)...")
    for wq, a_map in tqdm(co_occurrence.items(), desc="Calculating P_MI"):
        wq_freq = q_doc_freq[wq]
        positive_mi = {}

        for wa, co_count in a_map.items():
            # ratio > 1 means the words co-occur more often than chance.
            ratio = (co_count * n_pairs) / (wq_freq * a_doc_freq[wa])
            if ratio > 1.0:
                positive_mi[wa] = math.log2(ratio)

        # 3. Normalize over all positive candidates, keep only the top N.
        total_mi_for_wq = sum(positive_mi.values())
        if total_mi_for_wq > 0:
            ranked = sorted(positive_mi.items(), key=lambda item: item[1], reverse=True)
            for wa, mi_score in ranked[:top_n]:
                p_mi_dictionary[wq][wa] = mi_score / total_mi_for_wq

    return p_mi_dictionary

# --- 3. DATA ACQUISITION AND EXECUTION ---

def run_data_preparation(sample_size=100000, test_word='comput'):
    """Load Yahoo Answers Q-A pairs, clean them and build the P_MI dictionary.

    Args:
        sample_size: Number of leading training rows to load. ``None`` loads
            the whole 'train' split. Defaults to 100000.
        test_word: Stemmed question word whose top translations are printed as
            a quick sanity check of the resulting dictionary.

    Returns:
        The dictionary produced by calculate_p_mi_dictionary, or an empty dict
        when the dataset could not be loaded (callers test it for truthiness).
    """
    print("--- Phase 1: Data Acquisition and Pre-processing for TRLM ---")

    # 1. Load the raw dataset (network / disk-cache access).
    split = 'train' if sample_size is None else f'train[:{sample_size}]'
    try:
        raw_dataset = load_dataset(
            "sentence-transformers/yahoo-answers",
            name=Q_A_SUBSET_NAME,
            split=split,
        )
    except Exception as e:
        # Deliberate best-effort: report the problem and hand back an empty
        # dict so the calling cell can skip the prediction phase.
        print(f"\nFATAL ERROR LOADING DATASET. Error: {e}")
        return {}

    # 2. Clean and tokenize; keep only pairs where both sides have tokens left.
    cleaned_qa_pairs = []
    for item in tqdm(raw_dataset, desc="Cleaning data"):
        q_tokens = clean_text(item.get('question', ''), stem=True)
        a_tokens = clean_text(item.get('answer', ''), stem=True)
        if q_tokens and a_tokens:
            cleaned_qa_pairs.append((q_tokens, a_tokens))

    print(f"\nTotal usable cleaned Q-A pairs: {len(cleaned_qa_pairs):,}")

    # 3. Build the General Probabilistic Dictionary
    print("\n3. Building the General Probabilistic Dictionary...")
    general_dictionary = calculate_p_mi_dictionary(cleaned_qa_pairs, TOP_N_TRANSLATIONS)

    # 4. Verification: show the strongest translations of one known stem.
    # .get() is used so a missing word is not inserted into the defaultdict.
    translations = general_dictionary.get(test_word)
    if translations:
        print(f"\nTop translations for '{test_word}':")
        ranked = sorted(translations.items(), key=lambda item: item[1], reverse=True)
        for wa, prob in ranked[:5]:
            print(f"  {wa}: {prob:.5f}")
    else:
        print(f"\nNo translations found for test word '{test_word}'.")

    print("\nPhase 1 (Data Prep) Complete: General dictionary created. Next is Category Prediction.")

    return general_dictionary


# NOTE: run_data_preparation() is invoked from the final execution cell of this notebook.

In [11]:
# --- 4. CATEGORY PREDICTION AND TESTING UTILITY ---

def predict_category_score(query, category_words, dictionary, top_n):
    """Score a query against a set of category words with the TRLM dictionary.

    The score is the sum of P_MI(w_A | w_Q) over every query token w_Q and
    every category word w_A, restricted to the ``top_n`` most probable
    translations of each w_Q, divided by the number of query tokens.

    Args:
        query (str): The raw user question string; cleaned with clean_text.
        category_words (list): Stemmed words representing the category/answer.
            Repeated words are counted once per repetition.
        dictionary (dict): Mapping ``{w_Q: {w_A: P_MI(w_A | w_Q)}}``.
        top_n (int | None): Maximum number of translations considered per
            query word. ``None`` disables the limit; values below 1 leave no
            candidates.

    Returns:
        float: The length-normalized relevance score; 0.0 when the query has
        no usable tokens.
    """
    query_tokens = clean_text(query, stem=True)
    if not query_tokens:
        return 0.0

    total_score = 0.0
    limited = {}  # per-call cache of truncated translation tables

    for wq in query_tokens:
        # .get() never inserts a key, even when `dictionary` is a defaultdict.
        translations = dictionary.get(wq)
        if not translations:
            continue

        # Enforce the top_n limit only when the stored table is larger.
        if top_n is not None and len(translations) > top_n:
            if wq not in limited:
                ranked = sorted(translations.items(), key=lambda item: item[1], reverse=True)
                limited[wq] = dict(ranked[:max(top_n, 0)])
            translations = limited[wq]

        for wa in category_words:
            total_score += translations.get(wa, 0.0)

    # Normalize by the query length so longer queries do not always win.
    return total_score / len(query_tokens)

# --- 5. EXECUTION EXTENSION (TESTING) ---

def run_prediction_test(general_dictionary, categories=None, sample_queries=None):
    """Print TRLM relevance scores of sample queries against test categories.

    Args:
        general_dictionary: Mapping ``{w_Q: {w_A: P_MI(w_A | w_Q)}}``.
        categories: Optional ``{category name: [stemmed words]}``. The words
            must be Porter stems to match the dictionary entries. Defaults to
            three built-in demo categories.
        sample_queries: Optional list of raw query strings. Defaults to three
            built-in demo queries.

    Returns:
        None. Results are printed.
    """
    print("\n" + "="*50)
    print("--- Phase 2: Testing TRLM for Category Prediction ---")
    print("="*50)

    # 1. Test categories and their representative (already stemmed) words.
    if categories is None:
        categories = {
            'Technology': ['comput', 'softwar', 'internet', 'hardwar', 'game'],
            'Sport': ['footbal', 'basketbal', 'game', 'team', 'ball'],
            'Finance': ['money', 'bank', 'invest', 'stock', 'loan']
        }

    # 2. Sample queries.
    if sample_queries is None:
        sample_queries = [
            "What is the best way to invest my money in the stock market?",
            "How do I install new computer software and games on my laptop?",
            "Who won the football game last night?"
        ]

    # 3. Score every query against every category.
    for i, query in enumerate(sample_queries, start=1):
        print(f"\nQUERY {i}: '{query}'")

        scores = {
            category_name: predict_category_score(
                query, category_words, general_dictionary, TOP_N_TRANSLATIONS
            )
            for category_name, category_words in categories.items()
        }

        if not scores:
            # Only reachable when no categories were supplied.
            print("Could not generate scores (no categories were provided).")
            continue

        # Stable sort: among equal scores the first-defined category stays first.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        best_category, best_score = ranked[0]
        # An all-zero result carries no evidence, so no winner is announced.
        has_match = best_score > 0

        print("--- Relevance Scores ---")
        for cat, score in ranked:
            marker = ' <--- BEST MATCH' if has_match and cat == best_category else ''
            print(f"  {cat}: {score:.5f}{marker}")
        if not has_match:
            print("  No category scored above zero; no best match reported.")

In [None]:
# --- 6. MAIN EXECUTION ---

if __name__ == "__main__":
    # Phase 1: load the data and build the translation dictionary.
    general_dictionary = run_data_preparation()

    # Phase 2: the prediction test needs a non-empty dictionary.
    if not general_dictionary:
        print("\nDictionary is empty. Cannot run prediction test.")
    else:
        run_prediction_test(general_dictionary)

--- Phase 1: Data Acquisition and Pre-processing for TRLM ---


Cleaning data:   0%|          | 0/100000 [00:00<?, ?it/s]