# In [None]:
"""
Advanced Marathi Sentiment Analysis
A streamlined, high-performance solution for sentiment analysis of Marathi text
Optimized for Google Colab execution
"""

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoTokenizer, AutoModel

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (plus all CUDA
    devices when CUDA is available), then switches cuDNN to deterministic
    mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random

    # Same seed for each library-level generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    # cuDNN flags are set regardless of whether CUDA is available.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed
set_seed()

# Mount Google Drive for data access.
# IN_COLAB is True only when both the import and the mount succeed; any
# ordinary failure falls back to "not in Colab". `except Exception` (rather
# than a bare `except`) lets KeyboardInterrupt / SystemExit propagate.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
except Exception:
    IN_COLAB = False
    print("Not running in Google Colab")

# =============================================
# MEMORY OPTIMIZATION FOR COLAB
# =============================================

def clear_memory():
    """Run Python garbage collection, then empty PyTorch's CUDA cache.

    Called between training phases to reduce the chance of OOM errors in Colab.
    """
    from gc import collect

    collect()
    torch.cuda.empty_cache()

# Function to display Colab RAM usage
def display_memory_usage():
    """Print current system RAM usage when running in Colab.

    Does nothing outside Colab (``IN_COLAB`` is False). Reporting is
    best-effort: if ``psutil`` is missing or the query fails, a short notice
    is printed instead of raising.
    """
    if not IN_COLAB:
        return

    try:
        from psutil import virtual_memory
        ram = virtual_memory()
        print(f"RAM Usage: {ram.percent}% ({ram.used / 1e9:.1f}GB / {ram.total / 1e9:.1f}GB)")
    except Exception:
        # Narrowed from a bare `except` so Ctrl-C / SystemExit are not swallowed.
        print("Could not display memory usage")

# =============================================
# DATA LOADING AND PREPROCESSING
# =============================================

class MarathiTextPreprocessor:
    """Text cleaner for Marathi (Devanagari) input.

    ``process`` strips URLs and HTML tags, drops every character that is not
    Devanagari, whitespace, a digit or basic punctuation, collapses
    whitespace, and optionally removes a small built-in stopword list.
    """

    # Patterns are compiled once at class level instead of on every call.
    _URL_RE = re.compile(r'https?://\S+|www\.\S+')
    _HTML_TAG_RE = re.compile(r'<.*?>')
    # Everything NOT matched by this class is kept. Both quote characters are
    # listed explicitly: the previous pattern was split across two adjacent
    # string literals, which silently dropped the single quote from the set.
    _DISALLOWED_RE = re.compile(r"[^\u0900-\u097F\s\d.,!?;:\"'()+\-]")
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self, remove_stopwords=True):
        """
        Args:
            remove_stopwords: When True, ``process`` drops tokens found in
                the built-in stopword set.
        """
        self.remove_stopwords = remove_stopwords
        self.stopwords = self._load_stopwords()

    def _load_stopwords(self):
        """Return the built-in set of common Marathi stopwords."""
        return {
            'आणि', 'आहे', 'तो', 'ती', 'ते', 'होते', 'होता', 'होती', 'आहेत',
            'या', 'च', 'ला', 'तर', 'पण', 'की', 'म्हणून', 'हे', 'त्या', 'तू', 'मी',
            'आम्ही', 'आपण', 'तुम्ही', 'त्यांचा', 'त्यांची', 'त्यांचे', 'वर', 'मध्ये'
        }

    def process(self, text):
        """Clean a single text for sentiment analysis.

        Args:
            text: Raw input. Any non-string value (e.g. NaN) yields ``""``.

        Returns:
            The cleaned string, with single spaces between tokens and no
            leading or trailing whitespace.
        """
        if not isinstance(text, str):
            return ""

        # Remove URLs, then HTML tags.
        text = self._URL_RE.sub('', text)
        text = self._HTML_TAG_RE.sub('', text)

        # Remove characters outside the allowed set.
        text = self._DISALLOWED_RE.sub('', text)

        # Normalize whitespace AFTER filtering so that removed runs (e.g. a
        # Latin word between two Marathi words) do not leave doubled spaces.
        text = self._WHITESPACE_RE.sub(' ', text).strip()

        # Remove stopwords if enabled (whole-token match only).
        if self.remove_stopwords:
            text = ' '.join(
                word for word in text.split()
                if word.lower() not in self.stopwords
            )

        return text

def load_dataset(file_path):
    """Load a Marathi sentiment CSV and normalize its columns and labels.

    A ``tweet`` column is renamed to ``comment`` when no ``comment`` column
    exists. If any label is negative, labels are remapped ``-1 -> 0``,
    ``0 -> 1``, ``1 -> 2`` and the originals are kept in ``original_label``.

    Args:
        file_path: Path (or anything ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        The prepared DataFrame, or ``None`` if loading or validation failed
        (the error is printed, not raised).
    """
    try:
        # Load dataset
        df = pd.read_csv(file_path)

        # Rename columns if needed
        if 'tweet' in df.columns and 'comment' not in df.columns:
            df = df.rename(columns={'tweet': 'comment'})

        for col in ('comment', 'label'):
            if col not in df.columns:
                raise ValueError(f"Required column '{col}' not found in the dataset")

        print(f"Dataset loaded successfully with {len(df)} rows")
        print("\nClass distribution:")
        print(df['label'].value_counts(normalize=True) * 100)

        # Display sample data. Capped at the row count: sampling 5 rows from
        # a smaller frame raises, which used to make tiny datasets "fail".
        print("\nSample data:")
        print(df.sample(min(5, len(df))))

        # Remap labels to ensure they are non-negative integers starting from 0
        # Map: -1 → 0, 0 → 1, 1 → 2
        if df['label'].min() < 0:
            print("\nRemapping labels to non-negative integers...")
            label_mapping = {-1: 0, 0: 1, 1: 2}
            df['original_label'] = df['label'].copy()  # Keep original labels
            df['label'] = df['label'].map(label_mapping)

            # `.map` turns values missing from the table into NaN; report
            # that here instead of letting NaN labels surface later.
            unmapped = df['label'].isna() & df['original_label'].notna()
            if unmapped.any():
                bad_values = sorted(df.loc[unmapped, 'original_label'].unique().tolist())
                raise ValueError(f"Unexpected label values (expected -1, 0 or 1): {bad_values}")

            print("New label distribution:")
            print(df['label'].value_counts(normalize=True) * 100)

        return df
    except Exception as e:
        # Deliberate best-effort: callers check for None.
        print(f"Error loading dataset: {e}")
        return None

# =============================================
# DATASET CLASS
# =============================================

class MarathiSentimentDataset(Dataset):
    """Map-style dataset that tokenizes one (text, label) pair per item.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors, all of dtype ``torch.long``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128, preprocessor=None):
        """
        Args:
            texts: Indexable collection of raw texts.
            labels: Indexable collection of labels, aligned with ``texts``.
            tokenizer: Callable invoked with the text plus ``max_length``,
                ``padding``, ``truncation`` and ``return_tensors`` keywords.
            max_length: Length each sequence is padded/truncated to.
            preprocessor: Optional object with a ``process(text)`` method.
        """
        self.preprocessor = preprocessor
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.labels = labels
        self.texts = texts

    def __len__(self):
        """Number of samples, taken from ``texts``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized sample at ``idx``."""
        raw_text = self.texts[idx]
        target = self.labels[idx]

        # Optional cleaning step.
        cleaned = self.preprocessor.process(raw_text) if self.preprocessor else raw_text

        encoded = self.tokenizer(
            cleaned,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Flatten the tokenizer output and force long dtype.
        sample = {
            field: encoded[field].flatten().long()
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample

# =============================================
# MODEL ARCHITECTURE
# =============================================

class MarathiSentimentClassifier(nn.Module):
    """Pre-trained transformer encoder with a small MLP classification head.

    ``forward`` returns a dict with ``logits`` and, when labels are given,
    a label-smoothed cross-entropy ``loss`` (otherwise ``None``).
    """

    def __init__(self, model_name, num_labels=3, dropout_rate=0.2):
        """
        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            num_labels: Number of output classes.
            dropout_rate: Dropout probability used twice in the head.
        """
        super().__init__()

        # Load pre-trained model
        self.transformer = AutoModel.from_pretrained(model_name)
        self.num_labels = num_labels

        # Get hidden size from config
        self.hidden_size = self.transformer.config.hidden_size

        # Classification head: hidden -> hidden/2 -> num_labels
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.GELU(),
            nn.LayerNorm(self.hidden_size // 2),
            nn.Dropout(dropout_rate),
            nn.Linear(self.hidden_size // 2, num_labels)
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: Token id tensor; cast to long before use.
            attention_mask: Attention mask tensor; cast to long before use.
            labels: Optional class-index tensor. When given, a loss is computed.

        Returns:
            ``{'loss': Tensor | None, 'logits': Tensor}``.

        Raises:
            ValueError: If any label lies outside ``[0, num_labels - 1]``.
        """
        # Ensure input tensors have correct data type
        input_ids = input_ids.long()
        attention_mask = attention_mask.long()

        # Get transformer outputs
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Prefer the pooled output, but fall back to the first ([CLS]) token
        # when the attribute is missing OR present-but-None. A plain
        # `hasattr` check let a None pooler_output reach the classifier.
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0, :]

        # Apply classifier
        logits = self.classifier(pooled_output)

        # Calculate loss if labels are provided
        loss = None
        if labels is not None:
            # Ensure labels are within valid range
            lowest, highest = torch.min(labels), torch.max(labels)
            if lowest < 0 or highest >= self.num_labels:
                raise ValueError(f"Labels must be in range [0, {self.num_labels-1}], but got range [{lowest.item()}, {highest.item()}]")

            # Use label smoothing for better generalization
            loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fn(logits, labels)

        return {
            'loss': loss,
            'logits': logits
        }

# =============================================
# TRAINER CLASS
# =============================================

class MarathiSentimentTrainer:
    """Training, evaluation, prediction, persistence and plotting helper.

    Wraps a model whose ``forward`` returns a dict with ``loss`` and
    ``logits``. The best checkpoint (by weighted validation F1) is kept in
    ``best_model_state`` as an independent CPU copy.
    """

    def __init__(self, model, device, mixed_precision=True):
        """
        Args:
            model: The ``nn.Module`` to train.
            device: Device batches are moved to.
            mixed_precision: Request CUDA automatic mixed precision. It is
                only enabled when ``device`` is a CUDA device, so no AMP
                objects are created on CPU.
        """
        self.model = model
        self.device = device
        self.mixed_precision = bool(mixed_precision) and torch.device(device).type == 'cuda'

        # Gradient scaler is only needed when AMP is active.
        self.scaler = GradScaler() if self.mixed_precision else None

        # Initialize best model state
        self.best_model_state = None

    @staticmethod
    def _detached_cpu_copy(obj):
        """Recursively copy tensors inside dicts/lists/tuples to CPU."""
        if torch.is_tensor(obj):
            return obj.detach().to('cpu', copy=True)
        if isinstance(obj, dict):
            return {
                key: MarathiSentimentTrainer._detached_cpu_copy(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return type(obj)(MarathiSentimentTrainer._detached_cpu_copy(v) for v in obj)
        return obj

    def _snapshot_model_state(self):
        """Return a state dict whose tensors are independent of the live model."""
        live_state = self.model.state_dict()
        snapshot = type(live_state)(
            (name, self._detached_cpu_copy(value))
            for name, value in live_state.items()
        )
        # Carry over the version metadata attribute when the state dict has one.
        metadata = getattr(live_state, '_metadata', None)
        if metadata is not None:
            snapshot._metadata = metadata
        return snapshot

    def train(self, train_loader, val_loader, optimizer, scheduler=None,
              num_epochs=10, patience=3, clip_grad_norm=1.0):
        """Train with per-epoch validation and early stopping.

        Args:
            train_loader: DataLoader of training batches.
            val_loader: DataLoader of validation batches.
            optimizer: Optimizer updating the model parameters.
            scheduler: Optional LR scheduler; stepped once per batch.
            num_epochs: Maximum number of epochs.
            patience: Epochs without F1 improvement before stopping.
            clip_grad_norm: Max gradient norm; values <= 0 disable clipping.

        Returns:
            Dict of per-epoch lists: ``train_loss``, ``val_loss``,
            ``val_accuracy``, ``val_f1`` and ``learning_rates``.
        """
        # -inf so the first epoch is always recorded, even with an F1 of 0.
        best_val_f1 = float('-inf')
        patience_counter = 0
        training_history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_f1': [],
            'learning_rates': []
        }

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            # Display memory usage (Colab optimization)
            display_memory_usage()

            # Training
            train_loss = self._train_epoch(train_loader, optimizer, scheduler, clip_grad_norm)
            training_history['train_loss'].append(train_loss)

            # Store current learning rate (first param group, end of epoch)
            current_lr = optimizer.param_groups[0]['lr']
            training_history['learning_rates'].append(current_lr)

            # Clear memory before validation
            clear_memory()

            # Validation
            val_metrics = self._evaluate(val_loader)

            # Store metrics
            training_history['val_loss'].append(val_metrics['loss'])
            training_history['val_accuracy'].append(val_metrics['accuracy'])
            training_history['val_f1'].append(val_metrics['f1'])

            # Print metrics
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_metrics['loss']:.4f}")
            print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
            print(f"Val F1: {val_metrics['f1']:.4f}")

            # Print per-class metrics (report keys that are class indices)
            print("\nPer-class metrics:")
            class_report = val_metrics['class_report']
            for class_idx in sorted(int(k) for k in class_report if k.isdigit()):
                row = class_report[str(class_idx)]
                print(f"Class {class_idx}: Precision={row['precision']:.4f}, "
                      f"Recall={row['recall']:.4f}, "
                      f"F1={row['f1-score']:.4f}")

            # Check for improvement
            if val_metrics['f1'] > best_val_f1:
                best_val_f1 = val_metrics['f1']
                patience_counter = 0

                # Save best model state. state_dict() returns references to
                # the live tensors, so they must be copied; otherwise later
                # epochs overwrite the "best" weights in place.
                self.best_model_state = {
                    'epoch': epoch,
                    'model_state_dict': self._snapshot_model_state(),
                    'optimizer_state_dict': self._detached_cpu_copy(optimizer.state_dict()),
                    'val_f1': val_metrics['f1'],
                    'val_accuracy': val_metrics['accuracy']
                }

                print(f"New best model saved (F1: {best_val_f1:.4f})")
            else:
                patience_counter += 1
                print(f"No improvement. Patience: {patience_counter}/{patience}")

                # Early stopping
                if patience_counter >= patience:
                    print("Early stopping triggered.")
                    break

            # Clear memory after each epoch
            clear_memory()

        # Restore best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state['model_state_dict'])
            print(f"Restored best model from epoch {self.best_model_state['epoch'] + 1} with F1: {self.best_model_state['val_f1']:.4f}")

        return training_history

    def _train_epoch(self, dataloader, optimizer, scheduler=None, clip_grad_norm=1.0):
        """Train for one epoch and return the mean batch loss."""
        self.model.train()
        total_loss = 0

        # Use tqdm for progress tracking
        progress_bar = tqdm(dataloader, desc="Training")

        for step, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(self.device) for k, v in batch.items()}

            if self.mixed_precision:
                # Forward pass under autocast
                with autocast():
                    outputs = self.model(**batch)
                    loss = outputs['loss']

                # Backward pass with gradient scaling
                optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                # Gradients must be unscaled before clipping by norm
                if clip_grad_norm > 0:
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad_norm)

                # Update weights
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                # Standard full-precision training
                outputs = self.model(**batch)
                loss = outputs['loss']

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                # Clip gradients
                if clip_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad_norm)

                # Update weights
                optimizer.step()

            # Update learning rate (once per batch, not per epoch)
            if scheduler is not None:
                scheduler.step()

            # Update loss
            total_loss += loss.item()

            # Show running mean loss
            progress_bar.set_postfix({"loss": f"{total_loss / (step + 1):.4f}"})

        return total_loss / len(dataloader)

    def _evaluate(self, dataloader):
        """Evaluate on a labelled loader.

        Returns:
            Dict with mean ``loss``, ``accuracy``, weighted ``f1`` and the
            sklearn ``class_report`` dict.
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                outputs = self.model(**batch)

                # Get loss
                total_loss += outputs['loss'].item()

                # Get predictions
                preds = torch.argmax(outputs['logits'], dim=1).cpu().numpy()
                labels = batch['labels'].cpu().numpy()

                all_preds.extend(preds)
                all_labels.extend(labels)

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average='weighted')

        # Calculate per-class metrics
        class_report = classification_report(all_labels, all_preds, output_dict=True)

        return {
            'loss': total_loss / len(dataloader),
            'accuracy': accuracy,
            'f1': f1,
            'class_report': class_report
        }

    def predict(self, dataloader):
        """Predict class indices for every batch in ``dataloader``.

        Any ``labels`` entry in a batch is ignored.

        Returns:
            ``(all_preds, all_logits)``: a list of predicted class indices
            and the stacked logits array (an empty list if the loader
            yielded no batches).
        """
        self.model.eval()
        all_preds = []
        all_logits = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass without labels
                outputs = self.model(**{k: v for k, v in batch.items() if k != 'labels'})

                # Get predictions
                logits = outputs['logits']
                preds = torch.argmax(logits, dim=1).cpu().numpy()

                all_preds.extend(preds)
                all_logits.append(logits.cpu().numpy())

        # Concatenate all logits
        if all_logits:
            all_logits = np.vstack(all_logits)

        return all_preds, all_logits

    def save_model(self, path):
        """Save the best model's state dict to ``path`` (if one exists)."""
        if self.best_model_state is not None:
            # Save only the model state dict to avoid pickle errors in Colab
            torch.save(self.best_model_state['model_state_dict'], path)
            print(f"Model saved to {path}")
        else:
            print("No best model state available to save")

    def load_model(self, path):
        """Load weights from ``path`` into the model.

        Accepts either a bare state dict or a checkpoint dict containing a
        ``model_state_dict`` key.
        """
        # Load with map_location to handle device changes
        state_dict = torch.load(path, map_location=self.device)

        # Handle both full checkpoint and state_dict only formats
        if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
            self.model.load_state_dict(state_dict['model_state_dict'])
            self.best_model_state = state_dict
        else:
            self.model.load_state_dict(state_dict)
            self.best_model_state = {'model_state_dict': state_dict}

        print(f"Model loaded from {path}")

    def plot_training_history(self, history, save_path=None):
        """Plot loss, accuracy, F1 and learning-rate curves in a 2x2 grid.

        Args:
            history: Dict as returned by ``train``.
            save_path: Optional file path the figure is saved to.
        """
        plt.figure(figsize=(15, 10))

        # Plot loss
        plt.subplot(2, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Val Loss')
        plt.title('Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Plot accuracy
        plt.subplot(2, 2, 2)
        plt.plot(history['val_accuracy'], label='Accuracy')
        plt.title('Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        # Plot F1 score
        plt.subplot(2, 2, 3)
        plt.plot(history['val_f1'], label='F1 Score')
        plt.title('F1 Score')
        plt.xlabel('Epoch')
        plt.ylabel('F1 Score')
        plt.legend()

        # Plot learning rate
        plt.subplot(2, 2, 4)
        plt.plot(history['learning_rates'], label='Learning Rate')
        plt.title('Learning Rate')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.legend()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)

        plt.show()

    def plot_confusion_matrix(self, true_labels, predictions, class_names=None, save_path=None):
        """Plot a confusion-matrix heatmap.

        Args:
            true_labels: Ground-truth class indices.
            predictions: Predicted class indices.
            class_names: Optional tick labels; defaults to ``0..k-1`` where
                ``k`` is the number of distinct true labels.
            save_path: Optional file path the figure is saved to.
        """
        cm = confusion_matrix(true_labels, predictions)
        tick_labels = class_names if class_names else range(len(np.unique(true_labels)))

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=tick_labels,
                    yticklabels=tick_labels)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')

        if save_path:
            plt.savefig(save_path)

        plt.show()

# =============================================
# DATA AUGMENTATION
# =============================================

def augment_data(df, target_class, multiplier=0.5):
    """Oversample one class with word-swapped copies of its comments.

    Rows of ``target_class`` are sampled with replacement, two random pairs
    of words are swapped in each sampled comment (comments of three words or
    fewer are left unchanged), and the result is appended to ``df``.

    Args:
        df: DataFrame with ``comment`` and ``label`` columns. Not modified.
        target_class: Label value to augment.
        multiplier: Number of generated rows as a fraction of the class size.

    Returns:
        A new DataFrame with the augmented rows appended and a fresh index,
        or ``df`` itself when no rows would be generated.
    """
    print(f"Performing data augmentation for class {target_class}...")

    # Get samples from target class
    target_samples = df[df['label'] == target_class]

    # Determine number of samples to generate
    num_samples = int(len(target_samples) * multiplier)

    if num_samples <= 0:
        return df

    # Select random samples for augmentation. The explicit copy guarantees
    # the column assignment below never writes through to `df`.
    augmented_samples = target_samples.sample(num_samples, replace=True).copy()

    # Simple text augmentation (word swapping)
    def swap_words(text):
        # Missing comments (NaN) are passed through instead of crashing.
        if not isinstance(text, str):
            return text

        words = text.split()
        if len(words) <= 3:
            return text

        # With more than three words this is always two swaps.
        num_swaps = min(2, len(words) // 2)
        for _ in range(num_swaps):
            idx1, idx2 = np.random.choice(len(words), 2, replace=False)
            words[idx1], words[idx2] = words[idx2], words[idx1]

        return ' '.join(words)

    # Apply augmentation
    augmented_samples['comment'] = augmented_samples['comment'].apply(swap_words)

    # Combine with original data
    augmented_df = pd.concat([df, augmented_samples], ignore_index=True)

    print(f"Generated {len(augmented_samples)} augmented samples")
    print(f"New dataset size: {len(augmented_df)}")
    print("New class distribution:")
    print(augmented_df['label'].value_counts(normalize=True) * 100)

    return augmented_df

# =============================================
# COLAB OPTIMIZED BATCH SIZE FINDER
# =============================================

def find_optimal_batch_size(model, tokenizer, preprocessor, sample_texts, sample_labels, max_batch_size=64):
    """Probe increasing batch sizes and return the largest one that fits.

    For each candidate (8, 16, 24, 32, 48, 64, capped by ``max_batch_size``)
    one forward and backward pass is run on a dataset built from the first
    100 samples. Probing stops at the first out-of-memory error.

    Args:
        model: Model to probe; batches are moved to the device of its
            parameters.
        tokenizer: Tokenizer used to build the probe dataset.
        preprocessor: Optional text preprocessor for the probe dataset.
        sample_texts: Texts to draw the probe samples from.
        sample_labels: Labels aligned with ``sample_texts``.
        max_batch_size: Upper bound on the candidates that are tried.

    Returns:
        The largest batch size that succeeded. If none succeeded (or none
        was tried), a conservative fallback of at most 4 is returned.

    Raises:
        RuntimeError: Any runtime error other than out-of-memory.
    """
    print("Finding optimal batch size for your environment...")

    # Create a small dataset
    dataset = MarathiSentimentDataset(
        sample_texts[:100],
        sample_labels[:100],
        tokenizer,
        max_length=128,
        preprocessor=preprocessor
    )

    # Probe on the device the model actually lives on.
    first_param = next(model.parameters(), None)
    probe_device = first_param.device if first_param is not None else torch.device('cpu')

    candidates = (8, 16, 24, 32, 48, 64)
    optimal_batch_size = None

    model.train()

    # Try different batch sizes
    for batch_size in candidates:
        if batch_size > max_batch_size:
            break

        try:
            # Create data loader
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

            # Try a forward and backward pass on the first batch only
            for batch in loader:
                batch = {k: v.to(probe_device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs['loss']
                loss.backward()
                break

            # If we got here, this batch size works
            optimal_batch_size = batch_size
            print(f"Batch size {batch_size} works well")
        except RuntimeError as e:
            if 'out of memory' in str(e).lower():
                print(f"Batch size {batch_size} is too large (OOM error)")
                break
            raise
        finally:
            # Release the probe gradients so they neither hold memory nor
            # leak into real training, then clear cached memory.
            model.zero_grad(set_to_none=True)
            clear_memory()

    # Previously an UnboundLocalError when the first candidate failed or
    # max_batch_size was below the smallest candidate.
    if optimal_batch_size is None:
        optimal_batch_size = max(1, min(max_batch_size, candidates[0] // 2))
        print(f"No tested batch size succeeded; falling back to {optimal_batch_size}")

    # Return the largest successful batch size
    print(f"Optimal batch size: {optimal_batch_size}")
    return optimal_batch_size

# =============================================
# MAIN FUNCTION
# =============================================

def main():
    """Run the full pipeline: load data, train, evaluate and save artifacts.

    Steps: load the CSV (optionally uploaded in Colab), split 70/15/15 with
    stratification, augment the minority class of the training split,
    fine-tune the classifier with early stopping, evaluate on the test
    split, and write the model, metadata and plots to the working directory.

    Returns:
        ``(model, tokenizer, preprocessor)`` on success, or ``None`` when
        the dataset could not be loaded.
    """
    # Load dataset
    print("Loading dataset...")
    # Replace with your dataset path or use the default path
    file_path = '/content/drive/MyDrive/combined_dataset.csv'

    # Allow user to specify a different path
    if IN_COLAB:
        from google.colab import files
        import ipywidgets as widgets
        from IPython.display import display

        use_default = widgets.Checkbox(
            value=True,
            description='Use default dataset path',
            disabled=False
        )

        file_path_widget = widgets.Text(
            value=file_path,
            description='Dataset path:',
            disabled=True
        )

        def on_checkbox_change(change):
            # The path box is editable only while the checkbox is unticked.
            file_path_widget.disabled = change['new']

        use_default.observe(on_checkbox_change, names='value')

        display(use_default)
        display(file_path_widget)

        # Allow user to upload a file if needed
        print("Or upload a CSV file:")
        try:
            uploaded = files.upload()
            if uploaded:
                file_path = next(iter(uploaded))
                print(f"Using uploaded file: {file_path}")
        except Exception:
            # Best-effort; narrowed from a bare `except` so that
            # KeyboardInterrupt / SystemExit still propagate.
            print("File upload not available")

    # Load the dataset
    df = load_dataset(file_path)

    if df is None:
        print("Failed to load dataset. Please check the file path.")
        return

    # Split data into train, validation, and test sets
    print("Splitting data into train, validation, and test sets...")
    train_df, temp_df = train_test_split(df, test_size=0.3, stratify=df['label'], random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42)

    print(f"Train set: {len(train_df)} samples")
    print(f"Validation set: {len(val_df)} samples")
    print(f"Test set: {len(test_df)} samples")

    # Check class distribution
    class_counts = train_df['label'].value_counts()
    min_class = class_counts.idxmin()

    # Augment data for minority class (training split only)
    train_df = augment_data(train_df, min_class, multiplier=0.5)

    # Initialize tokenizer
    print("Initializing tokenizer...")
    model_name = "google/muril-base-cased"  # MuRIL is better for Indian languages
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Initialize text preprocessor
    print("Initializing text preprocessor...")
    preprocessor = MarathiTextPreprocessor(remove_stopwords=True)

    # Create datasets
    print("Creating datasets...")

    def build_dataset(frame):
        # One dataset per split, all with identical settings.
        return MarathiSentimentDataset(
            frame['comment'].values,
            frame['label'].values,
            tokenizer,
            max_length=128,
            preprocessor=preprocessor
        )

    train_dataset = build_dataset(train_df)
    val_dataset = build_dataset(val_df)
    test_dataset = build_dataset(test_df)

    # Initialize model
    print("Initializing model...")
    num_labels = len(df['label'].unique())
    model = MarathiSentimentClassifier(model_name, num_labels=num_labels)
    model.to(device)

    # Find optimal batch size for Colab
    if IN_COLAB:
        batch_size = find_optimal_batch_size(
            model,
            tokenizer,
            preprocessor,
            train_df['comment'].values,
            train_df['label'].values
        )
    else:
        batch_size = 16  # Default batch size

    # Create data loaders with optimal batch size
    print("Creating data loaders...")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False)

    # Initialize trainer
    print("Initializing trainer...")
    trainer = MarathiSentimentTrainer(
        model=model,
        device=device,
        mixed_precision=True
    )

    # Create optimizer
    print("Creating optimizer and scheduler...")
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    # Create scheduler. The trainer calls scheduler.step() once per BATCH,
    # so the restart period is given in optimizer steps: 5 epochs' worth.
    # (A bare T_0=5 restarted the learning rate every 5 batches.)
    steps_per_epoch = max(1, len(train_loader))
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5 * steps_per_epoch,  # First restart after 5 epochs
        T_mult=1,  # Subsequent restarts after the same period
        eta_min=1e-6  # Minimum learning rate
    )

    # Train model
    print("Training model...")
    history = trainer.train(
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        num_epochs=15,
        patience=3,
        clip_grad_norm=1.0
    )

    # Plot training history
    trainer.plot_training_history(history, save_path='training_history.png')

    # Clear memory before evaluation
    clear_memory()

    # Evaluate on test set
    print("\nEvaluating on test set...")
    test_metrics = trainer._evaluate(test_loader)

    print(f"Test Loss: {test_metrics['loss']:.4f}")
    print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Test F1 Score: {test_metrics['f1']:.4f}")

    # Predict once and reuse for both the report and the confusion matrix.
    # test_loader is not shuffled, so predictions line up with test_df rows.
    test_preds, _ = trainer.predict(test_loader)
    test_labels = test_df['label'].tolist()

    # Print detailed classification report
    print("\nClassification Report:")
    print(classification_report(test_labels, test_preds))

    # Map class indices to original labels if available
    class_mapping = None
    if 'original_label' in df.columns:
        class_mapping = {0: -1, 1: 0, 2: 1}  # Reverse of the original mapping
        class_names = [class_mapping.get(i, i) for i in range(num_labels)]
    else:
        class_names = list(range(num_labels))

    # Plot confusion matrix
    trainer.plot_confusion_matrix(
        test_labels,
        test_preds,
        class_names=class_names,
        save_path='confusion_matrix.png'
    )

    # Save model
    print("\nSaving final model...")
    trainer.save_model('marathi_sentiment_model.pt')

    # Save tokenizer and preprocessor info for inference
    import json
    with open('model_info.json', 'w') as f:
        json.dump({
            'model_name': model_name,
            'num_labels': num_labels,
            'class_mapping': class_mapping
        }, f)

    # If in Colab, provide download links
    if IN_COLAB:
        from google.colab import files
        for artifact in ('marathi_sentiment_model.pt', 'model_info.json',
                         'training_history.png', 'confusion_matrix.png'):
            files.download(artifact)

    print("Training and evaluation completed!")

    # Return model and tokenizer for inference
    return model, tokenizer, preprocessor

# =============================================
# INFERENCE FUNCTION
# =============================================

def predict_sentiment(text, model, tokenizer, preprocessor=None, max_length=128):
    """Predict the sentiment of a single text.

    Args:
        text: Raw input text.
        model: Classifier whose forward returns a dict containing ``logits``.
        tokenizer: Tokenizer callable producing ``input_ids`` and
            ``attention_mask`` tensors.
        preprocessor: Optional object with a ``process(text)`` method.
        max_length: Padding/truncation length passed to the tokenizer.

    Returns:
        Dict with ``sentiment`` (label name, or the class index as a string
        for indices other than 0-2), ``confidence`` (probability of the
        predicted class) and ``probabilities`` (NumPy array over classes).
    """
    model.eval()

    # Preprocess text if preprocessor is provided
    if preprocessor:
        text = preprocessor.process(text)

    # Tokenize text
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Send inputs to the device the model is on, rather than a module-level
    # global that may not match it. A parameterless model keeps the inputs
    # on CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    input_ids = encoding['input_ids'].to(target_device)
    attention_mask = encoding['attention_mask'].to(target_device)

    # Get prediction
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs['logits']
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(logits, dim=1).item()

    # Map prediction to sentiment
    sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}

    return {
        'sentiment': sentiment_map.get(pred, str(pred)),
        'confidence': probs[0][pred].item(),
        'probabilities': probs[0].cpu().numpy()
    }

# =============================================
# INTERACTIVE DEMO FOR COLAB
# =============================================

def create_interactive_demo(model, tokenizer, preprocessor):
    """Show a text box + button widget that runs sentiment prediction.

    Only active in Colab; elsewhere a notice is printed and nothing is
    displayed. Each click renders the predicted sentiment, its confidence
    and a bar chart of class probabilities.

    Args:
        model: Trained classifier passed to ``predict_sentiment``.
        tokenizer: Tokenizer passed to ``predict_sentiment``.
        preprocessor: Preprocessor passed to ``predict_sentiment``.
    """
    if not IN_COLAB:
        print("Interactive demo is only available in Google Colab")
        return

    try:
        import html
        import ipywidgets as widgets
        from IPython.display import display, HTML

        # Create input widget
        text_input = widgets.Textarea(
            value='मला हा चित्रपट खूप आवडला!',
            placeholder='Enter Marathi text here...',
            description='Text:',
            disabled=False,
            layout=widgets.Layout(width='100%', height='100px')
        )

        # Create output widget
        output = widgets.Output()

        # Create button
        button = widgets.Button(
            description='Analyze Sentiment',
            disabled=False,
            button_style='primary',
            tooltip='Click to analyze sentiment',
            icon='check'
        )

        # Define button click handler
        def on_button_clicked(b):
            with output:
                output.clear_output()

                # Get input text
                text = text_input.value

                if not text:
                    print("Please enter some text")
                    return

                # Predict sentiment
                result = predict_sentiment(text, model, tokenizer, preprocessor)

                # Display result
                sentiment = result['sentiment']
                confidence = result['confidence'] * 100

                # Color based on sentiment
                color = 'red' if sentiment == 'Negative' else 'green' if sentiment == 'Positive' else 'gray'

                # The text is user input: escape it so markup typed into the
                # box is shown literally instead of being rendered/executed.
                safe_text = html.escape(text)
                safe_sentiment = html.escape(str(sentiment))

                display(HTML(f"""
                <div style="padding: 10px; border-radius: 5px; background-color: #f5f5f5;">
                    <h3>Sentiment Analysis Result</h3>
                    <p><b>Text:</b> {safe_text}</p>
                    <p><b>Sentiment:</b> <span style="color: {color}; font-weight: bold;">{safe_sentiment}</span></p>
                    <p><b>Confidence:</b> {confidence:.2f}%</p>
                </div>
                """))

                # Display probabilities
                probs = result['probabilities']
                sentiment_labels = ['Negative', 'Neutral', 'Positive']
                # Fall back to class indices when the model is not 3-class,
                # so the bar chart never gets mismatched lengths.
                if len(probs) != len(sentiment_labels):
                    sentiment_labels = [str(i) for i in range(len(probs))]

                plt.figure(figsize=(10, 5))
                plt.bar(sentiment_labels, probs)
                plt.title('Sentiment Probabilities')
                plt.ylabel('Probability')
                plt.ylim(0, 1)
                for i, p in enumerate(probs):
                    plt.text(i, p + 0.02, f'{p:.2f}', ha='center')
                plt.show()

        # Connect button click handler
        button.on_click(on_button_clicked)

        # Display widgets
        display(HTML("<h2>Marathi Sentiment Analysis Demo</h2>"))
        display(text_input)
        display(button)
        display(output)

    except ImportError:
        print("Could not create interactive demo. Make sure you're running in Colab with ipywidgets installed.")

# =============================================
# EXAMPLE USAGE
# =============================================

if __name__ == "__main__":
    # Train model and get components for inference.
    # main() returns None when the dataset cannot be loaded; unpacking that
    # directly raised a TypeError, so check before unpacking.
    components = main()

    if components is None:
        print("Training did not complete; skipping demo and example inference.")
    else:
        model, tokenizer, preprocessor = components

        # Create interactive demo if in Colab
        if IN_COLAB:
            create_interactive_demo(model, tokenizer, preprocessor)

        # Example inference
        print("\nExample inference:")

        example_texts = [
            "मला हा चित्रपट खूप आवडला!",  # I liked this movie very much!
            "हा उत्पाद सामान्य आहे, काही विशेष नाही.",  # This product is average, nothing special.
            "सेवा खूपच वाईट होती, मी पुन्हा येणार नाही."  # The service was very bad, I won't come again.
        ]

        for text in example_texts:
            result = predict_sentiment(text, model, tokenizer, preprocessor)
            print(f"\nText: {text}")
            print(f"Sentiment: {result['sentiment']}")
            print(f"Confidence: {result['confidence']:.4f}")
