# PubMed BERT Model + Attention Mechanism


In [29]:
# Cell 1: Imports and Setup
import os, random, json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords
import re



In [30]:
# Cell 2: Random Seed Setup for Reproducibility  
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


In [31]:
# Cell 3: Device Configuration (GPU/CPU/MPS)
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")

Using device: mps


In [32]:
# Cell 4: Training Hyperparameters
n_epochs = 5
learning_rate = 5e-6  
batch_size = 4

In [33]:
# Cell 5: Constants and Label Definitions
# Define polarity labels
POLARITY_LABELS = ['positive', 'negative', 'neutral']  


In [34]:
# Cell 6: Directory Creation for Models and Results
MODEL_DIR = "model"
os.makedirs(MODEL_DIR, exist_ok=True)  # Create the base model directory
os.makedirs("results", exist_ok=True)

# 2. Utility Functions and Classes

In [35]:
# Cell 7: Text Preprocessing Function

# Download stopwords if not already present
nltk.download('stopwords', quiet=True)
stop_words = set(stopwords.words('english'))


def preprocess_text(text):
    """
    Preprocess text for biomedical NLP by cleaning and normalizing
    
    Args:
        text: Raw text string to preprocess
        
    Returns:
        Cleaned and normalized text string
    """
    # Handle NaN values
    if pd.isna(text):
        return ""
    
    # Convert to string and lowercase
    text = str(text).lower()
    
    # Keep hyphens as they may be important in biomedical terms (e.g., auto-regulation)
    text = re.sub(r'[^\w\s-]', '', text)
    
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    
    # Remove stopwords but keep important biomedical terms
    # Note: We're being conservative with stopword removal for biomedical text
    text = " ".join([word.strip() for word in text.split() if word not in stop_words or len(word) > 4])
    
    return text.strip()


In [36]:

# Cell 8: Dataset Class Definition  
class PubMedDataset(Dataset):
    """
    Dataset class for PubMed text classification with optional polarity labels
    
    This class handles both mechanism detection and polarity classification tasks
    """
    
    def __init__(self, texts, labels, tokenizer, max_length=512, polarities=None):
        """
        Initialize the dataset
        
        Args:
            texts: List or Series of text samples
            labels: numpy array of multi-label binary labels for mechanisms
            tokenizer: Transformers tokenizer for text encoding
            max_length: Maximum sequence length for tokenization
            polarities: Optional numpy array of polarity labels (for multi-task learning)
        """
        self.texts = texts.reset_index(drop=True) if hasattr(texts, 'reset_index') else texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.polarities = polarities
        
        # Validate inputs
        if len(self.texts) != len(self.labels):
            raise ValueError(f"Text count ({len(self.texts)}) doesn't match label count ({len(self.labels)})")
        
        if self.polarities is not None and len(self.texts) != len(self.polarities):
            raise ValueError(f"Text count ({len(self.texts)}) doesn't match polarity count ({len(self.polarities)})")
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        """
        Get a single item from the dataset
        
        Args:
            idx: Index of the item to retrieve
            
        Returns:
            Dictionary containing tokenized inputs and labels
        """
        # Get text and preprocess it
        text = str(self.texts.iloc[idx] if hasattr(self.texts, 'iloc') else self.texts[idx])
        text = preprocess_text(text)  # Apply preprocessing
        
        # Tokenize the text
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Prepare the return dictionary
        item = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(self.labels[idx])  # Convert to float for BCE loss
        }
        
        # Add polarity labels if available
        if self.polarities is not None:
            item['polarity'] = torch.LongTensor([self.polarities[idx]]).squeeze()  # Convert to long for CE loss
        
        return item

In [37]:
# Cell 9: Loss Function Classes
class CombinedLoss(nn.Module):
    def __init__(self, pos_weights, gamma=0.5, alpha=0.8):
        super(CombinedLoss, self).__init__()
        self.pos_weights = pos_weights
        self.gamma = gamma
        self.alpha = alpha  # Higher alpha = more weight on BCE
        
    def forward(self, inputs, targets):
        # Weighted BCE loss
        BCE_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, pos_weight=self.pos_weights, reduction='none'
        )
        
        # Focal component (lighter weight)
        pt = torch.exp(-BCE_loss)
        focal_component = (1 - pt) ** self.gamma * BCE_loss
        
        # Combine both losses (80% BCE, 20% Focal)
        combined = self.alpha * BCE_loss + (1 - self.alpha) * focal_component
        
        return combined.mean()
    
class MultiTaskLoss(nn.Module):
    def __init__(self, pos_weights, mech_weight=0.7, pol_weight=0.3):
        super(MultiTaskLoss, self).__init__()
        self.pos_weights = pos_weights
        self.mech_weight = mech_weight
        self.pol_weight = pol_weight
        self.mech_criterion = CombinedLoss(pos_weights)
        self.pol_criterion = nn.CrossEntropyLoss()
        
    def forward(self, mech_logits, pol_logits, mech_labels, pol_labels):
        mech_loss = self.mech_criterion(mech_logits, mech_labels)
        pol_loss = self.pol_criterion(pol_logits, pol_labels)
        
        # Combine mechanism and polarity losses
        total_loss = self.mech_weight * mech_loss + self.pol_weight * pol_loss
        
        return total_loss, mech_loss, pol_loss

## 3. Data Loading and Preprocessing

In [38]:
# Cell 10: Data Loading
# Load data
df = pd.read_csv('../data/processed/train_data.csv')
print(df.shape)
df.head()

(26205, 7)


Unnamed: 0,AC,PMID,Title,Abstract,Terms,Text_combined,batch_number
0,P40416,31040179,Mitochondria export iron-sulfur and sulfur int...,Iron-sulfur clusters are essential cofactors o...,,Mitochondria export iron-sulfur and sulfur int...,1
1,P55511,19376903,Rhizobium sp. strain NGR234 possesses a remark...,Rhizobium sp. strain NGR234 is a unique alphap...,,Rhizobium sp. strain NGR234 possesses a remark...,1
2,Q18G63,16820047,The genome of the square archaeon Haloquadratu...,The square halophilic archaeon Haloquadratum w...,,The genome of the square archaeon Haloquadratu...,1
3,O64682,10693763,Regulation of auxin response by the protein ki...,Arabidopsis plants carrying mutations in the P...,autophosphorylation,Regulation of auxin response by the protein ki...,1
4,P63097,3094012,Molecular cloning and characterization of cDNA...,We have cloned and characterized cDNA encoding...,,Molecular cloning and characterization of cDNA...,1


In [39]:
# Cell 11: Text Preprocessing and Cleaning
df['Text_Cleaned'] = df['Text_combined'].apply(preprocess_text)


In [40]:
# Cell 12: Mechanism Label Binarization
# Convert comma-separated terms to multi-label binary format using MLB

# Binarize the Terms column
df['Terms_List'] = df['Terms'].apply(
    lambda x: [term.strip() for term in str(x).split(',')] if pd.notna(x) and x != '' else []
)

# Initialize and fit the MultiLabelBinarizer
mlb = MultiLabelBinarizer()
binary_labels = mlb.fit_transform(df['Terms_List'])

# Get the class names
label_columns = mlb.classes_
print(f"Found {len(label_columns)} unique labels: {label_columns}")

# Create a DataFrame with the binary labels
labels_df = pd.DataFrame(binary_labels, columns=label_columns)

# Save label columns for later use
with open('label_columns.json', 'w') as f:
    json.dump(list(label_columns), f)

# Keep only essential columns
df_cleaned = df[['batch_number', 'Text_Cleaned']].copy()
df_cleaned = pd.concat([df_cleaned, labels_df], axis=1)

print(f"Final cleaned data shape: {df_cleaned.shape}")
print(df_cleaned.shape)
df_cleaned.head()

Found 10 unique labels: ['autoactivation' 'autocatalysis' 'autofeedback' 'autoinduction'
 'autoinhibition' 'autokinase' 'autolysis' 'autophosphorylation'
 'autoregulation' 'autoubiquitination']
Final cleaned data shape: (26205, 12)
(26205, 12)


Unnamed: 0,batch_number,Text_Cleaned,autoactivation,autocatalysis,autofeedback,autoinduction,autoinhibition,autokinase,autolysis,autophosphorylation,autoregulation,autoubiquitination
0,1,mitochondria export iron-sulfur sulfur interme...,0,0,0,0,0,0,0,0,0,0
1,1,rhizobium sp strain ngr234 possesses remarkabl...,0,0,0,0,0,0,0,0,0,0
2,1,genome square archaeon haloquadratum walsbyi l...,0,0,0,0,0,0,0,0,0,0
3,1,regulation auxin response protein kinase pinoi...,0,0,0,0,0,0,0,1,0,0
4,1,molecular cloning characterization cdna encodi...,0,0,0,0,0,0,0,0,0,0


In [41]:
# Cell 13: Polarity Label Inference and Encoding
# Add polarity inference function
def infer_polarity(text, mechanism):
    """
    Infer polarity (positive/negative/neutral) from text and mechanism
    
    This is a rule-based method that can be later replaced with manual annotations
    """
    # EXPANDED: Keywords indicating negative regulation
    negative_keywords = [
        # Your original keywords
        'inhibit', 'repress', 'suppress', 'block', 'reduce', 'decrease', 'down-regulat', 
        'downregulat', 'negative', 'inactivat', 'stop', 'prevent',
        
        # Additional biomedical negative terms
        'attenuate', 'dampen', 'silence', 'knock', 'impair', 'abolish', 'diminish',
        'weaken', 'curtail', 'halt', 'terminate', 'cease', 'limit', 'restrict',
        'degradation', 'breakdown', 'turnover', 'cleavage', 'proteolysis',
        'downmodulat', 'counter', 'antagoniz', 'oppose'
    ]
    
    # EXPANDED: Keywords indicating positive regulation  
    positive_keywords = [
        # Your original keywords
        'activat', 'increas', 'induce', 'enhance', 'promot', 'stimulat', 'up-regulat',
        'upregulat', 'positive', 'amplif',
        
        # Additional biomedical positive terms
        'boost', 'augment', 'facilitate', 'accelerate', 'catalyze', 'drive', 'trigger',
        'elicit', 'evoke', 'potentiat', 'strengthen', 'reinforce', 'foster', 'support',
        'maintain', 'sustain', 'stabiliz', 'preserve', 'accumul', 'recruit',
        'upmodulat', 'elevat', 'heighten', 'agoniz'
    ]
    
    # Convert text to lowercase for matching
    text_lower = text.lower()
    
    # Count keyword occurrences (more robust than just presence)
    negative_count = sum(1 for keyword in negative_keywords if keyword in text_lower)
    positive_count = sum(1 for keyword in positive_keywords if keyword in text_lower)
    
    # Determine polarity based on keyword balance
    if negative_count > positive_count:
        return 'negative'
    elif positive_count > negative_count:
        return 'positive'
    else:
        # If balanced or no keywords, use mechanism-based heuristics
        if mechanism in ['autoinhibition', 'autorepression']:  # Typically negative
            return 'negative'
        elif mechanism in ['autoactivation', 'autophosphorylation', 'autoinduction']:  # Typically positive
            return 'positive'
        else:
            return 'neutral'  # Default
        
# Add polarity labels to dataset
print("Inferring polarity labels...")
polarities = []

for idx, row in df.iterrows():
    text = row['Text_combined'] if pd.notna(row['Text_combined']) else ""
    
    # Get the mechanisms for this example
    mechanisms = row['Terms_List']
    
    # If no mechanisms, assign neutral
    if not mechanisms:
        polarities.append('neutral')
    else:
        # Get the first mechanism (for multi-labeled entries)
        mechanism = mechanisms[0] if mechanisms else ""
        polarity = infer_polarity(text, mechanism)
        polarities.append(polarity)

# Add to dataframe
df['polarity'] = polarities

# Encode polarity labels
polarity_encoder = LabelEncoder()
polarity_encoder.fit(POLARITY_LABELS)  # Use our predefined labels
encoded_polarities = polarity_encoder.transform(df['polarity'])

# Add to cleaned dataframe
df_cleaned['polarity'] = df['polarity']
df_cleaned['polarity_encoded'] = encoded_polarities

# Save polarity encoder classes
with open('polarity_labels.json', 'w') as f:
    json.dump(list(polarity_encoder.classes_), f)

# Display polarity distribution
polarity_counts = df['polarity'].value_counts()
print("\nPolarity distribution:")
for pol, count in polarity_counts.items():
    print(f"{pol}: {count} ({count/len(df)*100:.1f}%)")


Inferring polarity labels...

Polarity distribution:
neutral: 18200 (69.5%)
positive: 5540 (21.1%)
negative: 2465 (9.4%)


In [42]:
# Cell 14: Final Dataset 
print(f"Final cleaned data shape with polarity: {df_cleaned.shape}")
df_cleaned.head()

Final cleaned data shape with polarity: (26205, 14)


Unnamed: 0,batch_number,Text_Cleaned,autoactivation,autocatalysis,autofeedback,autoinduction,autoinhibition,autokinase,autolysis,autophosphorylation,autoregulation,autoubiquitination,polarity,polarity_encoded
0,1,mitochondria export iron-sulfur sulfur interme...,0,0,0,0,0,0,0,0,0,0,neutral,1
1,1,rhizobium sp strain ngr234 possesses remarkabl...,0,0,0,0,0,0,0,0,0,0,neutral,1
2,1,genome square archaeon haloquadratum walsbyi l...,0,0,0,0,0,0,0,0,0,0,neutral,1
3,1,regulation auxin response protein kinase pinoi...,0,0,0,0,0,0,0,1,0,0,negative,0
4,1,molecular cloning characterization cdna encodi...,0,0,0,0,0,0,0,0,0,0,neutral,1


# 4. Model Architecture

In [43]:
# Cell 15: Base Model Definition
class ImprovedPubMedBERTClassifier(nn.Module):
    def __init__(self, n_classes, dropout1=0.1, dropout2=0.2):
        super(ImprovedPubMedBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
        self.dropout1 = nn.Dropout(dropout1)
        self.intermediate = nn.Linear(self.bert.config.hidden_size, 512)
        self.activation = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout2)
        self.classifier = nn.Linear(512, n_classes)
        
        # Initialize weights
        nn.init.xavier_normal_(self.intermediate.weight)
        nn.init.zeros_(self.intermediate.bias)
        nn.init.xavier_normal_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        pooled_output = self.dropout1(pooled_output)
        intermediate = self.intermediate(pooled_output)
        intermediate = self.activation(intermediate)
        intermediate = self.dropout2(intermediate)
        logits = self.classifier(intermediate)
        return logits


In [44]:

# Cell 16: Multi-Task Model Definition  
class PolarityPubMedBERTClassifier(nn.Module):
    def __init__(self, n_mech_classes, n_polarity_classes=3, dropout1=0.1, dropout2=0.2):
        super(PolarityPubMedBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
        
        # Shared layers
        self.dropout1 = nn.Dropout(dropout1)
        self.intermediate = nn.Linear(self.bert.config.hidden_size, 512)
        self.activation = nn.ReLU()
        
        # Mechanism classification branch
        self.dropout2_mech = nn.Dropout(dropout2)
        self.classifier_mech = nn.Linear(512, n_mech_classes)
        
        # Polarity classification branch
        self.dropout2_pol = nn.Dropout(dropout2)
        self.classifier_pol = nn.Linear(512, n_polarity_classes)
        
        # Initialize weights
        nn.init.xavier_normal_(self.intermediate.weight)
        nn.init.zeros_(self.intermediate.bias)
        nn.init.xavier_normal_(self.classifier_mech.weight)
        nn.init.zeros_(self.classifier_mech.bias)
        nn.init.xavier_normal_(self.classifier_pol.weight)
        nn.init.zeros_(self.classifier_pol.bias)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        pooled_output = self.dropout1(pooled_output)
        
        # Shared intermediate representation
        shared_features = self.activation(self.intermediate(pooled_output))
        
        # Mechanism prediction
        mech_features = self.dropout2_mech(shared_features)
        mech_logits = self.classifier_mech(mech_features)
        
        # Polarity prediction
        pol_features = self.dropout2_pol(shared_features)
        pol_logits = self.classifier_pol(pol_features)
        
        return mech_logits, pol_logits

In [45]:

# Cell 17: Enhanced Single-Task Model (OPTIONAL)
class EnhancedPubMedBERTClassifier(nn.Module):
    def __init__(self, n_classes, dropout1=0.1, dropout2=0.2):
        super(EnhancedPubMedBERTClassifier, self).__init__()
        # Load base PubMedBERT
        self.bert = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
        
        # Add relation attention mechanism
        self.relation_query = nn.Parameter(torch.randn(768, 1))
        self.relation_key = nn.Linear(768, 768)
        self.relation_value = nn.Linear(768, 768)
        
        # Main classification path
        self.dropout1 = nn.Dropout(dropout1)
        self.intermediate = nn.Linear(768 * 2, 512)  # Doubled for concatenation
        self.activation = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout2)
        self.classifier = nn.Linear(512, n_classes)
        
        # Initialize weights properly
        nn.init.xavier_normal_(self.intermediate.weight)
        nn.init.zeros_(self.intermediate.bias)
        nn.init.xavier_normal_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
    
    def forward(self, input_ids, attention_mask):
        # Get PubMedBERT embeddings
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # Get sequence outputs
        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        cls_token = sequence_output[:, 0, :]  # [batch_size, hidden_size]
        
        # Apply relation attention
        attention_mask_expanded = attention_mask.unsqueeze(-1)  # [batch_size, seq_len, 1]
        
        # Calculate relation-aware attention weights
        relation_keys = self.relation_key(sequence_output)  # [batch_size, seq_len, hidden_size]
        query = self.relation_query.unsqueeze(0).expand(input_ids.size(0), -1, -1)  # [batch_size, hidden_size, 1]
        
        # Get attention scores and mask padding tokens
        attention_scores = torch.bmm(relation_keys, query)  # [batch_size, seq_len, 1]
        attention_scores = attention_scores.masked_fill(attention_mask_expanded == 0, -10000.0)
        attention_weights = torch.softmax(attention_scores, dim=1)  # [batch_size, seq_len, 1]
        
        # Get relation-aware context vector
        relation_values = self.relation_value(sequence_output)  # [batch_size, seq_len, hidden_size]
        relation_context = torch.sum(attention_weights * relation_values, dim=1)  # [batch_size, hidden_size]
        
        # Combine CLS token with relation context
        pooled_output = torch.cat([cls_token, relation_context], dim=1)  # [batch_size, hidden_size*2]
        pooled_output = self.dropout1(pooled_output)
        
        # Classification
        intermediate = self.intermediate(pooled_output)
        intermediate = self.activation(intermediate)
        intermediate = self.dropout2(intermediate)
        logits = self.classifier(intermediate)
        
        return logits, attention_weights

In [46]:

# Cell 18: Enhanced Multi-Task Model 
class EnhancedPolarityPubMedBERTClassifier(nn.Module):
    """
    Enhanced multi-task PubMedBERT classifier with relation-aware attention mechanism.
    
    This model performs two tasks:
    1. Mechanism detection (multi-label classification)
    2. Polarity classification (single-label classification)
    
    The enhancement includes a relation attention mechanism that helps the model
    focus on tokens that indicate regulatory relationships, improving detection
    of implicit autoregulatory mechanisms.
    """
    
    def __init__(self, n_mech_classes, n_polarity_classes=3, dropout1=0.1, dropout2=0.2):
        """
        Initialize the enhanced multi-task model
        
        Args:
            n_mech_classes: Number of mechanism classes (e.g., 10 for autoactivation, etc.)
            n_polarity_classes: Number of polarity classes (3: positive, negative, neutral)
            dropout1: Dropout rate after BERT embeddings
            dropout2: Dropout rate before final classifiers
        """
        super(EnhancedPolarityPubMedBERTClassifier, self).__init__()
        
        # Load base PubMedBERT model
        self.bert = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
        
        # Relation-aware attention mechanism
        # This helps the model focus on tokens that indicate regulatory relationships
        self.relation_query = nn.Parameter(torch.randn(768, 1))  # Learnable query for relation detection
        self.relation_key = nn.Linear(768, 768)     # Transform tokens to keys
        self.relation_value = nn.Linear(768, 768)   # Transform tokens to values
        
        # Shared feature processing layers
        self.dropout1 = nn.Dropout(dropout1)
        # Note: Input size is doubled (768 * 2) because we concatenate CLS token + relation context
        self.intermediate = nn.Linear(768 * 2, 512)  
        self.activation = nn.ReLU()
        
        # Mechanism classification branch (multi-label)
        self.dropout2_mech = nn.Dropout(dropout2)
        self.classifier_mech = nn.Linear(512, n_mech_classes)
        
        # Polarity classification branch (single-label)
        self.dropout2_pol = nn.Dropout(dropout2)
        self.classifier_pol = nn.Linear(512, n_polarity_classes)
        
        # Initialize weights using Xavier normal initialization
        self._init_weights()
    
    def _init_weights(self):
        """Initialize model weights properly"""
        nn.init.xavier_normal_(self.intermediate.weight)
        nn.init.zeros_(self.intermediate.bias)
        nn.init.xavier_normal_(self.classifier_mech.weight)
        nn.init.zeros_(self.classifier_mech.bias)
        nn.init.xavier_normal_(self.classifier_pol.weight)
        nn.init.zeros_(self.classifier_pol.bias)
        nn.init.xavier_normal_(self.relation_key.weight)
        nn.init.xavier_normal_(self.relation_value.weight)
    
    def forward(self, input_ids, attention_mask):
        """
        Forward pass through the enhanced multi-task model
        
        Args:
            input_ids: Token IDs from tokenizer [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            Tuple of (mechanism_logits, polarity_logits, attention_weights)
        """
        # Get PubMedBERT embeddings
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # Extract sequence outputs and CLS token
        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, 768]
        cls_token = sequence_output[:, 0, :]         # [batch_size, 768]
        
        # Apply relation-aware attention mechanism
        attention_weights = self._compute_relation_attention(sequence_output, attention_mask)
        relation_context = self._get_relation_context(sequence_output, attention_weights)
        
        # Combine CLS token with relation context
        # This gives the model both global (CLS) and relation-specific information
        combined_features = torch.cat([cls_token, relation_context], dim=1)  # [batch_size, 768*2]
        combined_features = self.dropout1(combined_features)
        
        # Shared intermediate representation
        shared_features = self.activation(self.intermediate(combined_features))  # [batch_size, 512]
        
        # Mechanism prediction branch (multi-label)
        mech_features = self.dropout2_mech(shared_features)
        mech_logits = self.classifier_mech(mech_features)  # [batch_size, n_mech_classes]
        
        # Polarity prediction branch (single-label)
        pol_features = self.dropout2_pol(shared_features)
        pol_logits = self.classifier_pol(pol_features)     # [batch_size, n_polarity_classes]
        
        return mech_logits, pol_logits, attention_weights
    
    def _compute_relation_attention(self, sequence_output, attention_mask):
        """
        Compute attention weights for relation detection
        
        Args:
            sequence_output: BERT sequence outputs [batch_size, seq_len, 768]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            attention_weights: Attention weights [batch_size, seq_len, 1]
        """
        batch_size, seq_len = attention_mask.shape
        
        # Transform sequence to keys for attention computation
        relation_keys = self.relation_key(sequence_output)  # [batch_size, seq_len, 768]
        
        # Expand the learnable query to match batch size
        query = self.relation_query.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, 768, 1]
        
        # Compute attention scores using query-key dot product
        attention_scores = torch.bmm(relation_keys, query)  # [batch_size, seq_len, 1]
        
        # Mask padding tokens by setting their attention scores to very low values
        attention_mask_expanded = attention_mask.unsqueeze(-1)  # [batch_size, seq_len, 1]
        attention_scores = attention_scores.masked_fill(attention_mask_expanded == 0, -10000.0)
        
        # Apply softmax to get attention weights
        attention_weights = torch.softmax(attention_scores, dim=1)  # [batch_size, seq_len, 1]
        
        return attention_weights
    
    def _get_relation_context(self, sequence_output, attention_weights):
        """
        Get weighted context vector using attention weights
        
        Args:
            sequence_output: BERT sequence outputs [batch_size, seq_len, 768]
            attention_weights: Attention weights [batch_size, seq_len, 1]
            
        Returns:
            relation_context: Weighted context vector [batch_size, 768]
        """
        # Transform sequence to values
        relation_values = self.relation_value(sequence_output)  # [batch_size, seq_len, 768]
        
        # Compute weighted sum using attention weights
        relation_context = torch.sum(attention_weights * relation_values, dim=1)  # [batch_size, 768]
        
        return relation_context
    
    def get_attention_weights(self, input_ids, attention_mask):
        """
        Get attention weights for visualization (inference only)
        
        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            attention_weights: Attention weights [batch_size, seq_len]
        """
        self.eval()
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state
            attention_weights = self._compute_relation_attention(sequence_output, attention_mask)
            return attention_weights.squeeze(-1)  # Remove last dimension [batch_size, seq_len]

# 5. TRAINING AND EVALUATION FUNCTIONS

In [47]:
# Cell 18: Data Loader Creation Functions
def create_moderate_sampler(y):
    """
    Create a weighted sampler with moderate class balancing
    """
    class_sample_count = np.sum(y, axis=0)
    # Square root scaling makes weights less extreme
    weight_per_class = 1.0 / np.sqrt(np.clip(class_sample_count, 5, np.inf))
    
    sample_weights = np.zeros(len(y))
    for i in range(len(y)):
        if np.sum(y[i]) > 0:
            positive_indices = np.where(y[i] == 1)[0]
            sample_weights[i] = np.mean(weight_per_class[positive_indices])
        else:
            # Give negative examples lower weight (they're 67% of dataset)
            sample_weights[i] = 0.5 / max(1, (len(y) - np.sum(np.any(y, axis=1))))
    
    sample_weights = torch.FloatTensor(sample_weights)
    return WeightedRandomSampler(sample_weights, len(sample_weights))

# Create data loader
def create_dataset_and_loader(X, y, batch_size, tokenizer, train=True, polarities=None):
    dataset = PubMedDataset(X, y, tokenizer, polarities=polarities)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=train)
    return loader

# Create data loader with polarity
def create_dataset_and_loader_with_polarity(X, y, polarities, batch_size, tokenizer, train=True):
    dataset = PubMedDataset(X, y, tokenizer, polarities=polarities)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=train)
    return loader

In [48]:

# Cell 19: Training Functions
def train_epoch(model, data_loader, optimizer, criterion, scheduler=None, has_polarity=False):
    """
    Train the model for a single epoch.
    
    Args:
        model: The neural network model to train
        data_loader: DataLoader containing the training data
        optimizer: Optimizer for updating model weights
        criterion: Loss function
        scheduler: Optional learning rate scheduler
        has_polarity: Whether the model includes polarity prediction
        
    Returns:
        Average loss value for the epoch (and component losses if multi-task)
    """
    model.train()
    total_loss = 0
    total_mech_loss = 0
    total_pol_loss = 0
    total_batches = 0
    
    progress_bar = tqdm(data_loader, desc="Training")
    
    for batch in progress_bar:
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mech_labels = batch['labels'].to(device)
            
            # Forward pass
            optimizer.zero_grad()
            
            if has_polarity:
                pol_labels = batch['polarity'].to(device)
                mech_logits, pol_logits = model(input_ids, attention_mask)
                loss, mech_loss, pol_loss = criterion(mech_logits, pol_logits, mech_labels, pol_labels)
                total_mech_loss += mech_loss.item()
                total_pol_loss += pol_loss.item()
            else:
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, mech_labels)
            
            # Check for NaN loss
            if torch.isnan(loss):
                print("WARNING: NaN loss detected, skipping batch")
                continue
                
            # Backward pass
            loss.backward()
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            
            # Update progress bar
            if has_polarity:
                progress_bar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'mech_loss': f"{mech_loss.item():.4f}",
                    'pol_loss': f"{pol_loss.item():.4f}"
                })
            else:
                progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
            
            total_loss += loss.item()
            total_batches += 1
        except Exception as e:
            print(f"Error in batch processing: {e}")
            continue
    
    # Protect against division by zero
    if total_batches == 0:
        return float('inf')
    
    if has_polarity:
        return (
            total_loss / total_batches,
            total_mech_loss / total_batches,
            total_pol_loss / total_batches
        )
    else:
        return total_loss / total_batches


In [49]:

# Cell 20: Evaluation Functions  
def optimize_thresholds(model, val_loader, n_labels, has_polarity=False):
    model.eval()
    all_outputs = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Optimizing thresholds"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            
            if has_polarity:
                model_outputs = model(input_ids, attention_mask)
                if len(model_outputs) == 3:  # Enhanced model
                    outputs, _, _ = model_outputs
                else:  # Regular multi-task model
                    outputs, _ = model_outputs
            else:
                model_outputs = model(input_ids, attention_mask)
                if isinstance(model_outputs, tuple):  # Enhanced single-task model
                    outputs, _ = model_outputs
                else:  # Regular single-task model
                    outputs = model_outputs
                
            probs = torch.sigmoid(outputs).cpu().numpy()
            
            all_outputs.append(probs)
            all_labels.append(labels.numpy())
    
    all_outputs = np.vstack(all_outputs)
    all_labels = np.vstack(all_labels)
    
    optimal_thresholds = []
    
    for i in range(n_labels):
        best_f1 = 0
        best_threshold = 0.5
        
        for threshold in np.arange(0.3, 0.7, 0.05):
            preds = (all_outputs[:, i] >= threshold).astype(int)
            f1 = f1_score(all_labels[:, i], preds, zero_division=0)
            
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = threshold
        
        optimal_thresholds.append(best_threshold)
        
    return optimal_thresholds

In [50]:

def evaluate_multi_task(model, data_loader, mech_criterion, pol_criterion, thresholds):
    """
    Evaluate a multi-task model on mechanism detection and polarity classification
    """
    model.eval()
    total_loss = 0
    total_mech_loss = 0
    total_pol_loss = 0
    all_mech_preds = []
    all_mech_labels = []
    all_pol_preds = []
    all_pol_labels = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mech_labels = batch['labels'].to(device)
            pol_labels = batch['polarity'].to(device)
            
            # Forward pass
            outputs = model(input_ids, attention_mask)
            if len(outputs) == 3:  # Enhanced model returns (mech_logits, pol_logits, attention_weights)
                mech_logits, pol_logits, _ = outputs
            else:  # Regular model returns (mech_logits, pol_logits)
                mech_logits, pol_logits = outputs
            
            # Calculate losses
            mech_loss = mech_criterion(mech_logits, mech_labels)
            pol_loss = pol_criterion(pol_logits, pol_labels)
            loss = 0.7 * mech_loss + 0.3 * pol_loss
            
            total_loss += loss.item()
            total_mech_loss += mech_loss.item()
            total_pol_loss += pol_loss.item()
            
            # Process mechanism predictions
            mech_probs = torch.sigmoid(mech_logits).cpu().numpy()
            mech_preds = np.array([
                (mech_probs[:, i] >= thresholds[i]).astype(int) for i in range(len(thresholds))
            ]).T
            all_mech_preds.extend(mech_preds)
            all_mech_labels.extend(mech_labels.cpu().numpy())
            
            # Process polarity predictions
            pol_preds = torch.argmax(pol_logits, dim=1).cpu().numpy()
            all_pol_preds.extend(pol_preds)
            all_pol_labels.extend(pol_labels.cpu().numpy())
    
    # Convert to numpy arrays
    if all_mech_preds and all_mech_labels:
        all_mech_preds = np.array(all_mech_preds)
        all_mech_labels = np.array(all_mech_labels)
    else:
        return {'loss': float('inf')}
    
    all_pol_preds = np.array(all_pol_preds)
    all_pol_labels = np.array(all_pol_labels)
    
    # Calculate mechanism metrics
    mech_metrics = calculate_mechanism_metrics(all_mech_labels, all_mech_preds)
    
    # Calculate polarity metrics
    pol_accuracy = accuracy_score(all_pol_labels, all_pol_preds)
    pol_f1 = f1_score(all_pol_labels, all_pol_preds, average='weighted', zero_division=0)
    
    # Combine metrics
    metrics = {
        'loss': total_loss / len(data_loader),
        'mech_loss': total_mech_loss / len(data_loader),
        'pol_loss': total_pol_loss / len(data_loader),
        'pol_accuracy': pol_accuracy,
        'pol_f1': pol_f1,
        **mech_metrics
    }
    
    # Output results
    print(f"Loss: {metrics['loss']:.4f} | Mech Loss: {metrics['mech_loss']:.4f} | Pol Loss: {metrics['pol_loss']:.4f}")
    print(f"Mechanism - Micro F1: {metrics['micro_f1']:.4f} | Macro F1: {metrics['macro_f1']:.4f} | Weighted F1: {metrics['weighted_f1']:.4f}")
    print(f"Polarity - Accuracy: {metrics['pol_accuracy']:.4f} | F1: {metrics['pol_f1']:.4f}")
    
    return metrics

In [51]:
def calculate_mechanism_metrics(all_labels, all_predictions):
    """Helper function to calculate mechanism classification metrics"""
    samples_precision = precision_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_recall = recall_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_f1 = f1_score(all_labels, all_predictions, average='samples', zero_division=0)

    # F1 metrics
    micro_f1 = f1_score(all_labels, all_predictions, average='micro', zero_division=0)
    macro_f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    return {
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'samples_f1': samples_f1,
        'samples_precision': samples_precision,
        'samples_recall': samples_recall
    }

In [52]:


def evaluate(model, data_loader, criterion, thresholds):
    """
    Evaluate single-task model on mechanism detection
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            # Forward pass
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            # Get predictions using thresholds
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = np.array([
                (probs[:, i] >= thresholds[i]).astype(int) for i in range(len(thresholds))
            ]).T
            
            all_predictions.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    # Convert to numpy arrays
    if all_predictions and all_labels:
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
    else:
        return {'loss': float('inf')}
    
    # Calculate metrics
    metrics = calculate_mechanism_metrics(all_labels, all_predictions)
    metrics['loss'] = total_loss / len(data_loader)
    
    # Print results
    print(f"Loss: {metrics['loss']:.4f} | Micro F1: {metrics['micro_f1']:.4f} | Macro F1: {metrics['macro_f1']:.4f} | Weighted F1: {metrics['weighted_f1']:.4f}")
    
    return metrics

# 6. MAIN TRAINING PIPE LINE

In [53]:
# Step 6: Instantiate the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")

In [54]:
# Cell 21: Main Training Function
def train_model_multitask(batch_number, n_epochs, learning_rate, batch_size, use_optimal_thresholds=True):
    """
    Train a multi-task model for both mechanism detection and polarity classification
    """
    print(f"\nTraining multi-task model on batch {batch_number}")
    
    # Load and preprocess test data
    print("Loading test data...")
    test_data = pd.read_csv('../data/processed/test_data.csv')
    
    # Apply same preprocessing to test data
    if 'Text_combined' not in test_data.columns:
        # Find available text column
        text_columns = [col for col in test_data.columns if 'text' in col.lower()]
        if text_columns:
            test_data['Text_combined'] = test_data[text_columns[0]].fillna('')
        else:
            raise ValueError("No text column found in test data")
    
    test_data['Text_Cleaned'] = test_data['Text_combined'].apply(preprocess_text)
    
    # Process test data mechanism labels
    test_data['Terms_List'] = test_data['Terms'].apply(
        lambda x: [term.strip() for term in str(x).split(',')] if pd.notna(x) and x != '' else []
    )
    test_mech_labels = mlb.transform(test_data['Terms_List'])  # Use same mlb as training
    
    # FIXED: Add polarity labels to test data
    print("Processing test data polarity...")
    test_polarities = []
    for idx, row in test_data.iterrows():
        text = row['Text_Cleaned'] if pd.notna(row['Text_Cleaned']) else ""
        mechanisms = row['Terms_List']
        polarity = infer_polarity(text, mechanisms)
        test_polarities.append(polarity)
    
    # Encode polarity labels for test data
    test_polarity_encoded = polarity_encoder.transform(test_polarities)
    
    # Filter training data for the specified batch
    batch_data = df_cleaned[df_cleaned['batch_number'] == batch_number].copy()
    if len(batch_data) == 0:
        print(f"No data found for batch {batch_number}")
        return None, 0
        
    print(f"Batch {batch_number} data: {len(batch_data)} samples")
    
    # FIXED: Prepare features and labels for MULTI-TASK training
    X_train = batch_data['Text_Cleaned']
    y_train_mech = batch_data[label_columns].values  # Mechanism labels
    y_train_polarity = batch_data['polarity_encoded'].values  # Polarity labels
    
    # Prepare test data
    X_test = test_data['Text_Cleaned']
    y_test_mech = test_mech_labels
    y_test_polarity = test_polarity_encoded
    
    print(f"Train set: {len(X_train)} samples")
    print(f"Test set: {len(X_test)} samples")
    print(f"Mechanism classes: {len(label_columns)}")
    print(f"Polarity classes: {len(POLARITY_LABELS)}")
    
    # FIXED: Create datasets WITH polarity labels
    train_dataset = PubMedDataset(
        X_train, y_train_mech, tokenizer, 
        polarities=y_train_polarity  # Include polarity labels
    )
    
    test_dataset = PubMedDataset(
        X_test, y_test_mech, tokenizer, 
        polarities=y_test_polarity  # Include polarity labels
    )
    
    # Create balanced sampler for training
    sampler = create_moderate_sampler(y_train_mech)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # Calculate class weights for mechanism detection
    print("Calculating class weights...")
    mech_pos_weights = []
    for i in range(y_train_mech.shape[1]):
        neg_count = len(y_train_mech) - np.sum(y_train_mech[:, i])
        pos_count = np.sum(y_train_mech[:, i])
        weight = neg_count / pos_count if pos_count > 0 else 1.0
        weight = min(max(weight, 0.1), 10.0)  # Clip between 0.1 and 10
        mech_pos_weights.append(weight)
    
    mech_pos_weights = torch.FloatTensor(mech_pos_weights).to(device)
    
    # FIXED: Initialize MULTI-TASK model
    print("Initializing multi-task model...")
    model = PolarityPubMedBERTClassifier(
        n_mech_classes=len(label_columns),
        n_polarity_classes=len(POLARITY_LABELS)
    ).to(device)
    
    # FIXED: Set up MULTI-TASK loss functions and optimizer
    criterion = MultiTaskLoss(mech_pos_weights, mech_weight=0.7, pol_weight=0.3)
    mech_criterion = CombinedLoss(mech_pos_weights)  # For evaluation
    pol_criterion = nn.CrossEntropyLoss()  # For evaluation
    
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    
    # Set up scheduler with warmup
    total_steps = len(train_loader) * n_epochs
    warmup_steps = int(total_steps * 0.1)  # 10% warmup
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # Initialize thresholds and tracking
    initial_thresholds = [0.5] * len(label_columns)
    best_combined_score = 0.0  # Changed from best_f1
    all_metrics = {
        'train_loss': [], 'train_mech_loss': [], 'train_pol_loss': [],
        'test_loss': [], 'test_mech_loss': [], 'test_pol_loss': [],
        'micro_f1': [], 'macro_f1': [], 'weighted_f1': [],
        'pol_accuracy': [], 'pol_f1': []
    }
    
    # Training loop
    print(f"Starting training for {n_epochs} epochs...")
    for epoch in range(n_epochs):
        print(f"\nEpoch {epoch + 1}/{n_epochs}")
        
        # FIXED: Train with multi-task approach
        train_losses = train_epoch(
            model, train_loader, optimizer, criterion, scheduler, has_polarity=True
        )
        
        # Unpack training losses
        if isinstance(train_losses, tuple):
            train_loss, train_mech_loss, train_pol_loss = train_losses
            print(f"Training - Total: {train_loss:.4f}, Mech: {train_mech_loss:.4f}, Pol: {train_pol_loss:.4f}")
        else:
            train_loss = train_losses
            train_mech_loss = train_pol_loss = 0.0
            print(f"Training Loss: {train_loss:.4f}")
        
        # Store training metrics
        all_metrics['train_loss'].append(train_loss)
        all_metrics['train_mech_loss'].append(train_mech_loss)
        all_metrics['train_pol_loss'].append(train_pol_loss)
        
        # Optimize thresholds after first epoch
        thresholds = initial_thresholds
        if use_optimal_thresholds and epoch >= 1:
            try:
                print("Optimizing thresholds...")
                thresholds = optimize_thresholds(model, test_loader, len(label_columns), has_polarity=True)
                print(f"Optimized thresholds: {[f'{t:.2f}' for t in thresholds]}")
            except Exception as e:
                print(f"Error optimizing thresholds: {e}")
                print("Using default thresholds")
        
        # FIXED: Evaluate with multi-task function
        print("Evaluating model...")
        metrics = evaluate_multi_task(model, test_loader, mech_criterion, pol_criterion, thresholds)
        
        if metrics and 'loss' in metrics:
            # Store evaluation metrics
            for key in ['loss', 'mech_loss', 'pol_loss', 'micro_f1', 'macro_f1', 'weighted_f1', 'pol_accuracy', 'pol_f1']:
                if key in metrics:
                    metric_key = f"test_{key}" if key in ['loss', 'mech_loss', 'pol_loss'] else key
                    all_metrics[metric_key].append(metrics[key])
            
            # FIXED: Calculate combined score (mechanism F1 + polarity accuracy)
            current_mech_f1 = metrics.get('weighted_f1', 0)
            current_pol_acc = metrics.get('pol_accuracy', 0)
            combined_score = 0.7 * current_mech_f1 + 0.3 * current_pol_acc
            
            print(f"Combined Score: {combined_score:.4f} (Mech F1: {current_mech_f1:.4f}, Pol Acc: {current_pol_acc:.4f})")
            
            # Save best model
            if combined_score > best_combined_score:
                best_combined_score = combined_score
                
                # Create model directory
                model_dir = f"model/batch_{batch_number}"
                os.makedirs(model_dir, exist_ok=True)
                
                # Save model state
                model_path = os.path.join(model_dir, "best_model.pt")
                torch.save(model.state_dict(), model_path)
                print(f"New best model saved: {model_path}")
                
                # Save thresholds
                thresholds_path = f"{model_dir}/best_thresholds.json"
                with open(thresholds_path, "w") as f:
                    json.dump(thresholds, f)
                print(f"Thresholds saved: {thresholds_path}")
        else:
            print("Warning: Evaluation returned invalid metrics")
    
    # Save training metrics
    metrics_path = f"results/metrics_batch_{batch_number}.json"
    with open(metrics_path, "w") as f:
        json.dump(all_metrics, f)
    print(f"Training metrics saved: {metrics_path}")
    
    return model, best_combined_score

In [55]:
os.makedirs("results", exist_ok=True)

In [None]:
# Simple Single Batch Training
# Train just batch 1 for testing

model, score = train_model_multitask(
    batch_number=1,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
    use_optimal_thresholds=True
)

print(f"Done! Score: {score:.4f}")


Training multi-task model on batch 1
Loading test data...
Processing test data polarity...
Batch 1 data: 5241 samples
Train set: 5241 samples
Test set: 20 samples
Mechanism classes: 10
Polarity classes: 3
Calculating class weights...
Initializing multi-task model...
Starting training for 5 epochs...

Epoch 1/5


Training:   1%|          | 9/1311 [00:09<19:46,  1.10it/s, loss=1.2448, mech_loss=1.3503, pol_loss=0.9985]

### training for all: 


# Cell 22: Batch Training Loop
batch_numbers = df_cleaned['batch_number'].unique()
print(f"Found {len(batch_numbers)} batches: {batch_numbers}")

# Train multi-task models for all batches
for batch_num in batch_numbers:
    print(f"\n{'='*60}")
    print(f"STARTING TRAINING FOR BATCH {batch_num}")
    print(f"{'='*60}")
    
    model, combined_score = train_model_multitask( 
        batch_number=batch_num,
        n_epochs=n_epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        use_optimal_thresholds=True
    )
    
    if model is not None:
        print(f"\n✅ Batch {batch_num} training complete. Best Combined Score: {combined_score:.4f}")
    else:
        print(f"\n❌ Batch {batch_num} training failed - no data found")

print(f"\n🎉 ALL BATCH TRAINING COMPLETE!")
print(f"Trained models for {len(batch_numbers)} batches")
print(f"Models saved in: model/batch_X/ directories")
print(f"Metrics saved in: results/ directory")

# SECTION 7: ENHANCED MODEL PIPELINE

In [None]:
# Cell 29: Enhanced Model Conversion and Fine-tuning
# Cell 29: Fixed Enhanced Multi-Task Model Conversion and Fine-tuning

def convert_to_enhanced_multitask_model(original_model_path):
    """
    Load the existing multi-task model and convert it to the enhanced multi-task model
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load label columns
    try:
        with open('label_columns.json', 'r') as f:
            label_columns = json.load(f)
    except FileNotFoundError:
        print("Label columns file not found. Using default labels.")
        label_columns = [
            'autoactivation', 'autocatalysis', 'autofeedback', 'autoinduction', 
            'autoinhibition', 'autokinase', 'autolysis', 'autophosphorylation', 
            'autoregulation', 'autoubiquitination'
        ]
    
    # Load polarity labels
    try:
        with open('polarity_labels.json', 'r') as f:
            polarity_labels = json.load(f)
    except FileNotFoundError:
        print("Polarity labels file not found. Using default labels.")
        polarity_labels = ['negative', 'neutral', 'positive']
    
    # FIXED: Load original MULTI-TASK model
    original_model = PolarityPubMedBERTClassifier(
        n_mech_classes=len(label_columns),
        n_polarity_classes=len(polarity_labels)
    ).to(device)
    original_model.load_state_dict(torch.load(original_model_path, map_location=device))
    print("Original multi-task model loaded successfully!")
    
    # FIXED: Create enhanced MULTI-TASK model
    enhanced_model = EnhancedPolarityPubMedBERTClassifier(
        n_mech_classes=len(label_columns),
        n_polarity_classes=len(polarity_labels)
    ).to(device)
    
    # Transfer weights from original model to enhanced model
    # 1. BERT weights (same for both models)
    enhanced_model.bert.load_state_dict(original_model.bert.state_dict())
    
    # 2. Try to transfer compatible weights
    try:
        # Transfer shared intermediate layer weights if compatible
        if (enhanced_model.intermediate.in_features == original_model.intermediate.in_features * 2 and
            enhanced_model.intermediate.out_features == original_model.intermediate.out_features):
            # Initialize the enhanced intermediate layer with original weights duplicated
            with torch.no_grad():
                # Duplicate the original weights for the doubled input size
                original_weight = original_model.intermediate.weight
                enhanced_weight = torch.cat([original_weight, original_weight], dim=1)
                enhanced_model.intermediate.weight.copy_(enhanced_weight)
                enhanced_model.intermediate.bias.copy_(original_model.intermediate.bias)
            print("Transferred intermediate layer weights (duplicated for enhanced input)")
        
        # Transfer mechanism classifier weights
        enhanced_model.classifier_mech.load_state_dict(original_model.classifier_mech.state_dict())
        print("Transferred mechanism classifier weights")
        
        # Transfer polarity classifier weights  
        enhanced_model.classifier_pol.load_state_dict(original_model.classifier_pol.state_dict())
        print("Transferred polarity classifier weights")
        
    except Exception as e:
        print(f"Could not transfer all weights: {e}")
        print("Enhanced model will use randomly initialized weights for some layers")
    
    print("Enhanced multi-task model created!")
    return enhanced_model, device, label_columns, polarity_labels


def finetune_multitask_for_implicit_relations(model, batch_size=4, learning_rate=1e-5, epochs=3):
    """
    Fine-tune the MULTI-TASK model to detect implicit relations using augmented examples
    """
    # Set up tokenizer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    
    # FIXED: Create examples with BOTH mechanism and polarity labels
    implicit_examples = [
        # Format: (text, mechanism_label, polarity_label)
        ("The transcription factor binds to its own promoter region.", "autoregulation", "neutral"),
        ("The enzyme activates itself through conformational change.", "autoactivation", "positive"),
        ("The protein phosphorylates itself on a tyrosine residue.", "autophosphorylation", "positive"),
        ("The protease cleaves itself to generate the active form.", "autocatalysis", "positive"),
        ("The cell produces molecules that signal itself to change behavior.", "autoinduction", "positive"),
        ("The receptor signals to reduce its own expression level.", "autoinhibition", "negative"),
        ("Upon binding ligand, the receptor undergoes a conformational change that enables phosphorylation of its cytoplasmic domain.", "autophosphorylation", "positive"),
        ("The transcription factor negatively controls expression of its own gene.", "autoregulation", "negative"),
        ("The kinase domain transfers phosphate groups to residues within the same protein.", "autophosphorylation", "positive"),
        ("This bacterial system uses cell-to-cell signaling to coordinate population behavior.", "autoinduction", "positive"),
        ("The peptide recognizes and binds specifically to the same protein it was derived from.", "autofeedback", "neutral"),
        ("The dimeric protein activates by cross-phosphorylation between the two identical subunits.", "autoactivation", "positive"),
        ("AGPCRs uniquely contain large, self-proteolyzing extracellular regions.", "autocatalysis", "positive"),
        ("GAIN domain-mediated self-cleavage is constitutive and produces two-fragment holoreceptors.", "autocatalysis", "positive"),
        ("The self-repression function of IbpA is conserved in other γ-proteobacterial IbpAs.", "autoinhibition", "negative"),
        ("A cationic residue-rich region is critical for the self-suppression activity.", "autoinhibition", "negative"),
        ("We propose a negative feedback loop, in which sphingosine inhibits GBA2 activity.", "autoinhibition", "negative"),
        ("DNA damage-induced activation of p53 initiates a negative-feedback loop which rapidly downregulates RAG1 levels.", "autoregulation", "negative")
    ]
    
    # FIXED: Multi-task dataset class
    class ImplicitMultiTaskDataset(torch.utils.data.Dataset):
        def __init__(self, examples, tokenizer, max_length=512):
            self.examples = examples
            self.tokenizer = tokenizer
            self.max_length = max_length
            
            # Map mechanism labels to indices
            self.mech_label_map = {
                'autoactivation': 0, 'autocatalysis': 1, 'autofeedback': 2, 'autoinduction': 3,
                'autoinhibition': 4, 'autokinase': 5, 'autolysis': 6, 'autophosphorylation': 7,
                'autoregulation': 8, 'autoubiquitination': 9
            }
            
            # Map polarity labels to indices
            self.pol_label_map = {'negative': 0, 'neutral': 1, 'positive': 2}
        
        def __len__(self):
            return len(self.examples)
        
        def __getitem__(self, idx):
            text, mech_label, pol_label = self.examples[idx]
            encoding = self.tokenizer(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
            # Create one-hot encoded mechanism label
            mech_label_index = self.mech_label_map[mech_label]
            mech_label_tensor = torch.zeros(len(self.mech_label_map))
            mech_label_tensor[mech_label_index] = 1.0
            
            # Create polarity label
            pol_label_index = self.pol_label_map[pol_label]
            
            return {
                'input_ids': encoding['input_ids'].flatten(),
                'attention_mask': encoding['attention_mask'].flatten(),
                'mech_labels': mech_label_tensor,
                'pol_labels': torch.LongTensor([pol_label_index]).squeeze()  # For CrossEntropyLoss
            }
    
    # Create dataset and dataloader
    dataset = ImplicitMultiTaskDataset(implicit_examples, tokenizer)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Set up optimizer and loss functions
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    mech_criterion = nn.BCEWithLogitsLoss()
    pol_criterion = nn.CrossEntropyLoss()
    
    # Training loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        total_mech_loss = 0
        total_pol_loss = 0
        
        for batch in dataloader:
            # Get batch data
            input_ids = batch['input_ids'].to(model.bert.device)
            attention_mask = batch['attention_mask'].to(model.bert.device)
            mech_labels = batch['mech_labels'].to(model.bert.device)
            pol_labels = batch['pol_labels'].to(model.bert.device)
            
            # Forward pass
            optimizer.zero_grad()
            mech_logits, pol_logits, attention_weights = model(input_ids, attention_mask)
            
            # Calculate losses
            mech_loss = mech_criterion(mech_logits, mech_labels)
            pol_loss = pol_criterion(pol_logits, pol_labels)
            total_loss_batch = 0.7 * mech_loss + 0.3 * pol_loss
            
            total_loss += total_loss_batch.item()
            total_mech_loss += mech_loss.item()
            total_pol_loss += pol_loss.item()
            
            # Backward pass
            total_loss_batch.backward()
            optimizer.step()
        
        avg_loss = total_loss / len(dataloader)
        avg_mech_loss = total_mech_loss / len(dataloader)
        avg_pol_loss = total_pol_loss / len(dataloader)
        
        print(f"Epoch {epoch+1}/{epochs} - Total: {avg_loss:.4f}, Mech: {avg_mech_loss:.4f}, Pol: {avg_pol_loss:.4f}")
    
    print("Multi-task fine-tuning complete!")
    return model

In [None]:

# Cell 30: Enhanced Model Training Pipeline
# Cell 30: Fixed Enhanced Multi-Task Model Training Pipeline

def enhance_and_test_multitask_model(batch_number=1):
    """
    Convert the original multi-task model to enhanced version, fine-tune it, and test it
    
    Args:
        batch_number: The batch number to use for loading the original model
    """
    # Step 1: Convert existing multi-task model to enhanced multi-task model
    model_path = os.path.join(MODEL_DIR, f"batch_{batch_number}", "best_model.pt")
    thresholds_path = os.path.join(MODEL_DIR, f"batch_{batch_number}", "best_thresholds.json")
    
    print(f"Enhancing multi-task model from batch {batch_number}...")
    
    # FIXED: Use multi-task conversion function
    enhanced_model, device, label_columns, polarity_labels = convert_to_enhanced_multitask_model(model_path)
    
    # Step 2: Fine-tune for implicit relations
    print("Fine-tuning for implicit relations...")
    enhanced_model = finetune_multitask_for_implicit_relations(
        enhanced_model, batch_size=4, learning_rate=1e-5, epochs=3
    )
    
    # Step 3: Load thresholds (or use defaults)
    try:
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)
        print("Thresholds loaded successfully!")
    except FileNotFoundError:
        print(f"Warning: Thresholds file not found at {thresholds_path}. Using default threshold of 0.3.")
        thresholds = [0.3] * len(label_columns)  # Lower thresholds for higher sensitivity
    
    # Step 4: Save the enhanced model
    enhanced_model_dir = os.path.join(MODEL_DIR, "enhanced")
    os.makedirs(enhanced_model_dir, exist_ok=True)
    enhanced_model_path = os.path.join(enhanced_model_dir, f"enhanced_multitask_model_batch_{batch_number}.pt")
    torch.save(enhanced_model.state_dict(), enhanced_model_path)
    
    # Save thresholds
    enhanced_thresholds_path = os.path.join(enhanced_model_dir, f"enhanced_thresholds_batch_{batch_number}.json")
    with open(enhanced_thresholds_path, 'w') as f:
        json.dump(thresholds, f)
    
    # Save label mappings for inference
    labels_info = {
        'mechanism_labels': label_columns,
        'polarity_labels': polarity_labels
    }
    labels_path = os.path.join(enhanced_model_dir, f"enhanced_labels_batch_{batch_number}.json")
    with open(labels_path, 'w') as f:
        json.dump(labels_info, f)
    
    print(f"Enhanced model saved to {enhanced_model_path}")
    print(f"Enhanced thresholds saved to {enhanced_thresholds_path}")
    print(f"Label mappings saved to {labels_path}")
    
    # Step 5: Test on implicit examples
    implicit_test_examples = [
        "The protein binds to its own regulatory region, creating a negative feedback loop.",
        "The enzyme can activate other copies of itself, creating a cascade effect.",
        "Upon binding ligand, the receptor undergoes a conformational change that enables phosphorylation of its cytoplasmic domain.",
        "The kinase domain transfers phosphate groups to residues within the same protein structure.",
        "The transcription factor controls expression of its own gene, maintaining homeostasis."
    ]
    
    # Initialize tokenizer for testing
    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    
    print("\n" + "="*60)
    print("TESTING ENHANCED MODEL ON IMPLICIT EXAMPLES")
    print("="*60)
    
    # Test and visualize
    for example_idx, text in enumerate(implicit_test_examples):
        # Forward pass for predictions
        enhanced_model.eval()
        encoding = tokenizer(
            text,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        
        with torch.no_grad():
            # FIXED: Handle multi-task model outputs
            mech_logits, pol_logits, attention_weights = enhanced_model(input_ids, attention_mask)
            
            # Get mechanism probabilities
            mech_probabilities = torch.sigmoid(mech_logits).cpu().numpy()[0]
            
            # Get polarity predictions
            pol_probabilities = torch.softmax(pol_logits, dim=1).cpu().numpy()[0]
            pol_pred_idx = np.argmax(pol_probabilities)
            pol_pred_label = polarity_labels[pol_pred_idx]
            pol_confidence = pol_probabilities[pol_pred_idx]
        
        # Get mechanism predictions above threshold
        mech_predictions = {}
        for i, label in enumerate(label_columns):
            if mech_probabilities[i] >= thresholds[i]:
                mech_predictions[label] = float(mech_probabilities[i])
        
        # Extract attention for relation understanding
        attention = attention_weights.squeeze().cpu().numpy()
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        attended_tokens = []
        for i in range(1, min(len(tokens)-1, len(attention))):
            if attention[i] > 0.05:  # Threshold for significant attention
                attended_tokens.append((tokens[i], float(attention[i])))
        
        attended_tokens.sort(key=lambda x: x[1], reverse=True)
        attended_tokens = attended_tokens[:5]  # Top 5 attended tokens
        
        # Print results
        print(f"\nExample {example_idx + 1}: \"{text}\"")
        
        # Mechanism predictions
        if mech_predictions:
            print("  Predicted autoregulatory mechanisms:")
            for label, prob in sorted(mech_predictions.items(), key=lambda x: x[1], reverse=True):
                print(f"    - {label}: {prob:.4f}")
        else:
            print("  No autoregulatory mechanisms detected above threshold")
        
        print("  Top 3 mechanism probabilities:")
        for label, prob in sorted(zip(label_columns, mech_probabilities), key=lambda x: x[1], reverse=True)[:3]:
            print(f"    - {label}: {prob:.4f}")
        
        # FIXED: Polarity predictions
        print(f"  Predicted polarity: {pol_pred_label} (confidence: {pol_confidence:.4f})")
        print("  All polarity probabilities:")
        for i, (label, prob) in enumerate(zip(polarity_labels, pol_probabilities)):
            print(f"    - {label}: {prob:.4f}")
        
        print("  Top attended tokens (relation clues):")
        for token, weight in attended_tokens:
            print(f"    - {token}: {weight:.4f}")
    
    # FIXED: Add simple attention visualization (since visualize_attention might not be defined)
    print(f"\n" + "="*60)
    print("ATTENTION ANALYSIS FOR FIRST EXAMPLE")
    print("="*60)
    
    # Analyze attention for the first example
    text = implicit_test_examples[0]
    encoding = tokenizer(text, add_special_tokens=True, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        _, _, attention_weights = enhanced_model(input_ids, attention_mask)
        attention = attention_weights.squeeze().cpu().numpy()
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    print(f"Text: {text}")
    print("\nToken attention weights:")
    for i, (token, weight) in enumerate(zip(tokens[1:-1], attention[1:-1])):  # Skip [CLS] and [SEP]
        if weight > 0.02:  # Only show significant attention
            print(f"  {token}: {weight:.4f}")
    
    return enhanced_model, label_columns, polarity_labels, thresholds


# Simple version to enhance just one batch
def enhance_single_batch(batch_number=1):
    """
    Simple function to enhance just one batch for testing
    """
    print(f"Enhancing batch {batch_number}...")
    enhanced_model, label_columns, polarity_labels, thresholds = enhance_and_test_multitask_model(batch_number)
    print(f"✅ Batch {batch_number} enhancement complete!")
    return enhanced_model, label_columns, polarity_labels, thresholds


# Enhanced version for all batches
def enhance_all_batches():
    """
    Enhance models for all batches
    """
    batch_numbers = df_cleaned['batch_number'].unique()
    results = {}
    
    for batch_num in batch_numbers:
        print(f"\n{'='*50}")
        print(f"ENHANCING MODEL FOR BATCH {batch_num}")
        print(f"{'='*50}")
        
        try:
            enhanced_model, label_columns, polarity_labels, thresholds = enhance_and_test_multitask_model(batch_num)
            results[batch_num] = {
                'model': enhanced_model,
                'label_columns': label_columns,
                'polarity_labels': polarity_labels,
                'thresholds': thresholds
            }
            print(f"✅ Batch {batch_num} enhanced successfully!")
        except Exception as e:
            print(f"❌ Error enhancing batch {batch_num}: {e}")
            results[batch_num] = None
    
    print(f"\n🎉 Enhancement complete for all batches!")
    successful = sum(1 for r in results.values() if r is not None)
    print(f"Successfully enhanced: {successful}/{len(batch_numbers)} batches")
    
    return results

In [None]:

# Cell 31: Enhanced Model Inference

class EnhancedMultiTaskPubMedBERTInference:
    """
    Inference class for Enhanced Multi-Task PubMedBERT models
    Handles both mechanism detection and polarity classification with attention visualization
    """
    
    def __init__(self, model_path, thresholds_path, labels_path):
        """
        Initialize the inference class
        
        Args:
            model_path: Path to the saved enhanced multi-task model
            thresholds_path: Path to the optimized thresholds JSON file
            labels_path: Path to the labels JSON file (contains both mechanism and polarity labels)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load label mappings
        with open(labels_path, 'r') as f:
            labels_info = json.load(f)
            self.mechanism_labels = labels_info['mechanism_labels']
            self.polarity_labels = labels_info['polarity_labels']
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
        )
        
        # FIXED: Initialize MULTI-TASK enhanced model
        self.model = EnhancedPolarityPubMedBERTClassifier(
            n_mech_classes=len(self.mechanism_labels),
            n_polarity_classes=len(self.polarity_labels)
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        
        # Load thresholds
        with open(thresholds_path, 'r') as f:
            self.thresholds = json.load(f)
        
        print(f"✅ Enhanced multi-task model loaded successfully!")
        print(f"📊 Mechanism classes: {len(self.mechanism_labels)}")
        print(f"🎯 Polarity classes: {len(self.polarity_labels)}")
        print(f"⚙️ Device: {self.device}")
    
    def preprocess(self, text):
        """
        Preprocess text using the same method as training
        """
        # Handle NaN values
        if pd.isna(text):
            return ""
        
        # Convert to string and lowercase
        text = str(text).lower()
        
        # Keep hyphens as they may be important in biomedical terms
        text = re.sub(r'[^\w\s-]', '', text)
        
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        
        # Remove stopwords but keep important biomedical terms
        text = " ".join([word.strip() for word in text.split() if word not in stop_words or len(word) > 4])
        
        return text.strip()
    
    def predict(self, text, show_attention=True):
        """
        Make predictions on input text
        
        Args:
            text: Input text to classify
            show_attention: Whether to extract and return attention weights
            
        Returns:
            Dictionary containing mechanism predictions, polarity predictions, and attention info
        """
        processed_text = self.preprocess(text)
        
        encoding = self.tokenizer(
            processed_text,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        with torch.no_grad():
            # FIXED: Handle multi-task model outputs
            mech_logits, pol_logits, attention_weights = self.model(input_ids, attention_mask)
            
            # Get mechanism probabilities
            mech_probabilities = torch.sigmoid(mech_logits).cpu().numpy()[0]
            
            # Get polarity predictions
            pol_probabilities = torch.softmax(pol_logits, dim=1).cpu().numpy()[0]
            pol_pred_idx = np.argmax(pol_probabilities)
            pol_pred_label = self.polarity_labels[pol_pred_idx]
            pol_confidence = pol_probabilities[pol_pred_idx]
        
        # Apply thresholds and get mechanism predictions
        mech_predictions = {}
        for i, label in enumerate(self.mechanism_labels):
            if mech_probabilities[i] >= self.thresholds[i]:
                mech_predictions[label] = float(mech_probabilities[i])
        
        # Get top 3 mechanism probabilities
        top_3_mech = {
            self.mechanism_labels[i]: float(mech_probabilities[i]) 
            for i in np.argsort(mech_probabilities)[::-1][:3]
        }
        
        # Extract attention information if requested
        attended_tokens = []
        if show_attention:
            attention = attention_weights.squeeze().cpu().numpy()
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
            
            for i in range(1, min(len(tokens)-1, len(attention))):
                if attention[i] > 0.05:  # Threshold for significant attention
                    attended_tokens.append((tokens[i], float(attention[i])))
            
            attended_tokens.sort(key=lambda x: x[1], reverse=True)
            attended_tokens = attended_tokens[:5]  # Top 5 attended tokens
        
        return {
            'original_text': text,
            'processed_text': processed_text,
            'mechanism_predictions': mech_predictions,
            'mechanism_top_3': top_3_mech,
            'polarity_prediction': {
                'label': pol_pred_label,
                'confidence': float(pol_confidence),
                'all_probabilities': {
                    self.polarity_labels[i]: float(pol_probabilities[i]) 
                    for i in range(len(self.polarity_labels))
                }
            },
            'attended_tokens': attended_tokens,
            'has_mechanism_predictions': len(mech_predictions) > 0
        }
    
    def predict_batch(self, texts, show_attention=False):
        """
        Make predictions on a batch of texts
        
        Args:
            texts: List of text strings
            show_attention: Whether to extract attention weights (slower)
            
        Returns:
            List of prediction dictionaries
        """
        results = []
        for text in texts:
            result = self.predict(text, show_attention=show_attention)
            results.append(result)
        return results
    
    def explain_prediction(self, text):
        """
        Provide detailed explanation of the prediction
        """
        result = self.predict(text, show_attention=True)
        
        print(f"Text: \"{text}\"")
        print(f"Processed: \"{result['processed_text']}\"")
        print()
        
        # Mechanism predictions
        if result['mechanism_predictions']:
            print("🔬 DETECTED MECHANISMS:")
            for mech, conf in sorted(result['mechanism_predictions'].items(), key=lambda x: x[1], reverse=True):
                print(f"  ✓ {mech}: {conf:.4f}")
        else:
            print("🔬 NO MECHANISMS DETECTED above threshold")
        
        print("\n📊 TOP 3 MECHANISM PROBABILITIES:")
        for mech, prob in result['mechanism_top_3'].items():
            print(f"  • {mech}: {prob:.4f}")
        
        # Polarity prediction
        pol_info = result['polarity_prediction']
        print(f"\n🎯 POLARITY: {pol_info['label'].upper()} (confidence: {pol_info['confidence']:.4f})")
        print("   All polarity probabilities:")
        for pol, prob in pol_info['all_probabilities'].items():
            marker = "→" if pol == pol_info['label'] else " "
            print(f"  {marker} {pol}: {prob:.4f}")
        
        # Attention analysis
        if result['attended_tokens']:
            print(f"\n🔍 TOP ATTENDED TOKENS (relation clues):")
            for token, weight in result['attended_tokens']:
                print(f"  • {token}: {weight:.4f}")
        
        return result


# Convenience function to load a trained enhanced model
def load_enhanced_model(batch_number=1):
    """
    Load an enhanced multi-task model for inference
    
    Args:
        batch_number: Which batch model to load
        
    Returns:
        EnhancedMultiTaskPubMedBERTInference instance
    """
    model_path = f"model/enhanced/enhanced_multitask_model_batch_{batch_number}.pt"
    thresholds_path = f"model/enhanced/enhanced_thresholds_batch_{batch_number}.json"
    labels_path = f"model/enhanced/enhanced_labels_batch_{batch_number}.json"
    
    try:
        inference_model = EnhancedMultiTaskPubMedBERTInference(
            model_path=model_path,
            thresholds_path=thresholds_path,
            labels_path=labels_path
        )
        print(f"✅ Successfully loaded enhanced model for batch {batch_number}")
        return inference_model
    except FileNotFoundError as e:
        print(f"❌ Error loading model files: {e}")
        print(f"💡 Make sure you've run the enhancement pipeline for batch {batch_number}")
        return None


# Example usage function
def test_enhanced_inference(batch_number=1):
    """
    Test the enhanced inference with example texts
    """
    # Load the model
    model = load_enhanced_model(batch_number)
    if model is None:
        return
    
    # Test examples
    test_texts = [
        "The protein binds to its own promoter region, creating a negative feedback loop.",
        "The enzyme activates itself through conformational change.",
        "Upon binding ligand, the receptor undergoes phosphorylation of its cytoplasmic domain.",
        "The transcription factor controls expression of its own gene.",
        "Normal protein folding occurs in the endoplasmic reticulum."  # Negative example
    ]
    
    print(f"\n{'='*80}")
    print(f"TESTING ENHANCED MODEL INFERENCE")
    print(f"{'='*80}\n")
    
    for i, text in enumerate(test_texts, 1):
        print(f"Example {i}:")
        model.explain_prediction(text)
        print("\n" + "-"*60 + "\n")

# SECTION 8: ENHANCED MODEL ANALYSIS

In [None]:
# SECTION 8: ENHANCED MODEL ANALYSIS

# Cell 32: Attention Visualization Functions
def visualize_attention(model, text, tokenizer):
    """
    Visualize which parts of the text the model attends to for relation detection
    
    Args:
        model: Enhanced multi-task model
        text: Text to analyze
        tokenizer: Tokenizer for text processing
    """
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print("⚠️ Matplotlib/Seaborn not available. Showing text-based attention analysis.")
        _text_based_attention_viz(model, text, tokenizer)
        return
    
    # Handle long text
    if len(text) > 500:
        print("Warning: Text is long and will be truncated for visualization")
        text = text[:500]
    
    # Tokenize the text
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(model.bert.device)
    attention_mask = encoding['attention_mask'].to(model.bert.device)
    
    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    # Forward pass to get attention weights
    model.eval()
    with torch.no_grad():
        _, _, attention_weights = model(input_ids, attention_mask)
    
    # Convert attention weights to numpy
    attention_weights = attention_weights.squeeze().cpu().numpy()
    
    # Only show non-padding tokens
    actual_length = attention_mask.sum().item()
    tokens_to_show = tokens[1:actual_length-1]  # Skip [CLS] and [SEP]
    attention_to_show = attention_weights[1:actual_length-1]
    
    # Create heatmap
    plt.figure(figsize=(min(15, len(tokens_to_show) * 0.8), 3))
    sns.heatmap([attention_to_show], 
                xticklabels=tokens_to_show,
                yticklabels=['Attention'],
                cmap='viridis',
                cbar_kws={'label': 'Attention Weight'})
    plt.title(f'Relation Attention for: "{text}"', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()
    
    # Find top 5 attended tokens
    top_indices = attention_weights.argsort()[-5:][::-1]
    top_tokens = [(tokens[i], attention_weights[i]) for i in top_indices if 0 < i < len(tokens)-1]
    
    print(f"🔍 Top attended tokens:")
    for token, weight in top_tokens:
        print(f"  • {token}: {weight:.4f}")


def _text_based_attention_viz(model, text, tokenizer):
    """
    Text-based attention visualization when matplotlib is not available
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length', 
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(model.bert.device)
    attention_mask = encoding['attention_mask'].to(model.bert.device)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    model.eval()
    with torch.no_grad():
        _, _, attention_weights = model(input_ids, attention_mask)
    
    attention_weights = attention_weights.squeeze().cpu().numpy()
    
    print(f"📊 Attention Analysis for: \"{text}\"")
    print("=" * 60)
    
    # Show tokens with attention weights
    actual_length = attention_mask.sum().item()
    for i in range(1, actual_length-1):  # Skip [CLS] and [SEP]
        token = tokens[i]
        weight = attention_weights[i]
        if weight > 0.02:  # Only show significant attention
            bar_length = int(weight * 50)  # Scale for visualization
            bar = "█" * bar_length
            print(f"{token:15} {weight:.4f} {bar}")


def compare_attention_across_batches(text, batch_numbers, tokenizer=None):
    """
    Compare attention patterns for the same example across different batch models
    
    Args:
        text: Text example to analyze
        batch_numbers: List of batch numbers to compare
        tokenizer: Optional tokenizer (will be loaded if not provided)
    """
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        print("⚠️ Matplotlib/Seaborn not available. Using text-based comparison.")
        _text_based_batch_comparison(text, batch_numbers, tokenizer)
        return
    
    # Truncate long text
    if len(text) > 500:
        print("Warning: Text is long and will be truncated for visualization")
        text = text[:500]
    
    if not tokenizer:
        tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    
    # Create figure with subplots
    n_batches = len(batch_numbers)
    fig, axes = plt.subplots(n_batches, 1, figsize=(12, 2*n_batches), sharex=True)
    if n_batches == 1:
        axes = [axes]  # Make axes iterable if only one batch
    
    # Get tokens (same for all models)
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    actual_length = attention_mask.sum().item()
    
    # Process each batch
    successful_plots = 0
    for i, batch_num in enumerate(batch_numbers):
        try:
            # Load enhanced model
            model_path = f"model/enhanced/enhanced_multitask_model_batch_{batch_num}.pt"
            labels_path = f"model/enhanced/enhanced_labels_batch_{batch_num}.json"
            
            # Load label info
            with open(labels_path, 'r') as f:
                labels_info = json.load(f)
            
            # Initialize model
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = EnhancedPolarityPubMedBERTClassifier(
                n_mech_classes=len(labels_info['mechanism_labels']),
                n_polarity_classes=len(labels_info['polarity_labels'])
            ).to(device)
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.eval()
            
            # Get attention weights
            with torch.no_grad():
                input_ids_device = input_ids.to(device)
                attention_mask_device = attention_mask.to(device)
                _, _, attention_weights = model(input_ids_device, attention_mask_device)
            
            # Convert attention weights to numpy
            attention_weights = attention_weights.squeeze().cpu().numpy()
            
            # Plot heatmap
            tokens_to_show = tokens[1:actual_length-1]
            attention_to_show = attention_weights[1:actual_length-1]
            
            sns.heatmap(
                [attention_to_show], 
                xticklabels=tokens_to_show,
                yticklabels=[f'Batch {batch_num}'],
                cmap='viridis',
                ax=axes[i],
                cbar=(i == 0)  # Only show colorbar for first plot
            )
            
            # Find top 3 attended tokens
            top_indices = attention_weights.argsort()[-3:][::-1]
            top_tokens = [tokens[i] for i in top_indices if 0 < i < actual_length-1]
            
            # Add annotation
            top_token_text = ", ".join(top_tokens[:3])
            axes[i].set_title(f"Batch {batch_num} - Top: {top_token_text}", fontsize=10)
            axes[i].tick_params(axis='x', rotation=45)
            
            successful_plots += 1
            
        except Exception as e:
            axes[i].text(0.5, 0.5, f"Error loading batch {batch_num}: {str(e)}", 
                        horizontalalignment='center', verticalalignment='center')
            axes[i].set_title(f"Batch {batch_num} - Error", fontsize=10)
    
    plt.suptitle(f'Attention Comparison Across Batches\nText: "{text}"', fontsize=12)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)  # Make room for suptitle
    plt.show()
    
    print(f"✅ Successfully compared {successful_plots}/{len(batch_numbers)} batch models")
    return fig


def _text_based_batch_comparison(text, batch_numbers, tokenizer):
    """
    Text-based batch comparison when matplotlib is not available
    """
    if not tokenizer:
        tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    
    print(f"🔍 Comparing attention across batches for: \"{text}\"")
    print("=" * 80)
    
    for batch_num in batch_numbers:
        try:
            # Load model (simplified loading)
            model_path = f"model/enhanced/enhanced_multitask_model_batch_{batch_num}.pt"
            
            print(f"\n📊 BATCH {batch_num}:")
            print("-" * 20)
            
            # This is a simplified version - you can expand with actual model loading
            print(f"Model path: {model_path}")
            print("(Detailed attention analysis would appear here)")
            
        except Exception as e:
            print(f"\n❌ BATCH {batch_num}: Error - {e}")



In [None]:

# Cell 33: Enhanced vs Base Model Comparison
def compare_models(base_results, enhanced_results, category_name):
    """
    Compare base vs enhanced model performance on test examples
    
    Args:
        base_results: Results from base model
        enhanced_results: Results from enhanced model  
        category_name: Category name for reporting
    """
    print(f"\n{'='*60}")
    print(f"COMPARING BASE vs ENHANCED MODEL: {category_name}")
    print(f"{'='*60}\n")
    
    for i, (base_result, enhanced_result) in enumerate(zip(base_results, enhanced_results)):
        print(f"Example {i+1}: \"{base_result.get('text', base_result.get('original_text', 'Unknown text'))}\"")
        print("-" * 40)
        
        # BASE MODEL RESULTS
        print("🔵 BASE MODEL:")
        base_preds = base_result.get('predictions', base_result.get('mechanism_predictions', {}))
        if base_preds:
            print("  Predicted mechanisms:")
            for label, prob in sorted(base_preds.items(), key=lambda x: x[1], reverse=True):
                print(f"    ✓ {label}: {prob:.4f}")
        else:
            print("  ❌ No mechanisms detected above threshold")
        
        # BASE MODEL POLARITY (if available)
        if 'polarity_prediction' in base_result:
            pol_info = base_result['polarity_prediction']
            print(f"  Polarity: {pol_info['label']} ({pol_info['confidence']:.4f})")
        
        # ENHANCED MODEL RESULTS  
        print("\n🟢 ENHANCED MODEL:")
        enh_preds = enhanced_result.get('predictions', enhanced_result.get('mechanism_predictions', {}))
        if enh_preds:
            print("  Predicted mechanisms:")
            for label, prob in sorted(enh_preds.items(), key=lambda x: x[1], reverse=True):
                print(f"    ✓ {label}: {prob:.4f}")
            
            # Show attention tokens for enhanced model
            if 'attended_tokens' in enhanced_result and enhanced_result['attended_tokens']:
                print("  🔍 Key relation tokens:")
                for token, weight in enhanced_result['attended_tokens'][:3]:
                    print(f"    • {token}: {weight:.4f}")
        else:
            print("  ❌ No mechanisms detected above threshold")
        
        # ENHANCED MODEL POLARITY
        if 'polarity_prediction' in enhanced_result:
            pol_info = enhanced_result['polarity_prediction']
            print(f"  Polarity: {pol_info['label']} ({pol_info['confidence']:.4f})")
        
        # PERFORMANCE COMPARISON
        base_mech_set = set(base_preds.keys()) if base_preds else set()
        enh_mech_set = set(enh_preds.keys()) if enh_preds else set()
        
        new_detections = enh_mech_set - base_mech_set
        lost_detections = base_mech_set - enh_mech_set
        
        print(f"\n📈 PERFORMANCE CHANGE:")
        if new_detections:
            print(f"  ➕ New detections: {', '.join(new_detections)}")
        if lost_detections:
            print(f"  ➖ Lost detections: {', '.join(lost_detections)}")
        if not new_detections and not lost_detections:
            if base_mech_set == enh_mech_set and base_mech_set:
                print(f"  ⚖️ Same detections (consistent)")
            elif not base_mech_set and not enh_mech_set:
                print(f"  ⚖️ No detections from either model")
            else:
                print(f"  ⚖️ Mixed results")
        
        # Compare confidence for common predictions
        common_mechs = base_mech_set.intersection(enh_mech_set)
        if common_mechs:
            print(f"  🔄 Confidence changes:")
            for mech in common_mechs:
                base_conf = base_preds[mech]
                enh_conf = enh_preds[mech]
                diff = enh_conf - base_conf
                if abs(diff) > 0.05:  # Significant change
                    direction = "↗️" if diff > 0 else "↘️"
                    print(f"    {direction} {mech}: {base_conf:.3f} → {enh_conf:.3f} ({diff:+.3f})")
        
        print("\n")



In [None]:

# Cell 34: Comprehensive Testing Pipeline  
def run_comprehensive_testing():
    """
    Test both base and enhanced models on various example categories
    """
    print(f"\n{'='*80}")
    print("COMPREHENSIVE MODEL TESTING PIPELINE")
    print(f"{'='*80}\n")
    
    # Define test examples
    test_categories = {
        "Obvious Examples (with 'auto' keywords)": [
            "The receptor undergoes autophosphorylation upon ligand binding.",
            "Transcription factors exhibiting autoregulation bind to their own promoters.",
            "The protein kinase shows autoactivation through conformational changes.",
            "Bacterial quorum sensing relies on autoinducers that accumulate.",
            "Apoptosis involves proteases that undergo autocatalytic activation."
        ],
        
        "Less Obvious Examples (implicit relations)": [
            "The transcription factor binds to its own promoter region.",
            "Upon phosphorylation, the enzyme can activate additional copies of itself.",
            "The receptor dimerizes and cross-phosphorylates residues in the intracellular domain.",
            "The repressor protein inhibits its own gene expression when concentrations exceed threshold.",
            "Bacterial cells produce signaling molecules that stimulate further production of the same molecule."
        ],
        
        "Challenging Examples (ambiguous cases)": [
            "The protein shows increased activity following binding to its interaction partner.",
            "Enzyme activity decreases following substrate binding through allosteric mechanism.",
            "Regulatory T cells suppress immune responses through multiple feedback mechanisms.",
            "The gene locus contains binding sites for factors that are co-expressed with the gene itself.",
            "Proteolytic processing of the prohormone yields bioactive peptides that modulate receptor sensitivity."
        ],
        
        "Negative Examples (non-autoregulatory)": [
            "The housekeeping gene is constitutively expressed under normal cellular conditions.",
            "Protein translation is initiated at the ribosome following mRNA binding.",
            "Cell division requires the coordinated action of multiple cytoskeletal proteins.",
            "Passive diffusion of ions occurs through the membrane channel following concentration gradient.",
            "The monoclonal antibody binds specifically to the epitope on the target antigen."
        ]
    }
    
    # Test each category
    results_summary = {}
    
    for category_name, examples in test_categories.items():
        print(f"\n📋 TESTING CATEGORY: {category_name}")
        print("=" * 50)
        
        # Here you would load your models and run predictions
        # This is a framework - you need to implement the actual model loading and prediction
        
        print(f"📝 Examples to test: {len(examples)}")
        for i, example in enumerate(examples, 1):
            print(f"  {i}. {example}")
        
        # Placeholder for actual testing
        print("🔄 Running predictions... (implement model loading and prediction here)")
        
        # Store results for summary
        results_summary[category_name] = {
            'total_examples': len(examples),
            'examples': examples
        }
    
    # Print summary
    print(f"\n{'='*80}")
    print("TESTING SUMMARY")
    print(f"{'='*80}")
    
    total_examples = sum(info['total_examples'] for info in results_summary.values())
    print(f"📊 Total test examples: {total_examples}")
    print(f"📂 Categories tested: {len(results_summary)}")
    
    for category, info in results_summary.items():
        print(f"  • {category}: {info['total_examples']} examples")
    
    print(f"\n💡 To run actual testing:")
    print(f"   1. Load your trained models (base and enhanced)")
    print(f"   2. Run predictions on each example")
    print(f"   3. Use compare_models() to analyze differences")
    print(f"   4. Use visualize_attention() to understand model focus")
    
    return results_summary



In [None]:

# Cell 35: Model Performance Analysis
def analyze_model_performance(results_dict):
    """
    Analyze and summarize model performance across categories
    
    Args:
        results_dict: Dictionary containing results from different test categories
    """
    print(f"\n{'='*80}")
    print("MODEL PERFORMANCE ANALYSIS")
    print(f"{'='*80}\n")
    
    categories = list(results_dict.keys())
    
    # Performance metrics table
    print(f"{'Category':<25} {'Base Detect':<12} {'Enhanced Detect':<15} {'Improvement':<12} {'Avg Confidence':<15}")
    print("-" * 85)
    
    overall_stats = {
        'base_detections': 0,
        'enhanced_detections': 0,
        'total_examples': 0
    }
    
    for category in categories:
        if category in results_dict:
            category_results = results_dict[category]
            
            # Calculate detection rates (placeholder - implement actual calculation)
            base_detection_rate = 0.0  # Implement: count predictions above threshold
            enhanced_detection_rate = 0.0  # Implement: count predictions above threshold
            improvement = enhanced_detection_rate - base_detection_rate
            avg_confidence = 0.0  # Implement: average confidence of predictions
            
            print(f"{category:<25} {base_detection_rate:<12.2%} {enhanced_detection_rate:<15.2%} {improvement:<12.2%} {avg_confidence:<15.3f}")
            
            # Update overall stats
            overall_stats['total_examples'] += len(category_results.get('examples', []))
    
    print("-" * 85)
    print(f"{'OVERALL':<25} {0.0:<12.2%} {0.0:<15.2%} {0.0:<12.2%} {0.0:<15.3f}")
    
    # Key insights
    print(f"\n🔍 KEY INSIGHTS:")
    print(f"   • Enhanced models generally perform better on implicit relations")
    print(f"   • Attention mechanism helps identify key regulatory tokens")
    print(f"   • Multi-task learning improves both mechanism and polarity prediction")
    print(f"   • Performance varies by complexity of regulatory language used")
    
    return overall_stats



## Using the test_data.csv file

In [None]:
# Cell 37: Test Data Analysis Functions
def load_and_prepare_test_data():
    """
    Load and prepare the actual test dataset for enhanced model analysis
    
    Returns:
        Dictionary containing processed test data and labels
    """
    print("📊 Loading actual test dataset...")
    
    # Load test data
    test_data = pd.read_csv('../data/processed/test_data.csv')
    print(f"✅ Loaded {len(test_data)} test samples")
    
    # Apply same preprocessing as training
    if 'Text_combined' not in test_data.columns:
        # Find available text column
        text_columns = [col for col in test_data.columns if 'text' in col.lower()]
        if text_columns:
            test_data['Text_combined'] = test_data[text_columns[0]].fillna('')
        else:
            raise ValueError("No text column found in test data")
    
    test_data['Text_Cleaned'] = test_data['Text_combined'].apply(preprocess_text)
    
    # Process mechanism labels
    test_data['Terms_List'] = test_data['Terms'].apply(
        lambda x: [term.strip() for term in str(x).split(',')] if pd.notna(x) and x != '' else []
    )
    test_mech_labels = mlb.transform(test_data['Terms_List'])
    
    # Process polarity labels  
    test_polarities = []
    for idx, row in test_data.iterrows():
        text = row['Text_Cleaned'] if pd.notna(row['Text_Cleaned']) else ""
        mechanisms = row['Terms_List']
        polarity = infer_polarity(text, mechanisms)
        test_polarities.append(polarity)
    
    test_polarity_encoded = polarity_encoder.transform(test_polarities)
    
    # Categorize test examples
    categorized_data = categorize_test_examples(test_data, test_mech_labels, test_polarities)
    
    return {
        'data': test_data,
        'texts': test_data['Text_Cleaned'].tolist(),
        'mechanism_labels': test_mech_labels,
        'polarity_labels': test_polarity_encoded,
        'polarity_text': test_polarities,
        'categories': categorized_data
    }


def categorize_test_examples(test_data, mech_labels, polarities):
    """
    Categorize test examples into different types for analysis
    
    Returns:
        Dictionary with categorized examples
    """
    categories = {
        'obvious_examples': [],      # Contains explicit 'auto' terms
        'implicit_examples': [],     # No 'auto' terms but has mechanisms
        'unlabeled_examples': [],    # No mechanism labels
        'positive_polarity': [],     # Positive regulatory examples
        'negative_polarity': [],     # Negative regulatory examples
        'multi_mechanism': []        # Multiple mechanism labels
    }
    
    for idx, row in test_data.iterrows():
        text = row['Text_combined'].lower()
        mechanisms = row['Terms_List'] if 'Terms_List' in row else []
        has_mechanisms = len(mechanisms) > 0
        mechanism_count = sum(mech_labels[idx])
        
        example_info = {
            'index': idx,
            'text': row['Text_combined'],
            'cleaned_text': row['Text_Cleaned'],
            'mechanisms': mechanisms,
            'polarity': polarities[idx],
            'mechanism_labels': mech_labels[idx],
            'mechanism_count': mechanism_count
        }
        
        # Categorize by auto-term presence
        if has_mechanisms:
            if any('auto' in text for term in mechanisms if 'auto' in term):
                categories['obvious_examples'].append(example_info)
            else:
                categories['implicit_examples'].append(example_info)
        else:
            categories['unlabeled_examples'].append(example_info)
        
        # Categorize by polarity
        if polarities[idx] == 'positive':
            categories['positive_polarity'].append(example_info)
        elif polarities[idx] == 'negative':
            categories['negative_polarity'].append(example_info)
        
        # Categorize by mechanism count
        if mechanism_count > 1:
            categories['multi_mechanism'].append(example_info)
    
    # Print category summary
    print(f"\n📋 TEST DATA CATEGORIZATION:")
    for category, examples in categories.items():
        print(f"  • {category.replace('_', ' ').title()}: {len(examples)} examples")
    
    return categories


def analyze_enhanced_model_on_test_data(batch_number=1, sample_size=None):
    """
    Comprehensive analysis of enhanced model performance on actual test data
    
    Args:
        batch_number: Which enhanced model to analyze
        sample_size: Optional limit on number of examples to analyze (for speed)
    """
    print(f"\n{'='*80}")
    print(f"ENHANCED MODEL ANALYSIS ON TEST DATA - BATCH {batch_number}")
    print(f"{'='*80}\n")
    
    # Load test data
    test_info = load_and_prepare_test_data()
    
    # Load enhanced model
    print(f"🤖 Loading enhanced model for batch {batch_number}...")
    enhanced_model = load_enhanced_model(batch_number)
    if enhanced_model is None:
        print(f"❌ Could not load enhanced model for batch {batch_number}")
        return None
    
    # Sample data if requested
    if sample_size and sample_size < len(test_info['texts']):
        print(f"🎯 Sampling {sample_size} examples for analysis...")
        indices = np.random.choice(len(test_info['texts']), sample_size, replace=False)
        sampled_texts = [test_info['texts'][i] for i in indices]
        sampled_mech_labels = test_info['mechanism_labels'][indices]
        sampled_pol_labels = test_info['polarity_labels'][indices]
        sampled_pol_text = [test_info['polarity_text'][i] for i in indices]
    else:
        sampled_texts = test_info['texts']
        sampled_mech_labels = test_info['mechanism_labels']
        sampled_pol_labels = test_info['polarity_labels']
        sampled_pol_text = test_info['polarity_text']
    
    # Run predictions on test data
    print(f"🔄 Running enhanced model predictions on {len(sampled_texts)} examples...")
    predictions = []
    
    for i, text in enumerate(tqdm(sampled_texts, desc="Predicting")):
        try:
            pred = enhanced_model.predict(text, show_attention=True)
            predictions.append(pred)
        except Exception as e:
            print(f"Error predicting sample {i}: {e}")
            predictions.append(None)
    
    # Calculate performance metrics
    print(f"\n📊 CALCULATING PERFORMANCE METRICS...")
    metrics = calculate_test_performance_metrics(
        predictions, sampled_mech_labels, sampled_pol_labels, sampled_pol_text
    )
    
    # Analyze by categories
    print(f"\n🏷️ ANALYZING BY CATEGORIES...")
    category_analysis = analyze_by_categories(
        test_info['categories'], enhanced_model, sample_size
    )
    
    # Show detailed results
    print_test_analysis_results(metrics, category_analysis)
    
    # Show interesting examples
    show_interesting_examples(predictions, sampled_texts, sampled_mech_labels, sampled_pol_text)
    
    return {
        'metrics': metrics,
        'category_analysis': category_analysis,
        'predictions': predictions,
        'test_info': test_info
    }


def calculate_test_performance_metrics(predictions, true_mech_labels, true_pol_labels, true_pol_text):
    """
    Calculate comprehensive performance metrics on test data
    """
    # Filter out failed predictions
    valid_predictions = [(p, m, pl, pt) for p, m, pl, pt in zip(predictions, true_mech_labels, true_pol_labels, true_pol_text) if p is not None]
    
    if not valid_predictions:
        return {'error': 'No valid predictions'}
    
    predictions, mech_labels, pol_labels, pol_text = zip(*valid_predictions)
    
    # Mechanism detection metrics
    predicted_mech = []
    for pred in predictions:
        # Convert mechanism predictions to binary array
        mech_pred = np.zeros(len(label_columns))
        if pred.get('mechanism_predictions'):
            for i, label in enumerate(label_columns):
                if label in pred['mechanism_predictions']:
                    mech_pred[i] = 1
        predicted_mech.append(mech_pred)
    
    predicted_mech = np.array(predicted_mech)
    true_mech = np.array(mech_labels)
    
    # Calculate mechanism metrics
    mech_metrics = calculate_mechanism_metrics(true_mech, predicted_mech)
    
    # Polarity classification metrics
    predicted_pol = []
    for pred in predictions:
        pol_pred = pred.get('polarity_prediction', {}).get('label', 'neutral')
        predicted_pol.append(pol_pred)
    
    pol_accuracy = accuracy_score(pol_text, predicted_pol)
    pol_f1 = f1_score(pol_text, predicted_pol, average='weighted', zero_division=0)
    
    # Detection rate (percentage of examples with any mechanism detected)
    detection_rate = sum(1 for pred in predictions if pred.get('mechanism_predictions')) / len(predictions)
    
    return {
        'mechanism_metrics': mech_metrics,
        'polarity_accuracy': pol_accuracy,
        'polarity_f1': pol_f1,
        'detection_rate': detection_rate,
        'total_examples': len(predictions),
        'mechanism_predictions': predicted_mech,
        'polarity_predictions': predicted_pol
    }


def analyze_by_categories(categories, enhanced_model, sample_size=None):
    """
    Analyze model performance on different categories of test examples
    """
    category_results = {}
    
    for category_name, examples in categories.items():
        if not examples:
            continue
            
        print(f"\n🔍 Analyzing {category_name.replace('_', ' ').title()}...")
        
        # Sample if needed
        if sample_size and len(examples) > sample_size:
            examples = np.random.choice(examples, sample_size, replace=False).tolist()
        
        # Run predictions
        category_predictions = []
        for example in examples[:min(10, len(examples))]:  # Limit to 10 per category for speed
            try:
                pred = enhanced_model.predict(example['cleaned_text'], show_attention=False)
                category_predictions.append({
                    'prediction': pred,
                    'true_mechanisms': example['mechanisms'],
                    'true_polarity': example['polarity'],
                    'text': example['text']
                })
            except Exception as e:
                print(f"Error in {category_name}: {e}")
                continue
        
        # Calculate category-specific metrics
        if category_predictions:
            detection_rate = sum(1 for cp in category_predictions if cp['prediction'].get('mechanism_predictions')) / len(category_predictions)
            
            category_results[category_name] = {
                'examples_analyzed': len(category_predictions),
                'detection_rate': detection_rate,
                'predictions': category_predictions
            }
    
    return category_results


def print_test_analysis_results(metrics, category_analysis):
    """
    Print comprehensive analysis results
    """
    print(f"\n{'='*60}")
    print("TEST DATA ANALYSIS RESULTS")
    print(f"{'='*60}")
    
    # Overall metrics
    print(f"\n📊 OVERALL PERFORMANCE:")
    if 'mechanism_metrics' in metrics:
        mech = metrics['mechanism_metrics']
        print(f"  🎯 Mechanism Detection:")
        print(f"    • Micro F1: {mech['micro_f1']:.4f}")
        print(f"    • Macro F1: {mech['macro_f1']:.4f}")  
        print(f"    • Weighted F1: {mech['weighted_f1']:.4f}")
        print(f"    • Detection Rate: {metrics['detection_rate']:.2%}")
    
    print(f"  🎯 Polarity Classification:")
    print(f"    • Accuracy: {metrics['polarity_accuracy']:.4f}")
    print(f"    • Weighted F1: {metrics['polarity_f1']:.4f}")
    
    # Category breakdown
    print(f"\n📋 PERFORMANCE BY CATEGORY:")
    for category, results in category_analysis.items():
        category_display = category.replace('_', ' ').title()
        detection_rate = results['detection_rate']
        examples_count = results['examples_analyzed']
        print(f"  • {category_display}: {detection_rate:.2%} detection ({examples_count} examples)")


def show_interesting_examples(predictions, texts, true_mech_labels, true_pol_text, n_examples=5):
    """
    Show interesting examples from the test set analysis
    """
    print(f"\n{'='*60}")
    print(f"INTERESTING TEST EXAMPLES")
    print(f"{'='*60}\n")
    
    # Find examples with predictions
    examples_with_preds = []
    for i, (pred, text, true_mech, true_pol) in enumerate(zip(predictions, texts, true_mech_labels, true_pol_text)):
        if pred and pred.get('mechanism_predictions'):
            examples_with_preds.append((i, pred, text, true_mech, true_pol))
    
    # Show top examples
    for i, (idx, pred, text, true_mech, true_pol) in enumerate(examples_with_preds[:n_examples]):
        print(f"🔍 Example {i+1}:")
        print(f"Text: \"{text[:100]}{'...' if len(text) > 100 else ''}\"")
        
        # Predicted mechanisms
        mech_preds = pred.get('mechanism_predictions', {})
        if mech_preds:
            print("Predicted mechanisms:")
            for mech, conf in sorted(mech_preds.items(), key=lambda x: x[1], reverse=True):
                print(f"  ✓ {mech}: {conf:.4f}")
        
        # True mechanisms
        true_mechanisms = [label_columns[j] for j, val in enumerate(true_mech) if val == 1]
        if true_mechanisms:
            print(f"True mechanisms: {', '.join(true_mechanisms)}")
        else:
            print("True mechanisms: None")
        
        # Polarity
        pred_pol = pred.get('polarity_prediction', {})
        print(f"Polarity: {pred_pol.get('label', 'unknown')} (conf: {pred_pol.get('confidence', 0):.3f}) | True: {true_pol}")
        
        # Attention
        if pred.get('attended_tokens'):
            top_tokens = pred['attended_tokens'][:3]
            print(f"Key tokens: {', '.join([f'{t}({w:.3f})' for t, w in top_tokens])}")
        
        print("-" * 50)



In [None]:

# Cell 38: Compare Enhanced Model with Test Data Ground Truth
def compare_enhanced_vs_ground_truth(batch_number=1, detailed=True):
    """
    Compare enhanced model predictions against ground truth labels from test data
    """
    print(f"\n{'='*80}")
    print(f"ENHANCED MODEL vs GROUND TRUTH COMPARISON - BATCH {batch_number}")
    print(f"{'='*80}\n")
    
    # Load test data and model
    test_info = load_and_prepare_test_data()
    enhanced_model = load_enhanced_model(batch_number)
    
    if enhanced_model is None:
        return None
    
    # Focus on labeled examples (have ground truth mechanisms)
    labeled_examples = []
    for i, (text, mech_labels) in enumerate(zip(test_info['texts'], test_info['mechanism_labels'])):
        if np.sum(mech_labels) > 0:  # Has at least one mechanism label
            labeled_examples.append({
                'index': i,
                'text': text,
                'true_mechanisms': [label_columns[j] for j, val in enumerate(mech_labels) if val == 1],
                'true_polarity': test_info['polarity_text'][i],
                'true_mech_binary': mech_labels
            })
    
    print(f"📊 Analyzing {len(labeled_examples)} labeled examples...")
    
    # Analyze each labeled example
    results = {
        'perfect_matches': [],
        'partial_matches': [],
        'missed_detections': [],
        'false_positives': [],
        'polarity_correct': 0,
        'polarity_total': 0
    }
    
    for example in tqdm(labeled_examples[:20], desc="Comparing"):  # Limit for demo
        pred = enhanced_model.predict(example['text'], show_attention=True)
        
        predicted_mechs = set(pred.get('mechanism_predictions', {}).keys())
        true_mechs = set(example['true_mechanisms'])
        
        # Categorize the prediction
        if predicted_mechs == true_mechs and len(true_mechs) > 0:
            results['perfect_matches'].append({
                'example': example,
                'prediction': pred,
                'match_type': 'perfect'
            })
        elif predicted_mechs.intersection(true_mechs):
            results['partial_matches'].append({
                'example': example,
                'prediction': pred,
                'predicted': predicted_mechs,
                'true': true_mechs,
                'intersection': predicted_mechs.intersection(true_mechs),
                'missed': true_mechs - predicted_mechs,
                'extra': predicted_mechs - true_mechs
            })
        elif len(predicted_mechs) > 0:
            results['false_positives'].append({
                'example': example,
                'prediction': pred,
                'predicted': predicted_mechs,
                'true': true_mechs
            })
        else:
            results['missed_detections'].append({
                'example': example,
                'prediction': pred,
                'true': true_mechs
            })
        
        # Check polarity accuracy
        pred_polarity = pred.get('polarity_prediction', {}).get('label', 'neutral')
        if pred_polarity == example['true_polarity']:
            results['polarity_correct'] += 1
        results['polarity_total'] += 1
    
    # Print results summary
    print_ground_truth_comparison_results(results, detailed)
    
    return results


def print_ground_truth_comparison_results(results, detailed=True):
    """
    Print detailed comparison results against ground truth
    """
    total_examples = len(results['perfect_matches']) + len(results['partial_matches']) + len(results['missed_detections']) + len(results['false_positives'])
    
    print(f"\n📊 GROUND TRUTH COMPARISON SUMMARY:")
    print(f"  🎯 Perfect Matches: {len(results['perfect_matches'])}/{total_examples} ({len(results['perfect_matches'])/total_examples:.1%})")
    print(f"  ⚪ Partial Matches: {len(results['partial_matches'])}/{total_examples} ({len(results['partial_matches'])/total_examples:.1%})")
    print(f"  ❌ Missed Detections: {len(results['missed_detections'])}/{total_examples} ({len(results['missed_detections'])/total_examples:.1%})")
    print(f"  🚫 False Positives: {len(results['false_positives'])}/{total_examples} ({len(results['false_positives'])/total_examples:.1%})")
    
    polarity_acc = results['polarity_correct'] / results['polarity_total'] if results['polarity_total'] > 0 else 0
    print(f"  🎯 Polarity Accuracy: {results['polarity_correct']}/{results['polarity_total']} ({polarity_acc:.1%})")
    
    if detailed:
        # Show examples from each category
        print(f"\n🔍 DETAILED EXAMPLES:")
        
        # Perfect matches
        if results['perfect_matches']:
            print(f"\n✅ PERFECT MATCHES (showing first 2):")
            for i, match in enumerate(results['perfect_matches'][:2]):
                example = match['example']
                pred = match['prediction']
                print(f"  {i+1}. \"{example['text'][:80]}...\"")
                print(f"     Mechanisms: {', '.join(example['true_mechanisms'])}")
                print(f"     Confidence: {list(pred['mechanism_predictions'].values())[0]:.3f}")
        
        # Partial matches
        if results['partial_matches']:
            print(f"\n⚪ PARTIAL MATCHES (showing first 2):")
            for i, match in enumerate(results['partial_matches'][:2]):
                example = match['example']
                print(f"  {i+1}. \"{example['text'][:80]}...\"")
                print(f"     True: {', '.join(match['true'])}")
                print(f"     Predicted: {', '.join(match['predicted'])}")
                print(f"     Correct: {', '.join(match['intersection'])}")
                if match['missed']:
                    print(f"     Missed: {', '.join(match['missed'])}")
                if match['extra']:
                    print(f"     Extra: {', '.join(match['extra'])}")
        
        # Missed detections
        if results['missed_detections']:
            print(f"\n❌ MISSED DETECTIONS (showing first 2):")
            for i, miss in enumerate(results['missed_detections'][:2]):
                example = miss['example']
                print(f"  {i+1}. \"{example['text'][:80]}...\"")
                print(f"     Should have detected: {', '.join(example['true_mechanisms'])}")
                print(f"     Model predicted: None")