In [2]:
import pandas as pd
import re
import os
import pickle
from tqdm import tqdm
from collections import Counter, defaultdict
import itertools
import nltk
from nltk.util import ngrams
from nltk.metrics import edit_distance
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from sklearn.metrics import classification_report
from nltk.probability import FreqDist
import string

nltk.download('punkt')

BASE_DIR = '/Users/bathulahoneypriya/Documents/graph/telugu'

BASE_SUFFIXES = [
    'తూ', 'నే', 'ను', 'డి', 'లే', 'మూ', 'నూ', 'గ', 'కూ', 'వు', 'ఆ', 'ల', 'యు', 'తో', 'అని', 'మా', 'త', 
    'ఆరు', 'లో', 'ఆల్', 'గవ', 'కు', 'గారు', 'జీ', 'అన్న', 'చి', 'తను', 'నూ', 'కూ', 'పై', 'వద్ద', 
    'పాల', 'మీద', 'నుంచి', 'తొ', 'పాటు', 'గా', 'ంచి', 'కొరకు', 'లాగా', 'వల్ల', 'కంటే', 'లోన', 'లోపల', 
    'కింద', 'వరకు', 'నుండి', 'తనం', 'త్వం', 'త్వము', 'డు', 'ము', 'వి', 'లు', 'ాను', 'ారు', 'ెను', 
    'ుడు', 'ుని', 'ులు', 'ాలు', 'ాడు', 'ిన', 'ని', 'నకు', 'రాలు', 'ించి', 'ించు', 'ించే', 'ిస్తే', 
    'ిస్తూ', 'గాను', 'టకు', 'కై', 'టపై', 'ంటే', 'కాల', 'లోని', 'నకైనా', 'నకైన', 'రాలి', 'వారు', 
    'వార', 'వారి', 'వాడు', 'ావు', 'ింది', 'ారము', 'తాను', 'తారు', 'తాము', 'తావు', 'స్తాను', 'స్తావు', 
    'స్తాడు', 'స్తుంది', 'స్తారు', 'స్తాము', 'లోకి', 'నించి', 'పైన', 'వెనుక', 'ముందు', 'చేత', 'ద్వారా', 
    'వలె', 'మైన', 'వంటి', 'లాంటి', 'చేసిన', 'అయిన', 'పడిన', 'గల', 'కల', 'లేదు', 'కాదు', 'వద్దు', 
    'కూడదు', 'లేని', 'కాని', 'ఏమిటి', 'ఏమి', 'ఎందుకు', 'ఎప్పుడు', 'ఎక్కడ', 'ఎలా', 'చేయవచ్చు', 
    'చేయాలి', 'చేయకూడదు', 'వేయాలి', 'పెట్టాలి', 'తెలుసుకోవాలి', 'ఆడు', 'ఇడు', 'ఉండు', 'పూర్వకం', 
    'పరిపాటి', 'తున్నారు', 'ఆడుతున్నారు',
    'పూర్వక', 'పూర్తి', 'ాయి', 'ాము', 'కొనుట', 'ించాలి', 'ిస్తారు', 'ేయండి', 'తున్నాను', 'తున్నాము',
    'తున్నాడు', 'తుంది', 'తున్నది', 'లోనికి', 'నుంచే', 'నందు', 'కోసం', 'ద్వారా', 'తరువాత', 'ముందర',
    'పర్యంతం', 'వరకూ', 'నుండే', 'అయితే', 'కాబట్టి', 'అందువలన', 'ఐనా', 'మరియు', 'కానీ', 'కాని', 
    'వాడం', 'చేయడం', 'పెట్టడం', 'రావడం', 'పోవడం', 'వెళ్లడం', 'ఉండటం', 'రాకపోవడం', 'ఇవ్వడం',
    'రాసినది', 'చదివినది', 'చూసినది', 'వచ్చినది', 'వినినది', 'తెలిసినది'
]

def augment_suffixes(suffixes):
    augmented = set(suffixes)
    
    for suffix in suffixes:
        if len(suffix) <= 6: 
            for char in ['ా', 'ి', 'ు', 'ె', 'ొ', 'ో', 'ం', 'ః']:
                augmented.add(suffix + char)
    
    for suf1, suf2 in itertools.product(suffixes, suffixes):
        if len(suf1) <= 3 and len(suf2) <= 3:  
            augmented.add(suf1 + suf2)
    
    return sorted(list(augmented), key=len, reverse=True)

TELUGU_SUFFIXES = augment_suffixes(BASE_SUFFIXES)

def prune_ngram_models(self, threshold=3):
    """Remove rare n-grams to save memory"""
    for name, model in self.ngram_models.items():
        self.ngram_models[name] = FreqDist({k: v for k, v in model.items() if v >= threshold})

def detect_compound_words(self, word):
    """Detect and handle compound words in Telugu"""
    
    potential_joins = []
    for i in range(3, len(word)-3):
        left = word[:i]
        right = word[i:]
        if self.is_valid_stem(left) and self.is_valid_stem(right):
            potential_joins.append((left, right, 
                                  self.compute_stem_quality(left, left) + 
                                  self.compute_stem_quality(right, right)))
    
    if potential_joins:
        best_split = max(potential_joins, key=lambda x: x[2])
        if best_split[2] > 1.5: 
            return best_split[0]  
    
    return None

class TeluguStemmer:
    def __init__(self):
        self.suffixes = TELUGU_SUFFIXES
        self.exception_dict = {}
        self.ngram_models = {}
        self.min_stem_length = 2
        self.max_suffix_length = 8  
        self.training_data = []
        self.similarity_cache = {}
        self.rule_weight = 1.0
        self.stat_weight = 1.0
        self.verbose = True
        
        self.root_forms = {
            'చేయ': 'చేయు', 'పోవ': 'పోవు', 'రావ': 'రావు', 'వెళ్ల': 'వెళ్లు',
            'తిన': 'తిను', 'అడుగ': 'అడుగు', 'చదువ': 'చదువు', 'రాయ': 'రాయు',
            'చూడ': 'చూడు', 'విన': 'విను', 'మాట్లాడ': 'మాట్లాడు', 'నడవ': 'నడవు'
        }
        
        self.sandhi_rules = [
            (r'([క-హ])ు\s+అ', r'\1వ'),
            (r'([క-హ])ి\s+అ', r'\1్య'),
            (r'([క-హ])ా\s+అ', r'\1ాయ'),
        ]

        additional_sandhi_rules = [
            (r'([క-హ])్\s+[అ]', r'\1'), 
            (r'([ఎఏ])\s+([ఇఈ])', r'ఐ'),  
            (r'([ఒఓ])\s+([ఉఊ])', r'ఔ'),
          
        ]
        self.sandhi_rules.extend(additional_sandhi_rules)
        
        self.vowel_groups = {
            'front': set(['ి', 'ీ', 'ె', 'ే', 'ై']),
            'back': set(['ు', 'ూ', 'ొ', 'ో', 'ౌ']),
            'neutral': set(['ా', 'అ', 'ఆ', 'ఇ', 'ఈ', 'ఉ', 'ఊ', 'ఎ', 'ఏ', 'ఐ', 'ఒ', 'ఓ', 'ఔ'])
        }

        additional_roots = {
            'వస్తా': 'వచ్చు', 'పోతా': 'పోవు', 'చేస్తా': 'చేయు',
            'రాస్తా': 'రాయు', 'చూస్తా': 'చూడు', 'వింటా': 'విను',
            'అడుగుతా': 'అడుగు', 'తింటా': 'తిను', 'నడుస్తా': 'నడువు'
        }
        self.root_forms.update(additional_roots)

    def is_telugu_consonant(self, char):
        """Check if a character is a Telugu consonant"""
        return char in 'కఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరలవశషసహళ'

    def is_telugu_vowel(self, char):
        """Check if a character is a Telugu vowel"""
        return char in 'అఆఇఈఉఊఋఌఎఏఐఒఓఔ'

    def is_telugu_vowel_sign(self, char):
        """Check if a character is a Telugu vowel sign"""
        return char in 'ాిీుూృౄెేైొోౌ్'
    

    def apply_morphological_rules(self, word, stem):
        if not isinstance(stem, str):
            return word
            
        if stem in self.root_forms:
            return self.root_forms[stem]
            
        if word.endswith('ము') and not stem.endswith('ము') and stem.endswith('ు'):
            return stem[:-1] + 'ము'
        if word.endswith('కం') and not stem.endswith('కం') and len(stem) > 2:
            return stem + 'క'
        if word.endswith('తం') and not stem.endswith('తం') and len(stem) > 2:
            return stem + 'త'
            
        if word.endswith(('తున్నాడు', 'స్తున్నాడు', 'ఆడు', 'స్తున్నాను', 'తున్నాను')):
            if stem.endswith(('తున్న', 'స్తు', 'ఆ')):
                return stem[:-len(stem[-4:]) if len(stem) >= 4 else len(stem)] + 'ు'
                
        if word.endswith(('వైన', 'మైన')) and not stem.endswith(('వైన', 'మైన')):
            return stem + word[len(stem):len(stem)+3]
            
        
        
        if word.endswith(('స్తున్నాను', 'స్తున్నావు', 'స్తున్నాడు', 'స్తున్నాము', 'స్తున్నారు')):
         
            return stem[:-len('స్తున్న')] + 'ు'

        if word.endswith(('ించాను', 'ించావు', 'ించాడు', 'ించాము', 'ించారు')):
        
            return stem[:-len('ించ')] + 'ు'

            
        return stem

    def context_aware_stem(self, word, context_words=None):
        """Use surrounding words to improve stemming accuracy"""
        if not context_words or len(context_words) == 0:
            return self.hybrid_stem(word)
        
        is_likely_verb = any(cw.endswith(('చేసింది', 'చేశారు', 'ఉంది', 'ఉన్నారు')) for cw in context_words)
        is_likely_noun = any(cw in ('ఒక', 'ఆ', 'ఈ', 'మన', 'ఏ', 'కొన్ని') for cw in context_words)
        
        if is_likely_verb:
            temp_rule_weight = self.rule_weight * 1.2 
            temp_stat_weight = self.stat_weight * 0.9
        elif is_likely_noun:
            temp_rule_weight = self.rule_weight * 0.9
            temp_stat_weight = self.stat_weight * 1.1 
        else:
            temp_rule_weight = self.rule_weight
            temp_stat_weight = self.stat_weight
        
        orig_rule_weight = self.rule_weight
        orig_stat_weight = self.stat_weight
        
        self.rule_weight = temp_rule_weight
        self.stat_weight = temp_stat_weight
        
        result = self.hybrid_stem(word)
        
        self.rule_weight = orig_rule_weight
        self.stat_weight = orig_stat_weight
        
        return result
    
    def detailed_error_analysis(self, test_data):
        """More detailed error analysis by word type"""
      
        length_errors = defaultdict(int)
        length_totals = defaultdict(int)
        
        ending_errors = defaultdict(int)
        ending_totals = defaultdict(int)
        
        for word, true_stem in test_data:
            length = len(word)
            length_totals[length] += 1
            
            ending = word[-2:] if len(word) >= 2 else word
            ending_totals[ending] += 1
            
            pred_stem = self.hybrid_stem(word)
            if pred_stem != true_stem:
                length_errors[length] += 1
                ending_errors[ending] += 1
        
        length_rates = {k: v/length_totals[k] for k, v in length_errors.items() if length_totals[k] > 0}
        ending_rates = {k: v/ending_totals[k] for k, v in ending_errors.items() if ending_totals[k] >= 5}
        
        return {
            'length_error_rates': length_rates,
            'problematic_endings': sorted(ending_rates.items(), key=lambda x: x[1], reverse=True)[:10]
        }
    
    def normalize_input(self, word):
        """Normalize Telugu input to handle variations"""
       
        word = re.sub(r'ం', 'ం', word)  
        word = re.sub(r'ఁ', 'ం', word) 
        
        word = re.sub(r'ఽ', '', word)
        
        word = re.sub(r'([క-హ])్([ాిీుూృౄెేైొోౌ])', r'\1\2్', word)
        
        return word

    def analyze_errors(self, test_data):
        error_data = []
        for word, true_stem in test_data:
            if not isinstance(word, str) or not isinstance(true_stem, str):
                continue
                
            predicted_stem = self.hybrid_stem(word)
            if predicted_stem != true_stem:
                removed_suffix = word[len(predicted_stem):] if len(predicted_stem) < len(word) else ''
                error_data.append({
                    'word': word,
                    'length': len(word),
                    'predicted_stem': predicted_stem,
                    'true_stem': true_stem,
                    'removed_suffix': removed_suffix,
                    'edit_distance': edit_distance(predicted_stem, true_stem),
                    'category': self.categorize_error(word, predicted_stem, true_stem)
                })
                
        return pd.DataFrame(error_data)

    def categorize_error(self, word, predicted, true):
        if predicted == word and true != word:
            return "No_suffix_removed"
        elif len(predicted) > len(true):
            return "Understemming"
        elif len(predicted) < len(true):
            return "Overstemming"
        else:
            return "Wrong_stem"
        


    def update_from_errors(self, error_df, update_threshold=3):
       
        high_confidence_fixes = []
        medium_confidence_fixes = []
        
        hybrid_failures = []
        for _, row in error_df.iterrows():
            word = row['word']
            true_stem = row['true_stem']
            hybrid_pred = self.hybrid_stem(word)
            rule_pred = self.rule_based_stem(word)
            stat_pred = self.statistical_stem(word)
            
            if hybrid_pred != true_stem:
                if rule_pred == true_stem or stat_pred == true_stem:
                    hybrid_failures.append((word, true_stem, rule_pred, stat_pred))
                    high_confidence_fixes.append((word, true_stem))
                elif self.similarity_score(word, rule_pred, true_stem) > 0.7 or self.similarity_score(word, stat_pred, true_stem) > 0.7:
                    medium_confidence_fixes.append((word, true_stem))
        
        for word, true_stem in high_confidence_fixes[:min(5, len(high_confidence_fixes))]:
            self.exception_dict[word] = true_stem
            
        for word, true_stem in medium_confidence_fixes[:min(3, len(medium_confidence_fixes))]:
            self.exception_dict[word] = true_stem
        
        new_exceptions = []
        max_exceptions_per_iteration = 13
        
        for _, row in error_df.iterrows():
            if len(new_exceptions) >= max_exceptions_per_iteration:
                break
                
            if row['edit_distance'] <= update_threshold or (row['category'] == "Wrong_stem" and row['edit_distance'] <= update_threshold + 1):
                new_exceptions.append((row['word'], row['true_stem']))
        
        num_new = 0
        for word, true_stem in new_exceptions:
            self.exception_dict[word] = true_stem
            num_new += 1
            
        overstemming = sum(1 for _, row in error_df.iterrows() if row['category'] == "Overstemming")
        understemming = sum(1 for _, row in error_df.iterrows() if row['category'] == "Understemming")
        
        total_errors = len(error_df)
        if total_errors > 0:
            if overstemming / total_errors > 0.6:
                self.stat_weight = max(0.8, self.stat_weight - 0.05)
                self.rule_weight = min(1.2, self.rule_weight + 0.05)
            elif understemming / total_errors > 0.6:
                self.rule_weight = max(0.8, self.rule_weight - 0.05)
                self.stat_weight = min(1.2, self.stat_weight + 0.05)
        
        if self.verbose:
            print(f"Added {len(high_confidence_fixes)} high confidence fixes")
            print(f"Added {len(medium_confidence_fixes)} medium confidence fixes")
            print(f"Added {num_new} general exceptions")
            print(f"Updated weights - Rule: {self.rule_weight:.2f}, Stat: {self.stat_weight:.2f}")

        systematic_fixes = self.find_systematic_errors(error_df)
        for word, true_stem in systematic_fixes.items():
            self.exception_dict[word] = true_stem
        if self.verbose:
            print(f"Added {len(systematic_fixes)} systematic error fixes")


    def load_exceptions(self, exceptions_file):
        if os.path.exists(exceptions_file):
            try:
                df = pd.read_excel(exceptions_file)
                self.exception_dict = dict(zip(df['word'], df['stem']))
                if self.verbose:
                    print(f"Loaded {len(self.exception_dict)} exceptions")
            except Exception as e:
                print(f"Error loading exceptions: {e}")
                self.exception_dict = {}

    def train_ngram_models(self, words, max_workers=4, sample_size=1000000):
        print("Training n-gram models...")
        words = list(itertools.islice(words, sample_size))
        char_data = "".join(words)
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                'unigram': executor.submit(FreqDist, char_data),
                'bigram': executor.submit(FreqDist, ngrams(char_data, 2)),
                'trigram': executor.submit(FreqDist, ngrams(char_data, 3))
            }
            
            for name, future in tqdm(futures.items(), desc="Training models"):
                self.ngram_models[name] = future.result()
                
        print("N-gram models trained successfully")

    def is_valid_stem(self, stem):
        if not isinstance(stem, str) or len(stem) < self.min_stem_length:
            return False
            
        if not stem:
            return False
            
        valid_endings = ['ు', 'ి', 'ీ', 'ా', 'ె', 'ో', 'ూ', 'ే', 'అ', 'ఇ', 'ఈ', 'ఉ', 'ఊ', 'ఎ', 'ఏ', 'ఐ', 'ఒ', 'ఓ', 'ఔ',
                        'క', 'గ', 'చ', 'జ', 'ట', 'డ', 'త', 'ద', 'న', 'ప', 'బ', 'మ', 'య', 'ర', 'ల', 'వ', 'శ', 'స', 'హ']
        
        last_char = stem[-1]
        if last_char not in valid_endings:
            return False
            
        if len(stem) >= 3 and 'trigram' in self.ngram_models:
            trigram = tuple(stem[-3:])
            if self.ngram_models['trigram'].get(trigram, 0) < 2:  
                if len(stem) >= 2 and 'bigram' in self.ngram_models:
                    bigram = tuple(stem[-2:])
                    if self.ngram_models['bigram'].get(bigram, 0) < 3:  
                        return False
                        
        consonant_pattern = re.compile(r'[క-హ][క-హ][క-హ][క-హ]+')  # More restrictive pattern
        if consonant_pattern.search(stem):
            return False  
            
        return True

    def check_vowel_harmony(self, stem):
        if len(stem) < 3:
            return True
            
        vowels = [c for c in stem if c in 'ాిీుూెేైొోౌ']
        if not vowels:
            return True
            
        front_vowels = sum(1 for v in vowels if v in self.vowel_groups['front'])
        back_vowels = sum(1 for v in vowels if v in self.vowel_groups['back'])
        
        if front_vowels > 2 and back_vowels == 0:
            return True
        if back_vowels > 2 and front_vowels == 0:
            return True
            
        return True

    def rule_based_stem(self, word):
        if not isinstance(word, str):
            return word
            
        if word in self.exception_dict:
            return self.exception_dict[word]
            
        for suffix in self.suffixes:
            if word.endswith(suffix) and len(word) > len(suffix) + self.min_stem_length - 1:
                candidate_stem = word[:-len(suffix)]
                
                if self.is_valid_stem(candidate_stem) and self.check_vowel_harmony(candidate_stem):
                    return self.apply_morphological_rules(word, candidate_stem)
                    
        for pattern, replacement in self.sandhi_rules:
            modified = re.sub(pattern, replacement, word)
            if modified != word:
                return self.rule_based_stem(modified) 
                
        return word 

    def statistical_stem(self, word):
        if not isinstance(word, str):
            return word
            
        if word in self.exception_dict:
            return self.exception_dict[word]
            
        best_stem, best_score = word, -1
        max_suffix_len = min(self.max_suffix_length, len(word) - self.min_stem_length + 1)
        
        for suffix_len in range(1, max_suffix_len + 1):
            candidate_stem = word[:-suffix_len]
            
            if not self.is_valid_stem(candidate_stem):
                continue
                
            if len(candidate_stem) >= 3 and 'trigram' in self.ngram_models:
                trigram_score = 0
                for i in range(len(candidate_stem) - 2):
                    trigram = tuple(candidate_stem[i:i+3])
                    weight = (len(candidate_stem) - i) / len(candidate_stem)
                    trigram_score += self.ngram_models['trigram'].get(trigram, 0) * weight
            else:
                trigram_score = 0
                
            if len(candidate_stem) >= 2 and 'bigram' in self.ngram_models:
                bigram_score = 0
                for i in range(len(candidate_stem) - 1):
                    bigram = tuple(candidate_stem[i:i+2])
                    weight = (len(candidate_stem) - i) / len(candidate_stem)
                    bigram_score += self.ngram_models['bigram'].get(bigram, 0) * weight
            else:
                bigram_score = 0
                
            if 'unigram' in self.ngram_models:
                unigram_score = sum(self.ngram_models['unigram'].get(c, 0) for c in candidate_stem) / len(candidate_stem) if candidate_stem else 0
            else:
                unigram_score = 0
            
            combined_score = (trigram_score * 3 + bigram_score * 2 + unigram_score) / 6
            
            length_penalty = 0.8 if len(candidate_stem) <= 2 else 1.0
            
            final_score = combined_score * length_penalty
            
            if final_score > best_score:
                best_score, best_stem = final_score, candidate_stem
                
        return self.apply_morphological_rules(word, best_stem if best_score > 0 else word)

    def similarity_score(self, word, stem1, stem2):
        cache_key = f"{word}|{stem1}|{stem2}"
        
        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]
            
        if not isinstance(stem1, str) or not isinstance(stem2, str):
            return 0.0
            
        max_len = max(len(stem1), len(stem2))
        if max_len == 0:
            self.similarity_cache[cache_key] = 1.0
            return 1.0
            
        edit_sim = 1 - (edit_distance(stem1, stem2) / max_len)
        
        chars1 = set(stem1)
        chars2 = set(stem2)
        if not chars1 or not chars2:
            char_sim = 0.0
        else:
            char_sim = len(chars1.intersection(chars2)) / len(chars1.union(chars2))
            
        if len(stem1) >= 2 and len(stem2) >= 2:
            bigrams1 = set(ngrams(stem1, 2))
            bigrams2 = set(ngrams(stem2, 2))
            ngram_sim = len(bigrams1.intersection(bigrams2)) / len(bigrams1.union(bigrams2)) if bigrams1 and bigrams2 else 0.0
        else:
            ngram_sim = 0.0
            
        final_score = (edit_sim * 0.5) + (char_sim * 0.3) + (ngram_sim * 0.2)
        
        self.similarity_cache[cache_key] = final_score
        return final_score
    
    def detect_telugu_patterns(self, word):
        """Detect specific Telugu morphological patterns."""
        patterns = {
            'verb_present': re.compile(r'తున్నా(ను|వు|డు|ము|రు|)$'),
            'verb_past': re.compile(r'ా(ను|వు|డు|ము|రు|)$'),
            'verb_future': re.compile(r'తా(ను|వు|డు|ము|రు|)$'),
            'plural_noun': re.compile(r'(లు|ములు|వులు)$'),
            'case_markers': re.compile(r'(లో|కి|తో|పై|నుండి)$')
        }
        
        detected_patterns = {}
        for pattern_name, regex in patterns.items():
            if regex.search(word):
                detected_patterns[pattern_name] = True
        
        return detected_patterns

    def resolve_stem_disagreement(self, word, rule_stem, stat_stem):
        """Enhanced resolution for stem disagreements with linguistic insights."""
      
        if word in self.exception_dict:
            return self.exception_dict[word]
        
        if rule_stem in self.root_forms.values():
            return rule_stem
        if stat_stem in self.root_forms.values():
            return stat_stem
        
        word_suffix = word[min(len(rule_stem), len(stat_stem)):]
        if any(word_suffix.endswith(complex_suffix) for complex_suffix in ['తున్నాను', 'తున్నాడు', 'తున్నారు']):
      
            return rule_stem if len(rule_stem) >= self.min_stem_length else stat_stem
        
        if word.endswith(('లు', 'డు', 'ము', 'వు')):
            return stat_stem if len(stat_stem) >= self.min_stem_length else rule_stem
        
        if rule_stem[-1] in 'ాిీుూెేైొోౌ' and not stat_stem[-1] in 'ాిీుూెేైొోౌ':
            return rule_stem
        if stat_stem[-1] in 'ాిీుూెేైొోౌ' and not rule_stem[-1] in 'ాిీుూెేైొోౌ':
            return stat_stem
        
        if 'trigram' in self.ngram_models and len(rule_stem) >= 3 and len(stat_stem) >= 3:
            rule_ngram_score = sum(self.ngram_models['trigram'].get(tuple(rule_stem[i:i+3]), 0) for i in range(len(rule_stem)-2))
            stat_ngram_score = sum(self.ngram_models['trigram'].get(tuple(stat_stem[i:i+3]), 0) for i in range(len(stat_stem)-2))
            
            if rule_ngram_score > stat_ngram_score * 1.5:
                return rule_stem
            if stat_ngram_score > rule_ngram_score * 1.5:
                return stat_stem
        
        rule_quality = self.compute_stem_quality(word, rule_stem) * self.rule_weight
        stat_quality = self.compute_stem_quality(word, stat_stem) * self.stat_weight
        
        return rule_stem if rule_quality >= stat_quality else stat_stem

    def hybrid_stem(self, word):
        if not isinstance(word, str):
            return word
            
        if word in self.exception_dict:
            return self.exception_dict[word]
            
        rule_stem = self.rule_based_stem(word)
        stat_stem = self.statistical_stem(word)
        
        if rule_stem == stat_stem:
            return self.apply_morphological_rules(word, rule_stem)
        
        word_length = len(word)
        has_complex_suffix = any(word.endswith(s) for s in [suffix for suffix in self.suffixes if len(suffix) > 3])
        
        rule_weight = self.rule_weight * (1.2 if has_complex_suffix else 1.0)
        stat_weight = self.stat_weight * (1.2 if word_length > 8 else 1.0)
        
        if word_length <= 5:
            rule_weight *= 1.3
        
        rule_quality = self.compute_stem_quality(word, rule_stem) * rule_weight
        stat_quality = self.compute_stem_quality(word, stat_stem) * stat_weight
        
        if rule_quality > stat_quality * 1.3:
            return self.apply_morphological_rules(word, rule_stem)
        elif stat_quality > rule_quality * 1.3:
            return self.apply_morphological_rules(word, stat_stem)
        
        patterns = self.detect_telugu_patterns(word)
        if patterns.get('verb_present') or patterns.get('verb_past') or patterns.get('verb_future'):
           
            rule_weight *= 1.3
        elif patterns.get('plural_noun') or patterns.get('case_markers'):
            
            stat_weight *= 1.2
        
        return self.resolve_stem_disagreement(word, rule_stem, stat_stem)
    
    def find_systematic_errors(self, error_df):
        """Find patterns in errors to automatically generate new exceptions."""
        error_patterns = defaultdict(list)
        
        for _, row in error_df.iterrows():
            word = row['word']
            true_stem = row['true_stem']
            pred_stem = row['predicted_stem']
            
            for i in range(1, min(6, len(word))):
                ending = word[-i:]
                error_patterns[ending].append((word, true_stem, pred_stem))
        
        systematic_errors = {}
        for ending, errors in error_patterns.items():
            if len(errors) >= 3:  
                pred_true_pairs = [(e[2], e[1]) for e in errors]
                if len(set(pred_true_pairs)) <= 2: 
                    for word, true_stem, _ in errors:
                        systematic_errors[word] = true_stem
        
        return systematic_errors
        
    def compute_stem_quality(self, word, stem):
        """Compute a quality score for a proposed stem."""
        if not isinstance(stem, str) or not stem:
            return 0.0
            
        morph_score = 1.0 if self.is_valid_stem(stem) else 0.4
        
        ngram_score = 0.0
        if len(stem) >= 3 and 'trigram' in self.ngram_models:
            trigrams = list(ngrams(stem, 3))
            weighted_sum = 0
            for i, trigram in enumerate(trigrams):
                position_weight = 1.0 + (i / (len(trigrams) or 1)) * 0.7
                weighted_sum += self.ngram_models['trigram'].get(trigram, 0) * position_weight
            ngram_score = min(weighted_sum / (10 * (len(trigrams) or 1)), 1.0)
        
        if len(stem) < 3:
            length_ratio = 0.5 * (len(stem) / len(word))
        else:
            length_ratio = min(len(stem) / len(word), 0.8)
        
        struct_score = 1.0 if self.check_vowel_harmony(stem) else 0.6
        
        root_bonus = 1.3 if stem in self.root_forms.values() else 1.0
        
        ending_bonus = 1.0
        if stem.endswith(('ు', 'ి', 'ా', 'క', 'గ', 'చ', 'జ', 'ట', 'డ', 'త', 'ద', 'న', 'ప', 'బ', 'మ', 'య', 'ర', 'ల', 'వ', 'స')):
            ending_bonus = 1.2
         
        quality_score = (morph_score * 0.35 + 
                        ngram_score * 0.25 + 
                        length_ratio * 0.15 + 
                        struct_score * 0.15) * root_bonus * ending_bonus
                        
        return quality_score


    def evaluate(self, test_data):
        prediction_methods = {
            'rule_based': self.rule_based_stem,
            'statistical': self.statistical_stem,
            'hybrid': self.hybrid_stem
        }
        
        predictions = {method: [] for method in prediction_methods}
        for word, true_stem in tqdm(test_data, desc="Evaluating"):
            for method, predictor in prediction_methods.items():
                predictions[method].append(predictor(word))
        
        y_true = [true_stem for _, true_stem in test_data]
        
        accuracies = {method: np.mean([pred == true for pred, true in zip(preds, y_true)]) * 100 
                      for method, preds in predictions.items()}
        
        reports = {}
        error_analysis = {}
        
        for method, preds in predictions.items():
            try:
                reports[method] = pd.DataFrame(classification_report(y_true, preds, output_dict=True, zero_division=0)).T
                
                errors = [(word, pred, true) for (word, _), pred, true in zip(test_data, preds, y_true) if pred != true]
                error_by_length = defaultdict(int)
                for word, _, _ in errors:
                    error_by_length[len(word)] += 1
                
                error_analysis[method] = {
                    'total_errors': len(errors),
                    'error_rate': len(errors) / len(test_data) * 100,
                    'errors_by_length': dict(error_by_length),
                    'avg_error_word_length': np.mean([len(word) for word, _, _ in errors]) if errors else 0
                }
                
            except ValueError as e:
                print(f"Warning: Could not generate report for {method} due to: {e}")
                reports[method] = None
                error_analysis[method] = {'error': str(e)}
        
        return accuracies, reports, error_analysis

    def tune_parameters(self, validation_data):
        print("Tuning stemmer parameters...")
        best_accuracy = 0
        best_hybrid_advantage = 0
        best_params = {
            'rule_weight': self.rule_weight,
            'stat_weight': self.stat_weight,
            'min_stem_length': self.min_stem_length,
            'max_suffix_length': self.max_suffix_length
        }
        
        param_grid = {
        'rule_weight': [0.75, 0.85, 0.95, 1.0, 1.05, 1.15, 1.25],
        'stat_weight': [0.75, 0.85, 0.95, 1.0, 1.05, 1.15, 1.25],
        'min_stem_length': [2, 3],
        'max_suffix_length': [5, 6, 7, 8]
        }
        
        sample_size = min(500, len(validation_data))
        sampled_validation = validation_data[:sample_size]
        
        for rule_w in param_grid['rule_weight']:
            for stat_w in param_grid['stat_weight']:
                for min_len in param_grid['min_stem_length']:
                    for max_suffix in param_grid['max_suffix_length']:
                        self.rule_weight = rule_w
                        self.stat_weight = stat_w
                        self.min_stem_length = min_len
                        self.max_suffix_length = max_suffix
                        
                        hybrid_correct = 0
                        rule_correct = 0
                        stat_correct = 0
                        
                        for word, true_stem in sampled_validation:
                            if self.hybrid_stem(word) == true_stem:
                                hybrid_correct += 1
                            if self.rule_based_stem(word) == true_stem:
                                rule_correct += 1
                            if self.statistical_stem(word) == true_stem:
                                stat_correct += 1
                        
                        hybrid_accuracy = hybrid_correct / sample_size * 100
                        rule_accuracy = rule_correct / sample_size * 100
                        stat_accuracy = stat_correct / sample_size * 100
                        
                        best_individual = max(rule_accuracy, stat_accuracy)
                        hybrid_advantage = hybrid_accuracy - best_individual
                        
                        if hybrid_accuracy > best_accuracy:
                            best_accuracy = hybrid_accuracy
                            best_hybrid_advantage = hybrid_advantage
                            best_params = {
                                'rule_weight': rule_w,
                                'stat_weight': stat_w,
                                'min_stem_length': min_len,
                                'max_suffix_length': max_suffix
                            }
        
        self.rule_weight = best_params['rule_weight']
        self.stat_weight = best_params['stat_weight']
        self.min_stem_length = best_params['min_stem_length']
        self.max_suffix_length = best_params['max_suffix_length']
        
        print(f"Best parameters found: {best_params} with accuracy: {best_accuracy:.2f}% (advantage: +{best_hybrid_advantage:.2f}%)")
        return best_params, best_accuracy



def load_datasets(base_dir):
    try:
        annotated_df = pd.read_excel(os.path.join(base_dir, "training_data.xlsx"))
        test_df = pd.read_excel(os.path.join(base_dir, "test_set.xlsx"))
        
        corpus_file = os.path.join(base_dir, "processed_telugu_corpus.txt")
        if os.path.exists(corpus_file):
            with open(corpus_file, 'r', encoding='utf-8') as f:
                words = (word.strip() for word in f.read().split())
        else:
            words = annotated_df['word'].tolist()
        
        print(f"Loaded {len(annotated_df)} training samples, {len(test_df)} test samples")
        return words, annotated_df, test_df
    
    except Exception as e:
        print(f"Error loading datasets: {e}")
        return [], pd.DataFrame(), pd.DataFrame()

def split_data(data_df, test_size=0.2):
    indices = np.random.permutation(len(data_df))
    test_count = int(len(data_df) * test_size)
    test_indices = indices[:test_count]
    train_indices = indices[test_count:]
    
    train_df = data_df.iloc[train_indices]
    val_df = data_df.iloc[test_indices]
    
    return train_df, val_df


def main():
    stemmer = TeluguStemmer()
    print("Initializing Telugu Stemmer...")
    
    words, annotated_df, test_df = load_datasets(BASE_DIR)
    
    train_df, val_df = split_data(annotated_df, test_size=0.15)
    
    test_data = list(zip(test_df['word'], test_df['stem']))
    train_words, train_stems = train_df['word'].tolist(), train_df['stem'].tolist()
    val_words, val_stems = val_df['word'].tolist(), val_df['stem'].tolist()
    
    training_data = list(zip(train_words, train_stems))
    validation_data = list(zip(val_words, val_stems))
    
    stemmer.training_data = training_data

    print("\nInitial training...")
    stemmer.load_exceptions(os.path.join(BASE_DIR, "telugu_exceptions.xlsx"))
    stemmer.train_ngram_models(words, sample_size=800000) 
    
    print("\nParameter tuning...")
    stemmer.tune_parameters(validation_data[:800])  
    
    word_endings = Counter()
    for word, _ in training_data:
        for i in range(1, min(6, len(word))):
            word_endings[word[-i:]] += 1
    
    common_suffixes = [suffix for suffix, count in word_endings.most_common(50) 
                      if count > 5 and suffix not in stemmer.suffixes]
    
    stemmer.suffixes.extend(common_suffixes)
    stemmer.suffixes = sorted(stemmer.suffixes, key=len, reverse=True)

    print("\nInitial evaluation...")
    accuracies, reports, error_analysis = stemmer.evaluate(test_data)
    
    print("\nInitial Accuracy Scores:")
    for method, acc in accuracies.items():
        print(f"{method.title()}: {acc:.2f}%")
        print(f"  Errors: {error_analysis[method]['total_errors']}, Rate: {error_analysis[method]['error_rate']:.2f}%")

    print("\nPerforming iterative improvement...")
    max_iterations = 5  
    prev_hybrid_accuracy = accuracies['hybrid']
    improvement_threshold = 0.05  
    
    for iteration in range(max_iterations):
        print(f"\nIteration {iteration + 1}")
        
        hybrid_errors = stemmer.analyze_errors(test_data)
        
        rule_errors = []
        for word, true_stem in test_data:
            if stemmer.rule_based_stem(word) != true_stem:
                rule_errors.append((word, true_stem))
        
        stat_errors = []
        for word, true_stem in test_data:
            if stemmer.statistical_stem(word) != true_stem:
                stat_errors.append((word, true_stem))
        
        hybrid_error_tuples = [(row['word'], row['true_stem']) for _, row in hybrid_errors.iterrows()]
        
        all_errors = hybrid_error_tuples + rule_errors[:20] + stat_errors[:20]
        
        seen = set()
        unique_errors = []
        for item in all_errors:
            if item not in seen:
                seen.add(item)
                unique_errors.append(item)
        
        error_words = [word for word, _ in unique_errors]
        error_true_stems = [true_stem for _, true_stem in unique_errors]
        temp_data = list(zip(error_words, error_true_stems))
        error_df = stemmer.analyze_errors(temp_data)
        
        current_error_rate = len(hybrid_errors)/len(test_data)
        print(f"Found {len(hybrid_errors)} hybrid errors, {len(rule_errors)} rule errors, {len(stat_errors)} statistical errors")
        print(f"Combined unique errors: {len(error_df)}")
        print(f"Current hybrid error rate: {current_error_rate*100:.2f}%")
        
        if len(hybrid_errors) == 0 or (iteration > 0 and accuracies['hybrid'] - prev_hybrid_accuracy < improvement_threshold):
            print("No significant improvement, stopping iterations.")
            break
            
        prev_hybrid_accuracy = accuracies['hybrid']
        
        stemmer.update_from_errors(error_df.sample(min(75, len(error_df))))
        
        if iteration % 2 == 1:
            stemmer.tune_parameters(validation_data[:500])
        
        new_accuracies, new_reports, new_error_analysis = stemmer.evaluate(test_data)
        
        improvement = {k: new_accuracies[k] - accuracies[k] for k in accuracies}
        
        print("\nAccuracy Improvements:")
        for method, imp in improvement.items():
            print(f"{method.title()}: {imp:+.2f}%")
            
        accuracies, reports, error_analysis = new_accuracies, new_reports, new_error_analysis
    
    print("\nFinal Accuracy Scores:")
    for method, acc in accuracies.items():
        print(f"{method.title()}: {acc:.2f}%")
        print(f"  Errors: {error_analysis[method]['total_errors']}, Rate: {error_analysis[method]['error_rate']:.2f}%")
    
   
    
    model_file = os.path.join(BASE_DIR, "telugu_stemmer_model.pkl")
    try:
        with open(model_file, 'wb') as f:
            pickle.dump(stemmer, f)
        print(f"\nModel saved to {model_file}")
    except Exception as e:
        print(f"Error saving model: {e}")
    
    # NEW CODE: Print some example predictions
    print("\n--- Sample Stemming Results ---")
    
    # Use some words from test data as examples
    if test_data:
        print("\nTest Data Examples:")
        for i, (word, true_stem) in enumerate(test_data[:10]):  # Print first 10 examples
            predicted_stem = stemmer.hybrid_stem(word)
            print(f"Word: {word}")
            print(f"True stem: {true_stem}")
            print(f"Predicted stem: {predicted_stem}")
            print(f"Match: {'✓' if predicted_stem == true_stem else '✗'}")
            print("-" * 30)
    
    # You can also add some custom examples if you have specific words to test
    custom_examples = [
        "రాజులు", "చేస్తున్నాను", "పిల్లలకు", "పుస్తకములు", "వచ్చినది"
    ]
    
    print("\nCustom Examples:")
    for word in custom_examples:
        rule_stem = stemmer.rule_based_stem(word)
        stat_stem = stemmer.statistical_stem(word)
        hybrid_stem = stemmer.hybrid_stem(word)
        
        print(f"Word: {word}")
        print(f"Rule-based stem: {rule_stem}")
        print(f"Statistical stem: {stat_stem}")
        print(f"Hybrid stem: {hybrid_stem}")
        print("-" * 30)

    # Print words with different results between methods
    print("\nWords with different stemming results:")
    divergent_results = []
    for word, true_stem in test_data:
        rule_stem = stemmer.rule_based_stem(word)
        stat_stem = stemmer.statistical_stem(word)
        hybrid_stem = stemmer.hybrid_stem(word)
        
        if rule_stem != stat_stem or rule_stem != hybrid_stem or stat_stem != hybrid_stem:
            divergent_results.append((word, true_stem, rule_stem, stat_stem, hybrid_stem))
    
    for i, (word, true_stem, rule_stem, stat_stem, hybrid_stem) in enumerate(divergent_results[:5]):
        print(f"Word: {word}")
        print(f"True stem: {true_stem}")
        print(f"Rule-based: {rule_stem}")
        print(f"Statistical: {stat_stem}")
        print(f"Hybrid: {hybrid_stem}")
        def get_best_method(rule_stem, stat_stem, hybrid_stem, true_stem):
            stems = {'rule': rule_stem, 'statistical': stat_stem, 'hybrid': hybrid_stem}
            return max(['Rule', 'Statistical', 'Hybrid'], 
                    key=lambda m: 1 if stems[m.lower()] == true_stem else 0)

        print(f"Best method: {get_best_method(rule_stem, stat_stem, hybrid_stem, true_stem)}")
        print("-" * 30)


if __name__ == "__main__":
    main()

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/bathulahoneypriya/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Initializing Telugu Stemmer...
Loaded 2867 training samples, 200 test samples

Initial training...
Loaded 344 exceptions
Training n-gram models...


Training models: 100%|██████████| 3/3 [00:05<00:00,  1.95s/it]


N-gram models trained successfully

Parameter tuning...
Tuning stemmer parameters...
Best parameters found: {'rule_weight': 1.05, 'stat_weight': 0.75, 'min_stem_length': 3, 'max_suffix_length': 7} with accuracy: 22.33% (advantage: +1.86%)

Initial evaluation...


Evaluating: 100%|██████████| 200/200 [00:00<00:00, 1229.83it/s]


Initial Accuracy Scores:
Rule_Based: 59.50%
  Errors: 81, Rate: 40.50%
Statistical: 37.00%
  Errors: 126, Rate: 63.00%
Hybrid: 57.50%
  Errors: 85, Rate: 42.50%

Performing iterative improvement...

Iteration 1





Found 85 hybrid errors, 81 rule errors, 126 statistical errors
Combined unique errors: 82
Current hybrid error rate: 42.50%
Added 6 high confidence fixes
Added 42 medium confidence fixes
Added 13 general exceptions
Updated weights - Rule: 1.05, Stat: 0.75
Added 0 systematic error fixes


Evaluating: 100%|██████████| 200/200 [00:00<00:00, 1427.33it/s]



Accuracy Improvements:
Rule_Based: +6.00%
Statistical: +8.50%
Hybrid: +8.50%

Iteration 2
Found 68 hybrid errors, 69 rule errors, 109 statistical errors
Combined unique errors: 65
Current hybrid error rate: 34.00%
Added 1 high confidence fixes
Added 41 medium confidence fixes
Added 13 general exceptions
Updated weights - Rule: 1.05, Stat: 0.75
Added 0 systematic error fixes
Tuning stemmer parameters...
Best parameters found: {'rule_weight': 1.25, 'stat_weight': 0.75, 'min_stem_length': 3, 'max_suffix_length': 5} with accuracy: 20.47% (advantage: +0.00%)


Evaluating: 100%|██████████| 200/200 [00:00<00:00, 1616.92it/s]


Accuracy Improvements:
Rule_Based: +6.50%
Statistical: +7.00%
Hybrid: +6.50%

Iteration 3





Found 55 hybrid errors, 56 rule errors, 95 statistical errors
Combined unique errors: 53
Current hybrid error rate: 27.50%
Added 1 high confidence fixes
Added 34 medium confidence fixes
Added 13 general exceptions
Updated weights - Rule: 1.25, Stat: 0.75
Added 0 systematic error fixes


Evaluating: 100%|██████████| 200/200 [00:00<00:00, 1867.16it/s]


Accuracy Improvements:
Rule_Based: +6.50%
Statistical: +6.00%
Hybrid: +6.50%

Iteration 4





Found 42 hybrid errors, 43 rule errors, 83 statistical errors
Combined unique errors: 41
Current hybrid error rate: 21.00%
Added 0 high confidence fixes
Added 24 medium confidence fixes
Added 13 general exceptions
Updated weights - Rule: 1.25, Stat: 0.75
Added 0 systematic error fixes
Tuning stemmer parameters...
Best parameters found: {'rule_weight': 1.25, 'stat_weight': 0.75, 'min_stem_length': 3, 'max_suffix_length': 5} with accuracy: 20.47% (advantage: +0.00%)


Evaluating: 100%|██████████| 200/200 [00:00<00:00, 2153.46it/s]


Accuracy Improvements:
Rule_Based: +6.50%
Statistical: +6.50%
Hybrid: +6.50%

Iteration 5





Found 29 hybrid errors, 30 rule errors, 70 statistical errors
Combined unique errors: 28
Current hybrid error rate: 14.50%
Added 0 high confidence fixes
Added 14 medium confidence fixes
Added 13 general exceptions
Updated weights - Rule: 1.25, Stat: 0.75
Added 0 systematic error fixes


Evaluating: 100%|██████████| 200/200 [00:00<00:00, 2612.57it/s]


Accuracy Improvements:
Rule_Based: +6.50%
Statistical: +6.50%
Hybrid: +6.50%

Final Accuracy Scores:
Rule_Based: 91.50%
  Errors: 17, Rate: 8.50%
Statistical: 71.50%
  Errors: 57, Rate: 28.50%
Hybrid: 92.00%
  Errors: 16, Rate: 8.00%

Model saved to /Users/bathulahoneypriya/Documents/graph/telugu/telugu_stemmer_model.pkl

--- Sample Stemming Results ---

Test Data Examples:
Word: పక్షులు
True stem: పక్షు
Predicted stem: పక్షు
Match: ✓
------------------------------
Word: చేపలు
True stem: చేప
Predicted stem: చేప
Match: ✓
------------------------------
Word: అడిగాడు
True stem: అడిగ
Predicted stem: అడిగ
Match: ✓
------------------------------
Word: పక్షులు
True stem: పక్షు
Predicted stem: పక్షు
Match: ✓
------------------------------
Word: కారణాలు
True stem: కారణా
Predicted stem: కారణా
Match: ✓
------------------------------
Word: పండితులు
True stem: పండితు
Predicted stem: పండితు
Match: ✓
------------------------------
Word: వాదనలు
True stem: వాదన
Predicted stem: వాదన
Match: ✓
----------




Word: పక్షులు
True stem: పక్షు
Rule-based: పక్షు
Statistical: పక్షుల
Hybrid: పక్షు
Best method: Rule
------------------------------
Word: చేపలు
True stem: చేప
Rule-based: చేప
Statistical: చేపల
Hybrid: చేప
Best method: Rule
------------------------------
Word: పక్షులు
True stem: పక్షు
Rule-based: పక్షు
Statistical: పక్షుల
Hybrid: పక్షు
Best method: Rule
------------------------------
Word: కారణాలు
True stem: కారణా
Rule-based: కారణా
Statistical: కారణాల
Hybrid: కారణా
Best method: Rule
------------------------------
Word: వాదనలు
True stem: వాదన
Rule-based: వాదన
Statistical: వాదనల
Hybrid: వాదన
Best method: Rule
------------------------------
