# Hybrid TF-IDF + BERT Interest Classification Model
## Code ∙ Version 2 

In [3]:
import pandas as pd
import numpy as np
import re
import os
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from transformers import pipeline
import torch
import logging
import time
from typing import List, Dict, Tuple, Union, Optional

# Root logging setup: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Closed set of interest labels; the order here is the column order used
# wherever the categories are passed on as a class list.
INTEREST_CATEGORIES = [
    "Music",
    "Food",
    "Sports",
    "Technology",
    "Arts",
    "Travel",
    "Education",
]

class InterestClassifier:
    """
    Hybrid Interest Classification model that combines TF-IDF with BERT zero-shot classification
    """
    def __init__(self, 
                 model_path: Optional[str] = None,
                 alpha: float = 0.6, 
                 threshold: float = 0.5,
                 bert_model_name: str = 'facebook/bart-large-mnli',
                 use_gpu: bool = torch.cuda.is_available()):
        """
        Initialize the hybrid classifier
        
        Args:
            model_path: Path to a saved model (if None, a new model will be created).
                When the file exists it is loaded via load_model(), which replaces
                alpha, threshold and bert_model_name with the saved values.
            alpha: Weight for TF-IDF model (1-alpha for BERT)
            threshold: Classification threshold for final predictions
            bert_model_name: Name of the BERT model to use
            use_gpu: Whether to use GPU for BERT inference
        """
        self.alpha = alpha
        self.threshold = threshold
        self.bert_model_name = bert_model_name
        self.use_gpu = use_gpu
        
        # Initialize models as None
        self.tfidf_pipeline = None
        self.mlb = None
        self.bert_classifier = None
        
        # Load the model if path is provided
        if model_path:
            if os.path.exists(model_path):
                self.load_model(model_path)
            else:
                # Previously a wrong path was ignored without any trace, leaving
                # the caller with an untrained model and no hint why.
                logger.warning(f"Model path '{model_path}' does not exist; starting with an untrained model")
        
        # load_model() sets up the BERT pipeline itself, so only initialize here
        # when that has not happened (or failed); this avoids loading the large
        # model twice in a row.
        if self.bert_classifier is None:
            self._init_bert_classifier()
    
    def _improved_preprocess_text(self, text: str) -> str:
        """
        Enhanced text preprocessing that better preserves domain-specific indicators
        
        Args:
            text: Input text to preprocess
            
        Returns:
            Preprocessed text
        """
        # Handle potential NaN values
        if pd.isna(text):
            return ""
        
        # Convert to lowercase
        text = text.lower()
        
        # Remove special characters while preserving important separators
        text = re.sub(r'[^\w\s|-]', ' ', text)
        
        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)
        
        # Define domain terms dictionary
        domain_terms = {
            'music': ['music', 'guitar', 'band', 'concert', 'gig', 'sing', 'song', 'play music', 'musician'],
            'food': ['food', 'cook', 'cuisine', 'recipe', 'restaurant', 'eat', 'culinary', 'bake', 'chef'],
            'sports': ['sport', 'run', 'gym', 'fitness', 'workout', 'exercise', 'athletic', 'training'],
            'arts': ['art', 'paint', 'draw', 'museum', 'gallery', 'exhibit', 'creative', 'design'],
            'technology': ['tech', 'code', 'program', 'software', 'developer', 'computer', 'app', 'digital'],
            'education': ['education', 'learn', 'course', 'class', 'study', 'book', 'read', 'academic'],
            'travel': ['travel', 'trip', 'hike', 'explore', 'tour', 'visit', 'journey', 'destination']
        }
        
        # Check for domain terms and emphasize them
        modified_text = text
        for category, terms in domain_terms.items():
            for term in terms:
                if term in text:
                    # Add the category name explicitly if a related term is found
                    modified_text += f" {category} {category} {term} {term}"
        
        # Split on common separators but preserve the important phrases
        parts = []
        for part in re.split(r'\s*\|\s*', modified_text):
            # Remove numbers (but keep words with numbers like "web3")
            part = re.sub(r'\b\d+\b', '', part)
            parts.append(part)
        
        # Define a more focused stopwords list
        core_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'the', 'a', 'an', 'and', 'but', 
                          'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 
                          'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 
                          'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 
                          'under', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were'}
        
        # Process each part and filter stopwords
        processed_parts = []
        for part in parts:
            words = part.split()
            filtered_words = [word for word in words if word not in core_stopwords]
            
            if filtered_words:
                processed_parts.append(' '.join(filtered_words))
        
        # Join the processed parts back
        processed_text = ' '.join(processed_parts)
        
        return processed_text.strip()
    
    def _init_bert_classifier(self):
        """
        Build the zero-shot classification pipeline and store it on the instance.

        On any failure the error is logged and ``self.bert_classifier`` is set to
        None so that the rest of the class can fall back to TF-IDF only.
        """
        try:
            logger.info(f"Initializing BERT zero-shot classifier with model: {self.bert_model_name}")

            # Device index 0 is used only when a GPU was requested and CUDA is
            # actually available; -1 is passed otherwise.
            if self.use_gpu and torch.cuda.is_available():
                device = 0
            else:
                device = -1

            self.bert_classifier = pipeline(
                'zero-shot-classification',
                model=self.bert_model_name,
                device=device,
            )
            logger.info("BERT classifier successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize BERT classifier: {e}")
            logger.warning("Proceeding without BERT - will use TF-IDF only")
            self.bert_classifier = None
    
    def train(self, 
              df: pd.DataFrame, 
              text_column: str = 'survey_answer', 
              labels_column: str = 'labels_list',
              test_size: float = 0.2):
        """
        Train the TF-IDF + Logistic Regression model
        
        The input DataFrame is not modified: label parsing and text
        preprocessing are done on local copies of the relevant columns.
        
        Args:
            df: DataFrame containing survey responses and labels
            text_column: Column name containing the survey responses
            labels_column: Column name containing the labels, either as lists
                or as string representations such as "['Music', 'Food']"
            test_size: Proportion of data to use for testing
        
        Returns:
            Evaluation metrics on test set (keys: 'hamming_loss', 'micro_f1',
            'macro_f1')
        
        Raises:
            ValueError: If df has no rows.
        """
        logger.info("Starting model training...")
        
        if len(df) == 0:
            raise ValueError("Cannot train on an empty DataFrame")
        
        # Prepare labels
        labels = df[labels_column]
        if isinstance(labels.iloc[0], str):
            logger.info("Converting labels from string to list...")

            def _parse_labels(items):
                # A missing cell stays NaN after .str.split; treat it as "no labels".
                if not isinstance(items, list):
                    return []
                cleaned = (item.strip().strip("'\"") for item in items)
                # "[]" splits into [''] - drop such empty entries.
                return [item for item in cleaned if item]

            # Convert string representation of lists to actual lists
            labels = labels.str.strip('[]').str.split(',').apply(_parse_labels)
        
        # Preprocess text
        logger.info("Preprocessing text data...")
        processed_text = df[text_column].apply(self._improved_preprocess_text)
        
        # Initialize MultiLabelBinarizer
        self.mlb = MultiLabelBinarizer(classes=INTEREST_CATEGORIES)
        y = self.mlb.fit_transform(labels)
        logger.info(f"Target shape: {y.shape}")
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            processed_text, y, test_size=test_size, random_state=42, shuffle=True
        )
        logger.info(f"Training set: {X_train.shape[0]} samples, Test set: {X_test.shape[0]} samples")
        
        # Create TF-IDF pipeline
        logger.info("Creating and training TF-IDF pipeline...")
        tfidf_vectorizer = TfidfVectorizer(
            max_features=3000,
            min_df=2,
            max_df=0.9,
            ngram_range=(1, 3),
            sublinear_tf=True
        )
        
        lr_clf = LogisticRegression(
            C=0.5,
            max_iter=1000,
            class_weight='balanced',
            solver='liblinear',
            penalty='l2'
        )
        
        # One independent binary logistic regression per interest category
        multi_lr = MultiOutputClassifier(lr_clf)
        
        self.tfidf_pipeline = Pipeline([
            ('tfidf', tfidf_vectorizer),
            ('classifier', multi_lr)
        ])
        
        # Train the pipeline
        self.tfidf_pipeline.fit(X_train, y_train)
        logger.info("TF-IDF pipeline trained successfully")
        
        # Evaluate on test set (TF-IDF model only; BERT is not involved here)
        logger.info("Evaluating model on test set...")
        y_pred = self.tfidf_pipeline.predict(X_test)
        
        # Calculate metrics
        from sklearn.metrics import hamming_loss, f1_score
        h_loss = hamming_loss(y_test, y_pred)
        micro_f1 = f1_score(y_test, y_pred, average='micro')
        macro_f1 = f1_score(y_test, y_pred, average='macro')
        
        logger.info(f"Hamming Loss: {h_loss:.4f}")
        logger.info(f"Micro F1 Score: {micro_f1:.4f}")
        logger.info(f"Macro F1 Score: {macro_f1:.4f}")
        
        return {
            'hamming_loss': h_loss,
            'micro_f1': micro_f1,
            'macro_f1': macro_f1
        }
    
    def get_tfidf_predictions(self, text: str) -> Dict[str, float]:
        """
        Get predictions from TF-IDF model with confidence scores
        
        Args:
            text: The input text to classify
            
        Returns:
            Dictionary of label -> score
        """
        if self.tfidf_pipeline is None:
            raise ValueError("TF-IDF model is not trained yet. Call train() first.")
            
        # Preprocess text
        processed_text = self._improved_preprocess_text(text)
        
        # Get raw prediction probabilities
        y_proba = self.tfidf_pipeline.predict_proba([processed_text])
        
        # Convert to dictionary of label -> score
        scores = {}
        for i, label in enumerate(self.mlb.classes_):
            # For MultiOutputClassifier, each element of y_proba is a list of arrays
            # Each array is for one label and has 2 values: [prob_for_0, prob_for_1]
            scores[label] = y_proba[i][0][1]  # Get probability of positive class
        
        return scores
    
    def get_bert_predictions(self, text: str) -> Dict[str, float]:
        """
        Score a single text with the zero-shot classifier.
        
        Args:
            text: The input text to classify
            
        Returns:
            Dictionary of label -> score covering every interest category.
            All scores are 0.0 when the classifier is unavailable or the
            call fails.
        """
        if self.bert_classifier is None:
            logger.warning("BERT classifier is not available, returning empty scores")
            return {label: 0.0 for label in INTEREST_CATEGORIES}
        
        try:
            # multi_label=True is requested so each category is scored on its own
            result = self.bert_classifier(text, INTEREST_CATEGORIES, multi_label=True)
            
            # Pair up the returned labels with their scores
            scores = {label: score for label, score in zip(result['labels'], result['scores'])}
            
            # Back-fill any category the classifier did not report
            for category in INTEREST_CATEGORIES:
                scores.setdefault(category, 0.0)
            
            return scores
        
        except Exception as e:
            logger.error(f"Error in BERT prediction: {e}")
            return {label: 0.0 for label in INTEREST_CATEGORIES}
    
    def predict(self, 
                text: str, 
                alpha: Optional[float] = None,
                threshold: Optional[float] = None,
                return_scores: bool = False) -> Union[List[str], Dict]:
        """
        Blend TF-IDF and BERT scores for one text and threshold the result.
        
        Args:
            text: The input text to classify
            alpha: Weight for TF-IDF predictions (1-alpha for BERT), uses self.alpha if None
            threshold: Threshold for classification, uses self.threshold if None
            return_scores: Whether to return scores along with labels
        
        Returns:
            The list of predicted labels, or - when return_scores is True - a
            dictionary with the labels, combined/per-model scores, timings and
            the parameters that were used
        
        Raises:
            ValueError: If the TF-IDF pipeline has not been trained or loaded.
        """
        if self.tfidf_pipeline is None:
            raise ValueError("Model is not trained yet. Call train() first.")
        
        # Fall back to the instance-level settings
        if alpha is None:
            alpha = self.alpha
        if threshold is None:
            threshold = self.threshold
        
        started_at = time.time()
        
        # TF-IDF scores
        tfidf_scores = self.get_tfidf_predictions(text)
        tfidf_time = time.time() - started_at
        
        # BERT scores (all zeros when no classifier is loaded)
        bert_started_at = time.time()
        use_bert = self.bert_classifier is not None
        if use_bert:
            bert_scores = self.get_bert_predictions(text)
        else:
            bert_scores = {category: 0.0 for category in INTEREST_CATEGORIES}
            logger.warning("BERT classifier not available, using TF-IDF only")
        bert_time = time.time() - bert_started_at
        
        # Per-category weighted average; TF-IDF alone when BERT is missing
        combined_scores = {}
        for category in INTEREST_CATEGORIES:
            tfidf_score = tfidf_scores.get(category, 0.0)
            bert_score = bert_scores.get(category, 0.0)
            if not use_bert:
                combined_scores[category] = tfidf_score
            else:
                combined_scores[category] = (alpha * tfidf_score) + ((1 - alpha) * bert_score)
        
        # Keep every category whose blended score reaches the threshold
        final_labels = [
            category
            for category, blended in combined_scores.items()
            if blended >= threshold
        ]
        
        total_time = time.time() - started_at
        
        if not return_scores:
            return final_labels
        
        # Highest score first, for easier interpretation
        sorted_scores = sorted(combined_scores.items(), key=lambda pair: pair[1], reverse=True)
        timing = {
            'tfidf': tfidf_time,
            'bert': bert_time,
            'total': total_time
        }
        
        return {
            'labels': final_labels,
            'scores': combined_scores,
            'sorted_scores': sorted_scores,
            'tfidf_scores': tfidf_scores,
            'bert_scores': bert_scores,
            'timing': timing,
            'alpha': alpha,
            'threshold': threshold,
            'using_bert': use_bert
        }
    
    def save_model(self, path: str = "hybrid_interest_classifier.pkl"):
        """
        Pickle the trained TF-IDF components and settings to disk.
        
        The BERT pipeline itself is not written; only its model name is stored.
        
        Args:
            path: Path to save the model
        
        Raises:
            ValueError: If the TF-IDF pipeline has not been trained or loaded.
        """
        if self.tfidf_pipeline is None:
            raise ValueError("Model is not trained yet. Call train() first.")
        
        # Everything needed to rebuild the classifier except the BERT weights
        components = dict(
            tfidf_pipeline=self.tfidf_pipeline,
            mlb=self.mlb,
            alpha=self.alpha,
            threshold=self.threshold,
            bert_model_name=self.bert_model_name,
            interest_categories=INTEREST_CATEGORIES,
            version='1.0',
        )
        
        with open(path, 'wb') as out_file:
            pickle.dump(components, out_file)
        
        logger.info(f"Model saved to {path}")
    
    def load_model(self, path: str):
        """
        Load a saved model from disk
        
        Restores the TF-IDF pipeline and label binarizer, overrides alpha,
        threshold and bert_model_name with the saved values (falling back to
        0.6, 0.5 and 'facebook/bart-large-mnli' when absent), and makes sure a
        BERT classifier for the restored model name is initialized.
        
        SECURITY: the file is read with pickle, which can execute arbitrary
        code during loading. Only load files from a trusted source.
        
        Args:
            path: Path to the saved model
        
        Raises:
            Exception: Any error while reading or unpacking the file is logged
                and re-raised unchanged (e.g. OSError, pickle.UnpicklingError,
                KeyError for a file without the required entries).
        """
        try:
            with open(path, 'rb') as f:
                components = pickle.load(f)
            
            # Read the mandatory entries before touching self, so a malformed
            # file cannot leave the instance half-loaded.
            tfidf_pipeline = components['tfidf_pipeline']
            mlb = components['mlb']
            
            previous_bert_model = self.bert_model_name
            
            self.tfidf_pipeline = tfidf_pipeline
            self.mlb = mlb
            self.alpha = components.get('alpha', 0.6)
            self.threshold = components.get('threshold', 0.5)
            self.bert_model_name = components.get('bert_model_name', 'facebook/bart-large-mnli')
            
            logger.info(f"Model loaded from {path}")
            
            # (Re-)initialize BERT only when none is loaded yet or the saved
            # model name differs; reloading an identical model is pure overhead.
            if self.bert_classifier is None or self.bert_model_name != previous_bert_model:
                self._init_bert_classifier()
            
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise


# Example usage
def main(data_path: str = 'survey_interest_dataset_enhanced.csv',
         model_path: str = "hybrid_interest_classifier.pkl"):
    """
    End-to-end demo: train on a CSV, save the model and print sample predictions.

    Args:
        data_path: CSV file with the survey answers and labels that
            InterestClassifier.train() expects with its default column names.
        model_path: Destination file for the pickled model.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
    """
    try:
        # Load dataset
        logger.info(f"Loading dataset: {data_path}")
        df = pd.read_csv(data_path)
        
        # Note: string-encoded label lists are converted inside train(), so no
        # separate conversion pass is done here.
        
        # Initialize classifier
        logger.info("Initializing classifier with alpha=0.6, threshold=0.5")
        classifier = InterestClassifier(alpha=0.6, threshold=0.5)
        
        # Train the model
        logger.info("Training the model...")
        metrics = classifier.train(df)
        logger.info(f"Training metrics: {metrics}")
        
        # Save the model
        logger.info(f"Saving model to {model_path}")
        classifier.save_model(model_path)
        
        # Test on some examples
        test_examples = [
            "I love hiking in the mountains and trying local foods wherever I travel.",
            "I'm a software developer who plays guitar in a band on weekends.",
            "I spend most of my time reading books and attending online courses.",
            "I enjoy painting landscapes and visiting art museums when I travel."
        ]
        
        logger.info("Testing model on example inputs...")
        for example in test_examples:
            result = classifier.predict(example, return_scores=True)
            logger.info(f"\nExample: '{example}'")
            logger.info(f"Predicted interests: {result['labels']}")
            logger.info("Top interests by score:")
            for category, score in result['sorted_scores'][:3]:
                logger.info(f"  {category}: {score:.4f}")
                
        # Fine-tuning alpha parameter demo: same text, different TF-IDF/BERT mix
        logger.info("\nFine-tuning alpha parameter:")
        example = "I work as a software developer and enjoy hiking on weekends"
        for alpha in [0.3, 0.5, 0.7, 0.9]:
            result = classifier.predict(example, alpha=alpha, return_scores=True)
            logger.info(f"\nAlpha = {alpha} (TF-IDF weight: {alpha}, BERT weight: {1-alpha})")
            logger.info(f"Predicted interests: {result['labels']}")
            logger.info("Top 3 scores:")
            for category, score in result['sorted_scores'][:3]:
                logger.info(f"  {category}: {score:.4f}")
        
        logger.info("Model training and evaluation completed successfully")
        
    except Exception as e:
        logger.error(f"Error in main function: {e}", exc_info=True)
        raise


if __name__ == "__main__":
    main()

2025-04-12 19:59:16,282 - INFO - Loading dataset: survey_interest_dataset_enhanced.csv
2025-04-12 19:59:16,306 - INFO - Converting labels_list from string to list...
2025-04-12 19:59:16,315 - INFO - Initializing classifier with alpha=0.6, threshold=0.5
2025-04-12 19:59:16,317 - INFO - Initializing BERT zero-shot classifier with model: facebook/bart-large-mnli






Device set to use cpu
2025-04-12 19:59:19,738 - INFO - BERT classifier successfully initialized
2025-04-12 19:59:19,740 - INFO - Training the model...
2025-04-12 19:59:19,742 - INFO - Starting model training...
2025-04-12 19:59:19,742 - INFO - Preprocessing text data...
2025-04-12 19:59:19,798 - INFO - Target shape: (3018, 7)
2025-04-12 19:59:19,806 - INFO - Training set: 2414 samples, Test set: 604 samples
2025-04-12 19:59:19,807 - INFO - Creating and training TF-IDF pipeline...
2025-04-12 19:59:19,901 - INFO - TF-IDF pipeline trained successfully
2025-04-12 19:59:19,901 - INFO - Evaluating model on test set...
2025-04-12 19:59:19,925 - INFO - Hamming Loss: 0.0116
2025-04-12 19:59:19,926 - INFO - Micro F1 Score: 0.9762
2025-04-12 19:59:19,927 - INFO - Macro F1 Score: 0.9759
2025-04-12 19:59:19,927 - INFO - Training metrics: {'hamming_loss': 0.011589403973509934, 'micro_f1': 0.9762251334303736, 'macro_f1': 0.9759012864191312}
2025-04-12 19:59:19,928 - INFO - Saving model to hybrid_int

### Hybrid Model Fine-Tuning Script Code
#### Optimized Model Tuning Code with Checkpoint Support 

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import f1_score, hamming_loss
import matplotlib.pyplot as plt
import time
import json
import os
import pickle
from tqdm import tqdm
import logging
import traceback

# Import the hybrid classifier
from hybrid_interest_classifier import InterestClassifier

# Logging setup for the tuning script: timestamp, level and message at INFO+.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def tune_alpha_parameter(df, alphas=[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 
                         n_splits=5, thresholds=[0.3, 0.4, 0.5, 0.6],
                         checkpoint_file='alpha_tuning_checkpoint.pkl',
                         resume=True):
    """
    Fine-tune the alpha parameter using cross-validation with checkpoint support
    
    Args:
        df: DataFrame containing the dataset
        alphas: List of alpha values to test
        n_splits: Number of CV splits
        thresholds: List of thresholds to test
        checkpoint_file: File to save/load checkpoints
        resume: Whether to resume from checkpoint if available
    
    Returns:
        Dictionary of results
    """
    logger.info("Starting alpha parameter tuning...")
    
    # Convert labels_list if it's a string representation
    if isinstance(df['labels_list'].iloc[0], str):
        df['labels_list'] = df['labels_list'].str.strip('[]').str.split(',')
        df['labels_list'] = df['labels_list'].apply(lambda x: [item.strip().strip("'\"") for item in x])
    
    # Initialize CrossValidator
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    
    # Check if checkpoint exists and we want to resume
    if os.path.exists(checkpoint_file) and resume:
        try:
            logger.info(f"Loading checkpoint from {checkpoint_file}")
            with open(checkpoint_file, 'rb') as f:
                checkpoint = pickle.load(f)
            
            results = checkpoint['results']
            start_fold = checkpoint['next_fold']
            start_alpha_idx = checkpoint['next_alpha_idx']
            start_threshold_idx = checkpoint['next_threshold_idx']
            
            logger.info(f"Resuming from fold {start_fold}, alpha_idx {start_alpha_idx}, threshold_idx {start_threshold_idx}")
            
            # Get all train/test splits to skip to the right fold
            all_splits = list(kf.split(df))
            
        except Exception as e:
            logger.error(f"Error loading checkpoint: {e}")
            logger.info("Starting from scratch...")
            results = {
                'params': [],
                'micro_f1': [],
                'macro_f1': [],
                'hamming': []
            }
            start_fold = 0
            start_alpha_idx = 0
            start_threshold_idx = 0
            all_splits = list(kf.split(df))
    else:
        # Initialize results
        results = {
            'params': [],
            'micro_f1': [],
            'macro_f1': [],
            'hamming': []
        }
        start_fold = 0
        start_alpha_idx = 0
        start_threshold_idx = 0
        all_splits = list(kf.split(df))
    
    # Run cross-validation with resume capability
    try:
        # Loop through folds
        for fold_idx in range(start_fold, n_splits):
            logger.info(f"Processing fold {fold_idx+1}/{n_splits}")
            
            # Get split
            train_idx, test_idx = all_splits[fold_idx]
            
            # Split data
            train_df = df.iloc[train_idx].copy()
            test_df = df.iloc[test_idx].copy()
            
            # Initialize and train classifier
            classifier = InterestClassifier(alpha=0.5, threshold=0.5)  # Default values
            classifier.train(train_df)
            
            # Test each combination of alpha and threshold
            for alpha_idx, alpha in enumerate(alphas[start_alpha_idx if fold_idx == start_fold else 0:]):
                current_alpha_idx = start_alpha_idx + alpha_idx if fold_idx == start_fold else alpha_idx
                
                for threshold_idx, threshold in enumerate(thresholds[start_threshold_idx if fold_idx == start_fold and alpha_idx == 0 else 0:]):
                    current_threshold_idx = start_threshold_idx + threshold_idx if fold_idx == start_fold and alpha_idx == 0 else threshold_idx
                    
                    logger.info(f"Testing alpha={alpha}, threshold={threshold}")
                    
                    try:
                        # Get predictions for test set
                        y_true = []
                        y_pred = []
                        
                        for idx, row in tqdm(test_df.iterrows(), total=len(test_df), 
                                            desc=f"Fold {fold_idx+1}, Alpha={alpha}, Threshold={threshold}"):
                            text = row['survey_answer']
                            true_labels = row['labels_list']
                            
                            # Get predictions with current parameters
                            # Add timeout or other safeguards if prediction takes too long
                            try:
                                pred_labels = classifier.predict(text, alpha=alpha, threshold=threshold)
                                
                                y_true.append(true_labels)
                                y_pred.append(pred_labels)
                            except Exception as e:
                                logger.error(f"Error predicting sample {idx}: {e}")
                                # Use an empty prediction in case of error
                                y_true.append(true_labels)
                                y_pred.append([])
                        
                        # Convert to multilabel format
                        from sklearn.preprocessing import MultiLabelBinarizer
                        mlb = MultiLabelBinarizer(classes=classifier.mlb.classes_)
                        y_true_bin = mlb.fit_transform(y_true)
                        y_pred_bin = mlb.transform(y_pred)
                        
                        # Calculate metrics
                        micro_f1 = f1_score(y_true_bin, y_pred_bin, average='micro')
                        macro_f1 = f1_score(y_true_bin, y_pred_bin, average='macro')
                        h_loss = hamming_loss(y_true_bin, y_pred_bin)
                        
                        # Store results
                        results['params'].append((fold_idx+1, alpha, threshold))
                        results['micro_f1'].append(micro_f1)
                        results['macro_f1'].append(macro_f1)
                        results['hamming'].append(h_loss)
                        
                        logger.info(f"Fold {fold_idx+1}, Alpha={alpha}, Threshold={threshold}: "
                                  f"Micro-F1={micro_f1:.4f}, Macro-F1={macro_f1:.4f}, "
                                  f"Hamming Loss={h_loss:.4f}")
                        
                        # Save checkpoint after each parameter evaluation
                        checkpoint = {
                            'results': results,
                            'next_fold': fold_idx,
                            'next_alpha_idx': current_alpha_idx,
                            'next_threshold_idx': current_threshold_idx + 1 if current_threshold_idx + 1 < len(thresholds) else 0,
                            'next_param_set': (
                                fold_idx,
                                current_alpha_idx + 1 if current_threshold_idx + 1 >= len(thresholds) else current_alpha_idx,
                                0 if current_threshold_idx + 1 >= len(thresholds) else current_threshold_idx + 1
                            )
                        }
                        
                        # If we've completed this alpha, advance to next alpha
                        if current_threshold_idx + 1 >= len(thresholds):
                            checkpoint['next_alpha_idx'] = current_alpha_idx + 1
                            checkpoint['next_threshold_idx'] = 0
                        
                        # If we've completed all alphas, advance to next fold
                        if current_alpha_idx + 1 >= len(alphas) and current_threshold_idx + 1 >= len(thresholds):
                            checkpoint['next_fold'] = fold_idx + 1
                            checkpoint['next_alpha_idx'] = 0
                            checkpoint['next_threshold_idx'] = 0
                        
                        with open(checkpoint_file, 'wb') as f:
                            pickle.dump(checkpoint, f)
                        
                    except KeyboardInterrupt:
                        logger.warning("KeyboardInterrupt detected. Saving checkpoint and exiting...")
                        # Save checkpoint before exiting
                        checkpoint = {
                            'results': results,
                            'next_fold': fold_idx,
                            'next_alpha_idx': current_alpha_idx,
                            'next_threshold_idx': current_threshold_idx,
                            'next_param_set': (fold_idx, current_alpha_idx, current_threshold_idx)
                        }
                        with open(checkpoint_file, 'wb') as f:
                            pickle.dump(checkpoint, f)
                        raise
                    except Exception as e:
                        logger.error(f"Error evaluating params alpha={alpha}, threshold={threshold}: {e}")
                        logger.error(traceback.format_exc())
                        continue
                
                # Reset threshold index for next alpha
                start_threshold_idx = 0
            
            # Reset alpha index for next fold
            start_alpha_idx = 0
    
    except KeyboardInterrupt:
        logger.warning("KeyboardInterrupt received. Saving progress and attempting to continue with analysis...")
        # We'll still try to produce meaningful results with what we have so far
    except Exception as e:
        logger.error(f"Error during tuning: {e}")
        logger.error(traceback.format_exc())
    
    # Analyze results even if we didn't complete all runs
    if len(results['params']) == 0:
        logger.error("No results collected. Cannot analyze.")
        return None
    
    # Aggregate results by parameters
    agg_results = {}
    for param in [(a, t) for a in alphas for t in thresholds]:
        alpha, threshold = param
        mask = [(fold, a, t) for fold, a, t in results['params'] if a == alpha and t == threshold]
        
        if not mask:  # Skip if we have no results for this parameter combination
            continue
        
        # Calculate mean and std for each metric
        micro_f1_values = [results['micro_f1'][results['params'].index(key)] for key in mask]
        macro_f1_values = [results['macro_f1'][results['params'].index(key)] for key in mask]
        hamming_values = [results['hamming'][results['params'].index(key)] for key in mask]
        
        agg_results[f"alpha={alpha},threshold={threshold}"] = {
            'micro_f1': {
                'mean': np.mean(micro_f1_values),
                'std': np.std(micro_f1_values)
            },
            'macro_f1': {
                'mean': np.mean(macro_f1_values),
                'std': np.std(macro_f1_values)
            },
            'hamming': {
                'mean': np.mean(hamming_values),
                'std': np.std(hamming_values)
            }
        }
    
    # Find best parameters
    best_micro_f1 = 0
    best_params_micro = None
    best_hamming = 1.0  # Lower is better for hamming loss
    best_params_hamming = None
    
    for param, metrics in agg_results.items():
        if metrics['micro_f1']['mean'] > best_micro_f1:
            best_micro_f1 = metrics['micro_f1']['mean']
            best_params_micro = param
        
        if metrics['hamming']['mean'] < best_hamming:
            best_hamming = metrics['hamming']['mean']
            best_params_hamming = param
    
    if best_params_micro:
        logger.info(f"Best parameters by Micro-F1: {best_params_micro}, score: {best_micro_f1:.4f}")
    if best_params_hamming:
        logger.info(f"Best parameters by Hamming Loss: {best_params_hamming}, score: {best_hamming:.4f}")
    
    # Save results
    with open('alpha_tuning_results.json', 'w') as f:
        json.dump({
            'aggregated': agg_results,
            'best_micro_f1': {
                'params': best_params_micro,
                'score': float(best_micro_f1) if best_params_micro else None
            },
            'best_hamming': {
                'params': best_params_hamming,
                'score': float(best_hamming) if best_params_hamming else None
            },
            'completion_status': {
                'completed_combinations': len(results['params']),
                'total_combinations': len(alphas) * len(thresholds) * n_splits,
                'percentage_complete': 100 * len(results['params']) / (len(alphas) * len(thresholds) * n_splits)
            }
        }, f, indent=2)
    
    # Plot results if we have enough data
    if agg_results:
        plt.figure(figsize=(15, 10))
        
        # Plot Micro-F1 by alpha for each threshold
        plt.subplot(2, 1, 1)
        for threshold in thresholds:
            x = []
            y = []
            yerr = []
            for alpha in alphas:
                key = f"alpha={alpha},threshold={threshold}"
                if key in agg_results:
                    x.append(alpha)
                    y.append(agg_results[key]['micro_f1']['mean'])
                    yerr.append(agg_results[key]['micro_f1']['std'])
            
            if x:  # Only plot if we have data
                plt.errorbar(x, y, yerr=yerr, marker='o', label=f'Threshold={threshold}')
        
        plt.xlabel('Alpha (TF-IDF Weight)')
        plt.ylabel('Micro F1-Score')
        plt.title('Micro F1-Score by Alpha and Threshold')
        plt.legend()
        plt.grid(True)
        
        # Plot Hamming Loss by alpha for each threshold
        plt.subplot(2, 1, 2)
        for threshold in thresholds:
            x = []
            y = []
            yerr = []
            for alpha in alphas:
                key = f"alpha={alpha},threshold={threshold}"
                if key in agg_results:
                    x.append(alpha)
                    y.append(agg_results[key]['hamming']['mean'])
                    yerr.append(agg_results[key]['hamming']['std'])
            
            if x:  # Only plot if we have data
                plt.errorbar(x, y, yerr=yerr, marker='o', label=f'Threshold={threshold}')
        
        plt.xlabel('Alpha (TF-IDF Weight)')
        plt.ylabel('Hamming Loss')
        plt.title('Hamming Loss by Alpha and Threshold')
        plt.legend()
        plt.grid(True)
        
        plt.tight_layout()
        plt.savefig('alpha_tuning_results.png')
        plt.close()
    
    # Return best parameters or defaults if not enough data
    result = {
        'aggregated': agg_results,
        'best_micro_f1': {
            'params': best_params_micro,
            'score': best_micro_f1
        },
        'best_hamming': {
            'params': best_params_hamming,
            'score': best_hamming
        }
    }
    
    # If we don't have enough results, provide sensible defaults
    if not best_params_micro:
        logger.warning("Not enough data to determine best parameters. Using default alpha=0.6, threshold=0.5")
        result['best_micro_f1'] = {
            'params': "alpha=0.6,threshold=0.5",
            'score': 0.0
        }
    
    return result

def evaluate_hybrid_vs_tfidf(df, best_alpha=0.6, best_threshold=0.5):
    """
    Compare hybrid model performance against TF-IDF only and BERT only

    Three InterestClassifier instances are trained on an 80% split of ``df``
    (hybrid with ``best_alpha``, TF-IDF only with alpha=1.0, BERT only with
    alpha=0.0). Each one predicts the held-out 20% row by row, and Micro-F1,
    Macro-F1, Hamming loss and mean prediction time are collected. The
    metrics are written to 'model_comparison_results.json' and plotted to
    'model_comparison_results.png'.

    Note: when ``labels_list`` holds string-encoded lists, the column is
    parsed in place, so the caller's DataFrame is modified.

    Args:
        df: DataFrame containing the dataset (the 'survey_answer' and
            'labels_list' columns are read)
        best_alpha: Best alpha value from tuning
        best_threshold: Best threshold value from tuning

    Returns:
        Dict mapping model name to its metrics. A model is omitted when no
        predictions were collected for it or its metrics could not be computed.
    """
    # Imported once per call instead of once per evaluated model
    from sklearn.preprocessing import MultiLabelBinarizer

    logger.info("Evaluating hybrid model vs TF-IDF only...")
    
    # Convert labels_list if needed
    if isinstance(df['labels_list'].iloc[0], str):
        df['labels_list'] = df['labels_list'].str.strip('[]').str.split(',')
        df['labels_list'] = df['labels_list'].apply(lambda x: [item.strip().strip("'\"") for item in x])
    
    # Split data
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
    
    # Initialize and train both classifiers
    logger.info("Training hybrid model...")
    hybrid_classifier = InterestClassifier(alpha=best_alpha, threshold=best_threshold)
    hybrid_classifier.train(train_df)
    
    logger.info("Training TF-IDF only model...")
    tfidf_only_classifier = InterestClassifier(alpha=1.0, threshold=best_threshold)  # Alpha=1.0 means 100% TF-IDF
    tfidf_only_classifier.train(train_df)
    
    # Initialize BERT-only classifier for comparison (alpha=0 means 100% BERT)
    logger.info("Training BERT-only model...")
    bert_only_classifier = InterestClassifier(alpha=0.0, threshold=best_threshold)
    bert_only_classifier.train(train_df)  # We still need to train for the MLBinarizer
    
    # Evaluate on test set
    models = {
        'Hybrid (BERT + TF-IDF)': hybrid_classifier,
        'TF-IDF Only': tfidf_only_classifier,
        'BERT Only': bert_only_classifier
    }
    
    results = {}
    for model_name, classifier in models.items():
        logger.info(f"Evaluating {model_name}...")
        
        y_true = []
        y_pred = []
        prediction_times = []
        
        for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc=f"Testing {model_name}"):
            # Read the ground truth before the try block so the error handler
            # below always records the labels of the *current* row (never an
            # unbound name or the previous row's labels).
            true_labels = row['labels_list']
            try:
                text = row['survey_answer']
                
                # Measure prediction time
                start_time = time.time()
                prediction = classifier.predict(text, return_scores=True)
                pred_labels = prediction['labels']
                end_time = time.time()
                
                prediction_times.append(end_time - start_time)
                
                y_true.append(true_labels)
                y_pred.append(pred_labels)
            except KeyboardInterrupt:
                logger.warning(f"KeyboardInterrupt during {model_name} evaluation. Processing collected data...")
                break
            except Exception as e:
                logger.error(f"Error predicting with {model_name} on sample {idx}: {e}")
                # Use empty prediction in case of error
                y_true.append(true_labels)
                y_pred.append([])
        
        # Skip further processing if we have no predictions (e.g., early interrupt)
        if not y_pred:
            logger.warning(f"No predictions collected for {model_name}. Skipping...")
            continue
            
        try:
            # Convert to multilabel format using the classifier's own label order
            mlb = MultiLabelBinarizer(classes=classifier.mlb.classes_)
            y_true_bin = mlb.fit_transform(y_true)
            y_pred_bin = mlb.transform(y_pred)
            
            # Calculate metrics; cast to built-in float so json.dump never
            # depends on how the metric functions represent their scalars
            micro_f1 = float(f1_score(y_true_bin, y_pred_bin, average='micro'))
            macro_f1 = float(f1_score(y_true_bin, y_pred_bin, average='macro'))
            h_loss = float(hamming_loss(y_true_bin, y_pred_bin))
            
            # Average prediction time (NaN when every prediction failed)
            avg_time = float(np.mean(prediction_times)) if prediction_times else float('nan')
            
            results[model_name] = {
                'micro_f1': micro_f1,
                'macro_f1': macro_f1,
                'hamming_loss': h_loss,
                'avg_prediction_time': avg_time,
                'samples_evaluated': len(y_pred)
            }
            
            logger.info(f"{model_name} results:")
            logger.info(f"  Micro-F1: {micro_f1:.4f}")
            logger.info(f"  Macro-F1: {macro_f1:.4f}")
            logger.info(f"  Hamming Loss: {h_loss:.4f}")
            logger.info(f"  Avg. Prediction Time: {avg_time:.4f} seconds")
            logger.info(f"  Samples evaluated: {len(y_pred)} of {len(test_df)}")
        except Exception as e:
            logger.error(f"Error calculating metrics for {model_name}: {e}")
            logger.error(traceback.format_exc())
    
    # Save results
    with open('model_comparison_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    
    # Plot comparison if we have results
    if results:
        plt.figure(figsize=(12, 10))
        model_names = list(results.keys())
        
        # (metric key, y-axis label, title, clamp y-axis to [0, 1]?)
        panels = [
            ('micro_f1', 'Micro F1-Score', 'Micro F1-Score Comparison', True),
            ('macro_f1', 'Macro F1-Score', 'Macro F1-Score Comparison', True),
            ('hamming_loss', 'Hamming Loss', 'Hamming Loss Comparison (Lower is Better)', False),
            ('avg_prediction_time', 'Average Prediction Time (seconds)', 'Prediction Time Comparison', False),
        ]
        for position, (metric, ylabel, title, unit_scale) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.bar(model_names, [results[model][metric] for model in model_names])
            plt.ylabel(ylabel)
            plt.title(title)
            if unit_scale:
                plt.ylim(0, 1)
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.savefig('model_comparison_results.png')
        # Release the figure once it is on disk (the tuning plot does the same)
        plt.close()
    
    return results

def main():
    """Main function to run the tuning process with checkpoint support

    Steps:
        1. Load 'survey_interest_dataset_enhanced.csv'.
        2. Tune alpha/threshold with tune_alpha_parameter (resume enabled),
           falling back to alpha=0.6, threshold=0.5 when no result is available.
        3. Compare hybrid, TF-IDF-only and BERT-only models.
        4. Run error analysis, only when the hybrid model produced comparison
           results.

    Returns nothing; progress and output file names are reported via the logger.
    A KeyboardInterrupt or any other exception raised during steps 2-4 is
    logged and swallowed.
    """
    # Load dataset
    try:
        df = pd.read_csv('survey_interest_dataset_enhanced.csv')
        logger.info(f"Dataset loaded with {len(df)} rows")
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return
    
    try:
        # Step 1: Tune alpha parameter with checkpoint support
        logger.info("Step 1: Tuning alpha parameter")
        tuning_results = tune_alpha_parameter(
            df,
            alphas=[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
            thresholds=[0.3, 0.4, 0.5, 0.6],
            n_splits=3,  # Using 3 folds for speed
            resume=True  # Enable checkpoint resuming
        )
        
        # If tuning was interrupted or failed, use default values
        if tuning_results is None or 'best_micro_f1' not in tuning_results or tuning_results['best_micro_f1']['params'] is None:
            logger.warning("Tuning did not complete successfully. Using default parameters.")
            best_alpha = 0.6  # Default value
            best_threshold = 0.5  # Default value
        else:
            # Extract best parameters from a string such as "alpha=0.6,threshold=0.5"
            best_param_str = tuning_results['best_micro_f1']['params']
            parsed = dict(item.split('=', 1) for item in best_param_str.split(','))
            best_alpha = float(parsed['alpha'])
            best_threshold = float(parsed['threshold'])
        
        logger.info(f"Using parameters: alpha={best_alpha}, threshold={best_threshold}")
        
        # Step 2: Compare models
        logger.info("Step 2: Comparing hybrid vs TF-IDF only vs BERT only")
        comparison_results = evaluate_hybrid_vs_tfidf(df, best_alpha, best_threshold)
        
        # Step 3: Error analysis - only if we have comparison results
        error_analysis = None
        if comparison_results and 'Hybrid (BERT + TF-IDF)' in comparison_results:
            logger.info("Step 3: Analyzing error cases")
            classifier = InterestClassifier(alpha=best_alpha, threshold=best_threshold)
            classifier.train(df)
            error_analysis = analyze_error_cases(df, classifier, n_examples=20)
        else:
            logger.warning("Skipping error analysis due to incomplete comparison results")
        
        logger.info("Model tuning and analysis complete!")
        logger.info(f"Best alpha: {best_alpha}")
        logger.info(f"Best threshold: {best_threshold}")
        logger.info("Results are saved to:")
        logger.info("  - alpha_tuning_results.json")
        logger.info("  - alpha_tuning_results.png")
        logger.info("  - model_comparison_results.json")
        logger.info("  - model_comparison_results.png")
        # The error-analysis files only exist when step 3 ran and found at
        # least one error case, so only advertise them in that situation.
        if error_analysis and error_analysis.get('error_cases'):
            logger.info("  - error_analysis.txt")
            logger.info("  - error_analysis.json")
        
    except KeyboardInterrupt:
        logger.warning("\nProcess was interrupted by user. Results up to this point have been saved.")
        logger.info("You can resume the process by running the script again.")
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        logger.error(traceback.format_exc())
        
def analyze_error_cases(df, classifier, n_examples=10):
    """
    Analyze error cases to understand where the model fails

    The 20% test split of ``df`` (random_state=42) is predicted row by row;
    every row whose predicted label set differs from the true label set is
    recorded together with its TF-IDF, BERT and combined scores. Missed and
    extra labels are counted over all error cases, and the first
    ``n_examples`` cases are written to 'error_analysis.txt' and
    'error_analysis.json'. While scanning, a partial snapshot is written to
    'error_analysis_partial.json' every 50 error cases.

    Note: when ``labels_list`` holds string-encoded lists, the column is
    parsed in place, so the caller's DataFrame is modified.

    Args:
        df: DataFrame with test data
        classifier: Trained classifier
        n_examples: Number of error examples to analyze

    Returns:
        Dict with 'missed_labels' and 'extra_labels' (label -> count, sorted
        by descending count) and 'error_cases' (at most ``n_examples``
        entries). All three are empty when no error case was found, in which
        case no output file is written.
    """
    logger.info("Analyzing error cases...")
    
    # Convert labels_list if needed
    if isinstance(df['labels_list'].iloc[0], str):
        df['labels_list'] = df['labels_list'].str.strip('[]').str.split(',')
        df['labels_list'] = df['labels_list'].apply(lambda x: [item.strip().strip("'\"") for item in x])
    
    # Split data
    _, test_df = train_test_split(df, test_size=0.2, random_state=42)

    def _as_plain_floats(scores):
        """Copy a label -> score mapping with built-in float values so json.dump accepts it."""
        return {label: float(score) for label, score in scores.items()}
    
    # Find error cases
    error_cases = []
    
    try:
        for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Finding errors"):
            try:
                text = row['survey_answer']
                true_labels = set(row['labels_list'])
                
                # Get predictions with detailed info
                prediction = classifier.predict(text, return_scores=True)
                pred_labels = set(prediction['labels'])
                
                # Check if there's an error
                if true_labels != pred_labels:
                    missed_labels = true_labels - pred_labels
                    extra_labels = pred_labels - true_labels
                    
                    error_cases.append({
                        'text': text,
                        'true_labels': list(true_labels),  # Convert sets to lists for JSON serialization
                        'pred_labels': list(pred_labels),
                        'missed_labels': list(missed_labels),
                        'extra_labels': list(extra_labels),
                        'tfidf_scores': _as_plain_floats(prediction['tfidf_scores']),
                        'bert_scores': _as_plain_floats(prediction['bert_scores']),
                        'combined_scores': _as_plain_floats(dict(prediction['sorted_scores']))
                    })
                    
                    # Save partial results periodically
                    if len(error_cases) % 50 == 0:
                        with open('error_analysis_partial.json', 'w') as f:
                            json.dump({'error_cases': error_cases[:n_examples]}, f, indent=2)
            except KeyboardInterrupt:
                logger.warning("KeyboardInterrupt during error analysis. Processing collected cases...")
                break
            except Exception as e:
                logger.error(f"Error analyzing sample {idx}: {e}")
                continue
    except KeyboardInterrupt:
        # Interrupt landed outside the per-row handler (e.g. while iterating)
        logger.warning("KeyboardInterrupt during error analysis. Processing collected cases...")
    except Exception as e:
        logger.error(f"Error during error analysis: {e}")
    
    logger.info(f"Found {len(error_cases)} error cases out of {len(test_df)} test examples")
    
    # If we don't have any error cases, return early
    if not error_cases:
        logger.warning("No error cases found for analysis.")
        return {
            'missed_labels': {},
            'extra_labels': {},
            'error_cases': []
        }
    
    # Analyze which labels are most frequently missed / wrongly added
    missed_label_counts = {}
    extra_label_counts = {}
    
    for case in error_cases:
        for label in case['missed_labels']:
            missed_label_counts[label] = missed_label_counts.get(label, 0) + 1
        
        for label in case['extra_labels']:
            extra_label_counts[label] = extra_label_counts.get(label, 0) + 1
    
    # Sort by count, most frequent first
    missed_label_counts = dict(sorted(missed_label_counts.items(),
                                      key=lambda item: item[1],
                                      reverse=True))
    extra_label_counts = dict(sorted(extra_label_counts.items(),
                                     key=lambda item: item[1],
                                     reverse=True))
    
    logger.info("Most frequently missed labels:")
    for label, count in list(missed_label_counts.items())[:5]:
        logger.info(f"  {label}: {count} times")
    
    logger.info("Most frequently incorrectly added labels:")
    for label, count in list(extra_label_counts.items())[:5]:
        logger.info(f"  {label}: {count} times")
    
    sample_cases = error_cases[:n_examples]
    
    # Save detailed analysis of top N error cases
    with open('error_analysis.txt', 'w') as f:
        f.write("DETAILED ERROR ANALYSIS\n")
        f.write("======================\n\n")
        
        f.write("Most frequently missed labels:\n")
        for label, count in missed_label_counts.items():
            f.write(f"  {label}: {count} times\n")
        
        f.write("\nMost frequently incorrectly added labels:\n")
        for label, count in extra_label_counts.items():
            f.write(f"  {label}: {count} times\n")
        
        f.write("\n\nSAMPLE ERROR CASES\n")
        f.write("==================\n\n")
        
        for i, case in enumerate(sample_cases):
            f.write(f"Example {i+1}:\n")
            f.write(f"Text: {case['text']}\n")
            f.write(f"True labels: {', '.join(case['true_labels'])}\n")
            f.write(f"Predicted labels: {', '.join(case['pred_labels'])}\n")
            f.write(f"Missed labels: {', '.join(case['missed_labels'])}\n")
            f.write(f"Extra labels: {', '.join(case['extra_labels'])}\n")
            
            # Each score table is listed from highest to lowest score
            for heading, score_key in (("\nTF-IDF scores:\n", 'tfidf_scores'),
                                       ("\nBERT scores:\n", 'bert_scores'),
                                       ("\nCombined scores:\n", 'combined_scores')):
                f.write(heading)
                for label, score in sorted(case[score_key].items(), key=lambda x: x[1], reverse=True):
                    f.write(f"  {label}: {score:.4f}\n")
            
            f.write("\n" + "="*50 + "\n\n")
    
    analysis = {
        'missed_labels': missed_label_counts,
        'extra_labels': extra_label_counts,
        'error_cases': sample_cases
    }
    
    # Also save as JSON for easier programmatic analysis
    with open('error_analysis.json', 'w') as f:
        json.dump(analysis, f, indent=2)
    
    logger.info("Error analysis saved to 'error_analysis.txt' and 'error_analysis.json'")
    
    return analysis


# The entry-point guard must come after every function main() calls; placed
# above analyze_error_cases it would run main() before that name exists.
if __name__ == "__main__":
    main()

2025-04-12 14:11:36,982 - INFO - Dataset loaded with 3018 rows
2025-04-12 14:11:36,984 - INFO - Step 1: Tuning alpha parameter
2025-04-12 14:11:36,985 - INFO - Starting alpha parameter tuning...
2025-04-12 14:11:36,996 - INFO - Processing fold 1/3
2025-04-12 14:11:36,999 - INFO - Initializing BERT zero-shot classifier with model: facebook/bart-large-mnli
Device set to use cpu
2025-04-12 14:11:37,714 - INFO - BERT classifier successfully initialized
2025-04-12 14:11:37,715 - INFO - Starting model training...
2025-04-12 14:11:37,716 - INFO - Preprocessing text data...
2025-04-12 14:11:37,760 - INFO - Target shape: (2012, 7)
2025-04-12 14:11:37,762 - INFO - Training set: 1609 samples, Test set: 403 samples
2025-04-12 14:11:37,763 - INFO - Creating and training TF-IDF pipeline...
2025-04-12 14:11:37,815 - INFO - TF-IDF pipeline trained successfully
2025-04-12 14:11:37,815 - INFO - Evaluating model on test set...
2025-04-12 14:11:37,831 - INFO - Hamming Loss: 0.0181
2025-04-12 14:11:37,831 

### BERT Zero-Shot Classifier with Optimization Code

In [2]:
import os
import time
import torch
import pickle
import logging
import numpy as np
from functools import lru_cache
from typing import List, Dict, Any, Union, Optional
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

# Configure logging: INFO level, each record formatted as
# "<timestamp> - <level> - <message>"; `logger` is the module-level logger
# used by the classes defined below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class OptimizedBERTZeroShotClassifier:
    """
    Optimized BERT zero-shot classifier with caching and performance improvements

    Wraps a Hugging Face "zero-shot-classification" pipeline and memoizes its
    results in a per-instance LRU cache keyed on
    (text, categories tuple, multi_label). Cache hits and misses are counted
    on the instance.
    """
    def __init__(self, 
                 model_name: str = 'facebook/bart-large-mnli',
                 use_gpu: bool = torch.cuda.is_available(),
                 cache_size: int = 1024,
                 batch_size: int = 8):
        """
        Initialize the optimized BERT zero-shot classifier
        
        Args:
            model_name: The name of the pre-trained model to use
            use_gpu: Whether to use GPU for inference
            cache_size: Size of the LRU cache for classification results
            batch_size: Batch size for batch processing (stored; not used by
                the methods of this class)

        Raises:
            Exception: whatever model/tokenizer/pipeline loading raises is
                logged and re-raised.
        """
        self.model_name = model_name
        self.use_gpu = use_gpu
        self.cache_size = cache_size
        self.batch_size = batch_size
        self.cache_hits = 0
        self.cache_misses = 0
        
        # Initialize the model
        self._init_model()
        
        # Wrap the *bound* method, so the cache belongs to this instance and
        # its keys are the call arguments (text, categories_tuple, multi_label).
        self._classify_cached = lru_cache(maxsize=cache_size)(self._classify_uncached)
    
    def _init_model(self):
        """Load tokenizer and model, place them on the chosen device and build the pipeline"""
        start_time = time.time()
        
        try:
            # Device index 0 when a GPU is requested *and* available, else -1 (CPU)
            device = 0 if self.use_gpu and torch.cuda.is_available() else -1
            logger.info(f"Initializing model {self.model_name} on device {device}")
            
            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
            
            # Move model to device
            if device >= 0:
                self.model.to(f"cuda:{device}")
            
            # Create zero-shot pipeline
            self.classifier = pipeline(
                "zero-shot-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=device
            )
            
            # Half precision only when the model really runs on the GPU;
            # `use_gpu=True` without CUDA leaves the model on CPU in full
            # precision. Module.half() casts in place, so the pipeline created
            # above sees the converted weights as well.
            if device >= 0:
                self.model = self.model.half()
                logger.info("Using half precision for GPU")
            
            logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
            
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise
    
    def _classify_uncached(self, text: str, categories_tuple: tuple, multi_label: bool = True) -> Dict[str, Any]:
        """
        Uncached version of the classification method
        
        Args:
            text: The text to classify
            categories_tuple: Tuple of category labels (for cache hashability)
            multi_label: Whether to use multi-label classification
            
        Returns:
            The pipeline result with an added 'processing_time' entry (seconds)
        """
        # Record time for performance monitoring
        start_time = time.time()
        
        # The pipeline is called with a list, so convert the tuple back
        result = self.classifier(text, list(categories_tuple), multi_label=multi_label)
        
        # Add timing information
        result['processing_time'] = time.time() - start_time
        
        # Reaching this method at all means the cache had no entry
        self.cache_misses += 1
        return result
    
    def classify_text(self, text: str, categories: List[str], multi_label: bool = True) -> Dict[str, Any]:
        """
        Classify text using the zero-shot classifier with caching
        
        Args:
            text: The text to classify
            categories: List of category labels
            multi_label: Whether to use multi-label classification
            
        Returns:
            Classification results. On a cache hit the very same dict object
            returned earlier is handed back (including its original
            'processing_time'), so callers should not mutate it.
        """
        misses_before = self.cache_misses
        
        # Tuple instead of list because cache keys must be hashable
        result = self._classify_cached(text, tuple(categories), multi_label)
        
        # The miss counter only moves when the uncached method actually ran
        if self.cache_misses == misses_before:
            self.cache_hits += 1
        
        return result
    
    def batch_classify(self, texts: List[str], categories: List[str], multi_label: bool = True) -> List[Dict[str, Any]]:
        """
        Classify multiple texts, reusing cached results where available

        Texts are classified one after another through classify_text.
        
        Args:
            texts: List of texts to classify
            categories: List of category labels
            multi_label: Whether to use multi-label classification
            
        Returns:
            List of classification results, in the order of ``texts``
        """
        start_time = time.time()
        
        results = [self.classify_text(text, categories, multi_label) for text in texts]
        
        logger.info(f"Batch processed {len(texts)} texts in {time.time() - start_time:.2f} seconds")
        logger.info(f"Cache hits: {self.cache_hits}, misses: {self.cache_misses}")
        
        return results
    
    def get_cache_stats(self) -> Dict[str, int]:
        """
        Get cache statistics

        Returns:
            Dict with 'hits', 'misses', 'size' (configured maximum) and
            'current_usage' (number of entries currently cached)
        """
        cache_info = getattr(self._classify_cached, 'cache_info', None)
        # currsize is already an integer entry count; no len() needed
        current_usage = cache_info().currsize if cache_info is not None else 0
        return {
            'hits': self.cache_hits,
            'misses': self.cache_misses,
            'size': self.cache_size,
            'current_usage': current_usage
        }
    
    def clear_cache(self):
        """Clear the classification cache (hit/miss counters are left unchanged)"""
        if hasattr(self._classify_cached, 'cache_clear'):
            self._classify_cached.cache_clear()
            logger.info("Cache cleared")


class HybridModelBERTOptimizer:
    """
    Class to optimize BERT usage in the hybrid interest classifier model

    Wraps an OptimizedBERTZeroShotClassifier and adds an optional on-disk
    cache of score dictionaries: one pickle file per (text, category set)
    pair, stored in cache_dir. Usage counters are kept in self.stats.
    """
    def __init__(self, 
                 bert_model_name: str = 'facebook/bart-large-mnli',
                 cache_dir: Optional[str] = './bert_cache',
                 use_gpu: bool = torch.cuda.is_available(),
                 cache_size: int = 2048,
                 batch_size: int = 16):
        """
        Initialize the BERT optimizer for hybrid model
        
        Args:
            bert_model_name: Name of the BERT model to use
            cache_dir: Directory for the on-disk score cache; a falsy value
                (None or '') disables the disk cache entirely
            use_gpu: Whether to use GPU
            cache_size: Size of the LRU cache (forwarded to the classifier)
            batch_size: Batch size for processing (forwarded to the classifier)
        """
        self.bert_model_name = bert_model_name
        self.cache_dir = cache_dir
        self.use_gpu = use_gpu
        self.cache_size = cache_size
        self.batch_size = batch_size
        
        # Create cache directory if needed. exist_ok avoids the
        # check-then-create race when another process creates it first.
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        
        # Initialize optimized classifier
        self.bert_classifier = OptimizedBERTZeroShotClassifier(
            model_name=bert_model_name,
            use_gpu=use_gpu,
            cache_size=cache_size,
            batch_size=batch_size
        )
        
        # Track statistics
        self.stats = {
            'disk_cache_hits': 0,
            'disk_cache_misses': 0,
            'total_predictions': 0,
            'batch_predictions': 0
        }
    
    def _get_cache_path(self, text: str, categories: List[str]) -> str:
        """
        Get path for disk cache file

        The file name is built from the MD5 of the text and the MD5 of the
        sorted, '_'-joined categories, so the order of categories does not
        affect the path. Requires self.cache_dir to be set.
        """
        import hashlib
        
        # Create a hash of the text and categories
        text_hash = hashlib.md5(text.encode()).hexdigest()
        categories_str = '_'.join(sorted(categories))
        categories_hash = hashlib.md5(categories_str.encode()).hexdigest()
        
        return os.path.join(self.cache_dir, f"{text_hash}_{categories_hash}.pkl")
    
    def _disk_cache_path(self, text: str, categories: List[str], use_disk_cache: bool) -> Optional[str]:
        """Return the cache file path, or None when the disk cache is not in effect"""
        if use_disk_cache and self.cache_dir:
            return self._get_cache_path(text, categories)
        return None
    
    def _load_cached_scores(self, cache_path: str) -> Optional[Dict[str, float]]:
        """Return the scores stored at cache_path, or None if missing or unreadable"""
        if not os.path.exists(cache_path):
            return None
        
        try:
            # NOTE(security): pickle.load can execute arbitrary code. Only use
            # a cache_dir that untrusted parties cannot write to.
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            # Best effort: an unreadable entry is treated like a missing one
            logger.warning(f"Error loading from disk cache: {e}")
            return None
    
    def _save_cached_scores(self, cache_path: str, scores: Dict[str, float]) -> None:
        """Write scores to cache_path; failures are logged, never raised"""
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(scores, f)
        except Exception as e:
            logger.warning(f"Error saving to disk cache: {e}")
    
    def _classify_to_scores(self, text: str, categories: List[str]) -> Dict[str, float]:
        """Run the classifier on one text and flatten its result to {label: score}"""
        result = self.bert_classifier.classify_text(text, categories, multi_label=True)
        return dict(zip(result['labels'], result['scores']))
    
    def get_bert_predictions(self, text: str, categories: List[str], use_disk_cache: bool = True) -> Dict[str, float]:
        """
        Get BERT predictions with optimizations
        
        Args:
            text: Text to classify
            categories: Interest categories
            use_disk_cache: Whether to use disk cache (only has an effect when
                cache_dir is set)
            
        Returns:
            Dictionary mapping categories to scores
        """
        self.stats['total_predictions'] += 1
        
        # Check disk cache first if enabled
        cache_path = self._disk_cache_path(text, categories, use_disk_cache)
        if cache_path is not None:
            cached_result = self._load_cached_scores(cache_path)
            if cached_result is not None:
                self.stats['disk_cache_hits'] += 1
                return cached_result
            
            # A miss is only recorded when the disk cache was actually
            # consulted, matching batch_get_bert_predictions.
            self.stats['disk_cache_misses'] += 1
        
        # Get prediction from BERT classifier
        scores = self._classify_to_scores(text, categories)
        
        # Save to disk cache if enabled
        if cache_path is not None:
            self._save_cached_scores(cache_path, scores)
        
        return scores
    
    def batch_get_bert_predictions(self, texts: List[str], categories: List[str], use_disk_cache: bool = True) -> Dict[str, Dict[str, float]]:
        """
        Get BERT predictions for multiple texts with optimizations
        
        Args:
            texts: List of texts to classify
            categories: Interest categories
            use_disk_cache: Whether to use disk cache (only has an effect when
                cache_dir is set)
            
        Returns:
            Dictionary mapping texts to prediction dictionaries. Because the
            result is keyed by text, duplicate texts collapse to one entry.
        """
        self.stats['batch_predictions'] += 1
        self.stats['total_predictions'] += len(texts)
        
        # Initialize results dictionary
        results = {}
        
        # First pass: serve what we can from disk and remember the rest
        # together with the path each result should be written to.
        pending = []
        for text in texts:
            cache_path = self._disk_cache_path(text, categories, use_disk_cache)
            if cache_path is not None:
                cached_result = self._load_cached_scores(cache_path)
                if cached_result is not None:
                    results[text] = cached_result
                    self.stats['disk_cache_hits'] += 1
                    continue
                
                # Missing and unreadable entries both count as misses
                self.stats['disk_cache_misses'] += 1
            
            pending.append((text, cache_path))
        
        # Second pass: classify the remaining texts. The classifier is called
        # once per text, so splitting the list into chunks changes nothing.
        for text, cache_path in pending:
            scores = self._classify_to_scores(text, categories)
            results[text] = scores
            
            # Save to disk cache if enabled
            if cache_path is not None:
                self._save_cached_scores(cache_path, scores)
        
        return results
    
    def get_stats(self) -> Dict[str, int]:
        """
        Get optimizer statistics

        Returns:
            The disk-cache counters from self.stats merged with the
            classifier's cache_hits, cache_misses and cache_size attributes
            (reported under 'memory_cache_*' keys)
        """
        # Add in-memory cache stats
        bert_cache_stats = {
            'memory_cache_hits': self.bert_classifier.cache_hits,
            'memory_cache_misses': self.bert_classifier.cache_misses,
            'memory_cache_size': self.bert_classifier.cache_size
        }
        
        return {**self.stats, **bert_cache_stats}
    
    def clear_caches(self):
        """Clear all caches (the classifier's cache and every *.pkl file in cache_dir)"""
        # Clear in-memory cache
        self.bert_classifier.clear_cache()
        
        # Clear disk cache if enabled
        if self.cache_dir and os.path.exists(self.cache_dir):
            import glob
            
            removed = 0
            for file in glob.glob(os.path.join(self.cache_dir, "*.pkl")):
                try:
                    os.remove(file)
                except OSError as e:
                    logger.warning(f"Error removing cache file {file}: {e}")
                else:
                    removed += 1
            
            # Report only the files that were actually deleted
            logger.info(f"Cleared {removed} disk cache files")


# Function to optimize BERT predictions for the hybrid model
def optimize_bert_for_hybrid_model(texts: List[str], categories: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Run batched BERT predictions through a freshly built optimizer

    A new HybridModelBERTOptimizer is constructed on every call, using the
    './bert_cache' directory for its disk cache.
    
    Args:
        texts: List of texts to classify
        categories: List of interest categories
        
    Returns:
        Dictionary mapping text to BERT predictions
    """
    # Settings for the one-off optimizer instance
    optimizer_settings = {
        'bert_model_name': 'facebook/bart-large-mnli',
        'cache_dir': './bert_cache',
        'use_gpu': torch.cuda.is_available(),
        'cache_size': 2048,
        'batch_size': 16,
    }
    bert_optimizer = HybridModelBERTOptimizer(**optimizer_settings)
    
    # Delegate the whole list to the batch API
    predictions = bert_optimizer.batch_get_bert_predictions(texts, categories)
    return predictions


# Example usage
def main():
    """Demo: time single-text and batch BERT predictions, then print optimizer stats"""
    def print_ranked(scores):
        # One line per category, highest score first
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        for category, score in ranked:
            print(f"  {category}: {score:.4f}")
    
    # Categories and sample inputs used throughout the demo
    categories = ["Music", "Food", "Sports", "Technology", "Arts", "Travel", "Education"]
    example_texts = [
        "I love hiking in the mountains and trying local foods wherever I travel.",
        "I'm a software developer who plays guitar in a band on weekends.",
        "I spend most of my time reading books and attending online courses.",
        "I enjoy painting landscapes and visiting art museums when I travel.",
        "I'm passionate about fitness and healthy cooking."
    ]
    
    bert_optimizer = HybridModelBERTOptimizer(cache_dir="./bert_cache")
    
    # Part 1: one call per text (only the first two texts)
    print("\nIndividual predictions:")
    for text in example_texts[:2]:
        print(f"\nText: '{text}'")
        began = time.time()
        single_scores = bert_optimizer.get_bert_predictions(text, categories)
        elapsed = time.time() - began
        
        print(f"Prediction time: {elapsed:.4f} seconds")
        print_ranked(single_scores)
    
    # Part 2: all texts through the batch API
    print("\nBatch predictions:")
    began = time.time()
    scores_by_text = bert_optimizer.batch_get_bert_predictions(example_texts, categories)
    elapsed = time.time() - began
    
    print(f"Batch prediction time for {len(example_texts)} texts: {elapsed:.4f} seconds")
    print(f"Average time per text: {elapsed/len(example_texts):.4f} seconds")
    
    # Show the batch result for the first text
    sample_text = example_texts[0]
    print(f"\nSample result for: '{sample_text}'")
    print_ranked(scores_by_text[sample_text])
    
    # Part 3: cache counters collected along the way
    print("\nOptimizer stats:")
    for key, value in bert_optimizer.get_stats().items():
        print(f"  {key}: {value}")


# Run the demo only when this module is executed directly, not when it is imported
if __name__ == "__main__":
    main()

2025-04-12 19:57:28,238 - INFO - Initializing model facebook/bart-large-mnli on device -1
Device set to use cpu
2025-04-12 19:57:29,504 - INFO - Model loaded in 1.27 seconds



Individual predictions:

Text: 'I love hiking in the mountains and trying local foods wherever I travel.'
Prediction time: 3.1845 seconds
  Travel: 0.9361
  Food: 0.7752
  Arts: 0.0063
  Music: 0.0013
  Sports: 0.0011
  Education: 0.0009
  Technology: 0.0004

Text: 'I'm a software developer who plays guitar in a band on weekends.'
Prediction time: 2.1096 seconds
  Technology: 0.9120
  Music: 0.3961
  Arts: 0.0085
  Travel: 0.0048
  Sports: 0.0014
  Education: 0.0006
  Food: 0.0003

Batch predictions:
Batch prediction time for 5 texts: 5.4421 seconds
Average time per text: 1.0884 seconds

Sample result for: 'I love hiking in the mountains and trying local foods wherever I travel.'
  Travel: 0.9361
  Food: 0.7752
  Arts: 0.0063
  Music: 0.0013
  Sports: 0.0011
  Education: 0.0009
  Technology: 0.0004

Optimizer stats:
  disk_cache_hits: 2
  disk_cache_misses: 5
  total_predictions: 7
  batch_predictions: 1
  memory_cache_hits: 0
  memory_cache_misses: 5
  memory_cache_size: 2048
