In [None]:
# ===== Environment Setup =====
# Package installs / version pins for this notebook. Bare `pip` in a cell presumably
# relies on IPython automagic (`%pip`). NOTE(review): restart the kernel after these
# cells so the pinned versions are the ones actually imported below.
pip install --upgrade bigframes

In [None]:
pip install toolz==0.12.0

In [None]:
# `rich` is removed here and reinstalled at the pinned 13.7.1 in the next cell.
pip uninstall rich --yes

In [None]:
pip install rich==13.7.1

In [None]:
pip install --upgrade category-encoders

In [None]:
# scikit-learn is removed here and reinstalled pinned to 1.6.1 in the next cell.
# NOTE(review): presumably pinned to match the version that produced the pickled
# count_vectorizer / lda_model / sentiment_encoder loaded later — confirm.
pip uninstall scikit-learn --yes

In [None]:
pip install scikit-learn==1.6.1

In [None]:
# NOTE(review): presumably required by huggingface_hub for Hub file transfers
# (hf_hub_download / upload_file are imported later) — confirm it is needed.
pip install hf_xet

In [None]:
# NOTE(review): the next cell upgrades grpcio-status and tensorflow-metadata, which
# may pull in a newer protobuf and override this pin; check the final installed
# version (e.g. `pip show protobuf`) after all install cells have run.
pip install protobuf==3.20.1

In [None]:
pip install --upgrade grpcio-status tensorflow-metadata

In [None]:
# ===== NLTK Resource Setup =====
import nltk
import os

def download_nltk_resources(resources=None):
    """Ensure the NLTK data packages this notebook needs are available locally.

    ``nltk.data.find`` expects a *resource path* inside an NLTK data directory
    (e.g. ``'tokenizers/punkt'``), not a bare package id such as ``'punkt'``.
    Looking up the bare id always raises ``LookupError``, which made the
    "already downloaded" branch unreachable and re-requested every package on
    every run. Each package id is therefore paired with its lookup path.

    Args:
        resources: Optional mapping ``{package_id: lookup_path}``. ``package_id``
            is what gets passed to ``nltk.download``; ``lookup_path`` is what gets
            passed to ``nltk.data.find``. Defaults to the five packages used here.
    """
    if resources is None:
        # NOTE(review): NLTK >= 3.9 tokenizers/taggers look for 'punkt_tab' and
        # 'averaged_perceptron_tagger_eng'; add them if word_tokenize/pos_tag
        # still raise LookupError after this runs.
        resources = {
            'punkt': 'tokenizers/punkt',
            'wordnet': 'corpora/wordnet',
            'vader_lexicon': 'sentiment/vader_lexicon.zip',
            'opinion_lexicon': 'corpora/opinion_lexicon',
            'averaged_perceptron_tagger': 'taggers/averaged_perceptron_tagger',
        }
    for resource, lookup_path in resources.items():
        try:
            nltk.data.find(lookup_path)
            print(f"{resource} already downloaded")
        except LookupError:
            print(f"Downloading {resource}...")
            nltk.download(resource)

# Runs when the cell executes: checks each NLTK package and downloads any that are
# missing (network access is presumably needed the first time in a fresh environment).
download_nltk_resources()

# ===== Imports & Configuration =====
import random
import pickle
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from transformers import AutoTokenizer, AutoModel
from torchvision.ops import StochasticDepth
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import time
from functools import partial
from torch.amp import autocast, GradScaler
import json
from huggingface_hub import login, hf_hub_download, upload_file


# Configuration parameters
# Directory holding the pre-computed artifacts (pickles, .npy arrays, CSVs) read below.
BASE_PATH = '/kaggle/input/optimized-set'
# train_model() writes the best state_dict here and reloads it at the end.
MODEL_SAVE_PATH = './hybrid_sentiment_model.pt'
# Hugging Face checkpoint loaded by UFENModule via AutoModel.from_pretrained.
BERT_MODEL_NAME = 'albert-base-v2'
# Placeholder only: overwritten below with len(sentiment_encoder.classes_).
NUM_CLASSES = 3
# NOTE(review): MAX_SEQ_LENGTH and BATCH_SIZE are not referenced in the code visible here.
MAX_SEQ_LENGTH = 128
BATCH_SIZE = 32
NUM_EPOCHS = 5
LEARNING_RATE = 3e-05
# Width of the projected text / topic feature vectors inside the model.
HIDDEN_SIZE = 768
# NOTE(review): not referenced in the code visible here; the model's lda_topics /
# lda_topics_25 arguments default to 15 / 25 instead.
LDA_TOPICS = 25
# Enables torch.amp autocast (+ GradScaler in training); only takes effect on CUDA.
USE_MIXED_PRECISION = True
GRADIENT_ACCUMULATION_STEPS = 1
# Epochs without improvement of the combined validation metric before training stops.
EARLY_STOPPING_PATIENCE = 3

# Set device for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed the torch, Python `random` and NumPy RNGs.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# ===== Load Serialized Objects =====
def _unpickle(filename):
    """Deserialize one artifact stored under BASE_PATH.

    NOTE(review): pickle.load can execute arbitrary code; only point BASE_PATH at
    artifacts you produced yourself.
    """
    with open(os.path.join(BASE_PATH, filename), 'rb') as handle:
        return pickle.load(handle)


try:
    # Fitted preprocessing artifacts, loaded in a fixed order.
    count_vectorizer = _unpickle('count_vectorizer.pkl')
    lda_model = _unpickle('lda_model.pkl')
    class_weights = _unpickle('class_weights.pkl')
    sentiment_encoder = _unpickle('sentiment_encoder.pkl')

    # Sizes derived from the artifacts; NUM_CLASSES replaces the earlier placeholder.
    VOCAB_SIZE = len(count_vectorizer.vocabulary_)
    NUM_TOPICS = lda_model.n_components
    NUM_CLASSES = len(sentiment_encoder.classes_)

    print(f"Vocabulary size: {VOCAB_SIZE}")
    print(f"Number of LDA topics: {NUM_TOPICS}")
    print(f"Number of sentiment classes: {NUM_CLASSES}")
except Exception as e:
    # Report which stage failed, then propagate so later cells do not run on missing state.
    print(f"Error loading serialized objects: {e}")
    raise


In [None]:
# ===== Dataset & Data Loading =====
class SentimentDataset(Dataset):
    """
    Tensor-backed dataset of pre-tokenized text plus LDA topic features.

    Every input is truncated to the length of the shortest one, so index ``i``
    refers to the same example in all arrays.

    Args:
        input_ids: Token-id matrix, one row per example (stored as int64).
        attention_mask: Mask matrix aligned with ``input_ids`` (stored as int64).
        topic_dist: Per-example topic distribution (stored as float32).
        topic_dist_25: Optional second, finer topic distribution (float32).
        labels: Optional class labels (stored as int64).
    """
    def __init__(self, input_ids, attention_mask, topic_dist, topic_dist_25=None, labels=None):
        sizes = [len(input_ids), len(attention_mask), len(topic_dist)]
        if labels is not None:
            sizes.append(len(labels))
        if topic_dist_25 is not None:
            sizes.append(len(topic_dist_25))
        n = min(sizes)

        self.input_ids = torch.tensor(input_ids[:n], dtype=torch.long)
        self.attention_mask = torch.tensor(attention_mask[:n], dtype=torch.long)
        self.topic_dist = torch.tensor(topic_dist[:n], dtype=torch.float)

        self.has_multi_granularity = topic_dist_25 is not None
        if self.has_multi_granularity:
            self.topic_dist_25 = torch.tensor(topic_dist_25[:n], dtype=torch.float)

        self.has_labels = labels is not None
        if self.has_labels:
            self.labels = torch.tensor(labels[:n], dtype=torch.long)

        print(f"Dataset created with {n} samples (reduced from {len(input_ids)} and {len(topic_dist)})")

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return a dict of per-example tensors; optional keys appear only when present."""
        fields = ['input_ids', 'attention_mask', 'topic_dist']
        if self.has_multi_granularity:
            fields.append('topic_dist_25')
        if self.has_labels:
            fields.append('labels')
        return {name: getattr(self, name)[idx] for name in fields}

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``__getitem__`` dicts into batched tensors.

        Optional keys are included when the first example carries them.
        """
        collated = {
            key: torch.stack([sample[key] for sample in batch])
            for key in ('input_ids', 'attention_mask', 'topic_dist')
        }
        for key in ('topic_dist_25', 'labels'):
            if key in batch[0]:
                collated[key] = torch.stack([sample[key] for sample in batch])
        return collated

def load_csv_data():
    """Read the train/val/test CSV splits (text data with metadata) from BASE_PATH.

    Returns:
        Tuple ``(train_csv, val_csv, test_csv)`` of DataFrames, in that order.
    """
    file_names = ('train_data_balanced.csv', 'val_data.csv', 'test_data.csv')
    train_csv, val_csv, test_csv = (pd.read_csv(os.path.join(BASE_PATH, name)) for name in file_names)

    print(f"Loaded CSV data - Train: {len(train_csv)} samples, Val: {len(val_csv)} samples, Test: {len(test_csv)} samples")
    return train_csv, val_csv, test_csv

# ===== Preprocessed Array Loading =====
def load_data(split, include_multi_granularity=True, base_path=None,
              expected_topics=15, expected_topics_25=25):
    """
    Load preprocessed NumPy arrays for a given split with multi-granularity topics.

    All returned arrays, including the optional 25-topic distributions, are
    truncated to one common length so row ``i`` is the same example everywhere.

    Args:
        split: Dataset split ('train', 'val', or 'test'); used as the file prefix.
        include_multi_granularity: Whether to also load ``{split}_lda_topics_25.npy``.
        base_path: Directory containing the ``.npy`` files. Defaults to ``BASE_PATH``.
        expected_topics: Expected width of ``{split}_lda_topics.npy`` (warning only).
        expected_topics_25: Expected width of ``{split}_lda_topics_25.npy`` (warning only).

    Returns:
        ``(input_ids, attention_mask, topic_dist, topic_dist_25, labels)`` when
        ``include_multi_granularity`` is true, otherwise
        ``(input_ids, attention_mask, topic_dist, labels)``.

    Raises:
        Any error from loading or validating the arrays is printed and re-raised.
    """
    root = BASE_PATH if base_path is None else base_path

    def _load(suffix):
        # Files follow the naming scheme "{split}_{suffix}.npy".
        return np.load(os.path.join(root, f'{split}_{suffix}.npy'))

    try:
        input_ids = _load('input_ids')
        attention_mask = _load('attention_mask')
        topic_dist = _load('lda_topics')
        labels = _load('labels')
        # Load the 25-topic array *before* reconciling lengths so it takes part in
        # the min-length computation; otherwise a shorter file would be returned
        # misaligned with the other arrays.
        topic_dist_25 = _load('lda_topics_25') if include_multi_granularity else None

        arrays = [input_ids, attention_mask, topic_dist, labels]
        if topic_dist_25 is not None:
            arrays.append(topic_dist_25)
        original_lengths = [len(a) for a in arrays]
        min_size = min(original_lengths)
        max_size = max(original_lengths)

        if min_size < max_size:
            truncation_percentage = (max_size - min_size) / max_size * 100
            print(f"Warning: Truncating {split} data by {truncation_percentage:.2f}% due to length mismatch.")
            if truncation_percentage > 5:
                print(f"Critical Warning: Significant data loss (>5%) in {split} data due to truncation!")
            print(f"Truncating {split} data from {len(input_ids)}/{len(topic_dist)} to {min_size} samples")
        else:
            print(f"Loaded {split} data: {min_size} samples")

        input_ids = input_ids[:min_size]
        attention_mask = attention_mask[:min_size]
        topic_dist = topic_dist[:min_size]
        labels = labels[:min_size]

        if topic_dist_25 is not None:
            topic_dist_25 = topic_dist_25[:min_size]
            # Dimension mismatches are reported but not treated as fatal.
            if topic_dist.shape[1] != expected_topics or topic_dist_25.shape[1] != expected_topics_25:
                print(f"Warning: LDA topic dimensions mismatch in {split} data! Expected {expected_topics} and {expected_topics_25}, "
                      f"got {topic_dist.shape[1]} and {topic_dist_25.shape[1]}.")
            return input_ids, attention_mask, topic_dist, topic_dist_25, labels

        if topic_dist.shape[1] != expected_topics:
            print(f"Warning: LDA topic dimension mismatch in {split} data! Expected {expected_topics}, "
                  f"got {topic_dist.shape[1]}.")
        return input_ids, attention_mask, topic_dist, labels
    except Exception as e:
        print(f"Error loading data for {split}: {e}")
        raise


In [None]:
# ===== Model Components =====
class MultiGranularityTopicModule(nn.Module):
    """
    Encodes LDA topic distributions at two granularities into one feature vector.

    The coarse distribution (``lda_topics`` wide, 15 by default) and the optional
    fine distribution (``lda_topics_25`` wide) are each smoothed, encoded by their
    own MLP to ``hidden_size // 2`` features, concatenated and linearly projected
    to ``hidden_size``. When the fine input is absent its half is zero-filled.
    """
    def __init__(self, lda_topics=15, lda_topics_25=25, hidden_size=HIDDEN_SIZE, dropout_rate=0.1):
        super().__init__()
        # Sub-modules are created in a fixed order (coarse encoder, fine encoder,
        # projection) so parameter registration and init-time RNG use are stable.
        self.topic_encoder = self._build_encoder(lda_topics, hidden_size, dropout_rate)
        self.topic_encoder_25 = self._build_encoder(lda_topics_25, hidden_size, dropout_rate)
        self.projection = nn.Linear(hidden_size, hidden_size)
        self._init_linear_layers()

    @staticmethod
    def _build_encoder(in_features, hidden_size, dropout_rate):
        """Linear -> BatchNorm -> ReLU -> Dropout -> Linear -> LayerNorm."""
        quarter, half = hidden_size // 4, hidden_size // 2
        return nn.Sequential(
            nn.Linear(in_features, quarter),
            nn.BatchNorm1d(quarter),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(quarter, half),
            nn.LayerNorm(half),
        )

    def _init_linear_layers(self):
        """Xavier-uniform (gain 0.1) weights and zero biases for every nn.Linear."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.1)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    @staticmethod
    def _smooth(dist):
        """Floor entries at 0.05, renormalise each row, then blend 70/30 with uniform."""
        floored = torch.clamp(dist, min=0.05)
        normalised = floored / torch.sum(floored, dim=1, keepdim=True)
        uniform = torch.ones_like(normalised) / normalised.size(1)
        return 0.7 * normalised + 0.3 * uniform

    def forward(self, topic_dist, topic_dist_25=None):
        features_15 = self.topic_encoder(self._smooth(topic_dist))
        if topic_dist_25 is None:
            # Zero padding (not a copy of the coarse features) for the missing granularity.
            features_25 = torch.zeros_like(features_15)
        else:
            features_25 = self.topic_encoder_25(self._smooth(topic_dist_25))
        return self.projection(torch.cat([features_15, features_25], dim=1))

class UFENModule(nn.Module):
    """
    Token-level text encoder built on the pre-trained ``BERT_MODEL_NAME`` checkpoint.

    The backbone is frozen except for its final encoder stage; its last hidden
    states are linearly projected to ``hidden_size`` and padded positions are
    zeroed. ``metadata_dim``, ``use_distil`` and the ``metadata`` forward
    argument are accepted for interface compatibility but are not used.
    """
    def __init__(self, hidden_size=HIDDEN_SIZE, metadata_dim=NUM_TOPICS, use_distil=True):
        super().__init__()
        self.bert = AutoModel.from_pretrained(BERT_MODEL_NAME)
        # Freeze everything, then re-enable gradients for the last encoder stage only.
        self.bert.requires_grad_(False)
        encoder = self.bert.encoder
        if hasattr(encoder, 'albert_layer_groups'):
            # NOTE(review): ALBERT shares weights within a layer group; with a single
            # group (as presumably in albert-base-v2) this unfreezes the entire shared
            # transformer stack, not just one layer.
            encoder.albert_layer_groups[-1].requires_grad_(True)
        elif hasattr(encoder, 'layer'):
            encoder.layer[-1].requires_grad_(True)

        self.projection = nn.Linear(self.bert.config.hidden_size, hidden_size)

    def forward(self, input_ids, attention_mask, metadata=None):
        """Return projected token features with padded positions multiplied by 0."""
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        projected = self.projection(hidden_states)
        keep = attention_mask.unsqueeze(-1).float()
        return projected * keep

class ECOSAMModule(nn.Module):
    """
    Sentiment-focused sequence pooling: BiLSTM -> self-attention -> gated pooling.

    Token features are contextualised by a BiLSTM and an 8-head self-attention
    block (residual + LayerNorm). A sigmoid gate scores every position; the gates
    are normalised over the sequence and used as pooling weights, giving one
    ``hidden_size`` vector per example.

    When ``attention_mask`` is supplied, padded positions are excluded as
    attention keys and receive zero pooling weight.
    """
    def __init__(self, hidden_size=HIDDEN_SIZE, dropout_rate=0.1):
        super().__init__()
        self.bilstm = nn.LSTM(hidden_size, hidden_size//2, bidirectional=True, batch_first=True)
        self.context_attention = nn.MultiheadAttention(hidden_size, num_heads=8, batch_first=True)
        self.attention_norm = nn.LayerNorm(hidden_size)
        self.sentiment_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.01)
        self.sentiment_attention = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.sent_dropout = nn.Dropout(dropout_rate)
        self.sent_gate = nn.Sequential(
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )
        self.context_dropout = nn.Dropout(dropout_rate)

    def forward(self, sequence_features, attention_mask=None):
        """
        Args:
            sequence_features: Token features, presumably (batch, seq_len, hidden_size).
            attention_mask: Optional (batch, seq_len) mask, non-zero for real tokens.
                ``None`` keeps the unmasked behaviour.

        Returns:
            Pooled context vector, one row per example.
        """
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0
            # A row with no real tokens would get all -inf scores (NaN after softmax);
            # fall back to attending over every position for such rows.
            key_padding_mask = key_padding_mask & ~key_padding_mask.all(dim=1, keepdim=True)

        # NOTE(review): the BiLSTM still runs over padding, so its backward direction
        # sees trailing pad positions; packing the sequences would remove that effect.
        lstm_out, _ = self.bilstm(sequence_features)
        context_attn, _ = self.context_attention(lstm_out, lstm_out, lstm_out,
                                                 key_padding_mask=key_padding_mask)
        context_attn = self.context_dropout(context_attn)
        context_enhanced = self.attention_norm(context_attn + lstm_out)

        batch_size = sequence_features.size(0)
        sent_query = self.sentiment_query.expand(batch_size, 1, -1)
        # NOTE(review): sent_context is computed but never used below, so
        # sentiment_query / sentiment_attention receive no gradient. It is kept so
        # existing checkpoints still load; decide whether to fuse it or remove it.
        sent_context, _ = self.sentiment_attention(sent_query, context_enhanced, context_enhanced,
                                                   key_padding_mask=key_padding_mask)
        sent_context = self.sent_dropout(sent_context).squeeze(1)

        sent_gates = self.sent_gate(context_enhanced)
        if key_padding_mask is not None:
            # Padded positions get no pooling weight.
            sent_gates = sent_gates.masked_fill(key_padding_mask.unsqueeze(-1), 0.0)
        sent_gates = sent_gates / (sent_gates.sum(dim=1, keepdim=True) + 1e-9)
        context_vector = (context_enhanced * sent_gates).sum(dim=1)
        return context_vector

class BidirectionalFusion(nn.Module):
    """
    Fuses neural text representations and symbolic topic features.

    Each side receives a linear transform of the other, weighted by a single
    learnable scalar ``gate`` (text side) and ``1 - gate`` (topic side).
    """
    def __init__(self, hidden_size=HIDDEN_SIZE):
        super().__init__()
        self.text_to_topic = nn.Linear(hidden_size, hidden_size)
        self.topic_to_text = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, text_features, topic_features):
        text_enhanced = text_features + self.gate * self.topic_to_text(topic_features)
        topic_enhanced = topic_features + (1 - self.gate) * self.text_to_topic(text_features)
        return text_enhanced, topic_enhanced

from huggingface_hub import PyTorchModelHubMixin

class HybridSentimentModel(nn.Module, PyTorchModelHubMixin):
    """
    Enhanced hybrid neuro-symbolic model that uses multi-granularity LDA topics.

    Topic features (``topic_module``) and token features (``ufen``) are computed
    separately; the tokens are pooled by ``ecosam`` using the attention mask,
    fused with the topic features by ``bidir_fusion``, summed and classified.
    """
    def __init__(self, num_classes=NUM_CLASSES, lda_topics=15, lda_topics_25=25, hidden_size=HIDDEN_SIZE, dropout_rate=0.1):
        super().__init__()
        # Constructor hyper-parameters recorded on the instance.
        self.config = {
            "num_classes": num_classes,
            "lda_topics": lda_topics,
            "lda_topics_25": lda_topics_25,
            "hidden_size": hidden_size,
            "dropout_rate": dropout_rate,
            "bert_model_name": BERT_MODEL_NAME
        }
        self.ufen = UFENModule(hidden_size=hidden_size, metadata_dim=lda_topics, use_distil=True)
        self.topic_module = MultiGranularityTopicModule(lda_topics, lda_topics_25, hidden_size, dropout_rate=dropout_rate)
        self.ecosam = ECOSAMModule(hidden_size=hidden_size, dropout_rate=dropout_rate)
        self.bidir_fusion = BidirectionalFusion(hidden_size=hidden_size)
        # Element-wise dropout before the output layer. The previous
        # StochasticDepth(mode='batch') zeroes the *entire* batch input with
        # probability p; it is meant for residual branches, and on this
        # non-residual path it periodically fed the classifier all-zero inputs.
        # Dropout has no parameters, so classifier.* state_dict keys are unchanged.
        self.drop_path = nn.Dropout(dropout_rate)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            self.drop_path,
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, input_ids, attention_mask, topic_dist, topic_dist_25=None):
        """Return class logits for a batch."""
        # Symbolic features from the LDA topic distributions.
        topic_features = self.topic_module(topic_dist, topic_dist_25)

        # Token-level text features from the transformer.
        sequence_features = self.ufen(input_ids, attention_mask)

        # Pass the mask so padded tokens are ignored by attention and pooling.
        context_vector = self.ecosam(sequence_features, attention_mask)

        enhanced_text, enhanced_topic = self.bidir_fusion(context_vector, topic_features)
        fused_features = enhanced_text + enhanced_topic

        logits = self.classifier(fused_features)
        return logits


In [None]:
# ===== Evaluation Functions =====
def calculate_metrics(y_true, y_pred, class_names=None):
    """Compute accuracy, weighted precision/recall/F1 and the confusion matrix.

    Args:
        y_true: Ground-truth labels; torch tensors are moved to CPU NumPy first.
        y_pred: Predicted labels, same form as ``y_true``.
        class_names: Unused; accepted for call-site compatibility.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``confusion_matrix``. Precision/recall/F1 are support-weighted and use
        ``zero_division=0`` for classes with no predictions or labels.
    """
    def _as_numpy(values):
        return values.cpu().numpy() if torch.is_tensor(values) else values

    y_true, y_pred = _as_numpy(y_true), _as_numpy(y_pred)
    averaged = dict(average='weighted', zero_division=0)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, **averaged),
        "recall": recall_score(y_true, y_pred, **averaged),
        "f1": f1_score(y_true, y_pred, **averaged),
        "confusion_matrix": confusion_matrix(y_true, y_pred),
    }

def plot_confusion_matrix(cm, class_names=None):
    """Draw ``cm`` as an annotated heatmap, save it to confusion_matrix.png and show it.

    Args:
        cm: Confusion matrix of integer counts (cells are formatted with ``fmt='d'``);
            rows are labelled as true classes, columns as predicted classes.
        class_names: Tick labels; defaults to "0", "1", ... for each row of ``cm``.
    """
    labels = class_names if class_names is not None else [str(i) for i in range(cm.shape[0])]
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.set_title('Confusion Matrix')
    fig.tight_layout()
    fig.savefig('confusion_matrix.png')
    plt.show()

def plot_metrics_history(metrics_history):
    """Plot per-epoch curves in a 2x2 grid, save to training_metrics.png and show it.

    Panels: accuracy, loss and F1 (train vs. validation), then validation
    precision and recall. ``metrics_history`` must contain every key listed in
    ``panels`` below (train_model builds a dict with exactly these keys).
    """
    panels = [
        ('Accuracy', [('train_acc', 'Train'), ('val_acc', 'Validation')]),
        ('Loss', [('train_loss', 'Train'), ('val_loss', 'Validation')]),
        ('F1 Score', [('train_f1', 'Train'), ('val_f1', 'Validation')]),
        ('Precision and Recall', [('val_precision', 'Precision'), ('val_recall', 'Recall')]),
    ]
    plt.figure(figsize=(15, 10))
    for position, (title, series) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for key, label in series:
            plt.plot(metrics_history[key], label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.legend()
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()


In [None]:
def analyze_csv_data(csv_data, topic_model, vectorizer, n_samples=5, n_print_topics=3,
                     n_return_topics=5, n_top_words=10):
    """
    Inspect a few CSV rows and the highest-weighted words of the topic model.

    The first ``n_samples`` texts are re-vectorised and re-transformed by the
    topic model, so their distributions can be compared against precomputed ones.
    The top words of the first ``n_print_topics`` topics are printed, and those of
    the first ``n_return_topics`` topics are returned.

    Args:
        csv_data: DataFrame with a ``'text'`` column.
        topic_model: Fitted topic model exposing ``transform`` and ``components_``
            (one weight row per topic), e.g. ``lda_model``.
        vectorizer: Fitted vectorizer exposing ``transform`` and
            ``get_feature_names_out``, e.g. ``count_vectorizer``.
        n_samples: Number of leading rows to re-transform.
        n_print_topics: Number of leading topics whose words are printed.
        n_return_topics: Number of leading topics whose words are returned.
        n_top_words: Words listed per topic.
        The defaults are 5 / 3 / 5 / 10.

    Returns:
        Dict with ``"sample_topics"`` (topic distributions of the sampled texts)
        and ``"topic_words"`` (``{topic_index: [words, heaviest first]}``).
    """
    text_samples = csv_data['text'].values[:n_samples]

    text_dtm = vectorizer.transform(text_samples)
    topic_distributions = topic_model.transform(text_dtm)

    feature_names = vectorizer.get_feature_names_out()

    def _top_words(topic_weights):
        # argsort is ascending; reverse it so the heaviest words come first.
        return [feature_names[i] for i in topic_weights.argsort()[::-1][:n_top_words]]

    # Compute each needed topic's word list once and reuse it for printing and the result.
    n_topics_needed = max(n_print_topics, n_return_topics)
    topic_words = {i: _top_words(topic)
                   for i, topic in enumerate(topic_model.components_[:n_topics_needed])}

    for topic_idx in range(min(n_print_topics, len(topic_words))):
        print(f"Topic #{topic_idx}: {', '.join(topic_words[topic_idx])}")

    return {
        "sample_topics": topic_distributions,
        "topic_words": {i: topic_words[i] for i in range(min(n_return_topics, len(topic_words)))},
    }


In [None]:
# ===== Enhanced Training Functions with Mixed Precision and Early Stopping =====
def _unpack_batch(batch):
    """Move one collated batch to ``device``.

    Returns:
        ``(input_ids, attention_mask, topic_dist, topic_dist_25, labels)``;
        ``topic_dist_25`` is ``None`` when the batch has no 25-topic features.
    """
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    topic_dist = batch['topic_dist'].to(device)
    topic_dist_25 = batch['topic_dist_25'].to(device) if 'topic_dist_25' in batch else None
    labels = batch['labels'].to(device)
    return input_ids, attention_mask, topic_dist, topic_dist_25, labels


def train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE,
                weight_decay=None):
    """
    Train ``model`` with AdamW + OneCycleLR, optional AMP and gradient accumulation.

    After every epoch the model is evaluated on ``val_loader``. The checkpoint with
    the best combined metric ``0.7 * F1 + 0.3 * mean(precision, recall)`` (weighted
    averages) is written to ``MODEL_SAVE_PATH``. Training stops early after
    ``EARLY_STOPPING_PATIENCE`` epochs without improvement, and the best checkpoint
    is reloaded into ``model`` before returning.

    Args:
        model: Network, optionally wrapped in ``nn.DataParallel``; called as
            ``model(input_ids, attention_mask, topic_dist, topic_dist_25)``.
        train_loader: Loader yielding dicts with ``input_ids``, ``attention_mask``,
            ``topic_dist``, ``labels`` and optionally ``topic_dist_25``.
        val_loader: Loader with the same batch format, used for model selection.
        num_epochs: Maximum number of epochs.
        lr: Peak learning rate (scaled by sqrt(#GPUs) under DataParallel).
        weight_decay: AdamW weight decay. ``None`` reads ``best_config["weight_decay"]``,
            a global that must exist before the call.

    Returns:
        ``model`` with the best checkpoint's weights loaded, or its final weights if
        no checkpoint was written.
    """
    use_amp = USE_MIXED_PRECISION and torch.cuda.is_available()
    # A GradScaler exists whenever CUDA is available; with USE_MIXED_PRECISION off it
    # is constructed disabled and its methods pass straight through.
    scaler = GradScaler(enabled=USE_MIXED_PRECISION) if torch.cuda.is_available() else None

    if isinstance(model, nn.DataParallel):
        # Square-root LR scaling for the larger effective batch across GPUs.
        num_gpus = torch.cuda.device_count()
        adjusted_lr = lr * (num_gpus ** 0.5)
        print(f"Adjusting learning rate for {num_gpus}-GPU training: {adjusted_lr:.6f}")
        lr = adjusted_lr

    model.to(device)

    if weight_decay is None:
        # NOTE(review): `best_config` is not defined in the code visible here; it must
        # exist as a global when train_model is called without weight_decay.
        weight_decay = best_config["weight_decay"]
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # One scheduler step per optimizer update, i.e. per full accumulation window.
    # max(1, ...) keeps OneCycleLR valid when the loader is shorter than one window.
    steps_per_epoch = max(1, len(train_loader) // GRADIENT_ACCUMULATION_STEPS)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=0.3,
        anneal_strategy='cos',
        div_factor=25.0,
        final_div_factor=10000.0,
    )

    # Class-weighted cross-entropy; fall back to uniform weights on a size mismatch.
    if len(class_weights) != NUM_CLASSES:
        print(f"Warning: class_weights length ({len(class_weights)}) doesn't match NUM_CLASSES ({NUM_CLASSES}).")
        class_weights_tensor = torch.ones(NUM_CLASSES, dtype=torch.float).to(device)
    else:
        class_weights_tensor = torch.tensor([class_weights[i] for i in range(NUM_CLASSES)], dtype=torch.float).to(device)
    weighted_loss = nn.CrossEntropyLoss(weight=class_weights_tensor)

    # Best *combined* metric so far. Starting at -inf guarantees that the first
    # completed epoch writes a checkpoint, so the reload at the end never picks up
    # a stale file left by an earlier run.
    best_f1 = float('-inf')
    patience_counter = 0
    checkpoint_saved = False

    try:
        class_names = sentiment_encoder.classes_
    except (NameError, AttributeError):
        class_names = [str(i) for i in range(NUM_CLASSES)]

    metrics_history = {
        'train_loss': [], 'train_acc': [], 'train_f1': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [],
        'val_recall': [], 'val_f1': []
    }

    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", unit="epoch")

    for epoch in epoch_pbar:
        # ---- Training phase ----
        model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        all_train_preds = []
        all_train_labels = []

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False, unit="batch")
        optimizer.zero_grad()  # gradients accumulate across the batches of a window

        for batch_idx, batch in enumerate(train_pbar):
            input_ids, attention_mask, topic_dist, topic_dist_25, labels = _unpack_batch(batch)
            is_update_step = (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0

            if use_amp:
                with autocast('cuda'):
                    outputs = model(input_ids, attention_mask, topic_dist, topic_dist_25)
                    # Scale the loss so accumulated gradients average over the window.
                    loss = weighted_loss(outputs, labels) / GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()
                if is_update_step:
                    scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
            else:
                outputs = model(input_ids, attention_mask, topic_dist, topic_dist_25)
                loss = weighted_loss(outputs, labels) / GRADIENT_ACCUMULATION_STEPS
                loss.backward()
                if is_update_step:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

            # Undo the accumulation scaling so the reported loss is per batch.
            batch_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS
            total_loss += batch_loss * labels.size(0)

            with torch.no_grad():
                _, preds = torch.max(outputs, dim=1)
                batch_correct = (preds == labels).sum().item()
                total_correct += batch_correct
                total_samples += labels.size(0)
                all_train_preds.extend(preds.cpu().numpy())
                all_train_labels.extend(labels.cpu().numpy())

            train_pbar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{batch_correct / labels.size(0):.4f}")

        # Apply gradients left over from a trailing partial accumulation window. The
        # scheduler is not stepped here because steps_per_epoch counts full windows only.
        if len(train_loader) % GRADIENT_ACCUMULATION_STEPS != 0:
            if scaler is not None:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / total_samples
        train_acc = total_correct / total_samples
        train_metrics = calculate_metrics(np.array(all_train_labels), np.array(all_train_preds))

        metrics_history['train_loss'].append(avg_loss)
        metrics_history['train_acc'].append(train_acc)
        metrics_history['train_f1'].append(train_metrics['f1'])

        # ---- Validation phase ----
        model.eval()
        val_loss = 0
        val_correct = 0
        val_samples = 0
        all_val_preds = []
        all_val_labels = []

        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", leave=False, unit="batch")

        with torch.no_grad():
            for batch in val_pbar:
                input_ids, attention_mask, topic_dist, topic_dist_25, labels = _unpack_batch(batch)

                if use_amp:
                    with autocast('cuda'):
                        outputs = model(input_ids, attention_mask, topic_dist, topic_dist_25)
                        loss = weighted_loss(outputs, labels)
                else:
                    outputs = model(input_ids, attention_mask, topic_dist, topic_dist_25)
                    loss = weighted_loss(outputs, labels)

                val_loss += loss.item() * labels.size(0)
                _, preds = torch.max(outputs, dim=1)
                batch_correct = (preds == labels).sum().item()
                val_correct += batch_correct
                val_samples += labels.size(0)
                all_val_preds.extend(preds.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())

                val_pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{batch_correct / labels.size(0):.4f}")

        avg_val_loss = val_loss / val_samples
        val_acc = val_correct / val_samples
        val_metrics = calculate_metrics(np.array(all_val_labels), np.array(all_val_preds))

        metrics_history['val_loss'].append(avg_val_loss)
        metrics_history['val_acc'].append(val_acc)
        metrics_history['val_precision'].append(val_metrics['precision'])
        metrics_history['val_recall'].append(val_metrics['recall'])
        metrics_history['val_f1'].append(val_metrics['f1'])

        epoch_pbar.set_postfix(
            train_loss=f"{avg_loss:.4f}",
            train_f1=f"{train_metrics['f1']:.4f}",
            val_f1=f"{val_metrics['f1']:.4f}"
        )

        print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Train - Loss: {avg_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_metrics['f1']:.4f}")
        print(f"  Val   - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}, Precision: {val_metrics['precision']:.4f}, "
              f"Recall: {val_metrics['recall']:.4f}, F1: {val_metrics['f1']:.4f}")

        # Model selection: weighted F1 blended with the mean of precision and recall.
        current_metric = 0.7 * val_metrics['f1'] + 0.3 * (val_metrics['precision'] + val_metrics['recall']) / 2
        if current_metric > best_f1:
            best_f1 = current_metric
            patience_counter = 0

            # Save the unwrapped module so the checkpoint loads without DataParallel.
            state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(state_dict, MODEL_SAVE_PATH)
            checkpoint_saved = True

            print(f"New best model saved with combined metric (F1-based): {best_f1:.4f}")
            plot_confusion_matrix(val_metrics['confusion_matrix'], class_names)
        else:
            patience_counter += 1
            print(f"Combined metric did not improve. Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")
            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

        torch.cuda.empty_cache()

    print(f"Training completed. Best combined metric (F1-based): {best_f1:.4f}")
    plot_metrics_history(metrics_history)
    torch.cuda.empty_cache()

    if checkpoint_saved:
        # map_location lets a checkpoint written on one device load onto `device`.
        best_model_state = torch.load(MODEL_SAVE_PATH, map_location=device)
        target = model.module if isinstance(model, nn.DataParallel) else model
        target.load_state_dict(best_model_state)
    else:
        print("Warning: no checkpoint was saved during this run; returning the final weights.")

    return model

def evaluate_model(model, test_loader, class_names=None):
    """Evaluate ``model`` on ``test_loader`` and report aggregate and per-class metrics.

    The model is put in eval mode and run without gradients (under CUDA autocast
    when ``USE_MIXED_PRECISION`` is set and a GPU is available). Predictions are
    the argmax over dim 1 of the model output. Overall metrics come from
    ``calculate_metrics``, the confusion matrix is plotted, and per-class
    accuracy / precision / recall / F1 are printed.

    Args:
        model: Classifier called as
            ``model(input_ids, attention_mask, topic_dist, topic_dist_25)``; its
            output's dim 1 is treated as the class dimension.
        test_loader: Iterable of dict batches with keys ``input_ids``,
            ``attention_mask``, ``topic_dist``, ``labels`` and optionally
            ``topic_dist_25`` (passed as ``None`` when absent).
        class_names: Display names indexed by class id. Defaults to
            ``sentiment_encoder.classes_`` when that global exists, otherwise
            ``["0", ..., str(NUM_CLASSES - 1)]``.

    Returns:
        The metrics dict produced by ``calculate_metrics`` (it contains at least
        ``accuracy``, ``precision``, ``recall``, ``f1`` and ``confusion_matrix``).
    """
    model.eval()
    if class_names is None:
        try:
            class_names = sentiment_encoder.classes_
        except (NameError, AttributeError):
            # No fitted label encoder in scope: fall back to numeric class names.
            class_names = [str(i) for i in range(NUM_CLASSES)]

    use_amp = USE_MIXED_PRECISION and torch.cuda.is_available()
    all_preds = []
    all_labels = []

    # Progress bar for test evaluation
    test_pbar = tqdm(test_loader, desc="Evaluating on Test Set", unit="batch")

    with torch.no_grad():
        for batch in test_pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            topic_dist = batch['topic_dist'].to(device)
            topic_dist_25 = batch['topic_dist_25'].to(device) if 'topic_dist_25' in batch else None
            labels = batch['labels'].to(device)

            if use_amp:
                with autocast('cuda'):
                    outputs = model(input_ids, attention_mask, topic_dist, topic_dist_25)
            else:
                outputs = model(input_ids, attention_mask, topic_dist, topic_dist_25)

            _, preds = torch.max(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Guard against an empty batch before dividing.
            n_in_batch = labels.size(0)
            if n_in_batch:
                batch_acc = (preds == labels).sum().item() / n_in_batch
                test_pbar.set_postfix(acc=f"{batch_acc:.4f}")

    y_true = np.asarray(all_labels)
    y_pred = np.asarray(all_preds)

    metrics = calculate_metrics(y_true, y_pred)

    print("\nTest Set Evaluation:")
    print(f"Accuracy:  {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall:    {metrics['recall']:.4f}")
    print(f"F1 Score:  {metrics['f1']:.4f}")

    cm = metrics['confusion_matrix']
    plot_confusion_matrix(cm, class_names)

    # Per-class precision/recall/F1 must be computed one-vs-rest over the FULL
    # prediction set: restricting to samples whose true label is `cls` would
    # discard false positives coming from other classes and make precision/F1
    # meaningless.
    classes = np.unique(y_true)
    per_class_prec = precision_score(y_true, y_pred, labels=classes, average=None, zero_division=0)
    per_class_rec = recall_score(y_true, y_pred, labels=classes, average=None, zero_division=0)
    per_class_f1 = f1_score(y_true, y_pred, labels=classes, average=None, zero_division=0)

    print("\nPer-class metrics:")
    for cls, cls_prec, cls_rec, cls_f1 in zip(classes, per_class_prec, per_class_rec, per_class_f1):
        cls_mask = y_true == cls
        # Fraction of this class's samples that were predicted correctly.
        cls_acc = accuracy_score(y_true[cls_mask], y_pred[cls_mask])

        print(f"Class {class_names[cls]}:")
        print(f"  Accuracy: {cls_acc:.4f}, Precision: {cls_prec:.4f}, Recall: {cls_rec:.4f}, F1: {cls_f1:.4f}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics


In [None]:
# ===== Main Execution =====
# Trains HybridSentimentModel with fixed hyperparameters, then evaluates the
# best checkpoint on the test split.
if __name__ == "__main__":
    # Fix for "can only test a child process" error
    import torch.multiprocessing as mp
    mp.set_start_method('spawn', force=True)

    cuda_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count() if cuda_available else 0

    if cuda_available:
        torch.cuda.empty_cache()
        print(f"Using device: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Load CSV data for additional analysis
    train_csv, val_csv, test_csv = load_csv_data()

    print("Analyzing topic distributions in CSV data...")
    topic_columns = [col for col in train_csv.columns if col.startswith('topic_')]
    if topic_columns:
        print(f"Found {len(topic_columns)} topic columns in CSV data")

    # Use pre-determined hyperparameters instead of tuning.
    # NOTE(review): 'weight_decay' is not passed to train_model below — confirm
    # whether train_model's optimizer should receive it.
    best_config = {
        'lr': 3e-05,
        'weight_decay': 0.01,
        'hidden_size': 768,
        'dropout_rate': 0.1,
        'batch_size': 16
    }

    print("\n==== Using Pre-determined Hyperparameters ====")
    print(f"Config: {best_config}")

    # Load tokenized inputs, both topic granularities and labels for each split
    train_input_ids, train_attention, train_topic, train_topic_25, train_labels = load_data('train', include_multi_granularity=True)
    val_input_ids, val_attention, val_topic, val_topic_25, val_labels = load_data('val', include_multi_granularity=True)
    test_input_ids, test_attention, test_topic, test_topic_25, test_labels = load_data('test', include_multi_granularity=True)

    # Print shape information for verification
    print(f"Train topics shape: {train_topic.shape}, Train topics 25 shape: {train_topic_25.shape}")

    # Create datasets with multi-granularity topics
    train_dataset = SentimentDataset(train_input_ids, train_attention, train_topic, train_topic_25, train_labels)
    val_dataset = SentimentDataset(val_input_ids, val_attention, val_topic, val_topic_25, val_labels)
    test_dataset = SentimentDataset(test_input_ids, test_attention, test_topic, test_topic_25, test_labels)

    # With DataParallel each GPU gets batch_size samples, so scale the global batch.
    effective_batch_size = best_config["batch_size"]
    if gpu_count >= 2:
        effective_batch_size *= gpu_count
        print(f"Using effective batch size of {effective_batch_size} with {gpu_count} GPUs")

    # All three loaders share the same settings; only shuffling differs.
    make_loader = partial(
        DataLoader,
        batch_size=effective_batch_size,
        collate_fn=SentimentDataset.collate_fn,
        num_workers=0,
        pin_memory=cuda_available,  # pinned memory only helps host->GPU copies
        persistent_workers=False,
    )
    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    # Initialize the model with specified hyperparameters
    model = HybridSentimentModel(
        num_classes=NUM_CLASSES,
        lda_topics=train_topic.shape[1],
        lda_topics_25=train_topic_25.shape[1],
        hidden_size=best_config["hidden_size"]
    )

    # Force every nn.Dropout layer to the configured rate
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = best_config["dropout_rate"]

    # Multi-GPU setup if available; otherwise fall back to `device` (GPU or CPU)
    if gpu_count >= 2:
        print(f"Using {gpu_count} GPUs")
        torch.cuda.set_device(0)
        model = model.to(torch.device('cuda:0'))
        model = nn.DataParallel(model, device_ids=list(range(gpu_count)))
    else:
        if cuda_available:
            print(f"Using single GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("CUDA not available; running on CPU")
        model = model.to(device)

    # Train the model with specified hyperparameters
    trained_model = train_model(model, train_loader, val_loader,
                                num_epochs=NUM_EPOCHS,
                                lr=best_config["lr"])

    # Evaluate on test set
    try:
        class_names = sentiment_encoder.classes_
    except (NameError, AttributeError):
        # No fitted label encoder in scope: fall back to numeric class names.
        class_names = [str(i) for i in range(NUM_CLASSES)]

    test_metrics = evaluate_model(trained_model, test_loader, class_names)

In [None]:
# ===== Publish to Hugging Face Hub =====
# Rebuilds the model from the best checkpoint saved by train_model and uploads
# the weights, the tokenizer and the pickled LDA components. This cell relies on
# names created earlier in the notebook: HF_TOKEN, train_topic, train_topic_25,
# best_config, count_vectorizer and lda_model.
HUB_REPO_ID = "aiguy68/neurosam-model"

# NOTE(review): HF_TOKEN is defined elsewhere; prefer loading it from an
# environment variable / secret store rather than hard-coding it in the notebook.
login(token=HF_TOKEN)

# Use the same hyperparameters the trained model was built with.
model = HybridSentimentModel(
    num_classes=NUM_CLASSES,
    lda_topics=train_topic.shape[1],
    lda_topics_25=train_topic_25.shape[1],
    hidden_size=best_config["hidden_size"],
    dropout_rate=best_config["dropout_rate"],
)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on any host.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location='cpu'))

# NOTE(review): push_to_hub must be provided by HybridSentimentModel (e.g. via a
# Hub mixin); a plain nn.Module has no such method — confirm.
model.push_to_hub(HUB_REPO_ID)

# Also push the tokenizer
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
tokenizer.push_to_hub(HUB_REPO_ID)

# LDA components: save locally first, then upload each file to the Hub.
lda_artifacts = {
    "count_vectorizer.pkl": count_vectorizer,
    "lda_model.pkl": lda_model,
}
for filename, artifact in lda_artifacts.items():
    with open(filename, "wb") as f:
        pickle.dump(artifact, f)

for filename in lda_artifacts:
    upload_file(
        path_or_fileobj=filename,
        path_in_repo=filename,
        repo_id=HUB_REPO_ID,
    )