# Enhanced BioBERT for Chemical-Disease Relation Extraction

This notebook implements an enhanced BioBERT-based approach for chemical-disease relation (CDR) extraction from biomedical literature, using the BioCreative V CDR corpus of annotated PubMed abstracts. It includes the following components:

1. Data processing and loading
2. Baseline CNN model
3. Original BioBERT implementation
4. Enhanced BioBERT with recall optimization
5. Longformer model for long-document processing
6. Knowledge-enhanced BioBERT with Comparative Toxicogenomics Database (CTD) integration
7. Data augmentation techniques
8. Training and evaluation pipelines
9. Experiment execution and result analysis

## 1. Setup and Imports

In [None]:
# Install required packages
%pip install torch transformers pandas numpy networkx scikit-learn nltk spacy tqdm matplotlib seaborn wandb
%pip install gensim  

In [None]:
# Import basic libraries
import os
import json
import re
import pickle
import random
from collections import defaultdict
from tqdm.notebook import tqdm

# Data processing and analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, classification_report

# PyTorch and transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    BertTokenizer,
    BertModel,
    BertForSequenceClassification,
    LongformerModel,
    LongformerTokenizer,
    get_linear_schedule_with_warmup
)

# For knowledge graph processing
import networkx as nx

# NLTK for data augmentation
import nltk
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer

# Set random seeds for reproducibility
def set_seed(seed_value=42):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    When CUDA is available, every CUDA device is seeded as well.

    Args:
        seed_value: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)

set_seed(42)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Download NLTK resources. Each resource is checked on its own so that a
# missing corpus (e.g. wordnet, used through `wordnet`/`WordNetLemmatizer`
# for data augmentation) is fetched even when another one is already installed.
for resource_path, package_name in (("tokenizers/punkt", "punkt"),
                                    ("corpora/wordnet", "wordnet")):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package_name)

## 2. Data Processing

In [None]:
class CDRDataProcessor:
    """
    Processor for the BioCreative V CDR dataset for chemical-disease relation extraction.
    The dataset consists of PubMed abstracts annotated with chemicals, diseases,
    and their binary relations (does a chemical induce a disease).

    Raw files are in PubTator format:

    * ``<pmid>|t|<title>`` and ``<pmid>|a|<abstract>`` lines (pipe-delimited)
      hold the document text;
    * tab-delimited ``pmid, start, end, mention, type, ids[, ...]`` lines hold
      entity mentions, with character offsets into ``title + ' ' + abstract``;
    * tab-delimited ``pmid, CID, chemical_id, disease_id`` lines hold relations.
    """

    # Heuristic sentence boundary: one whitespace character following '.' or '?',
    # unless the period looks like part of an abbreviation (e.g. "e.g." or "Dr.").
    _SENTENCE_SPLIT = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s')

    def __init__(self, data_dir):
        """
        Args:
            data_dir: Directory containing the dataset files
        """
        self.data_dir = data_dir

    @staticmethod
    def _get_document(documents, doc_id):
        """Return the record for ``doc_id`` in ``documents``, creating an empty one if needed."""
        if doc_id not in documents:
            documents[doc_id] = {
                'id': doc_id,
                'title': '',
                'abstract': '',
                'chemicals': [],
                'diseases': [],
                'relations': []
            }
        return documents[doc_id]

    @staticmethod
    def _parse_text_line(line):
        """
        Parse a PubTator ``<pmid>|t|<title>`` / ``<pmid>|a|<abstract>`` line.

        Returns:
            ``(doc_id, section, text)`` where section is ``'t'`` or ``'a'``,
            or ``None`` when the line is not a title/abstract line.
        """
        head = line.split('|', 2)
        # Annotation lines are tab-delimited: a tab before the first pipe means the
        # pipe belongs to a multi-ID concept field, not to a text line.
        if len(head) == 3 and head[1] in ('t', 'a') and '\t' not in head[0]:
            return head[0], head[1], head[2]
        return None

    def read_file(self, file_path):
        """
        Read the dataset file in PubTator format

        Args:
            file_path: Path to the dataset file

        Returns:
            List of dictionaries with abstract info, entities, and relations.
            Each document also gets a ``text`` field equal to
            ``title + ' ' + abstract``, which is what the entity offsets index.
        """
        documents = {}

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                # Title or abstract text (standard pipe-delimited PubTator form)
                text_line = self._parse_text_line(line)
                if text_line is not None:
                    doc_id, section, content = text_line
                    doc = self._get_document(documents, doc_id)
                    doc['title' if section == 't' else 'abstract'] = content
                    continue

                parts = line.split('\t')
                doc = self._get_document(documents, parts[0])

                # Tab-delimited title/abstract variant (kept for backward compatibility)
                if len(parts) >= 2 and parts[1] in ('t', 'a'):
                    section_type = 'title' if parts[1] == 't' else 'abstract'
                    doc[section_type] = parts[2] if len(parts) > 2 else ''

                # Entity annotation
                elif len(parts) >= 6:
                    start, end, mention, entity_type = parts[1:5]

                    # Handle the case of multiple concept IDs (pipe-separated)
                    concept_ids = parts[5].split('|')

                    entity = {
                        'id': concept_ids[0],  # Primary ID
                        'all_ids': concept_ids,  # All IDs
                        'type': entity_type,
                        'mention': mention,
                        'start': int(start),
                        'end': int(end)
                    }

                    if entity_type == 'Chemical':
                        doc['chemicals'].append(entity)
                    elif entity_type == 'Disease':
                        doc['diseases'].append(entity)

                # Relation annotation
                elif len(parts) == 4 and parts[1] == 'CID':
                    doc['relations'].append({
                        'chemical_id': parts[2],
                        'disease_id': parts[3]
                    })

        # Combine title and abstract; PubTator offsets assume a single separating space
        for doc in documents.values():
            doc['text'] = doc['title'] + ' ' + doc['abstract']

        return list(documents.values())

    def get_train_examples(self):
        """Get training examples"""
        return self._create_examples(
            self.read_file(os.path.join(self.data_dir, "CDR_TrainingSet.PubTator.txt")),
            "train"
        )

    def get_dev_examples(self):
        """Get development examples"""
        return self._create_examples(
            self.read_file(os.path.join(self.data_dir, "CDR_DevelopmentSet.PubTator.txt")),
            "dev"
        )

    def get_test_examples(self):
        """Get test examples"""
        return self._create_examples(
            self.read_file(os.path.join(self.data_dir, "CDR_TestSet.PubTator.txt")),
            "test"
        )

    def _create_examples(self, documents, set_type):
        """
        Create examples for the training and dev sets.

        For each document, generate examples for all chemical-disease mention pairs.
        A pair is positive when any of the chemical's concept IDs and any of the
        disease's concept IDs form a CID relation of the document.
        """
        examples = []

        for doc in documents:
            # Get the text for the document (title + abstract)
            text = doc['text']

            # Create a set of positive relations (chemical_id, disease_id)
            positive_rels = {(r['chemical_id'], r['disease_id']) for r in doc['relations']}

            # Sentence lookup depends only on the entity, so do it once per mention
            # instead of once per chemical-disease pair.
            chem_sentences = [self._find_entity_sentences(text, c) for c in doc['chemicals']]
            disease_sentences = [self._find_entity_sentences(text, d) for d in doc['diseases']]

            for chemical, chem_sents in zip(doc['chemicals'], chem_sentences):
                chemical_ids = chemical.get('all_ids', [chemical['id']])
                for disease, dis_sents in zip(doc['diseases'], disease_sentences):
                    disease_ids = disease.get('all_ids', [disease['id']])

                    is_relation = int(any(
                        (chemical_id, disease_id) in positive_rels
                        for chemical_id in chemical_ids
                        for disease_id in disease_ids
                    ))

                    examples.append({
                        'doc_id': doc['id'],
                        'text': text,
                        'title': doc['title'],
                        'abstract': doc['abstract'],
                        'chemical': chemical,
                        'disease': disease,
                        'label': is_relation,
                        # Copies keep examples independent of each other
                        'chemical_sentences': list(chem_sents),
                        'disease_sentences': list(dis_sents),
                        'set_type': set_type
                    })

        return examples

    def _find_entity_sentences(self, text, entity):
        """
        Find all sentences that contain the entity mention

        Args:
            text: Full document text
            entity: Entity dictionary with mention, start, end

        Returns:
            List of (stripped) sentences whose character span overlaps the entity span
        """
        # Simple heuristic sentence splitting; a production system should use a
        # dedicated sentence splitter.
        sentences = self._SENTENCE_SPLIT.split(text)

        entity_start = entity['start']
        entity_end = entity['end']

        containing_sentences = []
        current_pos = 0

        for sentence in sentences:
            sentence_start = current_pos
            sentence_end = current_pos + len(sentence)

            # Check if entity span overlaps with sentence span
            if not (entity_end <= sentence_start or entity_start >= sentence_end):
                containing_sentences.append(sentence.strip())

            # The split consumed exactly one whitespace character between sentences
            current_pos = sentence_end + 1

        return containing_sentences

In [None]:
# Define function to load data or process it if not already done
def load_or_process_data(data_dir):
    """
    Return (train, dev, test) example lists, using a pickle cache when possible.

    The cache lives in ``<data_dir>/processed``. If all three cached files are
    present they are loaded. Otherwise the raw PubTator files are parsed with
    ``CDRDataProcessor`` and the results are written to the cache. Positive and
    negative counts for each split are printed in both cases.

    NOTE: the cache is read with ``pickle``; only point this at trusted directories.
    """
    processed_dir = os.path.join(data_dir, "processed")
    os.makedirs(processed_dir, exist_ok=True)

    split_names = ("train", "dev", "test")
    cache_paths = {
        name: os.path.join(processed_dir, f"{name}_examples.pkl") for name in split_names
    }

    splits = {}
    if all(os.path.exists(path) for path in cache_paths.values()):
        print("Loading processed data...")
        for name in split_names:
            with open(cache_paths[name], 'rb') as handle:
                splits[name] = pickle.load(handle)
    else:
        print("Processing data from raw files...")
        processor = CDRDataProcessor(data_dir)
        splits["train"] = processor.get_train_examples()
        splits["dev"] = processor.get_dev_examples()
        splits["test"] = processor.get_test_examples()

        # Persist the freshly parsed splits for the next run
        for name in split_names:
            with open(cache_paths[name], 'wb') as handle:
                pickle.dump(splits[name], handle)

    display_names = {"train": "training", "dev": "development", "test": "test"}
    for name in split_names:
        examples = splits[name]
        n_total = len(examples)
        n_positive = sum(ex['label'] == 1 for ex in examples)
        print(f"Number of {display_names[name]} examples: {n_total} ({n_positive} positive, {n_total - n_positive} negative)")

    return splits["train"], splits["dev"], splits["test"]

In [None]:
# Set the path to your CDR dataset. The CDR_DATA_DIR environment variable, when set,
# overrides the default Google Drive location so the notebook runs outside Colab.
data_dir = os.environ.get(
    "CDR_DATA_DIR",
    "/content/drive/MyDrive/REBL/CDR_Data/CDR.Corpus.v010516",  # Update this to your dataset path
)

# Load or process the data
train_examples, dev_examples, test_examples = load_or_process_data(data_dir)

## 3. Dataset Classes

In [None]:
class CDRBaselineDataset(Dataset):
    """
    Dataset for the baseline CNN model.

    Each item is the text ``"Chemical: <mention> Disease: <mention> Context: <context>"``,
    where the context is a sentence mentioning both entities if one exists, or
    otherwise the first chemical sentence followed by the first disease sentence.
    """

    def __init__(self, examples, tokenizer, max_length=128):
        """
        Initialize dataset

        Args:
            examples: List of examples
            tokenizer: Tokenizer for text (must provide ``encode_plus``)
            max_length: Maximum sequence length
        """
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    @staticmethod
    def _select_context(example):
        """Pick the context text for an example deterministically (document order)."""
        chem_sentences = example['chemical_sentences']
        disease_sentences = example['disease_sentences']

        # Take the first chemical sentence that also mentions the disease. Picking
        # from a set (as before) depends on string hash order, which varies between
        # processes and breaks reproducibility.
        disease_lookup = set(disease_sentences)
        shared = next((s for s in chem_sentences if s in disease_lookup), None)
        if shared is not None:
            return shared

        # Otherwise use the first sentence for each entity
        chem_text = chem_sentences[0] if chem_sentences else ""
        disease_text = disease_sentences[0] if disease_sentences else ""
        return chem_text + " " + disease_text

    def __getitem__(self, idx):
        example = self.examples[idx]

        chemical = example['chemical']['mention']
        disease = example['disease']['mention']

        # Prefix the context with the two entity mentions
        text = f"Chemical: {chemical} Disease: {disease} Context: {self._select_context(example)}"

        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze(0) only drops the batch dimension added by return_tensors='pt'
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(example['label'], dtype=torch.long)
        }


class CDRBioBERTDataset(Dataset):
    """
    Dataset for the BioBERT relation extraction model.

    Document-level mode builds ``"Title: <title> [SEP] Abstract: <abstract>"``
    and wraps the example's chemical and disease mention in ``[CHEM] ... [/CHEM]``
    and ``[DISE] ... [/DISE]`` markers, located by their character offsets.
    Sentence-level mode uses the sentence(s) mentioning the two entities and
    marks every occurrence of each mention string.

    NOTE(review): unless the marker strings are added to the tokenizer's
    vocabulary elsewhere, the tokenizer will split them into sub-word pieces.
    """

    _TITLE_PREFIX = "Title: "
    _ABSTRACT_PREFIX = " [SEP] Abstract: "

    def __init__(self, examples, tokenizer, max_length=512, document_level=True):
        """
        Initialize dataset

        Args:
            examples: List of examples
            tokenizer: BioBERT tokenizer
            max_length: Maximum sequence length
            document_level: Whether to use whole document or sentence-level input
        """
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.document_level = document_level

    def mark_entities_in_text(self, text, entities):
        """
        Wrap entity spans in type-specific markers without altering the text.

        Markers are inserted at span boundaries and the span's own characters are
        kept, so nested or overlapping spans cannot corrupt one another. Nested
        spans get properly nested markers. Empty spans and spans outside ``text``
        are skipped.

        Args:
            text: Text to mark entities in
            entities: List of entity dicts with 'start'/'end' character offsets into
                ``text`` and 'type' ('Chemical' -> [CHEM], anything else -> [DISE]).
                A 'mention' key may be present but is not needed.

        Returns:
            Text with marked entities
        """
        insertions = []
        for order, entity in enumerate(entities):
            start, end = entity['start'], entity['end']
            if not 0 <= start < end <= len(text):
                continue
            is_chemical = entity['type'] == 'Chemical'
            open_marker = "[CHEM] " if is_chemical else "[DISE] "
            close_marker = " [/CHEM]" if is_chemical else " [/DISE]"
            # Sort keys: at one offset closings come before openings; among
            # openings the outer (longer) span comes first; among closings the
            # inner span comes first.
            insertions.append(((start, 1, -end, order), open_marker))
            insertions.append(((end, 0, -start, -order), close_marker))
        insertions.sort(key=lambda item: item[0])

        pieces = []
        cursor = 0
        for (position, _, _, _), marker in insertions:
            pieces.append(text[cursor:position])
            pieces.append(marker)
            cursor = position
        pieces.append(text[cursor:])
        return ''.join(pieces)

    def __len__(self):
        return len(self.examples)

    @staticmethod
    def _first_shared_sentence(example):
        """Return the first chemical sentence that also mentions the disease, else None.

        Choosing in document order keeps the result deterministic. Taking an
        element of a set depends on string hash order, which varies per process.
        """
        disease_lookup = set(example['disease_sentences'])
        return next((s for s in example['chemical_sentences'] if s in disease_lookup), None)

    def _document_level_text(self, example):
        """Build the title/abstract input with the example's two mentions marked by offset."""
        entities_to_mark = [
            {'start': example['chemical']['start'],
             'end': example['chemical']['end'],
             'mention': example['chemical']['mention'],
             'type': 'Chemical'},
            {'start': example['disease']['start'],
             'end': example['disease']['end'],
             'mention': example['disease']['mention'],
             'type': 'Disease'}
        ]

        if 'title' in example and 'abstract' in example:
            text = self._TITLE_PREFIX + example['title'] + self._ABSTRACT_PREFIX + example['abstract']

            # Offsets index ``title + ' ' + abstract``. Title entities move right by
            # the title prefix. Abstract entities also move by the abstract prefix,
            # minus the single separating space that the prefix replaces.
            title_len = len(example['title'])
            title_shift = len(self._TITLE_PREFIX)
            abstract_shift = title_shift + len(self._ABSTRACT_PREFIX) - 1

            for entity in entities_to_mark:
                shift = title_shift if entity['start'] < title_len else abstract_shift
                entity['start'] += shift
                entity['end'] += shift
        else:
            # Fall back to using the full text (offsets already match it)
            text = example['text']

        return self.mark_entities_in_text(text, entities_to_mark)

    def _sentence_level_text(self, example, chemical, disease):
        """Build a sentence-level input, marking every occurrence of each mention string."""
        shared = self._first_shared_sentence(example)
        if shared is not None:
            sent_text = shared.replace(chemical, f"[CHEM] {chemical} [/CHEM]")
            return sent_text.replace(disease, f"[DISE] {disease} [/DISE]")

        # No shared sentence: concatenate the first sentence of each entity
        chem_text = example['chemical_sentences'][0] if example['chemical_sentences'] else ""
        disease_text = example['disease_sentences'][0] if example['disease_sentences'] else ""

        chem_text = chem_text.replace(chemical, f"[CHEM] {chemical} [/CHEM]")
        disease_text = disease_text.replace(disease, f"[DISE] {disease} [/DISE]")

        return chem_text + " " + disease_text

    def __getitem__(self, idx):
        example = self.examples[idx]

        chemical = example['chemical']['mention']
        disease = example['disease']['mention']

        if self.document_level:
            text = self._document_level_text(example)
        else:
            text = self._sentence_level_text(example, chemical, disease)

        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze(0) drops only the batch dimension added by return_tensors='pt'.
        # The chemical and disease IDs are included for knowledge-enhanced models.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'token_type_ids': encoding['token_type_ids'].squeeze(0),
            'label': torch.tensor(example['label'], dtype=torch.long),
            'chemical_id': example['chemical']['id'],
            'disease_id': example['disease']['id']
        }


class CDRLongformerDataset(Dataset):
    """
    Dataset class for the CDR corpus using the Longformer tokenizer.

    The input is ``"Title: <title> Abstract: <abstract>"`` (truncated to
    ``max_length`` tokens). Every occurrence of the chemical mention is wrapped
    in ``[CHEM] ... [/CHEM]`` and every occurrence of the disease mention in
    ``[DISE] ... [/DISE]``. The first token and the tokens that make up those
    markers receive global attention.
    """

    _ENTITY_MARKERS = ("[CHEM]", "[/CHEM]", "[DISE]", "[/DISE]")

    def __init__(self, examples, tokenizer, max_length=4096):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Token-id sequences that encode each marker, indexed by their first id.
        # Markers are usually split into several sub-word tokens, so they have to
        # be located as id sequences; decoding single tokens never finds them.
        self._marker_patterns = self._encode_markers(tokenizer)

    @classmethod
    def _encode_markers(cls, tokenizer):
        """Map first token id -> list of token-id tuples that encode an entity marker."""
        patterns = {}
        for marker in cls._ENTITY_MARKERS:
            # Byte-level BPE tokenizers typically fold a preceding space into the
            # first token, so register both the spaced and the bare form.
            for variant in (" " + marker, marker):
                ids = tuple(tokenizer.encode(variant, add_special_tokens=False))
                if ids and ids not in patterns.setdefault(ids[0], []):
                    patterns[ids[0]].append(ids)
        return patterns

    @staticmethod
    def _mark_mentions(text, chemical, disease):
        """Wrap every occurrence of the two mentions in entity markers in a single pass.

        A single regex pass, with longer mentions tried first, prevents one marker
        insertion from hiding the other mention. Example: chemical "cocaine" inside
        disease "cocaine abuse". Empty mentions are ignored.
        """
        wrappers = {}
        if chemical:
            wrappers[chemical] = ("[CHEM] ", " [/CHEM]")
        if disease:
            if disease == chemical:
                wrappers[disease] = ("[CHEM] [DISE] ", " [/DISE] [/CHEM]")
            else:
                wrappers[disease] = ("[DISE] ", " [/DISE]")
        if not wrappers:
            return text

        pattern = re.compile("|".join(
            re.escape(mention) for mention in sorted(wrappers, key=len, reverse=True)
        ))

        def wrap(match):
            opening, closing = wrappers[match.group(0)]
            return opening + match.group(0) + closing

        return pattern.sub(wrap, text)

    def _marker_positions(self, token_ids):
        """Return indices of all tokens that belong to an entity-marker token sequence."""
        positions = set()
        for i, token_id in enumerate(token_ids):
            for pattern in self._marker_patterns.get(token_id, ()):
                if tuple(token_ids[i:i + len(pattern)]) == pattern:
                    positions.update(range(i, i + len(pattern)))
        return positions

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]

        # Get document text and mark the entity mentions
        text = f"Title: {example['title']} Abstract: {example['abstract']}"
        text = self._mark_mentions(text, example['chemical']['mention'], example['disease']['mention'])

        # Tokenize with Longformer tokenizer
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"]

        # Global attention on the first (CLS) token and on the entity-marker tokens
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[:, 0] = 1
        for position in self._marker_positions(input_ids[0].tolist()):
            global_attention_mask[0, position] = 1

        # squeeze(0) drops only the batch dimension added by return_tensors="pt"
        return {
            "input_ids": input_ids.squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "global_attention_mask": global_attention_mask.squeeze(0),
            "label": torch.tensor(example["label"], dtype=torch.long),
            "chemical_id": example['chemical']['id'],
            "disease_id": example['disease']['id']
        }

## 4. Model Implementations

### 4.1 Baseline CNN Model

In [None]:
class CNNBaseline(nn.Module):
    """
    Convolutional text classifier used as the non-transformer baseline.

    Token ids are embedded and passed through parallel 1-D convolutions of
    different widths. Each convolution output is max-pooled over the whole
    sequence; the pooled features are concatenated, passed through dropout and
    projected to two logits (CID relation / no relation).
    """

    def __init__(self, vocab_size, embedding_dim=100, filters=100, filter_sizes=(3, 4, 5), dropout=0.5):
        """
        Build the embedding, convolution and output layers.

        Args:
            vocab_size: Number of distinct token ids
            embedding_dim: Width of each token embedding
            filters: Output channels per convolution
            filter_sizes: Kernel widths, one convolution per entry
            dropout: Dropout probability applied before the output layer
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(embedding_dim, filters, kernel_size=width) for width in filter_sizes]
        )
        self.dropout = nn.Dropout(dropout)

        pooled_features = filters * len(filter_sizes)
        self.fc = nn.Linear(pooled_features, 2)

    def forward(self, x):
        """
        Compute relation logits.

        Args:
            x: Token ids of shape [batch_size, seq_len]; seq_len must be at least
               the largest filter size.

        Returns:
            Logits of shape [batch_size, 2]
        """
        # Conv1d expects channels first: [batch, embedding_dim, seq_len]
        channels_first = self.embedding(x).transpose(1, 2)

        pooled = []
        for conv in self.convs:
            activated = F.relu(conv(channels_first))
            # Max over the entire remaining time axis -> [batch, filters]
            pooled.append(F.max_pool1d(activated, activated.size(2)).squeeze(2))

        features = torch.cat(pooled, dim=1)
        return self.fc(self.dropout(features))

### 4.2 BioBERT Model

In [None]:
class BioBERTForRelationExtraction(nn.Module):
    """
    BioBERT model for chemical-disease relation extraction.

    This model fine-tunes a pre-trained BioBERT model to predict
    if there is a chemical-induced disease relation between two entities.
    The pooled [CLS] representation goes through dropout and a linear classifier.
    """

    def __init__(self, pretrained_model="dmis-lab/biobert-v1.1", num_labels=2):
        """
        Initialize with a pre-trained BioBERT model

        Args:
            pretrained_model: Name of pre-trained BioBERT model
            num_labels: Number of output labels (2 for binary classification)
        """
        super(BioBERTForRelationExtraction, self).__init__()

        # Load pre-trained BioBERT model
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """
        Forward pass

        Args:
            input_ids: Tensor of token ids
            attention_mask: Tensor of attention mask
            token_type_ids: Tensor of token type ids
            labels: Tensor of gold labels (optional)

        Returns:
            Tuple ``(loss, logits)``; ``loss`` is None when ``labels`` is not given.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        # Pooled [CLS] representation
        pooled_output = outputs.pooler_output

        # Apply dropout and classification layer
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Use the classifier's actual label count; a hard-coded 2 breaks num_labels != 2
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss, logits

### 4.3 Enhanced BioBERT with Recall Optimization

In [None]:
class ImprovedBioBERTForRelationExtraction(nn.Module):
    """
    Enhanced BioBERT model for relation extraction with recall-focused improvements.

    Two linear heads on the pooled [CLS] output are averaged. Each head gets its
    own dropout sample during training. Training uses a class-weighted focal loss.
    """
    def __init__(self, pretrained_model="dmis-lab/biobert-v1.1", num_labels=2,
                 focal_loss_gamma=2.0, class_weights=None):
        """
        Args:
            pretrained_model: Name of the pre-trained BioBERT model
            num_labels: Number of output labels
            focal_loss_gamma: Focusing parameter gamma of the focal loss
            class_weights: Optional per-class weights (tensor or sequence of
                length ``num_labels``); defaults to all ones
        """
        super(ImprovedBioBERTForRelationExtraction, self).__init__()

        # Load pre-trained BioBERT model
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)

        # Two classification heads, averaged in forward()
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.classifier2 = nn.Linear(self.bert.config.hidden_size, num_labels)

        self.num_labels = num_labels

        # Store parameters for focal loss
        self.focal_loss_gamma = focal_loss_gamma

        # Class weights for handling imbalance (default: equal weights). A
        # non-persistent buffer moves with model.to(device) but stays out of
        # state_dict, so checkpoint keys are unchanged.
        weights = (torch.ones(num_labels) if class_weights is None
                   else torch.as_tensor(class_weights, dtype=torch.float))
        self.register_buffer('class_weights', weights, persistent=False)

    def focal_loss(self, logits, labels):
        """
        Alpha-balanced focal loss: mean of ``w[y] * (1 - p_t)**gamma * CE``.

        ``p_t`` must be the model's unweighted probability of the true class. If it
        were derived from the class-weighted cross entropy it would equal
        ``p_t ** w[y]``, which distorts the focusing term whenever a weight is not 1.

        Args:
            logits: Unnormalised scores whose last dimension is the label dimension
            labels: Gold class indices

        Returns:
            Scalar loss tensor
        """
        num_labels = logits.size(-1)
        logits = logits.view(-1, num_labels)
        labels = labels.view(-1)

        ce_loss = F.cross_entropy(logits, labels, reduction='none')

        # Probability assigned to the true class
        pt = torch.exp(-ce_loss)

        # Apply focal loss formula: (1-pt)^gamma * ce_loss, then per-class weight
        focal = ((1 - pt) ** self.focal_loss_gamma) * ce_loss
        weights = self.class_weights.to(device=focal.device, dtype=focal.dtype)

        return (weights[labels] * focal).mean()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None,
                threshold=0.3):
        """
        Forward pass averaging the two classification heads.

        Args:
            input_ids: Tensor of token ids
            attention_mask: Tensor of attention mask
            token_type_ids: Tensor of token type ids
            labels: Gold labels (optional); when given, the focal loss is computed
            threshold: Accepted for interface compatibility but NOT applied here.
                The returned logits are unthresholded. A caller wanting a
                recall-oriented decision rule should compare ``softmax(logits)[:, 1]``
                against it.

        Returns:
            Tuple ``(loss, logits)``; ``loss`` is None when ``labels`` is not given.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        pooled_output = outputs.pooler_output

        # Multi-sample dropout: each head sees its own dropout mask during training.
        # Averaging two linear heads applied to the *same* tensor is algebraically a
        # single linear layer. In eval mode dropout is the identity and this is the
        # plain average of both heads.
        logits1 = self.classifier1(self.dropout(pooled_output))
        logits2 = self.classifier2(self.dropout(pooled_output))

        logits = (logits1 + logits2) / 2.0

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss = self.focal_loss(logits, labels)

        return loss, logits

### 4.4 Longformer for Document Processing

In [None]:
class LongformerForRelationExtraction(nn.Module):
    """
    Longformer model for document-level relation extraction with support for much longer contexts.

    The representation of the first (CLS) token goes through dropout and a
    linear classifier.
    """
    def __init__(self, pretrained_model="allenai/longformer-base-4096", num_labels=2):
        """
        Args:
            pretrained_model: Name of the pre-trained Longformer checkpoint
            num_labels: Number of output labels
        """
        super(LongformerForRelationExtraction, self).__init__()

        # Load pre-trained Longformer model with 4096 token support
        try:
            self.longformer = LongformerModel.from_pretrained(pretrained_model)
        except Exception as e:
            import warnings
            # Fall back to a randomly initialised model built from the checkpoint's
            # config so the notebook can still run. Warn loudly: any results from
            # this model are meaningless. If the config cannot be fetched either,
            # that error propagates.
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(pretrained_model)
            warnings.warn(
                f"Error loading Longformer model: {e}. Falling back to a randomly "
                "initialised model for demonstration purposes; its predictions "
                "are NOT meaningful.",
                RuntimeWarning,
            )
            self.longformer = LongformerModel(config)

        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.longformer.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, global_attention_mask=None, labels=None):
        """
        Forward pass with global attention on selected tokens

        Args:
            input_ids: Token IDs
            attention_mask: Attention mask for padded tokens
            global_attention_mask: Mask identifying tokens with global attention (1) or local attention (0)
            labels: Gold labels for loss calculation

        Returns:
            Tuple ``(loss, logits)``; ``loss`` is None when ``labels`` is not given.
        """
        # Default: global attention on the first ([CLS]) token only. Callers that
        # want entity markers to attend globally must pass their own mask.
        if global_attention_mask is None:
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1

        # Forward pass through Longformer
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask
        )

        # Use CLS token representation for classification
        sequence_output = outputs.last_hidden_state
        cls_output = sequence_output[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Use the classifier's actual label count; a hard-coded 2 breaks num_labels != 2
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss, logits

### 4.5 Knowledge-Enhanced BioBERT Model

In [None]:
class CTDKnowledgeBase:
    """
    Wrapper for the Comparative Toxicogenomics Database (CTD) knowledge base.

    Entities are stored as graph nodes / embedding keys of the form
    ``"C_<chemical_id>"`` (chemicals) and ``"D_<disease_id>"`` (diseases).

    Args:
        ctd_file_path: Tab-separated CTD chemical-disease file. Only read when
            ``dummy_mode`` is False and the file exists.
        dummy_mode: If True, use a two-node synthetic graph; unknown entities
            then get random embeddings and path queries return random answers.
            If False but the file is missing, the synthetic graph is still
            used, but unknown entities get zero vectors.
        embedding_dim: Length of generated embeddings and of fallback vectors
            (defaults to 128).

    Note:
        ``_build_knowledge_graph`` is still a placeholder that returns an empty
        graph, so the non-dummy path currently stores no embeddings.
    """
    def __init__(self, ctd_file_path="/content/drive/MyDrive/REBL/CDR_Data/CTD/CTD_chemicals_diseases.tsv",
                 dummy_mode=True, embedding_dim=128):
        self.dummy_mode = dummy_mode
        self.embedding_dim = embedding_dim

        # Reserved for (chemical, disease) pairs; not populated by this class yet
        self.ctd_relations = set()
        self.embeddings = {}

        if not dummy_mode and ctd_file_path and os.path.exists(ctd_file_path):
            # Load CTD chemical-disease associations
            self.ctd_data = pd.read_csv(ctd_file_path, sep='\t')

            print("Building knowledge graph...")
            self.graph = self._build_knowledge_graph()
            print("Generating embeddings...")
            self.embeddings = self._generate_embeddings(dim=embedding_dim)
        else:
            print("Using dummy CTD knowledge base for demonstration")
            # A single chemical-disease edge is enough to exercise the pipeline
            self.graph = nx.Graph()
            self.graph.add_node("C_MESH:D008687", type="chemical")
            self.graph.add_node("D_MESH:D006973", type="disease")
            self.graph.add_edge("C_MESH:D008687", "D_MESH:D006973")

            self.embeddings = {
                "C_MESH:D008687": np.random.randn(embedding_dim),
                "D_MESH:D006973": np.random.randn(embedding_dim)
            }

    def _build_knowledge_graph(self):
        """
        Build a NetworkX graph from ``self.ctd_data``.

        Placeholder: currently returns an empty graph; parsing CTD rows into
        nodes and edges is not implemented yet.
        """
        return nx.Graph()

    def _generate_embeddings(self, dim=None):
        """
        Assign every graph node an embedding vector.

        Placeholder for node2vec or a similar algorithm: vectors are drawn
        from a standard normal distribution.

        Args:
            dim: Embedding length; defaults to ``self.embedding_dim``.
        """
        if dim is None:
            dim = self.embedding_dim
        return {node: np.random.randn(dim) for node in self.graph.nodes()}

    def _lookup_embedding(self, key):
        """Return the stored embedding for ``key``, or a fallback vector if absent."""
        if key in self.embeddings:
            return self.embeddings[key]
        if self.dummy_mode:
            # Random stand-in so the demo pipeline sees non-constant features
            return np.random.randn(self.embedding_dim)
        # Unknown entity in real mode: a zero vector sized like the stored
        # embeddings. next(iter(...)) is O(1); the previous
        # list(self.embeddings.values()) copied every embedding on each miss.
        if self.embeddings:
            return np.zeros(next(iter(self.embeddings.values())).shape[0])
        return np.zeros(self.embedding_dim)

    def get_chemical_embedding(self, chemical_id):
        """Get the embedding for a chemical (fallback vector if unknown)."""
        return self._lookup_embedding(f"C_{chemical_id}")

    def get_disease_embedding(self, disease_id):
        """Get the embedding for a disease (fallback vector if unknown)."""
        return self._lookup_embedding(f"D_{disease_id}")

    def get_path_exists(self, chemical_id, disease_id):
        """
        Return 1.0 if the chemical and disease nodes are connected, else 0.0.

        In dummy mode the answer is random: 1.0 with probability of about 0.3.
        """
        if self.dummy_mode:
            return float(random.random() > 0.7)

        source = f"C_{chemical_id}"
        target = f"D_{disease_id}"
        # Missing nodes cannot be connected
        if source not in self.graph or target not in self.graph:
            return 0.0
        return 1.0 if nx.has_path(self.graph, source, target) else 0.0

In [None]:
class KnowledgeEnhancedBioBERT(nn.Module):
    """
    BioBERT relation classifier optionally fused with knowledge-base embeddings.

    Without a knowledge base, BERT's pooled output feeds a linear classifier.
    With one, the pooled output and the concatenated chemical/disease KB
    embeddings are each projected to 512 dims and mixed with per-example
    scalar gates before classification.

    Args:
        pretrained_model: Model name or path passed to ``BertModel.from_pretrained``.
        num_labels: Number of output classes.
        knowledge_base: Object exposing ``get_chemical_embedding(id)`` and
            ``get_disease_embedding(id)``; None disables knowledge fusion.
            Each embedding is expected to have length ``knowledge_dim``.
        knowledge_dim: Length of one KB embedding (the KB projection takes
            ``2 * knowledge_dim`` inputs).
    """
    def __init__(self, pretrained_model="dmis-lab/biobert-v1.1", num_labels=2,
                 knowledge_base=None, knowledge_dim=128):
        super().__init__()

        self.bert = BertModel.from_pretrained(pretrained_model)

        self.kb = knowledge_base
        self.knowledge_dim = knowledge_dim

        self.dropout = nn.Dropout(0.1)

        if self.kb is not None:
            # Project text features into the shared 512-dim fusion space
            self.bert_fc = nn.Linear(self.bert.config.hidden_size, 512)
            # *2 because chemical and disease embeddings are concatenated
            self.kb_fc = nn.Linear(knowledge_dim * 2, 512)
            # Scores each 512-dim feature vector to produce its fusion gate
            self.attention = nn.Linear(512, 1)
            self.classifier = nn.Linear(512, num_labels)
        else:
            self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def _kb_feature_tensor(self, chemical_ids, disease_ids, device):
        """
        Stack concatenated [chemical; disease] KB embeddings into a float tensor.

        The result is expected to be (batch, 2 * knowledge_dim) when the KB
        returns knowledge_dim-length vectors.

        Raises:
            ValueError: If the two ID sequences differ in length.
        """
        if len(chemical_ids) != len(disease_ids):
            raise ValueError(
                f"chemical_ids ({len(chemical_ids)}) and disease_ids "
                f"({len(disease_ids)}) must have the same length"
            )
        rows = [
            np.concatenate([self.kb.get_chemical_embedding(chem_id),
                            self.kb.get_disease_embedding(disease_id)])
            for chem_id, disease_id in zip(chemical_ids, disease_ids)
        ]
        return torch.as_tensor(np.stack(rows), dtype=torch.float32, device=device)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None,
                chemical_ids=None, disease_ids=None):
        """
        Forward pass with text features and, when available, knowledge features.

        Returns:
            (loss, logits): ``loss`` is None when ``labels`` is None.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        # pooler_output is BERT's pooled sequence representation (the first
        # token's hidden state after the pooler layer), not the raw [CLS] state
        text_features = self.dropout(outputs.pooler_output)

        if self.kb is not None:
            bert_features = self.bert_fc(text_features)

            if chemical_ids is not None and disease_ids is not None:
                kb_features = self.kb_fc(
                    self._kb_feature_tensor(chemical_ids, disease_ids, input_ids.device)
                )

                # Sigmoid gates, normalized so the two weights sum to 1 per example
                bert_gate = torch.sigmoid(self.attention(bert_features))
                kb_gate = torch.sigmoid(self.attention(kb_features))
                gate_sum = bert_gate + kb_gate
                combined_features = (bert_gate / gate_sum) * bert_features \
                    + (kb_gate / gate_sum) * kb_features
            else:
                # No entity IDs in this batch: the classifier expects 512-dim
                # inputs, so use the projected text features alone (a KB gate of 0)
                # instead of feeding the raw pooled output and crashing.
                combined_features = bert_features

            logits = self.classifier(combined_features)
        else:
            logits = self.classifier(text_features)

        loss = None
        if labels is not None:
            # Use the real class count rather than a hard-coded 2
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss, logits

## 5. Data Augmentation

In [None]:
class DataAugmenter:
    """
    Augment CDR relation examples by entity swapping and synonym replacement.

    Examples are dicts with ``'chemical'`` and ``'disease'`` entity dicts
    (keys ``'id'``, ``'mention'`` and optionally ``'type'``), an integer
    ``'label'`` (1 = positive relation), and optionally a ``'text'`` string.
    """

    def __init__(self, ctd_path="/content/drive/MyDrive/REBL/CDR_Data/CTD/CTD_chemicals_diseases.tsv"):
        """
        Initialize the augmenter.

        Args:
            ctd_path: Path to a tab-separated CTD chemical-disease file used to
                label swapped pairs (distant supervision). If it is falsy or
                does not exist, a two-pair demonstration set is used instead.
        """
        self.lemmatizer = WordNetLemmatizer()

        if ctd_path and os.path.exists(ctd_path):
            self.ctd_relations = self._load_ctd(ctd_path)
        else:
            print("Using dummy CTD relations for data augmentation demonstration")
            self.ctd_relations = {
                ("MESH:D008687", "MESH:D006973"),  # Example chemical-disease pair
                ("MESH:D004991", "MESH:D006973"),  # Another example pair
            }

    def _load_ctd(self, ctd_path):
        """
        Load (ChemicalID, DiseaseID) pairs from a tab-separated CTD file.

        Lines starting with '#' are skipped as comments, and rows with a
        missing or empty ID are dropped. Returns an empty set if the file
        cannot be read or lacks either column.
        """
        try:
            ctd_df = pd.read_csv(ctd_path, sep='\t', comment='#')
        except Exception as e:
            print(f"Error loading CTD data: {e}")
            return set()

        required = ['ChemicalID', 'DiseaseID']
        if any(column not in ctd_df.columns for column in required):
            return set()

        # dropna is required: NaN is truthy, so a plain truthiness test would
        # let pairs with missing IDs through.
        pairs = ctd_df[required].dropna()
        return {
            (chem_id, disease_id)
            for chem_id, disease_id in zip(pairs['ChemicalID'], pairs['DiseaseID'])
            if chem_id and disease_id
        }

    def synonym_replacement(self, text, n=2):
        """
        Replace up to n distinct words (longer than 3 chars) with a WordNet synonym.

        Every occurrence of a chosen word is replaced. Tokens are re-joined
        with single spaces, so the original spacing is not preserved.
        """
        words = nltk.word_tokenize(text)
        # dict.fromkeys deduplicates in first-seen order. Iterating a set of
        # strings depends on PYTHONHASHSEED, which made the seeded shuffle
        # below non-reproducible across runs.
        candidates = list(dict.fromkeys(word for word in words if len(word) > 3))
        random.shuffle(candidates)

        new_words = list(words)
        num_replaced = 0
        for candidate in candidates:
            # Check the budget first so that n <= 0 replaces nothing
            if num_replaced >= n:
                break
            synonyms = self.get_synonyms(candidate)
            if synonyms:
                # sorted() gives a hash-seed-independent order for random.choice
                synonym = random.choice(sorted(synonyms))
                new_words = [synonym if word == candidate else word for word in new_words]
                num_replaced += 1

        return ' '.join(new_words)

    def get_synonyms(self, word):
        """Return WordNet lemma names for ``word`` (underscores → spaces), excluding ``word``."""
        synonyms = set()

        for syn in wordnet.synsets(word):
            for lemma in syn.lemmas():
                synonym = lemma.name().replace('_', ' ')
                if synonym != word:
                    synonyms.add(synonym)

        return synonyms

    @staticmethod
    def _replace_mention(text, old_mention, new_mention):
        """Replace whole-token occurrences of ``old_mention`` in ``text`` with ``new_mention``."""
        # Lookarounds instead of \b so mentions ending in non-word chars (e.g. "Ca2+") still match
        pattern = re.compile(r'(?<!\w)' + re.escape(old_mention) + r'(?!\w)')
        # A function replacement avoids interpreting backslashes in the new mention
        return pattern.sub(lambda _match: new_mention, text)

    def _swap_entity(self, ex, role, pool):
        """
        Return a copy of ``ex`` with its ``role`` entity swapped, or None.

        The replacement is drawn from ``pool`` entries of the same type but a
        different id. The mention is also replaced in ``'text'``, and the
        label is set from the CTD relations.
        """
        old_entity = ex[role]
        candidates = [entity for entity in pool[old_entity.get('type', '')]
                      if entity['id'] != old_entity['id']]
        if not candidates:
            return None

        new_entity = random.choice(candidates)
        new_ex = ex.copy()
        new_ex[role] = new_entity

        # Keep the text consistent with the swapped entity. Otherwise the
        # example would describe the old pair while being labelled for the new one.
        text = ex.get('text')
        old_mention = old_entity.get('mention')
        new_mention = new_entity.get('mention')
        if text and old_mention and new_mention:
            new_ex['text'] = self._replace_mention(text, old_mention, new_mention)

        if role == 'chemical':
            new_pair = (new_entity['id'], ex['disease']['id'])
        else:
            new_pair = (ex['chemical']['id'], new_entity['id'])
        # Without CTD relations every swapped pair is assumed negative
        new_ex['label'] = 1 if self.ctd_relations and new_pair in self.ctd_relations else 0
        return new_ex

    def swap_entities(self, examples):
        """
        Swap similar chemicals/diseases between examples to create new examples.

        For every example, this produces up to two new examples: one with the
        chemical replaced by a same-type chemical, and one with the disease
        replaced by a same-type disease. Both come from the same ``examples``
        list. Input examples are not modified.
        """
        chemicals = defaultdict(list)
        diseases = defaultdict(list)
        for ex in examples:
            chemicals[ex['chemical'].get('type', '')].append(ex['chemical'])
            diseases[ex['disease'].get('type', '')].append(ex['disease'])

        augmented_examples = []
        for ex in examples:
            for role, pool in (('chemical', chemicals), ('disease', diseases)):
                new_ex = self._swap_entity(ex, role, pool)
                if new_ex is not None:
                    augmented_examples.append(new_ex)

        return augmented_examples

    def _augment_segment(self, segment, n=1):
        """Apply synonym replacement to ``segment`` while keeping its edge whitespace."""
        core = segment.strip()
        if not core:
            return segment
        leading = segment[:len(segment) - len(segment.lstrip())]
        trailing = segment[len(segment.rstrip()):]
        # Re-attaching the edge whitespace keeps entity mentions from being
        # glued to neighbouring words when the parts are concatenated.
        return leading + self.synonym_replacement(core, n=n) + trailing

    def _synonym_augment(self, ex):
        """
        Return a copy of ``ex`` with synonyms in the text around both mentions, or None.

        The entity mentions themselves are left untouched. Returns None when
        the text is empty, a mention is not found, or the two mentions overlap.
        """
        text = ex.get('text', '')
        if not text:
            return None

        chemical_mention = ex['chemical']['mention']
        disease_mention = ex['disease']['mention']
        chem_idx = text.find(chemical_mention)
        disease_idx = text.find(disease_mention)
        if chem_idx < 0 or disease_idx < 0:
            return None

        if chem_idx < disease_idx:
            first_idx, first, second_idx, second = chem_idx, chemical_mention, disease_idx, disease_mention
        else:
            first_idx, first, second_idx, second = disease_idx, disease_mention, chem_idx, chemical_mention

        first_end = first_idx + len(first)
        if first_end > second_idx:
            # Overlapping mentions cannot be split into before/middle/after
            return None

        before = self._augment_segment(text[:first_idx])
        middle = self._augment_segment(text[first_end:second_idx])
        after = self._augment_segment(text[second_idx + len(second):])

        new_ex = ex.copy()
        new_ex['text'] = before + first + middle + second + after
        return new_ex

    def augment_dataset(self, examples, methods=None, factor=1.5):
        """
        Apply multiple augmentation methods to increase dataset size

        Args:
            examples: Original dataset examples
            methods: List of methods to apply, options:
                     ['synonym', 'swap'] (applied in that fixed order: swap first)
            factor: Target size as multiple of original size (an upper bound;
                    fewer examples are added if the methods produce fewer)

        Returns:
            Augmented dataset examples (originals first, then new examples)
        """
        if methods is None:
            methods = ['synonym', 'swap']

        augmented_examples = list(examples)
        original_size = len(examples)
        target_size = int(original_size * factor)

        if 'swap' in methods and len(augmented_examples) < target_size:
            print("Applying entity swapping augmentation...")
            swap_examples = self.swap_entities(examples)
            remaining_slots = target_size - len(augmented_examples)
            augmented_examples.extend(swap_examples[:remaining_slots])

        if 'synonym' in methods and len(augmented_examples) < target_size:
            print("Applying synonym replacement augmentation...")
            # Only positive examples are paraphrased
            positive_examples = [ex for ex in examples if ex['label'] == 1]

            synonym_examples = []
            for ex in tqdm(positive_examples):
                new_ex = self._synonym_augment(ex)
                if new_ex is not None:
                    synonym_examples.append(new_ex)

            remaining_slots = target_size - len(augmented_examples)
            augmented_examples.extend(synonym_examples[:remaining_slots])

        positive_count = sum(1 for ex in augmented_examples if ex['label'] == 1)
        negative_count = len(augmented_examples) - positive_count

        print(f"Augmentation complete. Original size: {original_size}, New size: {len(augmented_examples)}")
        print(f"Class distribution: Positive: {positive_count}, Negative: {negative_count}")

        return augmented_examples

In [None]:
# Augment the training data with entity swapping and synonym replacement.
# Note: this rebinds train_examples, so re-running this cell augments the
# already-augmented set again; reload the original data before re-running.
augmenter = DataAugmenter()
augmented_train_examples = augmenter.augment_dataset(
    train_examples,
    methods=['synonym', 'swap'],
    factor=1.5  # Target size is 1.5x the original (adds at most 50%)
)
train_examples = augmented_train_examples

## 6. Training and Evaluation Functions

In [None]:
def _extract_label(example, index):
    """Return the integer label of one dataset item, or 0 (with a warning) if it has none."""
    if isinstance(example, dict) and 'label' in example:
        label = example['label']
    elif isinstance(example, tuple) and len(example) > 0:
        # Tuple items are assumed to carry the label last
        label = example[-1]
    else:
        print(f"Warning: Could not determine label for example {index}, using 0 as default")
        return 0
    if isinstance(label, torch.Tensor):
        label = label.item()
    # bincount needs integer class indices; float labels like 1.0 are common
    return int(label)


def create_balanced_sampler(dataset):
    """
    Create a sampler that oversamples the minority class (positive relations).

    Each example is weighted by the inverse frequency of its label, so every
    class present is drawn with roughly equal probability. Unreadable items
    are counted as label 0.

    Args:
        dataset: Indexable dataset whose items are dicts with a 'label' key or
            tuples whose last element is the label.

    Returns:
        A WeightedRandomSampler (with replacement, len(dataset) draws). Falls
        back to a RandomSampler when fewer than two classes are present or an
        error occurs, and to a SequentialSampler for an empty dataset.
    """
    try:
        labels = []
        for i in range(len(dataset)):
            try:
                labels.append(_extract_label(dataset[i], i))
            except Exception as e:
                print(f"Error processing example {i}: {e}")
                labels.append(0)

        if not labels:
            # RandomSampler rejects empty datasets; an empty sequential one is harmless
            print("Warning: Dataset is empty, using sequential sampling")
            return torch.utils.data.SequentialSampler(dataset)

        labels_tensor = torch.tensor(labels, dtype=torch.long)

        unique_labels = torch.unique(labels_tensor)
        if len(unique_labels) < 2:
            print(f"Warning: Only found one class in the dataset ({unique_labels.item()}), using uniform sampling")
            return torch.utils.data.RandomSampler(dataset)

        class_counts = torch.bincount(labels_tensor)
        print(f"Class distribution: {class_counts.tolist()}")

        # Inverse-frequency weights. A label value absent from the data (e.g. 0
        # when labels are {1, 2}) has count 0; give it weight 0 instead of
        # 1/0 = inf, which would turn the normalized weights into NaN.
        counts = class_counts.float()
        class_weights = torch.where(counts > 0, 1.0 / counts, torch.zeros_like(counts))
        class_weights = class_weights / class_weights.sum()

        # Vectorized per-example lookup
        sample_weights = class_weights[labels_tensor]

        return torch.utils.data.WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
    except Exception as e:
        print(f"Error creating balanced sampler: {e}")
        return torch.utils.data.RandomSampler(dataset)

def _training_loss(model, batch, model_type):
    """Run one forward pass appropriate to ``model_type`` and return the loss tensor."""
    if model_type == "cnn":
        # The CNN returns raw logits, so the loss is computed here
        return F.cross_entropy(model(batch['input_ids']), batch['label'])
    if model_type == "longformer":
        loss, _ = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            global_attention_mask=batch.get('global_attention_mask'),
            labels=batch['label']
        )
        return loss
    if model_type == "knowledge":
        loss, _ = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'],
            labels=batch['label'],
            chemical_ids=batch.get('chemical_id'),
            disease_ids=batch.get('disease_id')
        )
        return loss
    # "biobert", "improved_biobert", and any unrecognized type
    loss, _ = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        token_type_ids=batch['token_type_ids'],
        labels=batch['label']
    )
    return loss


def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler=None,
                num_epochs=3, device=device, model_type="biobert"):
    """
    Train a model and restore the weights from the epoch with the best validation F1.

    Args:
        model: Model to train; moved to ``device``.
        train_dataloader: Yields dict batches with at least 'input_ids' and 'label'.
        val_dataloader: Validation batches passed to ``evaluate_model``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Optional LR scheduler, stepped once per batch.
        num_epochs: Number of passes over the training data.
        device: Device for the model and batch tensors.
        model_type: "cnn", "longformer", "knowledge", "biobert" or
            "improved_biobert"; selects which inputs the model receives
            (unknown values are treated like "biobert").

    Returns:
        (model, history): ``history`` maps 'train_loss', 'val_loss',
        'val_precision', 'val_recall' and 'val_f1' to per-epoch lists.
    """
    model = model.to(device)

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': []
    }

    best_f1 = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc="Training")

        for batch in progress_bar:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            optimizer.zero_grad()
            loss = _training_loss(model, batch, model_type)
            loss.backward()

            # Clip to stabilize fine-tuning
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            progress_bar.set_postfix({'loss': batch_loss})

        # max(..., 1) guards against an empty dataloader
        avg_train_loss = train_loss / max(len(train_dataloader), 1)
        history['train_loss'].append(avg_train_loss)

        print("Evaluating...")
        metrics = evaluate_model(model, val_dataloader, device, model_type)

        history['val_loss'].append(metrics['loss'])
        history['val_precision'].append(metrics['precision'])
        history['val_recall'].append(metrics['recall'])
        history['val_f1'].append(metrics['f1'])

        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            # state_dict() returns references to the live parameter tensors, and
            # dict.copy() is shallow, so such a snapshot silently tracks later
            # updates. Clone to CPU so the snapshot is frozen and uses no GPU memory.
            best_model_state = {
                name: value.detach().cpu().clone() if isinstance(value, torch.Tensor) else value
                for name, value in model.state_dict().items()
            }

        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {metrics['loss']:.4f}")
        print(f"Val Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}")

    if best_model_state is not None:
        # load_state_dict copies the CPU snapshot back into the on-device parameters
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with F1: {best_f1:.4f}")

    return model, history

def evaluate_model(model, dataloader, device, model_type="biobert", threshold=0.5):
    """
    Evaluate a binary relation classifier on a dataset.

    Args:
        model: Model to evaluate (switched to eval mode).
        dataloader: Yields dict batches with at least 'input_ids' and 'label'.
        device: Device for the batch tensors.
        model_type: "cnn", "longformer", "knowledge", "biobert" or
            "improved_biobert" (unknown values are treated like "biobert").
        threshold: Positive-class probability cut-off. It is only used for
            "improved_biobert"; every other type predicts by argmax.

    Returns:
        dict with 'loss' (mean per batch); binary 'precision', 'recall' and
        'f1' for class 1; a 2x2 'confusion_matrix' (rows are true labels,
        columns are predictions, both in order [0, 1]); and the raw
        'predictions' and 'labels' lists.
    """
    model.eval()

    all_preds = []
    all_labels = []
    total_loss = 0.0

    progress_bar = tqdm(dataloader, desc="Evaluating")

    with torch.no_grad():
        for batch in progress_bar:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            labels = batch['label']

            if model_type == "cnn":
                logits = model(batch['input_ids'])
                loss = F.cross_entropy(logits, labels)
            elif model_type == "longformer":
                loss, logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    global_attention_mask=batch.get('global_attention_mask'),
                    labels=labels
                )
            elif model_type == "knowledge":
                loss, logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids'],
                    labels=labels,
                    chemical_ids=batch.get('chemical_id'),
                    disease_ids=batch.get('disease_id')
                )
            else:  # biobert or improved_biobert
                loss, logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids'],
                    labels=labels
                )

            if model_type == "improved_biobert":
                # Custom decision threshold on the positive-class probability
                probs = F.softmax(logits, dim=1)[:, 1]
                preds = (probs >= threshold).int()
            else:
                preds = torch.argmax(logits, dim=1)

            total_loss += loss.item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', zero_division=0
    )

    # Fix the label set so the matrix is always 2x2, even when only one class
    # appears in this split (visualize_confusion_matrix labels two classes).
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    return {
        # max(..., 1) guards against an empty dataloader
        'loss': total_loss / max(len(dataloader), 1),
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels
    }

def tune_threshold_for_recall(model, validation_dataloader, device, model_type="improved_biobert",
                             thresholds=None):
    """
    Pick the positive-class probability threshold that maximizes F2 on validation data.

    F2 (F-beta with beta = 2, computed as 5PR / (4P + R)) favours recall over
    precision. It does not enforce a minimum precision.

    Args:
        model: Model to evaluate (switched to eval mode).
        validation_dataloader: Yields dict batches with at least 'input_ids' and 'label'.
        device: Device for the batch tensors.
        model_type: "cnn", "longformer", "knowledge", "biobert" or
            "improved_biobert" (unknown values are treated like "biobert").
        thresholds: Candidate thresholds; defaults to np.arange(0.1, 0.9, 0.05).

    Returns:
        (best_threshold, results): ``results`` is a list of dicts with
        'threshold', 'precision', 'recall', 'f1' and 'f2' for every candidate.
        F2 ties go to the earliest candidate in ``thresholds`` order.
    """
    if thresholds is None:
        thresholds = np.arange(0.1, 0.9, 0.05)

    model.eval()
    positive_probs = []
    all_labels = []

    # Collect positive-class probabilities once; thresholds are applied afterwards
    with torch.no_grad():
        for batch in tqdm(validation_dataloader, desc="Collecting predictions"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            labels = batch['label']

            if model_type == "cnn":
                logits = model(batch['input_ids'])
            elif model_type == "longformer":
                _, logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    global_attention_mask=batch.get('global_attention_mask')
                )
            elif model_type == "knowledge":
                _, logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids'],
                    chemical_ids=batch.get('chemical_id'),
                    disease_ids=batch.get('disease_id')
                )
            else:  # biobert or improved_biobert
                _, logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids']
                )

            probs = F.softmax(logits, dim=1)
            positive_probs.extend(probs[:, 1].cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    prob_array = np.asarray(positive_probs)

    results = []
    for threshold in thresholds:
        preds = (prob_array >= threshold).astype(int)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, preds, average='binary', zero_division=0
        )

        # F2 score: weights recall more heavily than precision
        f2 = (5 * precision * recall) / (4 * precision + recall) if (precision + recall) > 0 else 0

        results.append({
            'threshold': threshold,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'f2': f2
        })

    # idxmax returns the first maximum, giving a deterministic tie-break
    # (an unstable sort_values could pick any tied row)
    results_df = pd.DataFrame(results)
    best_result = results_df.loc[results_df['f2'].idxmax()]
    print(f"Best threshold: {best_result['threshold']:.2f}")
    print(f"Precision: {best_result['precision']:.4f}, Recall: {best_result['recall']:.4f}, F1: {best_result['f1']:.4f}, F2: {best_result['f2']:.4f}")

    return best_result['threshold'], results


def visualize_pr_curve(results):
    """
    Plot precision, recall, F1 and F2 as functions of the decision threshold.

    Args:
        results: Per-threshold records (anything ``pd.DataFrame`` accepts)
            with 'threshold', 'precision', 'recall', 'f1' and 'f2' fields,
            e.g. the list produced during threshold tuning.

    Returns:
        The ``matplotlib.pyplot`` module with the figure still open, so the
        caller can ``savefig`` and ``close`` it.
    """
    plt.figure(figsize=(10, 6))

    curve = pd.DataFrame(results).sort_values('threshold')
    xs = curve['threshold']

    # One line per metric: (column, matplotlib format, legend label)
    for column, fmt, legend in (
        ('precision', 'b-', 'Precision'),
        ('recall', 'r-', 'Recall'),
        ('f1', 'g-', 'F1 Score'),
        ('f2', 'y-', 'F2 Score'),
    ):
        plt.plot(xs, curve[column], fmt, label=legend)

    # Dashed marker where F2 peaks ...
    best_cut = curve.loc[curve['f2'].idxmax(), 'threshold']
    plt.axvline(x=best_cut, color='k', linestyle='--', label=f'Best F2 Threshold: {best_cut:.2f}')
    # ... and a dotted reference at the conventional 0.5 cut-off
    plt.axvline(x=0.5, color='gray', linestyle=':', label='Default Threshold: 0.5')

    plt.xlabel('Classification Threshold')
    plt.ylabel('Score')
    plt.title('Precision, Recall, F1, and F2 vs. Threshold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    return plt


def visualize_confusion_matrix(cm, title="Confusion Matrix"):
    """
    Render a binary confusion matrix as an annotated heatmap.

    Args:
        cm: Count matrix laid out as rows = actual class, columns = predicted
            class (matching the axis labels set below). Cells are formatted
            with ``fmt='d'``, so integer counts are expected.
        title: Title shown above the heatmap.

    Returns:
        The ``matplotlib.pyplot`` module with the figure still open.
    """
    class_names = ['No Relation', 'CID Relation']

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(title)
    plt.tight_layout()
    return plt


def compare_models(results_dict):
    """
    Draw a grouped bar chart comparing precision, recall and F1 across models.

    Args:
        results_dict: Mapping from model name to a metrics mapping that
            contains at least 'precision', 'recall' and 'f1'.

    Returns:
        The ``matplotlib.pyplot`` module with the chart still open, so the
        caller can save and close it.
    """
    names = list(results_dict.keys())

    # (display column, key inside each model's metrics mapping)
    metric_columns = (
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1 Score', 'f1'),
    )
    table = {'Model': names}
    for column, key in metric_columns:
        table[column] = [results_dict[name][key] for name in names]

    # Long format (one row per model/metric pair) lets seaborn group bars by hue
    long_form = pd.DataFrame(table).melt(id_vars=['Model'], var_name='Metric', value_name='Value')

    plt.figure(figsize=(12, 8))
    axis = sns.barplot(x='Model', y='Value', hue='Metric', data=long_form)

    # Write each bar's value just above it
    for bar_group in axis.containers:
        axis.bar_label(bar_group, fmt='%.2f', padding=3)

    plt.title('Model Performance Comparison')
    plt.ylabel('Score')
    plt.xticks(rotation=45, ha='right')
    plt.ylim(0, 1.0)
    plt.legend(title='Metric')
    plt.tight_layout()

    return plt


def _error_analysis_logits(model, batch, model_type):
    """Run the forward pass matching ``model_type`` and return the class logits."""
    if model_type == "cnn":
        # The CNN baseline is called with token ids only and its output is used as logits
        return model(batch['input_ids'])
    if model_type == "longformer":
        _, logits = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            global_attention_mask=batch.get('global_attention_mask')
        )
        return logits
    if model_type == "knowledge":
        _, logits = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'],
            chemical_ids=batch.get('chemical_id'),
            disease_ids=batch.get('disease_id')
        )
        return logits
    # biobert or improved_biobert
    _, logits = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        token_type_ids=batch['token_type_ids']
    )
    return logits


def _error_record(example, max_chars=200):
    """Summarize one misclassified example: its mentions plus (truncated) text."""
    text = example['text']
    return {
        'example': example,
        'chemical': example['chemical']['mention'],
        'disease': example['disease']['mention'],
        'text': text[:max_chars] + '...' if len(text) > max_chars else text
    }


def analyze_errors(model, dataloader, examples, device, model_type="biobert", threshold=None):
    """
    Collect the false positives and false negatives a model makes on a dataset.

    Predictions are matched back to ``examples`` purely by position, so the
    dataloader must yield its items in the same order as ``examples`` (no
    shuffling or sampler). A warning is printed if the number of predictions
    and examples differ.

    Args:
        model: Model to analyze. For "cnn" it is called with the token ids and
            its output is treated as logits; for every other type it is called
            with keyword tensors and must return a ``(_, logits)`` pair.
        dataloader: Data loader yielding dict batches with at least
            'input_ids' and 'label' (plus the extra keys each model type reads).
        examples: Original examples, in dataloader order. Each must provide
            ['chemical']['mention'], ['disease']['mention'] and ['text'].
        device: Device that tensor entries of each batch are moved to.
        model_type: "cnn", "longformer" or "knowledge"; any other value is
            treated as a (improved) BioBERT model.
        threshold: Optional decision threshold on the softmax probability of
            class 1. ``None`` (default) uses argmax, as before. With a value,
            an example is predicted positive iff P(class 1) >= threshold, which
            lets a recall-tuned threshold be analyzed. NOTE(review): for "cnn"
            this assumes the model output is unnormalized logits — confirm.

    Returns:
        Dict with 'false_positives' and 'false_negatives' lists. Each entry
        holds the example, its chemical/disease mentions and its text
        truncated to 200 characters.
    """
    model.eval()

    errors = {
        'false_positives': [],
        'false_negatives': []
    }

    # Index into ``examples`` of the first item in the current batch
    example_idx = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Analyzing errors"):
            # Move tensors to the target device; leave non-tensor entries untouched
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            labels = batch['label']

            logits = _error_analysis_logits(model, batch, model_type)
            if threshold is None:
                preds = torch.argmax(logits, dim=1)
            else:
                # Positive iff the softmax probability of class 1 reaches the threshold
                preds = (F.softmax(logits, dim=1)[:, 1] >= threshold).long()

            for offset, (pred, label) in enumerate(zip(preds.tolist(), labels.tolist())):
                current_idx = example_idx + offset
                # Predictions past the end of ``examples`` cannot be matched to text
                if current_idx >= len(examples):
                    continue

                if pred == 1 and label == 0:
                    errors['false_positives'].append(_error_record(examples[current_idx]))
                elif pred == 0 and label == 1:
                    errors['false_negatives'].append(_error_record(examples[current_idx]))

            example_idx += labels.size(0)

    # Position-based matching silently goes wrong if the counts differ
    if example_idx != len(examples):
        print(f"WARNING: dataloader produced {example_idx} predictions but {len(examples)} examples were given; "
              "errors are matched by position and may be misaligned")

    # Print error summary
    print(f"Total false positives: {len(errors['false_positives'])}")
    print(f"Total false negatives: {len(errors['false_negatives'])}")

    return errors


## 7. Experiment Execution

### 7.1 Initialize Tokenizers and Models

In [None]:
# Initialize tokenizers
def initialize_tokenizers():
    """
    Load one tokenizer per model family and register the entity-marker tokens.

    The CNN uses plain ``bert-base-uncased``. BioBERT falls back to a fresh
    ``bert-base-uncased`` tokenizer if its checkpoint cannot be loaded.
    'improved_biobert' shares the BioBERT tokenizer object. Longformer falls
    back to that same BioBERT tokenizer object. Finally, the
    [CHEM]/[/CHEM]/[DISE]/[/DISE] markers are added to every entry.

    Returns:
        Dict with keys 'cnn', 'biobert', 'improved_biobert', 'longformer'.
    """
    print("Initializing tokenizers...")

    def load_or_fallback(load, family, fallback_message, fallback):
        # Try the preferred checkpoint; on any failure report it and use the fallback
        try:
            loaded = load()
        except Exception as err:
            print(f"Error loading {family} tokenizer: {err}")
            print(fallback_message)
            return fallback()
        print(f"Loaded {family} tokenizer successfully")
        return loaded

    cnn_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    biobert_tokenizer = load_or_fallback(
        lambda: BertTokenizer.from_pretrained("dmis-lab/biobert-v1.1"),
        "BioBERT",
        "Falling back to BERT base tokenizer for BioBERT",
        lambda: BertTokenizer.from_pretrained("bert-base-uncased"),
    )

    longformer_tokenizer = load_or_fallback(
        lambda: LongformerTokenizer.from_pretrained("allenai/longformer-base-4096"),
        "Longformer",
        "Falling back to BERT tokenizer for demonstration",
        lambda: biobert_tokenizer,
    )

    # Insertion order matters: it fixes the order of the "Added ..." messages below
    tokenizers = {
        'cnn': cnn_tokenizer,
        'biobert': biobert_tokenizer,
        'improved_biobert': biobert_tokenizer,  # same object as 'biobert'
        'longformer': longformer_tokenizer,
    }

    marker_tokens = {'additional_special_tokens': ['[CHEM]', '[/CHEM]', '[DISE]', '[/DISE]']}
    for family, tok in tokenizers.items():
        added = tok.add_special_tokens(marker_tokens)
        print(f"Added {added} special tokens to {family} tokenizer")

    return tokenizers


def initialize_models(tokenizers):
    """
    Build every model variant used in the experiments.

    Transformer models whose preferred checkpoint fails to load fall back to
    ``bert-base-uncased`` (BioBERT variants) or are skipped (Longformer,
    knowledge-enhanced). Each transformer's token embeddings are resized to
    its tokenizer length so the added entity-marker tokens have embeddings.
    No ``.to(device)`` happens here; presumably the training/eval helpers
    move the models — TODO confirm.

    Args:
        tokenizers: Dict from ``initialize_tokenizers`` with 'cnn', 'biobert'
            and 'longformer' entries that already contain the marker tokens.

    Returns:
        Dict of model type -> model. 'cnn', 'biobert' and 'improved_biobert'
        are always present; 'longformer' and 'knowledge' are omitted if their
        construction fails.
    """
    print("Initializing models...")

    models = {}

    # CNN Baseline.
    # len(tokenizer) includes the added entity-marker tokens, whereas
    # ``tokenizer.vocab`` only holds the base vocabulary. Sizing the embedding
    # from ``.vocab`` would leave the marker ids (appended after the base
    # vocab) out of range. This also matches the len(...) used for the
    # resize_token_embeddings calls below.
    vocab_size = len(tokenizers['cnn'])
    models['cnn'] = CNNBaseline(vocab_size=vocab_size)
    print("Initialized CNN baseline model")

    # BioBERT
    try:
        models['biobert'] = BioBERTForRelationExtraction(pretrained_model="dmis-lab/biobert-v1.1")
        print("Initialized BioBERT model with pre-trained weights")
    except Exception as e:
        print(f"Error loading BioBERT model: {e}")
        print("Initializing with BERT base model for demonstration")
        models['biobert'] = BioBERTForRelationExtraction(pretrained_model="bert-base-uncased")

    # Resize token embeddings for the special tokens
    models['biobert'].bert.resize_token_embeddings(len(tokenizers['biobert']))

    # Enhanced BioBERT with recall optimization.
    # Class weights are defined outside the try so the fallback branch can
    # always use them; the positive (CID) class is weighted 3x.
    class_weights = torch.tensor([1.0, 3.0])
    try:
        models['improved_biobert'] = ImprovedBioBERTForRelationExtraction(
            pretrained_model="dmis-lab/biobert-v1.1",
            focal_loss_gamma=2.0,
            class_weights=class_weights
        )
        print("Initialized enhanced BioBERT model with focal loss")
    except Exception as e:
        print(f"Error loading enhanced BioBERT model: {e}")
        print("Initializing with BERT base model for demonstration")
        models['improved_biobert'] = ImprovedBioBERTForRelationExtraction(
            pretrained_model="bert-base-uncased",
            focal_loss_gamma=2.0,
            class_weights=class_weights
        )

    # Resize token embeddings for the special tokens
    models['improved_biobert'].bert.resize_token_embeddings(len(tokenizers['biobert']))

    # Longformer Model (optional: skipped entirely on failure)
    try:
        models['longformer'] = LongformerForRelationExtraction(pretrained_model="allenai/longformer-base-4096")
        print("Initialized Longformer model with pre-trained weights")

        # Resize token embeddings for the special tokens
        models['longformer'].longformer.resize_token_embeddings(len(tokenizers['longformer']))
    except Exception as e:
        print(f"Error with Longformer model: {e}")
        print("Skipping Longformer model initialization")

    # Knowledge Enhanced BioBERT (optional: skipped entirely on failure)
    kb = CTDKnowledgeBase(dummy_mode=True)  # Using dummy mode for demonstration

    try:
        models['knowledge'] = KnowledgeEnhancedBioBERT(
            pretrained_model="dmis-lab/biobert-v1.1",
            knowledge_base=kb,
            knowledge_dim=128
        )
        print("Initialized Knowledge-Enhanced BioBERT model")

        # Resize token embeddings for the special tokens
        models['knowledge'].bert.resize_token_embeddings(len(tokenizers['biobert']))
    except Exception as e:
        print(f"Error with Knowledge-Enhanced model: {e}")
        print("Skipping Knowledge-Enhanced model initialization")

    return models

### 7.2 Create Datasets and Dataloaders

In [None]:
def create_datasets_and_dataloaders(tokenizers, train_examples, dev_examples, test_examples, batch_sizes=None):
    """
    Create per-model datasets and dataloaders for the train/dev/test splits.

    Each model family's datasets are built inside its own try/except. A
    failure is printed and that model is simply left out. Only the training
    split is shuffled. 'improved_biobert' training uses a balanced sampler
    when ``create_balanced_sampler`` succeeds.

    Args:
        tokenizers: Dict from ``initialize_tokenizers``.
        train_examples, dev_examples, test_examples: Example lists per split.
        batch_sizes: Optional dict of model type -> batch size. Entries
            override the defaults. Model types it omits keep their default
            (previously a partial dict raised ``KeyError`` inside the
            dataloader loop, silently dropping those dataloaders).

    Returns:
        ``(datasets, dataloaders)``, each a dict
        ``{'train': {...}, 'dev': {...}, 'test': {...}}`` keyed by model type.
    """
    print("Creating datasets and dataloaders...")

    default_batch_sizes = {
        'cnn': 16,
        'biobert': 16,
        'improved_biobert': 16,
        'longformer': 4,  # Smaller batch size for Longformer due to memory constraints
        'knowledge': 16
    }
    batch_sizes = {**default_batch_sizes, **(batch_sizes or {})}

    datasets = {
        'train': {},
        'dev': {},
        'test': {}
    }

    dataloaders = {
        'train': {},
        'dev': {},
        'test': {}
    }

    # Verify examples are not empty
    if not train_examples:
        print("WARNING: train_examples is empty!")
    if not dev_examples:
        print("WARNING: dev_examples is empty!")
    if not test_examples:
        print("WARNING: test_examples is empty!")

    # CNN Baseline datasets
    print("Creating CNN datasets...")
    try:
        datasets['train']['cnn'] = CDRBaselineDataset(train_examples, tokenizers['cnn'])
        datasets['dev']['cnn'] = CDRBaselineDataset(dev_examples, tokenizers['cnn'])
        datasets['test']['cnn'] = CDRBaselineDataset(test_examples, tokenizers['cnn'])
        print(f"CNN datasets created successfully. Sizes - Train: {len(datasets['train']['cnn'])}, Dev: {len(datasets['dev']['cnn'])}, Test: {len(datasets['test']['cnn'])}")
    except Exception as e:
        print(f"Error creating CNN datasets: {e}")

    # BioBERT datasets
    print("Creating BioBERT datasets...")
    try:
        datasets['train']['biobert'] = CDRBioBERTDataset(train_examples, tokenizers['biobert'])
        datasets['dev']['biobert'] = CDRBioBERTDataset(dev_examples, tokenizers['biobert'])
        datasets['test']['biobert'] = CDRBioBERTDataset(test_examples, tokenizers['biobert'])
        print(f"BioBERT datasets created successfully. Sizes - Train: {len(datasets['train']['biobert'])}, Dev: {len(datasets['dev']['biobert'])}, Test: {len(datasets['test']['biobert'])}")
    except Exception as e:
        print(f"Error creating BioBERT datasets: {e}")

    # Enhanced BioBERT datasets (same construction as BioBERT, but separate
    # objects so the two models never share dataset state)
    print("Creating enhanced BioBERT datasets...")
    try:
        datasets['train']['improved_biobert'] = CDRBioBERTDataset(train_examples, tokenizers['improved_biobert'])
        datasets['dev']['improved_biobert'] = CDRBioBERTDataset(dev_examples, tokenizers['improved_biobert'])
        datasets['test']['improved_biobert'] = CDRBioBERTDataset(test_examples, tokenizers['improved_biobert'])
        print(f"Enhanced BioBERT datasets created successfully. Sizes - Train: {len(datasets['train']['improved_biobert'])}")
    except Exception as e:
        print(f"Error creating enhanced BioBERT datasets: {e}")

    # Longformer datasets
    if 'longformer' in tokenizers:
        print("Creating Longformer datasets...")
        try:
            datasets['train']['longformer'] = CDRLongformerDataset(train_examples, tokenizers['longformer'])
            datasets['dev']['longformer'] = CDRLongformerDataset(dev_examples, tokenizers['longformer'])
            datasets['test']['longformer'] = CDRLongformerDataset(test_examples, tokenizers['longformer'])
            print(f"Longformer datasets created successfully. Sizes - Train: {len(datasets['train']['longformer'])}")
        except Exception as e:
            print(f"Error creating Longformer datasets: {e}")

    # Knowledge-Enhanced BioBERT datasets (separate objects, BioBERT tokenizer)
    print("Creating Knowledge-Enhanced BioBERT datasets...")
    try:
        datasets['train']['knowledge'] = CDRBioBERTDataset(train_examples, tokenizers['biobert'])
        datasets['dev']['knowledge'] = CDRBioBERTDataset(dev_examples, tokenizers['biobert'])
        datasets['test']['knowledge'] = CDRBioBERTDataset(test_examples, tokenizers['biobert'])
        print("Knowledge-Enhanced BioBERT datasets created successfully.")
    except Exception as e:
        print(f"Error creating Knowledge-Enhanced BioBERT datasets: {e}")

    # Create dataloaders
    print("Creating dataloaders...")
    for split in ['train', 'dev', 'test']:
        for model_type, dataset in datasets[split].items():
            try:
                # Make sure dataset is not empty
                if len(dataset) == 0:
                    print(f"WARNING: {split} dataset for {model_type} is empty, skipping dataloader creation")
                    continue

                # Never request a batch larger than the dataset itself
                requested = batch_sizes[model_type]
                effective_batch_size = min(requested, len(dataset))
                if effective_batch_size != requested:
                    print(f"WARNING: Reducing batch size for {model_type} {split} from {requested} to {effective_batch_size} due to small dataset size")

                if split == 'train' and model_type == 'improved_biobert':
                    # Use balanced sampling for improved_biobert training
                    try:
                        sampler = create_balanced_sampler(dataset)
                        dataloaders[split][model_type] = DataLoader(
                            dataset,
                            batch_size=effective_batch_size,
                            sampler=sampler
                        )
                        print(f"Created {split} dataloader for {model_type} with balanced sampling")
                    except Exception as e:
                        print(f"Error creating balanced sampler for {model_type}: {e}")
                        # Fall back to a standard shuffled training dataloader
                        dataloaders[split][model_type] = DataLoader(
                            dataset,
                            batch_size=effective_batch_size,
                            shuffle=True
                        )
                        print(f"Created {split} dataloader for {model_type} with standard configuration (fallback)")
                else:
                    # Standard dataloader; only training data is shuffled
                    dataloaders[split][model_type] = DataLoader(
                        dataset,
                        batch_size=effective_batch_size,
                        shuffle=(split == 'train')
                    )
                    print(f"Created {split} dataloader for {model_type} with batch size {effective_batch_size}")
            except Exception as e:
                print(f"Error creating dataloader for {model_type} {split}: {e}")

    return datasets, dataloaders


### 7.3 Training Function

In [None]:
def train_all_models(models, dataloaders, num_epochs=3, device=device, save_dir="saved_models"):
    """
    Train every model, checkpoint it, and evaluate it on the dev split.

    Models without both a train and a dev dataloader are skipped. Each trained
    model's ``state_dict`` is saved to ``<save_dir>/<model_type>_model.pt`` and
    its history to ``<save_dir>/<model_type>_history.pkl``. For
    'improved_biobert' a recall-oriented threshold is also tuned on the dev
    set, and the model is re-evaluated at that threshold.

    Args:
        models: Dict of model type -> model.
        dataloaders: Dict with 'train' and 'dev' entries keyed by model type.
        num_epochs: Number of training epochs per model.
        device: Device passed to training/evaluation.
        save_dir: Output directory (created if missing).

    Returns:
        ``(results, histories)``. ``results`` maps model type to dev-set
        metrics. When threshold tuning runs, it also has an
        'improved_biobert_optimized' entry whose metrics include the tuned
        value under 'threshold', so later stages can reuse it instead of
        hard-coding one.
    """
    os.makedirs(save_dir, exist_ok=True)

    results = {}
    histories = {}

    for model_type, model in models.items():
        print(f"\n{'='*50}")
        print(f"Training {model_type} model...")
        print(f"{'='*50}\n")

        # Get the dataloaders
        train_dataloader = dataloaders['train'].get(model_type)
        dev_dataloader = dataloaders['dev'].get(model_type)

        if train_dataloader is None or dev_dataloader is None:
            print(f"Skipping {model_type} - dataloaders not available")
            continue

        # Optimizer and scheduler per model family
        if model_type == 'cnn':
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            scheduler = None
        else:
            # Transformers: AdamW with linear decay and no warmup;
            # Longformer uses a slightly higher learning rate.
            lr = 3e-5 if model_type == 'longformer' else 2e-5
            optimizer = AdamW(model.parameters(), lr=lr, eps=1e-8)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=0,
                num_training_steps=len(train_dataloader) * num_epochs
            )

        # Train the model
        trained_model, history = train_model(
            model=model,
            train_dataloader=train_dataloader,
            val_dataloader=dev_dataloader,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            device=device,
            model_type=model_type
        )

        # Save the model and history
        try:
            torch.save(trained_model.state_dict(), os.path.join(save_dir, f"{model_type}_model.pt"))
            with open(os.path.join(save_dir, f"{model_type}_history.pkl"), 'wb') as f:
                pickle.dump(history, f)
            print(f"Saved {model_type} model and history")
        except Exception as e:
            print(f"Error saving model: {e}")

        # Evaluate on validation set
        print(f"Evaluating {model_type} model on validation set...")
        metrics = evaluate_model(
            model=trained_model,
            dataloader=dev_dataloader,
            device=device,
            model_type=model_type
        )

        # Store results
        results[model_type] = metrics
        histories[model_type] = history

        # Print metrics
        print(f"{model_type} Validation Results:")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")
        print(f"Confusion Matrix:\n{metrics['confusion_matrix']}")

        # For improved_biobert, tune threshold for better recall
        if model_type == 'improved_biobert':
            print("\nTuning threshold for improved recall...")
            best_threshold, threshold_results = tune_threshold_for_recall(
                model=trained_model,
                validation_dataloader=dev_dataloader,
                device=device,
                model_type=model_type
            )

            # Visualize precision-recall curve
            pr_plot = visualize_pr_curve(threshold_results)
            pr_plot.savefig(os.path.join(save_dir, "pr_curve.png"))
            pr_plot.close()

            # Re-evaluate with optimized threshold
            print(f"Re-evaluating with threshold={best_threshold:.2f}")
            metrics = evaluate_model(
                model=trained_model,
                dataloader=dev_dataloader,
                device=device,
                model_type=model_type,
                threshold=best_threshold
            )

            # Keep the tuned threshold next to its metrics so later stages
            # (e.g. test-set evaluation) can reuse it; previously it was discarded.
            metrics = dict(metrics, threshold=float(best_threshold))
            results[f"{model_type}_optimized"] = metrics

            print(f"{model_type} (Optimized Threshold) Validation Results:")
            print(f"Precision: {metrics['precision']:.4f}")
            print(f"Recall: {metrics['recall']:.4f}")
            print(f"F1 Score: {metrics['f1']:.4f}")
            print(f"Confusion Matrix:\n{metrics['confusion_matrix']}")

    return results, histories

### 7.4 Ablation Study

In [None]:
def run_ablation_study(models, dataloaders, device=device, save_dir="ablation_results", optimized_threshold=0.3):
    """
    Evaluate each model variant on the test set to gauge each enhancement's contribution.

    Args:
        models: Dict of model type -> model (as built by initialize_models).
        dataloaders: Dict with a 'test' entry mapping model type -> DataLoader.
        device: Device passed through to evaluate_model.
        save_dir: Directory for the comparison plot (created if missing).
        optimized_threshold: Decision threshold for the
            'improved_biobert_optimized' variant, i.e. the improved_biobert
            model evaluated at a tuned threshold. Pass the value found on the
            dev set (e.g. by tune_threshold_for_recall). The default 0.3 is the
            previous hard-coded placeholder. ``None`` skips this variant.

    Returns:
        Dict of variant name -> test metrics, for every variant whose
        underlying model and test dataloader are available.
    """
    os.makedirs(save_dir, exist_ok=True)

    print("\nRunning Ablation Study...")

    # (variant name, key of the underlying model/dataloader, threshold or None).
    # None means evaluate_model is called without a threshold argument.
    variants = [
        ('biobert', 'biobert', None),                                              # Original BioBERT
        ('improved_biobert', 'improved_biobert', None),                            # With recall optimization
        ('improved_biobert_optimized', 'improved_biobert', optimized_threshold),   # With threshold tuning
        ('knowledge', 'knowledge', None),                                          # With knowledge base
        ('longformer', 'longformer', None),                                        # With document-level context
    ]

    ablation_results = {}
    test_loaders = dataloaders['test']

    for variant, base_type, threshold in variants:
        if variant == 'improved_biobert_optimized' and threshold is None:
            continue
        # Check availability via the *underlying* key. 'improved_biobert_optimized'
        # has no entry of its own in ``models``/``dataloaders``, so checking the
        # variant name made the threshold-tuned evaluation unreachable.
        if base_type not in models or base_type not in test_loaders:
            continue

        print(f"\nEvaluating {variant}...")
        threshold_kwargs = {} if threshold is None else {'threshold': threshold}
        metrics = evaluate_model(
            model=models[base_type],
            dataloader=test_loaders[base_type],
            device=device,
            model_type=base_type,
            **threshold_kwargs
        )

        # Store results
        ablation_results[variant] = metrics

        # Print metrics
        print(f"{variant} Test Results:")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")

    # Compare models visually
    if len(ablation_results) > 1:
        comparison_plot = compare_models(ablation_results)
        comparison_plot.savefig(os.path.join(save_dir, "model_comparison.png"))
        comparison_plot.close()

    return ablation_results

### 7.5 Error Analysis

In [None]:
def run_error_analysis(models, dataloaders, examples, device=device, save_dir="error_analysis"):
    """
    Analyze and save the test-set errors of selected models, then plot FP/FN counts.

    For each analyzed model, the error dict from analyze_errors is pickled to
    ``<save_dir>/<name>_errors.pkl`` and up to three FP and three FN examples
    are printed. If at least two models were analyzed, a grouped bar chart of
    error counts is saved as ``error_distribution.png``.

    Args:
        models: Dict of model type -> model.
        dataloaders: Dict with a 'test' entry mapping model type -> DataLoader.
        examples: Test examples in the same order as the test dataloaders.
        device: Device passed to analyze_errors.
        save_dir: Output directory (created if missing).

    Returns:
        Dict of analyzed name -> error dict. Models whose model or test
        dataloader is unavailable are skipped (previously a missing
        'improved_biobert' raised KeyError instead of being skipped).
    """
    os.makedirs(save_dir, exist_ok=True)

    print("\nRunning Error Analysis...")

    error_results = {}
    test_loaders = dataloaders.get('test', {})

    # Choose which models to analyze
    models_to_analyze = ['biobert', 'improved_biobert_optimized']

    for model_type in models_to_analyze:
        # 'improved_biobert_optimized' reuses the improved_biobert model and dataloader.
        # NOTE(review): no threshold is passed to analyze_errors below, so this entry
        # reflects the model's default decision rule, not a tuned threshold.
        actual_model_type = 'improved_biobert' if model_type == 'improved_biobert_optimized' else model_type
        model = models.get(actual_model_type)
        dataloader = test_loaders.get(actual_model_type)

        if model is None or dataloader is None:
            print(f"Skipping {model_type} - model or dataloader not available")
            continue

        print(f"\nAnalyzing errors for {model_type}...")
        errors = analyze_errors(
            model=model,
            dataloader=dataloader,
            examples=examples,
            device=device,
            model_type=actual_model_type
        )

        error_results[model_type] = errors

        # Save error examples
        with open(os.path.join(save_dir, f"{model_type}_errors.pkl"), 'wb') as f:
            pickle.dump(errors, f)

        # Print some example errors
        print("\nFalse Positive Examples:")
        for i, error in enumerate(errors['false_positives'][:3]):
            print(f"\nFP Example {i+1}:")
            print(f"Chemical: {error['chemical']}")
            print(f"Disease: {error['disease']}")
            print(f"Text: {error['text']}")

        print("\nFalse Negative Examples:")
        for i, error in enumerate(errors['false_negatives'][:3]):
            print(f"\nFN Example {i+1}:")
            print(f"Chemical: {error['chemical']}")
            print(f"Disease: {error['disease']}")
            print(f"Text: {error['text']}")

    # Compare error distributions
    if len(error_results) > 1:
        fp_counts = [len(errors['false_positives']) for errors in error_results.values()]
        fn_counts = [len(errors['false_negatives']) for errors in error_results.values()]

        plt.figure(figsize=(10, 6))
        x = np.arange(len(error_results))
        width = 0.35

        # Side-by-side bars per model: FP left, FN right
        plt.bar(x - width/2, fp_counts, width, label='False Positives')
        plt.bar(x + width/2, fn_counts, width, label='False Negatives')

        plt.xlabel('Model')
        plt.ylabel('Count')
        plt.title('Error Distribution by Model')
        plt.xticks(x, list(error_results.keys()))
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "error_distribution.png"))
        plt.close()

    return error_results

## 8. Main Execution

In [None]:
def main(load_pretrained=False, num_epochs=3):
    """
    Run the whole experiment pipeline end to end.

    Steps:
        1. Build tokenizers and models.
        2. Create dataloaders from the module-level ``train_examples``,
           ``dev_examples`` and ``test_examples``.
        3. Either train every model or load previously saved weights.
        4. Run the test-set ablation study and the error analysis.
        5. Print a summary.

    Args:
        load_pretrained: If True, skip training and load each model's
            ``state_dict`` from ``saved_models/<model_type>_model.pt`` (the
            path train_all_models writes by default). Defaults to False,
            which trains every model. Previously, setting this flag still
            trained first and then overwrote the fresh weights.
        num_epochs: Training epochs per model when training.

    Returns:
        Tuple ``(ablation_results, error_results)``.
    """
    # 1-2. Tokenizers and models
    tokenizers = initialize_tokenizers()
    models = initialize_models(tokenizers)

    # 3. Datasets and dataloaders (only the dataloaders are needed below)
    _, dataloaders = create_datasets_and_dataloaders(
        tokenizers,
        train_examples,
        dev_examples,
        test_examples
    )

    # 4. Train from scratch (slow) or reuse saved checkpoints
    if load_pretrained:
        print("\nLoading pre-trained models...")
        for model_type, model in models.items():
            model_path = os.path.join("saved_models", f"{model_type}_model.pt")
            if os.path.exists(model_path):
                try:
                    model.load_state_dict(torch.load(model_path, map_location=device))
                    print(f"Loaded pre-trained {model_type} model")
                except Exception as e:
                    print(f"Error loading {model_type} model: {e}")
            else:
                print(f"No pre-trained model found for {model_type}")
    else:
        train_all_models(models, dataloaders, num_epochs=num_epochs)

    # 5. Ablation study on the test set
    ablation_results = run_ablation_study(models, dataloaders)

    # 6. Error analysis on the test set
    error_results = run_error_analysis(models, dataloaders, test_examples)

    # 7. Summary
    print("\n" + "="*50)
    print("Experiment Summary")
    print("="*50)

    for model_type, metrics in ablation_results.items():
        print(f"\n{model_type}:")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")

    print("\nExperiment pipeline completed!")

    return ablation_results, error_results

# Run the full pipeline when this cell executes; the returned ablation metrics
# and error-analysis dicts stay bound in the notebook namespace for inspection.
ablation_results, error_results = main()
