# ModernBERT Argument Mining Evaluation Pipeline

This script evaluates a fine-tuned ModernBERT model with LoRA adapters
on three argument mining tasks:
1. ADU Identification
2. ADU Classification
3. Stance Classification

## ArgumentMiningEvaluator Class
- loads tokenizer and stores model path
- defines task_configs for each task, specifying label names and task type
- load_model_for_task: loads the base model and attaches the correct LoRA adapter for each task
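
The on-demand loading described above can be sketched without any model weights. This is a minimal illustration, not the evaluator's implementation: `LazyTaskModels`, `fake_loader`, and the adapter path are hypothetical names, and `fake_loader` stands in for loading the base model and attaching the task adapter. It only demonstrates that each task is loaded at most once.

```python
from typing import Callable, Dict, List


class LazyTaskModels:
    """Caches one loaded object per task so each adapter is loaded at most once."""

    def __init__(self, adapter_paths: Dict[str, str], loader: Callable[[str], object]):
        self.adapter_paths = adapter_paths
        self.loader = loader  # hypothetical stand-in for the real model loading
        self.models: Dict[str, object] = {}

    def get(self, task_name: str) -> object:
        # Load only on the first request; later requests hit the cache
        if task_name not in self.models:
            self.models[task_name] = self.loader(self.adapter_paths[task_name])
        return self.models[task_name]


load_calls: List[str] = []


def fake_loader(path: str) -> str:
    # Records each call so the caching behavior can be observed
    load_calls.append(path)
    return f"model<{path}>"


registry = LazyTaskModels({"stance_classification": "adapters/stance"}, fake_loader)
first = registry.get("stance_classification")
second = registry.get("stance_classification")
```

Because the cache is keyed by task name, the per-example prediction calls in the evaluation loop pay the loading cost only once per task.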

### Benchmark Data Preparation
prepare_benchmark_data: loads and structures the evaluation data (claims, premises, categories) into one example per claim-premise pair
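
The pairing step can be sketched in isolation. The snippet below is an illustration under assumptions: `ADU` is a hypothetical stand-in for the database record (only `id` and `text` are used), `flatten_pairs` is an illustrative name, and the texts are made up.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ADU:
    """Hypothetical stand-in for the database record."""
    id: int
    text: str


def flatten_pairs(
    claims: List[ADU],
    premises_lists: List[List[ADU]],
    categories_lists: List[List[Optional[str]]],
) -> List[Dict[str, Any]]:
    """Emit one example per (claim, premise) pair, carrying the stance category."""
    examples = []
    for claim, premises, categories in zip(claims, premises_lists, categories_lists):
        for premise, category in zip(premises, categories):
            examples.append({
                "claim_id": claim.id,
                "premise_id": premise.id,
                "claim_text": claim.text,
                "premise_text": premise.text,
                "stance": category,  # e.g. 'stance_pro' or 'stance_con'
            })
    return examples


# Made-up data: one claim with two premises
claims = [ADU(1, "Cities should ban cars downtown.")]
premises_lists = [[ADU(10, "Air quality improves."), ADU(11, "Shops lose customers.")]]
categories_lists = [["stance_pro", "stance_con"]]
pairs = flatten_pairs(claims, premises_lists, categories_lists)
```

Note that a claim with several premises appears once per premise, so its text is classified repeatedly in any per-example loop.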

### Prediction Methods
- predict_adu_identification: Predicts ADU boundaries in text (token classification)
- predict_adu_classification: Classifies a text as claim or premise
- predict_stance_classification: Predicts stance (pro/con) between a claim and a premise
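
All three methods turn logits into a label by taking the argmax and indexing into the task's label list, so the list order has to match the order used during fine-tuning. A minimal sketch with made-up logits (`logits_to_label` is an illustrative helper, not part of the evaluator):

```python
from typing import List, Sequence


def logits_to_label(logits: Sequence[float], labels: List[str]) -> str:
    """Return the label whose index has the highest logit (argmax)."""
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return labels[best_index]


logits = [0.2, 1.7]  # made-up scores for one input; index 1 is the argmax

configured_order = ["con", "pro"]  # order used in task_configs
swapped_order = ["pro", "con"]     # same logits, reversed label list

label_configured = logits_to_label(logits, configured_order)
label_swapped = logits_to_label(logits, swapped_order)
```

With the same logits, the swapped list yields the opposite label, which is why a mismatched label order silently inverts every prediction.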

### Evaluation
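evaluate_all_tasks: Runs ADU classification and stance classification on every example, collects predictions and ground truth, and computes accuracy, macro and weighted F1, a per-class report, and a confusion matrix. create_error_analysis then collects the misclassified stance examples.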
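
To make the metrics concrete, here is a small pure-Python sketch of accuracy and macro F1 over string labels. It is an illustration of the definitions, not the scikit-learn code the evaluator calls; the function names and the four-example data are made up.

```python
from typing import Dict, List


def accuracy(y_true: List[str], y_pred: List[str]) -> float:
    """Fraction of positions where prediction and ground truth agree."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_f1(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """F1 per label, over every label seen in either list."""
    scores = {}
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores[label] = 2 * tp / denom if denom else 0.0
    return scores


def macro_f1(y_true: List[str], y_pred: List[str]) -> float:
    """Unweighted mean of the per-class F1 scores."""
    scores = per_class_f1(y_true, y_pred)
    return sum(scores.values()) / len(scores)


y_true = ["pro", "pro", "con", "con"]
y_pred = ["pro", "con", "con", "con"]

acc = accuracy(y_true, y_pred)
f1_by_class = per_class_f1(y_true, y_pred)
f1_macro = macro_f1(y_true, y_pred)
```

Macro F1 weights every class equally, so a class that appears only in the predictions (and never in the ground truth) contributes an F1 of 0 and pulls the average down sharply.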

In [71]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import pandas as pd
from typing import List, Dict, Tuple, Any
import json
from pathlib import Path
import logging
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

In [72]:
import sys
sys.path.append('./ArgumentMining')
print(sys.path)

['/Users/lenap/Desktop/Armin', '/opt/anaconda3/lib/python312.zip', '/opt/anaconda3/lib/python3.12', '/opt/anaconda3/lib/python3.12/lib-dynload', '', '/opt/anaconda3/lib/python3.12/site-packages', '/opt/anaconda3/lib/python3.12/site-packages/aeosa', '/opt/anaconda3/lib/python3.12/site-packages/setuptools/_vendor', '/var/folders/7b/bpwbm4g168sb_l8mr76jv93w0000gn/T/tmp41r0utec', './ArgumentMining', './ArgumentMining', './ArgumentMining', './ArgumentMining']


In [73]:
from db.queries import get_quality_data
from db.quality_data import data as benchmark_data

In [74]:
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [75]:
class ArgumentMiningEvaluator:
    """
    Evaluator for ModernBERT argument mining models with LoRA adapters
    """
    
    def __init__(self, base_model_path: str, adapter_paths: Dict[str, str], device: str = None):
        """
        Initialize the evaluator
        
        Args:
            base_model_path: Path to the base ModernBERT model
            adapter_paths: Dictionary mapping task names to adapter paths
                          e.g., {'adu_identification': 'path/to/adu_id_adapter', ...}
            device: Device to use for inference
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.base_model_path = base_model_path
        self.adapter_paths = adapter_paths
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        
        # Initialize models (will be loaded on demand)
        self.models = {}
        
        # Task configurations
        self.task_configs = {
            'adu_identification': {
                'labels': ['No', 'Yes'],  # 2 classes: No, Yes
                'task_type': 'token_classification'
            },
            'adu_classification': {
                'labels': ['claim', 'premise'],  # 2 classes
                'task_type': 'sequence_classification'
            },
            'stance_classification': {
                'labels': ['con', 'pro'],  # <-- match fine-tuning order!
                'task_type': 'sequence_classification'
            },
            'relationship_identification': {
                'labels': ['contradictory', 'supportive'],  # <-- match fine-tuning order!
                'task_type': 'sequence_classification'
            }
        }
        
        # Results storage
        self.results = {}
        
    def load_model_for_task(self, task_name: str):
        """Load model with specific LoRA adapter for a task"""
        if task_name in self.models:
            return self.models[task_name]
            
        logger.info(f"Loading model for task: {task_name}")
        
        # Load base model
        if self.task_configs[task_name]['task_type'] == 'token_classification':
            from transformers import AutoModelForTokenClassification
            base_model = AutoModelForTokenClassification.from_pretrained(
                self.base_model_path,
                num_labels=len(self.task_configs[task_name]['labels'])
            )
        else:
            base_model = AutoModelForSequenceClassification.from_pretrained(
                self.base_model_path,
                num_labels=len(self.task_configs[task_name]['labels'])
            )
        
        # Load LoRA adapter
        model = PeftModel.from_pretrained(base_model, self.adapter_paths[task_name])
        model = model.to(self.device)
        model.eval()
        
        self.models[task_name] = model
        return model
        
    def prepare_benchmark_data(self) -> List[Dict[str, Any]]:
        """
        Prepare benchmark data for evaluation
        Returns list of examples with ground truth labels
        """
        logger.info("Preparing benchmark data...")
        
        # Get benchmark data using your existing function or directly from database
        try:
            # Try to use your existing function
            from db.queries import get_quality_data
            claims, premises_lists, categories_lists = get_quality_data(benchmark_data)
        except ImportError:
            # Fallback: directly query database
            logger.info("Using fallback database query...")
            claims, premises_lists, categories_lists = self._get_benchmark_data_from_db()
        
        examples = []
        
        for i, (claim, premises_list, categories_list) in enumerate(zip(claims, premises_lists, categories_lists)):
            # Create example for each claim-premise pair
            for j, (premise, category) in enumerate(zip(premises_list, categories_list)):
                example = {
                    'claim_id': claim.id,
                    'premise_id': premise.id,
                    'claim_text': claim.text,
                    'premise_text': premise.text,
                    'stance': category,  # stance_pro or stance_con
                    'claim_type': 'claim',  # Ground truth for ADU classification
                    'premise_type': 'premise',  # Ground truth for ADU classification
                }
                examples.append(example)
                
        return examples
    
    def _get_benchmark_data_from_db(self):
        """
        Fallback method to get benchmark data directly from database
        """
        from .db import get_session
        from .models import ADU, Relationship
        from collections import defaultdict
        
        with get_session() as session:
            claim_ids = list(benchmark_data.keys())
            premise_ids = list({pid for pids in benchmark_data.values() for pid in pids})
            
            # Get claims and premises
            claims = session.query(ADU).filter(ADU.id.in_(claim_ids)).all()
            premises = session.query(ADU).filter(ADU.id.in_(premise_ids)).all()
            
            claims_by_id = {c.id: c for c in claims}
            premises_by_id = {p.id: p for p in premises}
            
            # Get relationships
            rows = (
                session
                .query(Relationship.from_adu_id, Relationship.to_adu_id, Relationship.category)
                .filter(
                    Relationship.to_adu_id.in_(claim_ids),
                    Relationship.from_adu_id.in_(premise_ids)
                )
                .all()
            )
            
            category_lookup = {(from_id, to_id): cat for from_id, to_id, cat in rows}
            
            output_claims = []
            output_premises = []
            output_categories = []
            
            for claim_id, premise_list in benchmark_data.items():
                claim = claims_by_id.get(claim_id)
                if not claim:
                    continue
                    
                current_premises = []
                current_categories = []
                
                for pid in premise_list:
                    premise = premises_by_id.get(pid)
                    if not premise:
                        continue
                        
                    current_premises.append(premise)
                    category = category_lookup.get((pid, claim_id), None)
                    current_categories.append(category)
                
                output_claims.append(claim)
                output_premises.append(current_premises)
                output_categories.append(current_categories)
            
            return output_claims, output_premises, output_categories
        
    def predict_adu_identification(self, text: str) -> List[str]:
        """
        Predict ADU membership for each token using token classification
        Returns one 'No'/'Yes' label per token (special tokens included)
        """
        model = self.load_model_for_task('adu_identification')
        
        # Tokenize
        inputs = self.tokenizer(
            text, 
            return_tensors="pt", 
            truncation=True, 
            padding=True,
            max_length=512
        ).to(self.device)
        
        # Predict
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1)
        
        # Convert to labels
        labels = self.task_configs['adu_identification']['labels']
        pred_labels = [labels[p] for p in predictions[0].cpu().numpy()]
        
        return pred_labels
        
    def predict_adu_classification(self, text: str) -> str:
        """
        Classify ADU type (claim/premise)
        """
        model = self.load_model_for_task('adu_classification')
        
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512
        ).to(self.device)
        
        with torch.no_grad():
            outputs = model(**inputs)
            prediction = torch.argmax(outputs.logits, dim=-1)
            
        labels = self.task_configs['adu_classification']['labels']
        return labels[prediction.item()]
        
    def predict_stance_classification(self, claim_text: str, premise_text: str) -> str:
        """
        Classify stance between claim and premise
        """
        model = self.load_model_for_task('stance_classification')
        
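        # Combine texts (adjust format based on your training).
        # Note: the tokenizer adds its own special tokens by default, so the literal
        # [CLS]/[SEP] markers below are only appropriate if fine-tuning used this same format.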
        combined_text = f"[CLS] {claim_text} [SEP] {premise_text} [SEP]"
        
        inputs = self.tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512
        ).to(self.device)
        
        with torch.no_grad():
            outputs = model(**inputs)
            prediction = torch.argmax(outputs.logits, dim=-1)
            
        labels = self.task_configs['stance_classification']['labels']
        return labels[prediction.item()]
        
    def evaluate_all_tasks(self) -> Dict[str, Any]:
        """
        Run evaluation on all tasks and return comprehensive results
        """
        logger.info("Starting comprehensive evaluation...")
        
        # Prepare data
        examples = self.prepare_benchmark_data()
        
        # Initialize prediction storage
        predictions = {
            'adu_classification_claim': [],
            'adu_classification_premise': [], 
            'stance_classification': []
        }
        
        ground_truth = {
            'adu_classification_claim': [],
            'adu_classification_premise': [],
            'stance_classification': []
        }
        
        # Run predictions
        logger.info(f"Processing {len(examples)} examples...")
        for example in tqdm(examples):
            # ADU Classification
            claim_pred = self.predict_adu_classification(example['claim_text'])
            premise_pred = self.predict_adu_classification(example['premise_text'])
            
            predictions['adu_classification_claim'].append(claim_pred)
            predictions['adu_classification_premise'].append(premise_pred)
            ground_truth['adu_classification_claim'].append(example['claim_type'])
            ground_truth['adu_classification_premise'].append(example['premise_type'])
            
            # Stance Classification
            stance_pred = self.predict_stance_classification(
                example['claim_text'], 
                example['premise_text']
            )
            predictions['stance_classification'].append(stance_pred)

            # Map ground truth to model's expected labels
            if example['stance'] == 'stance_pro':
                stance_gt = 'pro'
            elif example['stance'] == 'stance_con':
                stance_gt = 'con'
            else:
                stance_gt = None
            ground_truth['stance_classification'].append(stance_gt)
            
        # Calculate metrics
        results = {}
        
        for task_name in predictions.keys():
            y_true = ground_truth[task_name]
            y_pred = predictions[task_name]

            filtered = [(yt, yp) for yt, yp in zip(y_true, y_pred) if yt is not None and yp is not None]
            if not filtered:
                continue
            y_true_filtered, y_pred_filtered = zip(*filtered)

            results[task_name] = {
                'accuracy': accuracy_score(y_true_filtered, y_pred_filtered),
                'f1_macro': f1_score(y_true_filtered, y_pred_filtered, average='macro'),
                'f1_weighted': f1_score(y_true_filtered, y_pred_filtered, average='weighted'),
                'classification_report': classification_report(y_true_filtered, y_pred_filtered, output_dict=True),
                'confusion_matrix': confusion_matrix(y_true_filtered, y_pred_filtered).tolist(),
                'predictions': list(y_pred_filtered),
                'ground_truth': list(y_true_filtered)
            }
            
        self.results = results
        return results
        
    def print_results_summary(self):
        """Print a summary of evaluation results"""
        print("\n" + "="*80)
        print("ARGUMENT MINING EVALUATION RESULTS")
        print("="*80)
        
        for task_name, metrics in self.results.items():
            print(f"\n{task_name.upper().replace('_', ' ')}")
            print("-" * 50)
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"F1 (Macro): {metrics['f1_macro']:.4f}")
            print(f"F1 (Weighted): {metrics['f1_weighted']:.4f}")
            
            # Print per-class metrics
            report = metrics['classification_report']
            for label, scores in report.items():
                if isinstance(scores, dict) and label not in ['accuracy', 'macro avg', 'weighted avg']:
                    print(f"  {label}: Precision={scores['precision']:.3f}, "
                          f"Recall={scores['recall']:.3f}, F1={scores['f1-score']:.3f}")
                          
    def save_detailed_results(self, output_path: str):
        """Save detailed results to JSON file"""
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        
        with open(output_file, 'w') as f:
            json.dump(self.results, f, indent=2, default=str)
            
        logger.info(f"Detailed results saved to {output_file}")
        
    def create_error_analysis(self) -> pd.DataFrame:
        """Create error analysis DataFrame"""
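        # evaluate_all_tasks drops examples without a pro/con ground-truth stance,
        # so apply the same filter here to keep examples aligned with the stored predictions
        examples = [
            ex for ex in self.prepare_benchmark_data()
            if ex['stance'] in ('stance_pro', 'stance_con')
        ]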
        
        error_data = []
        
        # for stance classification results
        stance_preds = self.results['stance_classification']['predictions']
        stance_gt = self.results['stance_classification']['ground_truth']
        
        for i, (example, pred, gt) in enumerate(zip(examples, stance_preds, stance_gt)):
            if pred != gt:
                error_data.append({
                    'claim_id': example['claim_id'],
                    'premise_id': example['premise_id'],
                    'claim_text': example['claim_text'][:100] + '...',
                    'premise_text': example['premise_text'][:100] + '...',
                    'predicted': pred,
                    'ground_truth': gt,
                    'error_type': f"{gt} -> {pred}"
                })
                
        return pd.DataFrame(error_data)



In [76]:
def main():
    """
    Main evaluation function
    """
    
    BASE_MODEL_PATH = "answerdotai/ModernBERT-base"  # Hugging Face model
    ADAPTER_PATHS = {
        'adu_identification': "./argument-mining-modernbert-all/argument-mining-modernbert-adu_identification",
        'adu_classification': "./argument-mining-modernbert-all/argument-mining-modernbert-adu_classification", 
        'stance_classification': "./argument-mining-modernbert-all/argument-mining-modernbert-stance_classification"
    }
    

    
    # Initialize evaluator
    evaluator = ArgumentMiningEvaluator(
        base_model_path=BASE_MODEL_PATH,
        adapter_paths=ADAPTER_PATHS
    )
    
    # Run evaluation
    results = evaluator.evaluate_all_tasks()
    
    # Print results
    evaluator.print_results_summary()
    
    # Save detailed results
    evaluator.save_detailed_results("evaluation_results.json")
    
    # Create error analysis
    error_df = evaluator.create_error_analysis()
    error_df.to_csv("error_analysis.csv", index=False)
    
    print("\nError analysis saved to error_analysis.csv")
    print(f"Found {len(error_df)} errors out of {len(evaluator.results['stance_classification']['predictions'])} predictions")

In [77]:
if __name__ == "__main__":
    main()

INFO:__main__:Starting comprehensive evaluation...
INFO:__main__:Preparing benchmark data...
INFO:__main__:Processing 301 examples...
  0%|          | 0/301 [00:00<?, ?it/s]INFO:__main__:Loading model for task: adu_classification
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:__main__:Loading model for task: stance_classification
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 301/301 [06:58<00:00,  1.39s/it]  


ARGUMENT MINING EVALUATION RESULTS

ADU CLASSIFICATION CLAIM
--------------------------------------------------
Accuracy: 0.9767
F1 (Macro): 0.4941
F1 (Weighted): 0.9882
  claim: Precision=1.000, Recall=0.977, F1=0.988
  premise: Precision=0.000, Recall=0.000, F1=0.000

ADU CLASSIFICATION PREMISE
--------------------------------------------------
Accuracy: 0.8804
F1 (Macro): 0.4682
F1 (Weighted): 0.9364
  claim: Precision=0.000, Recall=0.000, F1=0.000
  premise: Precision=1.000, Recall=0.880, F1=0.936

STANCE CLASSIFICATION
--------------------------------------------------
Accuracy: 0.5671
F1 (Macro): 0.5640
F1 (Weighted): 0.5640
  con: Precision=0.581, Recall=0.483, F1=0.527
  pro: Precision=0.557, Recall=0.651, F1=0.601

Error analysis saved to error_analysis.csv
Found 129 errors out of 298 predictions
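
The gap between accuracy (0.9767) and macro F1 (0.4941) in the claim subset follows from how the subsets are built: every ground-truth label in `adu_classification_claim` is `claim`, so the `premise` class has no true instances, scores an F1 of 0, and halves the macro average. The sketch below reproduces that figure and shows what pooling both subsets into one two-class evaluation would give. The counts of correct predictions (294 and 265) are reconstructed from the reported accuracies and the 301 examples, so treat them as an assumption rather than logged output; `f1_from_counts` is an illustrative helper.

```python
def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """F1 from true positives, false positives, and false negatives."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


n = 301                 # examples per subset, from the log above
claim_correct = 294     # assumed: 0.9767 * 301, rounded
premise_correct = 265   # assumed: 0.8804 * 301, rounded
claim_wrong = n - claim_correct      # claims predicted as 'premise'
premise_wrong = n - premise_correct  # premises predicted as 'claim'

# Claim subset alone: 'premise' never occurs in the ground truth, so its F1 is 0
f1_claim_only = f1_from_counts(tp=claim_correct, fp=0, fn=claim_wrong)
macro_claim_only = (f1_claim_only + 0.0) / 2

# Pooled evaluation: both classes now have true instances
f1_claim = f1_from_counts(tp=claim_correct, fp=premise_wrong, fn=claim_wrong)
f1_premise = f1_from_counts(tp=premise_correct, fp=claim_wrong, fn=premise_wrong)
pooled_macro_f1 = (f1_claim + f1_premise) / 2
pooled_accuracy = (claim_correct + premise_correct) / (2 * n)

print(f"claim-only macro F1: {macro_claim_only:.4f}")
print(f"pooled accuracy: {pooled_accuracy:.4f}, pooled macro F1: {pooled_macro_f1:.4f}")
```

Under these assumed counts, the pooled macro F1 stays close to the pooled accuracy, which suggests the low per-subset macro scores reflect the single-class ground truth rather than poor classification.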
