# Adaptive Mixture-of-Experts for Cross-Domain Sentiment Classification

**High-level summary:**  
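Every loaded model (Transformer, BiLSTM, Logistic Regression, all trained on Yelp) scores every Google review. A rule-based router then assigns per-review weights from the review's length, sentiment clarity, detail level, and complexity, and the ensemble prediction is the weighted average of the models' class probabilities. The rules in Table 1 are checked top to bottom and the first match wins; the "primary model" is only tallied for routing statistics.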
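
Written out, the combination step is a convex mixture of the experts' class distributions. The notation below is introduced here for illustration and does not appear in the notebook: $w_m(x)$ is the normalized routing weight of expert $m$ for review $x$, and $p_m(c \mid x)$ is that expert's softmax (or `predict_proba`) probability for class $c$.

$$
\hat{p}(c \mid x) = \sum_{m} w_m(x)\, p_m(c \mid x), \qquad \sum_{m} w_m(x) = 1, \qquad \hat{y}(x) = \arg\max_{c}\, \hat{p}(c \mid x)
$$

Because the weights are non-negative and sum to one, $\hat{p}$ is itself a probability distribution over the three sentiment classes.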

**Table 1: Base Routing Rules**

| Condition                                                         | Transformer | BiLSTM | Logistic | Primary Model | Notes                                                                                           |
|-------------------------------------------------------------------|:-----------:|:------:|:--------:|:-------------:|-------------------------------------------------------------------------------------------------|
| `length < 15`                                                     |   0.25      |  0.45  |   0.30   | logistic      | Short Google reviews (< 15 words) – favor simpler models; BiLSTM + Logistic for simple patterns |
| `length > 50`                                                     |   0.60      |  0.25  |   0.15   | transformer   | Very detailed reviews (> 50 words) – favor Transformer; Transformer excels at long sequences     |
| `complexity_score > 0.05` **OR** `detail_level > 0.1`             |   0.55      |  0.30  |   0.15   | transformer   | High complexity or detailed mentions – Transformer for complex reasoning                        |
| `sentiment_clarity > 0.15`                                        |   0.40      |  0.40  |   0.20   | logistic      | Clear emotional sentiment – balanced approach; Transformer + BiLSTM for sentiment                |
| `15 ≤ length ≤ 30`                                                |   0.45      |  0.35  |   0.20   | bilstm        | Medium length, standard reviews – slight Transformer preference; Standard routing               |
| **Else** (ambiguous / neutral)                                    |   0.40      |  0.35  |   0.25   | transformer   | Ambiguous or neutral – favor ensemble diversity; More balanced for uncertain cases               |
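
To make the table concrete, here is a minimal standalone sketch of the base rules and the weighted-average step. The helper names `route` and `ensemble` are hypothetical (they are not the notebook's class methods), the platform-style adjustments applied in the code below are left out, and the probabilities in the usage lines are made up purely for illustration.

```python
# Standalone sketch of Table 1 (hypothetical helpers, not the notebook's methods).
# Weight order is [transformer, bilstm, logistic].

def route(length, sentiment_clarity=0.0, detail_level=0.0, complexity_score=0.0):
    """Return (weights, primary_model) for one review; the first matching rule wins."""
    if length < 15:
        weights, primary = [0.25, 0.45, 0.30], "logistic"
    elif length > 50:
        weights, primary = [0.60, 0.25, 0.15], "transformer"
    elif complexity_score > 0.05 or detail_level > 0.1:
        weights, primary = [0.55, 0.30, 0.15], "transformer"
    elif sentiment_clarity > 0.15:
        weights, primary = [0.40, 0.40, 0.20], "logistic"
    elif 15 <= length <= 30:
        weights, primary = [0.45, 0.35, 0.20], "bilstm"
    else:
        weights, primary = [0.40, 0.35, 0.25], "transformer"
    total = sum(weights)  # renormalize so the weights sum to 1
    return [w / total for w in weights], primary


def ensemble(weights, per_model_probs):
    """Weighted average of per-model class probabilities, plus the argmax class."""
    n_classes = len(per_model_probs[0])
    mixed = [sum(w * probs[c] for w, probs in zip(weights, per_model_probs))
             for c in range(n_classes)]
    return mixed, max(range(n_classes), key=mixed.__getitem__)


# A 60-word review matches the second rule (length > 50).
weights, primary = route(length=60)

# Made-up probabilities over [negative, neutral, positive], one row per model.
probs = [[0.1, 0.2, 0.7],   # transformer
         [0.3, 0.4, 0.3],   # bilstm
         [0.2, 0.5, 0.3]]   # logistic
mixed, label = ensemble(weights, probs)
print(primary, weights, label)
```

In this illustration the Transformer's confident "positive" vote outweighs the two weaker "neutral" votes because its routing weight is 0.60, so the ensemble picks class 2 (positive).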



In [None]:
# prompt: connect google drive

from google.colab import drive
drive.mount('/content/drive')

# prompt: load current directory

import os

os.chdir('/content/drive/My Drive/CS605-NLP-Project')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import re
import joblib
import time
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.feature_extraction.text import TfidfVectorizer
import math
import warnings
warnings.filterwarnings('ignore')

class AdaptiveMixtureOfExperts:
    """
    Adaptive routing system for Yelp-trained models on Google reviews
    Routes different review types to specialized models
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.routing_stats = {
            'transformer': 0,
            'bilstm': 0,
            'logistic': 0
        }

        print(f"🧠 Adaptive MoE initialized on {self.device}")

    def analyze_review_characteristics(self, text):
        """
        Analyze Google review characteristics for intelligent routing
        """
        if pd.isna(text) or text == "":
            return {
                'length': 0,
                'sentiment_clarity': 0.5,
                'detail_level': 0,
                'complexity_score': 0,
                'platform_style': 'simple'
            }

        text = str(text).lower()
        words = text.split()

        # 1. Text length analysis
        length = len(words)

        # 2. Sentiment strength indicators
        positive_words = ['amazing', 'excellent', 'fantastic', 'love', 'great', 'awesome', 'perfect', 'wonderful']
        negative_words = ['terrible', 'awful', 'horrible', 'hate', 'worst', 'disgusting', 'bad', 'disappointing']
        neutral_words = ['okay', 'average', 'decent', 'fine', 'normal', 'standard']

        pos_count = sum(1 for word in positive_words if word in text)
        neg_count = sum(1 for word in negative_words if word in text)
        neu_count = sum(1 for word in neutral_words if word in text)

        total_sentiment = pos_count + neg_count + neu_count
        sentiment_clarity = total_sentiment / max(length, 1) if length > 0 else 0

        # 3. Detail level (specific mentions)
        detail_indicators = ['staff', 'service', 'food', 'price', 'atmosphere', 'location', 'time', 'experience']
        detail_count = sum(1 for indicator in detail_indicators if indicator in text)
        detail_level = detail_count / max(length, 1) if length > 0 else 0

        # 4. Complexity indicators
        complex_patterns = [
            r'but\s+', r'however\s+', r'although\s+', r'despite\s+',  # Contrasts
            r'because\s+', r'since\s+', r'due\s+to',  # Explanations
            r'first\s+', r'then\s+', r'finally\s+',  # Sequences
        ]
        complexity_score = sum(1 for pattern in complex_patterns if re.search(pattern, text))
        complexity_score = complexity_score / max(length, 1) if length > 0 else 0

        # 5. Platform style classification
        if length < 10:
            platform_style = 'simple'
        elif detail_level > 0.1 and complexity_score > 0.05:
            platform_style = 'detailed'
        elif sentiment_clarity > 0.15:
            platform_style = 'emotional'
        else:
            platform_style = 'standard'

        return {
            'length': length,
            'sentiment_clarity': sentiment_clarity,
            'detail_level': detail_level,
            'complexity_score': complexity_score,
            'platform_style': platform_style
        }

    def intelligent_routing(self, review_characteristics):
        """
        Intelligent routing based on review characteristics
        Returns weights for [transformer, bilstm, logistic]
        """
        length = review_characteristics['length']
        sentiment_clarity = review_characteristics['sentiment_clarity']
        detail_level = review_characteristics['detail_level']
        complexity_score = review_characteristics['complexity_score']
        platform_style = review_characteristics['platform_style']

        # Base weights
        weights = [0.33, 0.33, 0.34]  # [transformer, bilstm, logistic]

        # Routing Logic for Cross-Domain Transfer (Yelp → Google)

        # 1. Short Google reviews (< 15 words) - favor simpler models
        if length < 15:
            weights = [0.25, 0.45, 0.30]  # BiLSTM + Logistic for simple patterns
            primary_model = 'logistic'

        # 2. Very detailed reviews (> 50 words) - favor Transformer
        elif length > 50:
            weights = [0.60, 0.25, 0.15]  # Transformer excels at long sequences
            primary_model = 'transformer'

        # 3. High complexity or detailed mentions - Transformer
        elif complexity_score > 0.05 or detail_level > 0.1:
            weights = [0.55, 0.30, 0.15]  # Transformer for complex reasoning
            primary_model = 'transformer'

        # 4. Clear emotional sentiment - balanced approach
        elif sentiment_clarity > 0.15:
            weights = [0.40, 0.40, 0.20]  # Transformer + BiLSTM for sentiment
            primary_model = 'logistic'

        # 5. Medium length, standard reviews - slight Transformer preference
        elif 15 <= length <= 30:
            weights = [0.45, 0.35, 0.20]  # Standard routing
            primary_model = 'bilstm'

        # 6. Ambiguous or neutral - favor ensemble diversity
        else:
            weights = [0.40, 0.35, 0.25]  # More balanced for uncertain cases
            primary_model = 'transformer'

        # Platform-specific adjustments for Google reviews
        if platform_style == 'simple':
            # Google users often write shorter, simpler reviews
            weights[1] += 0.1  # Boost BiLSTM
            weights[0] -= 0.05  # Reduce Transformer
            weights[2] -= 0.05  # Reduce Logistic

        elif platform_style == 'detailed':
            # Detailed Google reviews similar to Yelp
            weights[0] += 0.1  # Boost Transformer
            weights[1] -= 0.05  # Reduce others
            weights[2] -= 0.05

        # Ensure weights sum to 1
        weights = np.array(weights)
        weights = weights / weights.sum()

        return weights.tolist(), primary_model

    def load_all_models(self):
        """Load all three trained models with proper error handling"""
        print("🔄 Loading all Yelp-trained models...")

        # 1. Load Transformer
        try:
            transformer_checkpoint = torch.load('model/3class_transformer_v2.pth',
                                               map_location=self.device, weights_only=False)
            self.models['transformer'] = {
                'model': self.load_transformer_model(transformer_checkpoint),
                'vocab': transformer_checkpoint['vocab'],
                'config': transformer_checkpoint['model_config']
            }
            print("  ✅ Transformer loaded")
        except Exception as e:
            print(f"  ❌ Transformer loading failed: {e}")
            print("  ⚠️  Skipping Transformer - will use BiLSTM + Logistic")

        # 2. Load BiLSTM
        try:
            bilstm_checkpoint = torch.load('model/3class_bilstm_yelp.pth',
                                         map_location=self.device, weights_only=False)
            self.models['bilstm'] = {
                'model': self.load_bilstm_model(bilstm_checkpoint),
                'vocab': bilstm_checkpoint['vocab'],
                'config': bilstm_checkpoint['config']
            }
            print("  ✅ BiLSTM loaded")
        except Exception as e:
            print(f"  ❌ BiLSTM loading failed: {e}")
            print("  ⚠️  Skipping BiLSTM")

        # 3. Load Logistic
        try:
            self.models['logistic'] = {
                'model': joblib.load('model/3class_logistic_model_v2.pkl'),
                'vectorizer': joblib.load('model/3class_logistic_fidf_vectorizer_v2.pkl')
            }
            print("  ✅ Logistic loaded")
        except Exception as e:
            print(f"  ❌ Logistic loading failed: {e}")
            print("  ⚠️  Skipping Logistic")

        print(f"✅ Loaded {len(self.models)} models successfully!")

        # Adjust routing if some models failed to load
        if len(self.models) == 0:
            raise RuntimeError("❌ No models loaded successfully! Check model files.")
        elif len(self.models) < 3:
            print(f"⚠️  Only {len(self.models)} models loaded. Adjusting ensemble strategy...")
            self.adjust_routing_for_available_models()

    def load_transformer_model(self, checkpoint):
        """Rebuild the Transformer architecture and load its trained weights"""

        # The architecture must match the training code so the checkpoint's state_dict loads
        class ImprovedMultiHeadAttention(nn.Module):
            def __init__(self, d_model, n_heads, dropout=0.1):
                super().__init__()
                self.d_model = d_model
                self.n_heads = n_heads
                self.d_k = d_model // n_heads

                self.w_q = nn.Linear(d_model, d_model)
                self.w_k = nn.Linear(d_model, d_model)
                self.w_v = nn.Linear(d_model, d_model)
                self.w_o = nn.Linear(d_model, d_model)

                self.dropout = nn.Dropout(dropout)
                self.layer_norm = nn.LayerNorm(d_model)

            def forward(self, x, mask=None):
                batch_size, seq_len, d_model = x.size()
                residual = x
                x = self.layer_norm(x)

                Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
                K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
                V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

                attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

                if mask is not None:
                    attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

                attention_weights = torch.softmax(attention_scores, dim=-1)
                attention_weights = self.dropout(attention_weights)

                context = torch.matmul(attention_weights, V)
                context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
                output = self.w_o(context)

                return residual + self.dropout(output)

        class ImprovedTransformerBlock(nn.Module):
            def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
                super().__init__()
                self.attention = ImprovedMultiHeadAttention(d_model, n_heads, dropout)
                self.feed_forward = nn.Sequential(
                    nn.LayerNorm(d_model),
                    nn.Linear(d_model, d_ff),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_ff, d_model)
                )
                self.dropout = nn.Dropout(dropout)

            def forward(self, x, mask=None):
                x = self.attention(x, mask)
                residual = x
                ff_output = self.feed_forward(x)
                x = residual + self.dropout(ff_output)
                return x

        class ImprovedTransformer(nn.Module):
            def __init__(self, vocab_size, d_model=128, n_heads=8, n_layers=4, d_ff=512, max_length=384,
                         num_classes=3, dropout=0.15):
                super().__init__()
                self.d_model = d_model
                self.max_length = max_length

                self.token_embedding = nn.Embedding(vocab_size, d_model)
                self.position_embedding = nn.Embedding(max_length, d_model)
                self.embedding_dropout = nn.Dropout(dropout)
                self.embedding_norm = nn.LayerNorm(d_model)

                self.transformer_blocks = nn.ModuleList([
                    ImprovedTransformerBlock(d_model, n_heads, d_ff, dropout)
                    for _ in range(n_layers)
                ])

                self.final_norm = nn.LayerNorm(d_model)
                self.classifier = nn.Sequential(
                    nn.Linear(d_model * 2, d_model),  # *2 for concatenated pooling
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model, d_model // 2),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model // 2, num_classes)
                )

                self.init_weights()

            def init_weights(self):
                for module in self.modules():
                    if isinstance(module, nn.Linear):
                        torch.nn.init.xavier_uniform_(module.weight)
                        if module.bias is not None:
                            torch.nn.init.zeros_(module.bias)
                    elif isinstance(module, nn.Embedding):
                        torch.nn.init.normal_(module.weight, mean=0, std=0.02)

            def forward(self, x):
                batch_size, seq_len = x.size()
                positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(batch_size, seq_len)

                token_emb = self.token_embedding(x) * math.sqrt(self.d_model)
                pos_emb = self.position_embedding(positions)
                embeddings = token_emb + pos_emb
                embeddings = self.embedding_norm(embeddings)
                embeddings = self.embedding_dropout(embeddings)

                pad_mask = (x != 0).unsqueeze(1).unsqueeze(1)

                x = embeddings
                for transformer in self.transformer_blocks:
                    x = transformer(x, pad_mask)

                x = self.final_norm(x)

                # Dual pooling
                mask = (x.sum(dim=-1) != 0).float().unsqueeze(-1)
                x_mean = (x * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
                x_max, _ = (x + (1 - mask) * (-1e9)).max(dim=1)
                x_pooled = torch.cat([x_mean, x_max], dim=-1)

                logits = self.classifier(x_pooled)
                return logits

        model = ImprovedTransformer(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)
        model.eval()
        return model

    def load_bilstm_model(self, checkpoint):
        """Load BiLSTM model architecture"""
        class BiLSTMClassifier(nn.Module):
            def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes, padding_idx, n_layers=1, dropout=0.5):
                super().__init__()
                self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=padding_idx)
                self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=n_layers,
                                    bidirectional=True, batch_first=True,
                                    dropout=dropout if n_layers>1 else 0)
                self.dropout = nn.Dropout(dropout)
                self.fc = nn.Linear(hidden_dim*2, n_classes)

            def forward(self, x):
                x_emb = self.embedding(x)
                _, (h_n, _) = self.lstm(x_emb)
                h_f = h_n[-2]  # forward final
                h_b = h_n[-1]  # backward final
                h   = torch.cat([h_f, h_b], dim=1)
                return self.fc(self.dropout(h))

        config = checkpoint['config']
        vocab = checkpoint['vocab']

        model = BiLSTMClassifier(
            vocab_size=len(vocab['itos']),
            emb_dim=config['embed_dim'],
            hidden_dim=config['hidden_dim'],
            n_classes=config['n_classes'],
            padding_idx=vocab['stoi']['<PAD>']
        )

        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)
        model.eval()
        return model

    def predict_single_model(self, texts, model_name):
        """Get predictions from a single model"""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not loaded")

        if model_name == 'transformer':
            return self.predict_transformer(texts)
        elif model_name == 'bilstm':
            return self.predict_bilstm(texts)
        elif model_name == 'logistic':
            return self.predict_logistic(texts)

    def adjust_routing_for_available_models(self):
        """Adjust routing strategy based on available models"""
        available_models = list(self.models.keys())
        print(f"🔧 Adjusting routing for available models: {available_models}")

        # Update routing stats to only track available models
        self.routing_stats = {model: 0 for model in available_models}

        # Store original routing function
        self.original_intelligent_routing = self.intelligent_routing

        # Create new routing function for available models
        def adjusted_routing(characteristics):
            if len(available_models) == 1:
                # Only one model available
                model = available_models[0]
                return [1.0], model

            elif len(available_models) == 2:
                # Two models available
                if 'transformer' in available_models and 'bilstm' in available_models:
                    # Transformer + BiLSTM
                    if characteristics['length'] > 30 or characteristics['complexity_score'] > 0.05:
                        return [0.7, 0.3], 'transformer'
                    else:
                        return [0.4, 0.6], 'bilstm'

                elif 'transformer' in available_models and 'logistic' in available_models:
                    # Transformer + Logistic
                    if characteristics['length'] > 20:
                        return [0.8, 0.2], 'transformer'
                    else:
                        return [0.5, 0.5], 'transformer'

                elif 'bilstm' in available_models and 'logistic' in available_models:
                    # BiLSTM + Logistic
                    if characteristics['length'] < 15:
                        return [0.3, 0.7], 'logistic'
                    else:
                        return [0.7, 0.3], 'bilstm'

            # Fallback to original if all three available
            return self.original_intelligent_routing(characteristics)

        # Replace routing function
        self.intelligent_routing = adjusted_routing

    def predict_transformer(self, texts):
        """Predict with the Transformer, using the same preprocessing as at training time"""
        model = self.models['transformer']['model']
        vocab = self.models['transformer']['vocab']

        # Tokenization and padding mirror the Transformer's inference code
        class ImprovedYelpDataset:
            def __init__(self, max_length=384):
                self.max_length = max_length

            def enhanced_tokenize(self, text):
                if pd.isna(text) or text == "":
                    return ["<UNK>"]

                text = str(text).lower()

                # Handle contractions
                text = re.sub(r"won't", "will not", text)
                text = re.sub(r"can't", "cannot", text)
                text = re.sub(r"n't", " not", text)
                text = re.sub(r"'re", " are", text)
                text = re.sub(r"'ve", " have", text)
                text = re.sub(r"'ll", " will", text)
                text = re.sub(r"'d", " would", text)
                text = re.sub(r"'m", " am", text)

                # Handle sentiment patterns
                text = re.sub(r'[!]{2,}', ' very_excited ', text)
                text = re.sub(r'[?]{2,}', ' very_confused ', text)
                text = re.sub(r'[.]{3,}', ' continuation ', text)

                # Clean punctuation
                text = re.sub(r'[^a-zA-Z0-9\s!?.]', ' ', text)
                text = re.sub(r'(.)\1{2,}', r'\1\1', text)
                text = re.sub(r'\s+', ' ', text).strip()

                return text.split()

            def preprocess_text(self, text, vocab):
                tokens = self.enhanced_tokenize(text)
                token_ids = [vocab.get(token, vocab.get('<UNK>', 1)) for token in tokens]

                if len(token_ids) > self.max_length:
                    token_ids = token_ids[:self.max_length]
                else:
                    token_ids += [vocab.get('<PAD>', 0)] * (self.max_length - len(token_ids))

                return torch.tensor(token_ids, dtype=torch.long)

        dataset_processor = ImprovedYelpDataset(max_length=384)
        predictions = []
        probabilities = []

        # Process in batches
        batch_size = 32
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]

            batch_tensors = []
            for text in batch_texts:
                tensor = dataset_processor.preprocess_text(text, vocab)
                batch_tensors.append(tensor)

            batch_input = torch.stack(batch_tensors).to(self.device)

            with torch.no_grad():
                outputs = model(batch_input)
                probs = torch.softmax(outputs, dim=1)
                preds = outputs.argmax(dim=1)

                predictions.extend(preds.cpu().numpy())
                probabilities.extend(probs.cpu().numpy())

        return np.array(predictions), np.array(probabilities)

    def predict_bilstm(self, texts):
        """Predict with the BiLSTM, using the same preprocessing as at training time"""
        model = self.models['bilstm']['model']
        vocab = self.models['bilstm']['vocab']
        config = self.models['bilstm']['config']

        # Tokenization mirrors the BiLSTM training code
        def preprocess_text(text):
            text = str(text).lower()
            text = re.sub(r'[^a-z0-9\s]', '', text)
            return text.split()

        predictions = []
        probabilities = []

        # Process in batches
        batch_size = 128
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]

            batch_tensors = []
            for text in batch_texts:
                tokens = preprocess_text(text)
                token_ids = [vocab['stoi'].get(token, vocab['stoi']['<UNK>']) for token in tokens]

                max_len = config.get('max_len', 200)
                if len(token_ids) < max_len:
                    token_ids += [vocab['stoi']['<PAD>']] * (max_len - len(token_ids))
                else:
                    token_ids = token_ids[:max_len]

                batch_tensors.append(torch.tensor(token_ids, dtype=torch.long))

            batch_input = torch.stack(batch_tensors).to(self.device)

            with torch.no_grad():
                outputs = model(batch_input)
                probs = torch.softmax(outputs, dim=1)
                preds = outputs.argmax(dim=1)

                predictions.extend(preds.cpu().numpy())
                probabilities.extend(probs.cpu().numpy())

        return np.array(predictions), np.array(probabilities)

    def predict_logistic(self, texts):
        """Predict with Logistic model"""
        model = self.models['logistic']['model']
        vectorizer = self.models['logistic']['vectorizer']

        # Vectorize texts
        X_tfidf = vectorizer.transform(texts)

        # Predict
        predictions = model.predict(X_tfidf)
        probabilities = model.predict_proba(X_tfidf)

        return predictions, probabilities

    def adaptive_ensemble_predict(self, texts, return_routing_info=False):
        """
        Main adaptive ensemble prediction function with fallback handling
        """
        print(f"🎯 Running Adaptive MoE on {len(texts)} reviews...")

        available_models = list(self.models.keys())
        print(f"📋 Available models: {available_models}")

        if len(available_models) == 0:
            raise RuntimeError("No models available for prediction!")

        ensemble_predictions = []
        ensemble_probabilities = []
        routing_decisions = []

        # Get predictions from available models only
        print("  Getting predictions from available models...")
        model_predictions = {}
        model_probabilities = {}

        for model_name in list(available_models):  # iterate over a copy; failed models are removed below
            try:
                preds, probs = self.predict_single_model(texts, model_name)
                model_predictions[model_name] = preds
                model_probabilities[model_name] = probs
                print(f"    ✅ {model_name} predictions complete")
            except Exception as e:
                print(f"    ❌ {model_name} prediction failed: {e}")
                # Remove failed model from available models
                if model_name in available_models:
                    available_models.remove(model_name)

        if len(available_models) == 0:
            raise RuntimeError("All model predictions failed!")

        print("  Applying adaptive routing...")

        for i, text in enumerate(texts):
            # Analyze review characteristics
            characteristics = self.analyze_review_characteristics(text)

            # Get routing weights for available models
            weights, primary_model = self.intelligent_routing(characteristics)

            # Ensure primary model is available
            if primary_model not in available_models:
                primary_model = available_models[0]

            # Track routing statistics
            self.routing_stats[primary_model] += 1

            # Align the routing weights with the models that produced predictions.
            # The three-model router returns [transformer, bilstm, logistic]; if a
            # model failed at prediction time, keep the survivors and renormalize.
            if len(available_models) == 1:
                weights = [1.0]
            elif len(weights) == 3 and len(available_models) == 2:
                full_order = ['transformer', 'bilstm', 'logistic']
                weights = [weights[full_order.index(name)] for name in available_models]
                total = sum(weights)
                weights = [w / total for w in weights]

            # Weighted average of class probabilities over the available models
            ensemble_prob = sum(w * model_probabilities[name][i]
                                for w, name in zip(weights, available_models))

            ensemble_pred = np.argmax(ensemble_prob)

            ensemble_predictions.append(ensemble_pred)
            ensemble_probabilities.append(ensemble_prob)

            if return_routing_info:
                routing_decisions.append({
                    'weights': weights,
                    'primary_model': primary_model,
                    'characteristics': characteristics,
                    'available_models': available_models.copy()
                })

        if return_routing_info:
            return (np.array(ensemble_predictions),
                   np.array(ensemble_probabilities),
                   routing_decisions)

        return np.array(ensemble_predictions), np.array(ensemble_probabilities)

    def evaluate_on_google_reviews(self, google_reviews_df, review_column='review', star_column='stars'):
        """
        Evaluate adaptive ensemble on Google reviews dataset
        """
        print("🌐 EVALUATING ADAPTIVE MOE ON GOOGLE REVIEWS")
        print("="*60)

        # Data preprocessing
        print("Preprocessing Google reviews...")

        # Remove missing reviews
        initial_count = len(google_reviews_df)
        google_reviews_df = google_reviews_df.dropna(subset=[review_column, star_column])
        # .copy() so the columns added below do not write into a slice of the caller's frame
        google_reviews_df = google_reviews_df[google_reviews_df[review_column].str.strip() != ''].copy()

        print(f"✓ Cleaned data: {len(google_reviews_df)} reviews ({initial_count - len(google_reviews_df)} removed)")

        # Convert to 3-class labels (same as Yelp training)
        def convert_stars_to_3class(stars):
            if stars <= 2:
                return 0  # Negative
            elif stars == 3:
                return 1  # Neutral
            else:  # stars >= 4
                return 2  # Positive

        google_reviews_df['true_label'] = google_reviews_df[star_column].apply(convert_stars_to_3class)

        # Print distribution
        print("\nGoogle Reviews Distribution:")
        class_dist = google_reviews_df['true_label'].value_counts().sort_index()
        class_names = ['Negative (≤2★)', 'Neutral (3★)', 'Positive (≥4★)']
        # Index names by label so a class missing from the data does not shift the names
        for label, count in class_dist.items():
            print(f"  {class_names[label]}: {count} ({count/len(google_reviews_df)*100:.1f}%)")

        # Make predictions
        texts = google_reviews_df[review_column].tolist()
        true_labels = google_reviews_df['true_label'].tolist()

        start_time = time.time()
        ensemble_preds, ensemble_probs, routing_info = self.adaptive_ensemble_predict(
            texts, return_routing_info=True
        )
        inference_time = time.time() - start_time

        print(f"✓ Inference completed in {inference_time:.2f} seconds")
        print(f"✓ Average time per review: {inference_time/len(texts)*1000:.2f}ms")

        # Routing statistics
        print(f"\n📊 ROUTING STATISTICS:")
        total_routes = sum(self.routing_stats.values())
        for model, count in self.routing_stats.items():
            percentage = count / total_routes * 100 if total_routes > 0 else 0
            print(f"  {model.capitalize():>11}: {count:>5} routes ({percentage:>5.1f}%)")

        # Calculate metrics
        accuracy = accuracy_score(true_labels, ensemble_preds)
        macro_f1 = f1_score(true_labels, ensemble_preds, average='macro')
        weighted_f1 = f1_score(true_labels, ensemble_preds, average='weighted')

        # Print results
        print(f"\n🎯 ADAPTIVE MOE PERFORMANCE:")
        print("="*40)
        print(f"Cross-Domain Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"Macro F1-Score: {macro_f1:.4f}")
        print(f"Weighted F1-Score: {weighted_f1:.4f}")

        # Detailed classification report
        print(f"\n📊 DETAILED RESULTS:")
        target_names = ['Negative', 'Neutral', 'Positive']
        report = classification_report(true_labels, ensemble_preds, target_names=target_names, digits=4)
        print(report)

        # Compare with individual models
        print(f"\n🔍 INDIVIDUAL MODEL COMPARISON:")
        print("-"*50)

        # Get individual model predictions for comparison
        transformer_preds, _ = self.predict_single_model(texts, 'transformer')
        bilstm_preds, _ = self.predict_single_model(texts, 'bilstm')
        logistic_preds, _ = self.predict_single_model(texts, 'logistic')

        models_performance = {
            'Transformer': accuracy_score(true_labels, transformer_preds),
            'BiLSTM': accuracy_score(true_labels, bilstm_preds),
            'Logistic': accuracy_score(true_labels, logistic_preds),
            'Adaptive MoE': accuracy
        }

        for model_name, acc in models_performance.items():
            print(f"{model_name:>13}: {acc:.4f} accuracy")

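        # Pick the best single model, excluding the ensemble itself
        individual_only = {k: v for k, v in models_performance.items() if k != 'Adaptive MoE'}
        best_individual = max(individual_only.items(), key=lambda x: x[1])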
        improvement = accuracy - best_individual[1]

        print(f"\n🏆 ENSEMBLE IMPROVEMENT:")
        print(f"Best Individual: {best_individual[0]} ({best_individual[1]:.4f})")
        print(f"Adaptive MoE:    {accuracy:.4f}")
        # Use the '+' format flag so a negative difference prints as '-0.0100', not '+-0.0100'
        print(f"Improvement:     {improvement:+.4f} ({improvement*100:+.2f}%)")

        # Save results
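        # Keep an existing review_index, or create a positional one if it is missing
        if 'review_index' not in google_reviews_df.columns:
            google_reviews_df['review_index'] = range(len(google_reviews_df))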
        google_reviews_df['ensemble_prediction'] = ensemble_preds
        google_reviews_df['ensemble_confidence'] = ensemble_probs.max(axis=1)
        google_reviews_df['prob_negative'] = ensemble_probs[:, 0]
        google_reviews_df['prob_neutral'] = ensemble_probs[:, 1]
        google_reviews_df['prob_positive'] = ensemble_probs[:, 2]

        # Add routing information. Columns are assigned positionally from lists, so they
        # stay aligned with the rows even if cleaning left gaps in the DataFrame index.
        google_reviews_df['primary_model'] = [r['primary_model'] for r in routing_info]
        google_reviews_df['transformer_weight'] = [r['weights'][0] for r in routing_info]
        google_reviews_df['bilstm_weight'] = [r['weights'][1] for r in routing_info]
        google_reviews_df['logistic_weight'] = [r['weights'][2] for r in routing_info]

        output_file = 'google_reviews_adaptive_moe_results.csv'
        google_reviews_df.to_csv(output_file, index=False)
        print(f"\n💾 Results saved to '{output_file}'")

        return google_reviews_df, accuracy, macro_f1

# Example usage function
def run_adaptive_moe_on_google_reviews():
    """
    Main function to run Adaptive MoE on Google reviews
    """
    print("🚀 ADAPTIVE MIXTURE OF EXPERTS: YELP → GOOGLE")
    print("="*60)

    # Initialize Adaptive MoE
    moe = AdaptiveMixtureOfExperts()

    # Load all models
    moe.load_all_models()

    # Load Google reviews (replace with your actual Google reviews dataset)
    print("📱 Loading Google reviews...")
    try:
        # Replace this with your actual Google reviews CSV
        google_reviews = pd.read_csv("datastore/Google_Reviews.csv")
        #google_reviews  = pd.read_parquet("datastore/test-00000-of-00001.parquet")
        print(f"✓ Loaded {len(google_reviews)} Google reviews")
    except FileNotFoundError:
        print("⚠️  Using USS reviews as proxy for Google reviews...")
        # Use USS reviews as a proxy for Google reviews
        google_reviews = pd.read_csv("datastore/USS_Reviews_Silver.csv", parse_dates=["publishedAtDate"])
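        # The evaluation below expects 'review' and 'stars' columns. This mapping is an
        # identity placeholder: edit it if the proxy file names those columns differently.
        google_reviews = google_reviews.rename(columns={'review': 'review', 'stars': 'stars'})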


    # Run evaluation
    results_df, accuracy, macro_f1 = moe.evaluate_on_google_reviews(
        google_reviews,
        review_column='review',
        star_column='stars'
    )

    print(f"\n🎉 ADAPTIVE MOE COMPLETE!")
    print(f"Final Cross-Domain Accuracy: {accuracy:.4f}")
    print(f"Expected improvement over single models: +2-5%")

    return results_df, moe

if __name__ == "__main__":
    # Run the adaptive ensemble
    results, moe_system = run_adaptive_moe_on_google_reviews()

🚀 ADAPTIVE MIXTURE OF EXPERTS: YELP → GOOGLE
🧠 Adaptive MoE initialized on cuda
🔄 Loading all Yelp-trained models...
  ✅ Transformer loaded
  ✅ BiLSTM loaded
  ✅ Logistic loaded
✅ Loaded 3 models successfully!
📱 Loading Google reviews...
⚠️  Using USS reviews as proxy for Google reviews...
🌐 EVALUATING ADAPTIVE MOE ON GOOGLE REVIEWS
Preprocessing Google reviews...
✓ Cleaned data: 29412 reviews (0 removed)

Google Reviews Distribution:
  Negative (≤2★): 2206 (7.5%)
  Neutral (3★): 2133 (7.3%)
  Positive (≥4★): 25073 (85.2%)
🎯 Running Adaptive MoE on 29412 reviews...
📋 Available models: ['transformer', 'bilstm', 'logistic']
  Getting predictions from available models...
    ✅ transformer predictions complete
    ✅ bilstm predictions complete
    ✅ logistic predictions complete
  Applying adaptive routing...
✓ Inference completed in 15.22 seconds
✓ Average time per review: 0.52ms

📊 ROUTING STATISTICS:
  Transformer:  9058 routes ( 30.8%)
       Bilstm:  5098 routes ( 17.3%)
     Logistic