In [None]:
!wget -O data.zip https://github.com/MMartinelli-hub/ATA_Tutorship/raw/refs/heads/main/gutbrainie/data/GutBrainIE_Full_Collection_2025.zip
!unzip data.zip "*"
!mkdir data
!mv Annotations/ ./data
!mv Articles/ ./data
!mv Test_Data/ ./data
!rm -rf data.zip

## Setup and Imports

In [None]:
import json
import os
import random
import time
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModel,
    TrainingArguments,
    Trainer
)

# Seed every random number generator the notebook relies on (Python's `random`,
# NumPy's legacy global RNG, PyTorch on CPU and on all visible GPUs) so that
# negative sampling and weight initialisation are reproducible.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("\n".join([
    "Setup complete",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]))

## Define Relation Labels and Configuration

In [None]:
# Output classes of the relation classifier. Index 0 is the negative class
# assigned to entity pairs that are not related; the rest are the predicates
# allowed by the annotation guidelines.
RELATION_LABELS = [
    "no relation",  # For negative samples
    "administered",
    "affect",
    "change abundance",
    "change effect",
    "change expression",
    "compared to",
    "impact",
    "influence",
    "interact",
    "is a",
    "is linked to",
    "located in",
    "part of",
    "produced by",
    "strike",
    "target",
    "used by"
]

# Bidirectional lookup between class index and predicate string
id2label = dict(enumerate(RELATION_LABELS))
label2id = {label: idx for idx, label in id2label.items()}

print(f"Total relation labels: {len(RELATION_LABELS)}")
print(f"Labels: {RELATION_LABELS}")

# Entity-type constraints from the annotation guidelines, written as
# (subject_label, predicate, object_label) triples. Direction matters: each
# triple describes a relation pointing from the subject to the object.
LEGAL_RELATIONS = [
    ("DDF", "affect", "DDF"),
    ("microbiome", "is linked to", "DDF"),
    ("DDF", "target", "human"),
    ("drug", "change effect", "DDF"),
    ("DDF", "is a", "DDF"),
    ("microbiome", "located in", "human"),
    ("chemical", "influence", "DDF"),
    ("dietary supplement", "influence", "DDF"),
    ("DDF", "target", "animal"),
    ("chemical", "impact", "microbiome"),
    ("anatomical location", "located in", "animal"),
    ("microbiome", "located in", "animal"),
    ("chemical", "located in", "anatomical location"),
    ("bacteria", "part of", "microbiome"),
    ("DDF", "strike", "anatomical location"),
    ("drug", "administered", "animal"),
    ("bacteria", "influence", "DDF"),
    ("drug", "impact", "microbiome"),
    ("DDF", "change abundance", "microbiome"),
    ("microbiome", "located in", "anatomical location"),
    ("microbiome", "used by", "biomedical technique"),
    ("chemical", "produced by", "microbiome"),
    ("dietary supplement", "impact", "microbiome"),
    ("bacteria", "located in", "animal"),
    ("animal", "used by", "biomedical technique"),
    ("chemical", "impact", "bacteria"),
    ("chemical", "located in", "animal"),
    ("food", "impact", "bacteria"),
    ("microbiome", "compared to", "microbiome"),
    ("human", "used by", "biomedical technique"),
    ("bacteria", "change expression", "gene"),
    ("chemical", "located in", "human"),
    ("drug", "interact", "chemical"),
    ("food", "administered", "human"),
    ("DDF", "change abundance", "bacteria"),
    ("chemical", "interact", "chemical"),
    ("chemical", "part of", "chemical"),
    ("dietary supplement", "impact", "bacteria"),
    ("DDF", "interact", "chemical"),
    ("food", "impact", "microbiome"),
    ("food", "influence", "DDF"),
    ("bacteria", "located in", "human"),
    ("dietary supplement", "administered", "human"),
    ("bacteria", "interact", "chemical"),
    ("drug", "change expression", "gene"),
    ("drug", "impact", "bacteria"),
    ("drug", "administered", "human"),
    ("anatomical location", "located in", "human"),
    ("dietary supplement", "change expression", "gene"),
    ("chemical", "change expression", "gene"),
    ("bacteria", "interact", "bacteria"),
    ("drug", "interact", "drug"),
    ("microbiome", "change expression", "gene"),
    ("bacteria", "interact", "drug"),
    ("food", "change expression", "gene")
]

# Index the constraints by ordered entity-type pair:
# (subject_label, object_label) -> set of predicates allowed for that pair.
legal_pairs = {}
for subj_type, predicate, obj_type in LEGAL_RELATIONS:
    legal_pairs.setdefault((subj_type, obj_type), set()).add(predicate)

print(f"\nTotal legal relation patterns: {len(LEGAL_RELATIONS)}")
print(f"Total unique entity type pairs: {len(legal_pairs)}")

# Configuration
# Pretrained encoder: Microsoft BiomedBERT (abstract + full-text variant) for biomedical text.
model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
output_model_dir = "models/bert_biomedbert_re"
# Token budget per example; longer marked texts are truncated during tokenization.
max_length = 512
# Per document, at most NEGATIVE_SAMPLE_MULTIPLIER x (number of annotated relations)
# "no relation" pairs are sampled (fewer if not enough candidate pairs exist).
NEGATIVE_SAMPLE_MULTIPLIER = 10

print(f"\nModel: {model_name}")
print(f"Output directory: {output_model_dir}")
print(f"Negative sample multiplier: {NEGATIVE_SAMPLE_MULTIPLIER}")

## BERT Model with Entity Markers

In [None]:
class BertForREWithEntityMarkers(nn.Module):
    """
    BERT model for Relation Extraction with entity marker tokens.

    The encoder's hidden states at the [E1] and [E2] marker positions are
    extracted, concatenated, and passed through dropout and a linear
    classification head.

    Args:
        model_name: Model id/path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of relation classes.
        dropout_prob: Dropout probability applied to the concatenated entity
            representation (defaults to 0.1).
    """

    def __init__(self, model_name, num_labels, dropout_prob=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_prob)

        # Classification head: [h_E1; h_E2] (2 * hidden_size) -> num_labels
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, e1_mask, e2_mask, labels=None):
        """
        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            e1_mask: 0/1 mask marking the [E1] token position [batch_size, seq_len]
            e2_mask: 0/1 mask marking the [E2] token position [batch_size, seq_len]
            labels: Optional gold class ids [batch_size]

        Returns:
            Dict with 'logits' [batch_size, num_labels] and 'loss'
            (cross-entropy over the batch, or None when `labels` is None).

        Note:
            A row whose mask is all zeros (e.g. its marker was truncated away)
            yields a zero entity vector; a mask with several ones sums those
            hidden states.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Pool the hidden state at each marker with a batched (mask x states) product.
        # Masks are cast to the encoder's output dtype rather than always float32, so
        # this also works when the encoder runs in half or double precision.
        mask_dtype = sequence_output.dtype
        e1_h = torch.bmm(e1_mask.unsqueeze(1).to(mask_dtype), sequence_output).squeeze(1)  # [batch_size, hidden_size]
        e2_h = torch.bmm(e2_mask.unsqueeze(1).to(mask_dtype), sequence_output).squeeze(1)  # [batch_size, hidden_size]

        # Concatenate entity representations -> [batch_size, hidden_size * 2]
        concat_h = self.dropout(torch.cat([e1_h, e2_h], dim=-1))

        logits = self.classifier(concat_h)  # [batch_size, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            'loss': loss,
            'logits': logits
        }


print("BERT RE model class defined")

## Data Loading Functions

In [None]:
def load_re_data(file_paths):
    """
    Load relation extraction data from multiple JSON files and merge them.

    Each file holds a JSON object keyed by document id (PMID). Objects are
    merged in the order given, so a document id present in several files keeps
    the version from the LAST file; a warning reports how many documents were
    overwritten that way. Missing files are skipped with a warning.

    Args:
        file_paths: Iterable of paths to JSON annotation files.

    Returns:
        dict mapping document id -> document annotations.
    """
    all_data = {}

    for file_path in file_paths:
        if not os.path.exists(file_path):
            print(f"Warning: {file_path} not found")
            continue

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Surface silent overwrites (e.g. a lower-quality tier replacing a gold document)
        overlap = all_data.keys() & data.keys()
        if overlap:
            print(f"Warning: {len(overlap)} documents from {os.path.basename(file_path)} "
                  f"overwrite documents loaded from earlier files")

        all_data.update(data)
        print(f"Loaded {len(data)} documents from {os.path.basename(file_path)}")

    return all_data


print("Data loading function defined")

## Load Training and Dev Data

In [None]:
# Load training data from three quality levels
train_files = [
    "../data/Annotations/Train/gold_quality/json_format/train_gold.json",
    "../data/Annotations/Train/platinum_quality/json_format/train_platinum.json",
    "../data/Annotations/Train/silver_quality/json_format/train_silver.json"
]

train_data = load_re_data(train_files)
print(f"\nTotal training documents: {len(train_data)}")

In [None]:
# Load dev data
dev_data = load_re_data(["../data/Annotations/Dev/json_format/dev.json"])
print(f"Total dev documents: {len(dev_data)}")

## Prepare Relation Extraction Examples

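For each document:
1. Extract positive relation examples from the annotated relations.
2. Generate negative ("no relation") candidates by pairing entities that are NOT related, keeping only ordered pairs whose (subject type, object type) combination appears in the legal relation patterns.
3. Randomly sample at most `NEGATIVE_SAMPLE_MULTIPLIER` × (number of annotated relations) of those candidates, which controls the negative-to-positive ratio.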

In [None]:
def create_full_text_with_offsets(title, abstract):
    """
    Join title and abstract into a single string separated by one space.

    Returns:
        (full_text, abstract_offset): abstract_offset is the index in full_text
        where the abstract starts, i.e. the shift to add to abstract-relative
        character positions.
    """
    joined = f"{title} {abstract}"
    return joined, len(title) + 1


def adjust_entity_positions(entity, abstract_offset):
    """
    Map an entity's character span into title+abstract coordinates.

    Entities located in the abstract are shifted right by `abstract_offset`;
    entities in any other location keep their positions. A new dict holding only
    start_idx, end_idx, text_span and label is returned (the 'location' key is
    not copied and the input dict is not modified).
    """
    in_abstract = entity['location'] == 'abstract'
    start, end = entity['start_idx'], entity['end_idx']
    if in_abstract:
        start += abstract_offset
        end += abstract_offset
    return dict(
        start_idx=start,
        end_idx=end,
        text_span=entity['text_span'],
        label=entity['label'],
    )


def _resolve_relation_entity(relation, role, entities, adjusted_entities):
    """
    Return the adjusted entity that `relation`'s `role` ('subject' or 'object')
    refers to, or None if no entity matches.

    When the relation carries `<role>_location`, `<role>_start_idx` and
    `<role>_end_idx`, the entity is matched on those exact offsets (plus label),
    so repeated mentions of the same text (e.g. in title and abstract) are not
    confused. Otherwise, or if no exact match exists, it falls back to matching
    text span + label, where the last matching entity wins.

    `entities` and `adjusted_entities` must be parallel lists (same order).
    """
    label = relation[f'{role}_label']
    location = relation.get(f'{role}_location')
    start = relation.get(f'{role}_start_idx')
    end = relation.get(f'{role}_end_idx')

    if location is not None and start is not None and end is not None:
        for original, adjusted in zip(entities, adjusted_entities):
            if (original.get('location') == location
                    and original.get('start_idx') == start
                    and original.get('end_idx') == end
                    and original.get('label') == label):
                return adjusted

    # Fallback: text span + label; last match wins
    span = relation[f'{role}_text_span']
    match = None
    for adjusted in adjusted_entities:
        if adjusted['text_span'] == span and adjusted['label'] == label:
            match = adjusted
    return match


def prepare_re_examples(data, negative_multiplier=1, legal_pairs=None):
    """
    Prepare relation extraction examples with positive and negative samples.

    Positives come from the annotated relations; negatives are ordered entity
    pairs that are not annotated as related. Negatives are restricted to entity
    type pairs present in `legal_pairs` (when given) and randomly sampled using
    the module-level `random` generator.

    Args:
        data: Dictionary of documents (pmid -> article) with metadata, entities
            and relations.
        negative_multiplier: Per document, at most
            negative_multiplier * len(relations) negatives are sampled.
        legal_pairs: Dictionary mapping (subject_label, object_label) -> set of
            predicates; None disables the type filter for negatives.

    Returns:
        List of examples (dicts with text, subject, object, predicate, pmid);
        subject/object spans are in title+abstract coordinates.
    """
    examples = []

    for pmid, article in tqdm(data.items(), desc="Preparing RE examples"):
        title = article['metadata']['title']
        abstract = article['metadata']['abstract']
        full_text, abstract_offset = create_full_text_with_offsets(title, abstract)

        entities = article['entities']
        relations = article['relations']

        # Positions in full-text coordinates; adjusted_entities[k] corresponds to entities[k]
        adjusted_entities = [adjust_entity_positions(e, abstract_offset) for e in entities]

        # Positive examples from annotated relations
        positive_pairs = set()
        for relation in relations:
            subject = _resolve_relation_entity(relation, 'subject', entities, adjusted_entities)
            obj = _resolve_relation_entity(relation, 'object', entities, adjusted_entities)
            if subject is None or obj is None:
                continue  # relation refers to an entity that is not in the entity list

            examples.append({
                'text': full_text,
                'subject': subject,
                'object': obj,
                'predicate': relation['predicate'],
                'pmid': pmid
            })
            # Track positive pairs so they are never sampled as negatives
            positive_pairs.add((subject['start_idx'], subject['end_idx'],
                                obj['start_idx'], obj['end_idx']))

        # Negative candidates: ordered pairs of distinct entities with a legal type pair
        num_negatives = len(relations) * negative_multiplier
        negative_candidates = []
        for i, subj in enumerate(adjusted_entities):
            for j, obj in enumerate(adjusted_entities):
                if i == j:
                    continue  # no self-relations
                if legal_pairs is not None and (subj['label'], obj['label']) not in legal_pairs:
                    continue
                pair_key = (subj['start_idx'], subj['end_idx'],
                            obj['start_idx'], obj['end_idx'])
                if pair_key in positive_pairs:
                    continue
                negative_candidates.append({
                    'text': full_text,
                    'subject': subj,
                    'object': obj,
                    'predicate': 'no relation',
                    'pmid': pmid
                })

        # Sample negative examples
        if negative_candidates:
            num_to_sample = min(num_negatives, len(negative_candidates))
            examples.extend(random.sample(negative_candidates, num_to_sample))

    return examples


print("RE example preparation function defined")


In [None]:
# Build training examples: annotated positives plus sampled negatives
# restricted to legal entity-type pairs.
print("Preparing training examples...")
train_examples = prepare_re_examples(
    train_data,
    negative_multiplier=NEGATIVE_SAMPLE_MULTIPLIER,
    legal_pairs=legal_pairs,
)

# Every example is either a real predicate or the 'no relation' negative class
positive_count = sum(ex['predicate'] != 'no relation' for ex in train_examples)
negative_count = len(train_examples) - positive_count

print(f"\nTraining examples prepared: {len(train_examples)}")
print(f"  Positive examples: {positive_count}")
print(f"  Negative examples: {negative_count}")
print(f"  Ratio (neg/pos): {negative_count/positive_count:.2f}")

In [None]:
# Prepare dev examples
print("Preparing dev examples...")
dev_examples = prepare_re_examples(dev_data, negative_multiplier=NEGATIVE_SAMPLE_MULTIPLIER, legal_pairs=legal_pairs)

positive_count_dev = sum(1 for ex in dev_examples if ex['predicate'] != 'no relation')
negative_count_dev = sum(1 for ex in dev_examples if ex['predicate'] == 'no relation')

print(f"\nDev examples prepared: {len(dev_examples)}")
print(f"  Positive examples: {positive_count_dev}")
print(f"  Negative examples: {negative_count_dev}")

In [None]:
# Show example
print("\nExample training instance:")
example = train_examples[0]
print(f"  Text: {example['text'][:150]}...")
print(f"  Subject: '{example['subject']['text_span']}' [{example['subject']['label']}]")
print(f"  Object: '{example['object']['text_span']}' [{example['object']['label']}]")
print(f"  Predicate: {example['predicate']}")

## Initialize Tokenizer and Add Special Tokens

In [None]:
# Initialize tokenizer
print("Initializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Add special entity marker tokens
special_tokens = {"additional_special_tokens": ["[E1]", "[/E1]", "[E2]", "[/E2]"]}
tokenizer.add_special_tokens(special_tokens)

# Get token IDs for entity markers
e1_token_id = tokenizer.convert_tokens_to_ids("[E1]")
e2_token_id = tokenizer.convert_tokens_to_ids("[E2]")

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
print(f"  Vocabulary size (with special tokens): {len(tokenizer)}")
print(f"  [E1] token ID: {e1_token_id}")
print(f"  [E2] token ID: {e2_token_id}")

## Tokenization with Entity Markers

In [None]:
def insert_entity_markers(text, subject, obj):
    """
    Wrap the subject in [E1]...[/E1] and the object in [E2]...[/E2].

    Args:
        text: Full text the entity offsets refer to.
        subject: Subject entity dict with 'start_idx' and inclusive 'end_idx'.
        obj: Object entity dict with 'start_idx' and inclusive 'end_idx'.

    Returns:
        Text with the four markers inserted. Non-overlapping spans (including
        adjacent ones) are marked as "[E1]subj[/E1] ... [E2]obj[/E2]"; nested
        spans are nested, e.g. "[E1]ab[E2]cd[/E2]ef[/E1]".
    """
    # Each marker is an insertion event at a character position of the ORIGINAL
    # text; the output is then built in a single left-to-right pass, so no
    # running offset is needed and overlapping spans are handled correctly.
    spans = ((subject['start_idx'], subject['end_idx'], '[E1]', '[/E1]'),
             (obj['start_idx'], obj['end_idx'], '[E2]', '[/E2]'))
    events = []
    for rank, (start, end, open_marker, close_marker) in enumerate(spans):
        stop = end + 1  # end_idx is inclusive
        # Tie-breaking at the same character position:
        #  - closing markers (0) come before opening markers (1),
        #  - among openings, the wider span opens first (-stop),
        #  - among closings, the span that started later closes first (-start),
        #  - identical spans nest as [E1][E2]...[/E2][/E1] (rank / -rank).
        events.append((start, 1, -stop, rank, open_marker))
        events.append((stop, 0, -start, -rank, close_marker))
    events.sort()

    pieces = []
    cursor = 0
    for position, _, _, _, marker in events:
        pieces.append(text[cursor:position])
        pieces.append(marker)
        cursor = position
    pieces.append(text[cursor:])
    return ''.join(pieces)


def tokenize_re_example(example, tokenizer, e1_token_id, e2_token_id, max_length=512):
    """
    Encode one relation example as fixed-size model inputs.

    The subject/object are wrapped in entity markers, the marked text is
    tokenized (truncated and padded to `max_length`), and masks flag the
    positions of the [E1] and [E2] tokens.

    NOTE: because of truncation=True, a marker that falls past `max_length`
    tokens is cut off and its mask is then all zeros.

    Returns:
        Dictionary with 1-D 'input_ids', 'attention_mask', 'e1_mask', 'e2_mask'
        and a 0-D long 'labels' tensor holding the predicate's class id.
    """
    encoded = tokenizer(
        insert_entity_markers(example['text'], example['subject'], example['object']),
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt'
    )

    # Drop the batch dimension added by return_tensors='pt'
    input_ids = encoded['input_ids'][0]
    attention_mask = encoded['attention_mask'][0]

    marker_masks = {
        key: (input_ids == marker_id).long()
        for key, marker_id in (('e1_mask', e1_token_id), ('e2_mask', e2_token_id))
    }

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        **marker_masks,
        'labels': torch.tensor(label2id[example['predicate']], dtype=torch.long)
    }


print("Tokenization functions defined")

In [None]:
# Test tokenization
test_example = train_examples[0]
tokenized = tokenize_re_example(test_example, tokenizer, e1_token_id, e2_token_id)

print("Test tokenization:")
print(f"  Input IDs shape: {tokenized['input_ids'].shape}")
print(f"  E1 mask sum (should be 1): {tokenized['e1_mask'].sum().item()}")
print(f"  E2 mask sum (should be 1): {tokenized['e2_mask'].sum().item()}")
print(f"  Label: {tokenized['labels'].item()} ({id2label[tokenized['labels'].item()]})")

# Show marked text
marked = insert_entity_markers(test_example['text'], test_example['subject'], test_example['object'])
print(f"\nMarked text preview: {marked[:200]}...")

## Create Dataset Class

In [None]:
class REDataset(Dataset):
    """
    Map-style dataset for relation extraction.

    The raw example dicts are kept as-is; each item is tokenized on access
    with `tokenize_re_example`.
    """

    def __init__(self, examples, tokenizer, e1_token_id, e2_token_id, max_length=512):
        self.examples = examples
        self.tokenizer = tokenizer
        self.e1_token_id, self.e2_token_id = e1_token_id, e2_token_id
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return tokenize_re_example(
            self.examples[idx],
            self.tokenizer,
            self.e1_token_id,
            self.e2_token_id,
            max_length=self.max_length,
        )


print("Dataset class defined")

In [None]:
# Create datasets
print("Creating datasets...")
train_dataset = REDataset(train_examples, tokenizer, e1_token_id, e2_token_id, max_length)
dev_dataset = REDataset(dev_examples, tokenizer, e1_token_id, e2_token_id, max_length)

print(f"Training dataset: {len(train_dataset)} examples")
print(f"Dev dataset: {len(dev_dataset)} examples")

## Initialize Model

In [None]:
# Initialize model
print("Initializing BERT RE model...")
model = BertForREWithEntityMarkers(model_name, num_labels=len(RELATION_LABELS))

# Resize token embeddings to account for new special tokens
model.bert.resize_token_embeddings(len(tokenizer))

print(f"Model initialized")
print(f"  Number of labels: {model.num_labels}")
print(f"  Hidden size: {model.bert.config.hidden_size}")

## Custom Trainer for Entity Marker Model

In [None]:
class RETrainer(Trainer):
    """
    Trainer that delegates the loss to BertForREWithEntityMarkers.

    The labels are popped from the batch and passed explicitly, and the other
    batch entries (including e1_mask / e2_mask) are forwarded to the model as
    keyword arguments; the loss is read from the returned dict's 'loss' entry.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        gold_labels = inputs.pop("labels")
        model_output = model(**inputs, labels=gold_labels)
        batch_loss = model_output['loss']
        if return_outputs:
            return batch_loss, model_output
        return batch_loss


print("Custom trainer class defined")

## Configure Training Arguments

In [None]:
# Define training arguments
training_args = TrainingArguments(
    output_dir=output_model_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    push_to_hub=False,
    logging_steps=100,
    save_total_limit=2,
    seed=SEED,
    fp16=torch.cuda.is_available(),
    report_to="none"
)

print("Training configuration ready")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Learning rate: {training_args.learning_rate}")

## Train Model

**Note:** This cell might take several minutes to hours depending on dataset size and hardware.

**Hyperparameters to experiment with:**
- `NEGATIVE_SAMPLE_MULTIPLIER` (set to 10 in the configuration cell): also try smaller values such as 1, 2, 3, 5
- `learning_rate`: Try 1e-5, 2e-5, 3e-5
- `num_train_epochs`: Try 3, 5, 10
- `per_device_train_batch_size`: Adjust based on GPU memory
- Different pretrained models: "allenai/scibert_scivocab_uncased", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

In [None]:
# Initialize trainer
print("Initializing Trainer...")
trainer = RETrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset
)

print("Trainer initialized")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Evaluation samples: {len(dev_dataset)}")

In [None]:
# Start training
print("="*60)
print("Starting model training...")
print("="*60)

import time
training_start_time = time.time()

train_result = trainer.train()

training_duration = time.time() - training_start_time

print("\n" + "="*60)
print("TRAINING COMPLETED!")
print("="*60)
print(f"Training time: {training_duration/60:.2f} minutes")

## Save Trained Model

In [None]:
# Save the trained model
print("Saving trained model...")

os.makedirs(output_model_dir, exist_ok=True)
trainer.save_model(output_model_dir)
tokenizer.save_pretrained(output_model_dir)

# Save label mappings
with open(os.path.join(output_model_dir, 'label_mappings.json'), 'w') as f:
    json.dump({'label2id': label2id, 'id2label': id2label}, f, indent=2)

print(f"Model saved to: {output_model_dir}")

## Load Model for Inference

In [None]:
# Load the trained model for inference
print("Loading trained model for inference...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer first (it has the correct vocabulary with special tokens)
inference_tokenizer = AutoTokenizer.from_pretrained(output_model_dir)

# Initialize model architecture
inference_model = BertForREWithEntityMarkers(model_name, num_labels=len(RELATION_LABELS))

# IMPORTANT: Resize embeddings BEFORE loading state dict to match saved model
inference_model.bert.resize_token_embeddings(len(inference_tokenizer))

# Try loading the model - handle both old and new formats
try:
    # Try new format first (model.safetensors or pytorch_model.bin in subfolder)
    state_dict_path = os.path.join(output_model_dir, "model.safetensors")
    if os.path.exists(state_dict_path):
        from safetensors.torch import load_file
        state_dict = load_file(state_dict_path)
    else:
        # Try old format
        state_dict_path = os.path.join(output_model_dir, "pytorch_model.bin")
        state_dict = torch.load(state_dict_path, map_location=device)
    
    inference_model.load_state_dict(state_dict)
    print(f"Loaded model from: {state_dict_path}")
except FileNotFoundError:
    # If no checkpoint file, try loading the full model directly
    # This happens when trainer saves the full model
    print("Standard checkpoint not found, trying alternative loading method...")
    
    # List files in the directory to debug
    if os.path.exists(output_model_dir):
        files = os.listdir(output_model_dir)
        print(f"Files in {output_model_dir}: {files}")
    
    # Try to load from checkpoint subdirectory
    checkpoint_dirs = [d for d in os.listdir(output_model_dir) if d.startswith('checkpoint-')]
    if checkpoint_dirs:
        # Use the last checkpoint
        checkpoint_dirs.sort(key=lambda x: int(x.split('-')[1]))
        checkpoint_path = os.path.join(output_model_dir, checkpoint_dirs[-1])
        print(f"Loading from checkpoint: {checkpoint_path}")
        
        # Load from checkpoint
        if os.path.exists(os.path.join(checkpoint_path, "model.safetensors")):
            from safetensors.torch import load_file
            state_dict = load_file(os.path.join(checkpoint_path, "model.safetensors"))
        else:
            state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location=device)
        
        inference_model.load_state_dict(state_dict)
    else:
        raise FileNotFoundError(f"Could not find model checkpoint in {output_model_dir}")

inference_model.to(device)
inference_model.eval()

# Reload token IDs
e1_token_id = inference_tokenizer.convert_tokens_to_ids("[E1]")
e2_token_id = inference_tokenizer.convert_tokens_to_ids("[E2]")

print(f"Model loaded successfully")
print(f"  Device: {device}")

## Inference Function

In [None]:
def predict_relation(model, tokenizer, text, subject, obj, e1_token_id, e2_token_id, id2label, device, max_length=512):
    """
    Predict the relation from `subject` to `obj` in `text`.

    Args:
        model: Relation classifier called as model(input_ids=..., attention_mask=...,
            e1_mask=..., e2_mask=...) and returning a dict with 'logits'.
            Presumably already on `device` and in eval mode (the caller's job).
        tokenizer: Tokenizer that contains the entity marker tokens.
        text: Full text the entity offsets refer to.
        subject, obj: Entity dicts with 'start_idx' and inclusive 'end_idx'.
        e1_token_id, e2_token_id: Vocabulary ids of the [E1] and [E2] markers.
        id2label: Mapping from class index to relation label.
        device: torch device for the input tensors.
        max_length: Token budget (truncation and padding length). Defaults to 512;
            pass the training `max_length` to keep training and inference inputs identical.

    Returns:
        (predicted_label, confidence) where confidence is the softmax probability
        of the predicted class.
    """
    marked_text = insert_entity_markers(text, subject, obj)

    encoding = tokenizer(
        marked_text,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Masks flagging the [E1] / [E2] token positions
    e1_mask = (input_ids == e1_token_id).long()
    e2_mask = (input_ids == e2_token_id).long()

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            e1_mask=e1_mask,
            e2_mask=e2_mask
        )
        logits = outputs['logits']
        probs = torch.softmax(logits, dim=-1)

        pred_label_id = torch.argmax(logits, dim=-1).item()
        confidence = probs[0, pred_label_id].item()

    return id2label[pred_label_id], confidence


print("Inference function defined")

## Predict on Dev Set

In [None]:
# Run inference on dev set
print("Running inference on dev set...")

predictions = {}

for pmid, article in tqdm(dev_data.items(), desc="Predicting relations"):
    title = article['metadata']['title']
    abstract = article['metadata']['abstract']
    full_text, abstract_offset = create_full_text_with_offsets(title, abstract)
    
    entities = article['entities']
    adjusted_entities = [adjust_entity_positions(e, abstract_offset) for e in entities]
    
    predicted_relations = []
    
    # Generate entity pairs - only for legal entity type combinations
    for i, subject in enumerate(adjusted_entities):
        for j, obj in enumerate(adjusted_entities):
            if i != j:  # Don't create self-relations
                # Check if this entity type pair is legal
                type_pair = (subject['label'], obj['label'])
                if type_pair not in legal_pairs:
                    continue  # Skip non-legal entity type pairs
                
                # Predict relation
                pred_label, confidence = predict_relation(
                    inference_model,
                    inference_tokenizer,
                    full_text,
                    subject,
                    obj,
                    e1_token_id,
                    e2_token_id,
                    id2label,
                    device
                )
                
                # Only keep positive relations (not "no relation")
                if pred_label != "no relation":
                    # Map back to original entity positions
                    orig_subject = next(e for e in entities 
                                       if e['text_span'] == subject['text_span'] and 
                                          e['label'] == subject['label'])
                    orig_obj = next(e for e in entities 
                                   if e['text_span'] == obj['text_span'] and 
                                      e['label'] == obj['label'])
                    
                    predicted_relations.append({
                        "subject_start_idx": orig_subject['start_idx'],
                        "subject_end_idx": orig_subject['end_idx'],
                        "subject_location": orig_subject['location'],
                        "subject_text_span": orig_subject['text_span'],
                        "subject_label": orig_subject['label'],
                        "predicate": pred_label,
                        "object_start_idx": orig_obj['start_idx'],
                        "object_end_idx": orig_obj['end_idx'],
                        "object_location": orig_obj['location'],
                        "object_text_span": orig_obj['text_span'],
                        "object_label": orig_obj['label']
                    })
    

    predictions[pmid] = {"relations": predicted_relations}

total_relations = sum(len(p['relations']) for p in predictions.values())
print(f"\nInference completed: {len(predictions)} documents")
print(f"  Total relations predicted: {total_relations}")

## Save Predictions

In [None]:
# Save predictions to file
output_path = "predictions/bert_re_predictions.json"

os.makedirs("predictions", exist_ok=True)

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(predictions, f, ensure_ascii=False, indent=2)

print(f"Predictions saved to {output_path}")

## Example Predictions

In [None]:
# Show example predictions
print("Example Predictions:\n")

sample_pmids = list(dev_data.keys())[:5]

for pmid in sample_pmids:
    article = dev_data[pmid]
    pred = predictions[pmid]
    
    print(f"Document PMID: {pmid}")
    print(f"Title: {article['metadata']['title'][:100]}...")
    print(f"\nNumber of entities: {len(article['entities'])}")
    print(f"Number of gold relations: {len(article['relations'])}")
    print(f"Number of predicted relations: {len(pred['relations'])}")
    
    # Show first few predicted relations
    print("\nSample predicted relations:")
    for relation in pred['relations'][:5]:
        print(f"  ({relation['subject_text_span']} [{relation['subject_label']}]) "
              f"--[{relation['predicate']}]--> "
              f"({relation['object_text_span']} [{relation['object_label']}])")
    
    print("-" * 80)
    print()

## Analysis: Relation Distribution by Predicate

In [None]:
# Count relations by predicate in predictions
pred_predicate_counts = Counter()
for pmid, pred in predictions.items():
    for relation in pred['relations']:
        pred_predicate_counts[relation['predicate']] += 1

# Count relations by predicate in gold standard
gold_predicate_counts = Counter()
for pmid, article in dev_data.items():
    for relation in article['relations']:
        gold_predicate_counts[relation['predicate']] += 1

print("Relation Distribution by Predicate:")
print("="*60)
print(f"{'Predicate':<30} {'Gold':<10} {'Predicted':<10}")
print("-"*60)

all_predicates = set(gold_predicate_counts.keys()) | set(pred_predicate_counts.keys())
for predicate in sorted(all_predicates):
    print(f"{predicate:<30} {gold_predicate_counts[predicate]:<10} {pred_predicate_counts[predicate]:<10}")

print("-"*60)
print(f"{'TOTAL':<30} {sum(gold_predicate_counts.values()):<10} {sum(pred_predicate_counts.values()):<10}")

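## Post-process into the evaluation format

Attach each document's metadata and entities to its predicted relations (the evaluation format), then remove relations whose (subject type, predicate, object type) triple is not defined in the annotation guidelines.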

In [None]:
# Reload the saved model predictions and pair each one with the gold metadata and
# gold entities from the dev set. The helpers below operate on this layout.
import json

output_path = "predictions/rebel_finetuned_predictions.json"
# NOTE(review): the setup cell moves the dataset into ./data, while this path
# points to ../data — confirm it matches the notebook's working directory.
dev_path = "../data/Annotations/Dev/json_format/dev.json"

# Context managers close the handles. The dev file is read as UTF-8 explicitly
# rather than with the platform-default encoding.
with open(output_path, 'r', encoding='utf-8') as f:
    raw_predictions = json.load(f)
with open(dev_path, 'r', encoding='utf-8') as f:
    dev_data = json.load(f)

# Relations come from the model; metadata and entities come from the gold dev set.
predictions = {
    pmid: {
        'metadata': dev_data[pmid]['metadata'],
        'entities': dev_data[pmid]['entities'],
        'relations': content['relations'],
    }
    for pmid, content in raw_predictions.items()
}
dump_dict = {}

In [None]:
# Allowed (subject_label, predicate, object_label) triples. Direction matters:
# each relation goes from the subject to the object. remove_illegal_relations()
# drops any predicted relation whose triple is not listed here.
# NOTE(review): a LEGAL_RELATIONS list is also defined near the top of the
# notebook. This cell re-binds the name, presumably so this post-processing
# section can run on its own. Keep the two definitions in sync (or delete one).
LEGAL_RELATIONS = [
    ("DDF", "affect", "DDF"),
    ("microbiome", "is linked to", "DDF"),
    ("DDF", "target", "human"),
    ("drug", "change effect", "DDF"),
    ("DDF", "is a", "DDF"),
    ("microbiome", "located in", "human"),
    ("chemical", "influence", "DDF"),
    ("dietary supplement", "influence", "DDF"),
    ("DDF", "target", "animal"),
    ("chemical", "impact", "microbiome"),
    ("anatomical location", "located in", "animal"),
    ("microbiome", "located in", "animal"),
    ("chemical", "located in", "anatomical location"),
    ("bacteria", "part of", "microbiome"),
    ("DDF", "strike", "anatomical location"),
    ("drug", "administered", "animal"),
    ("bacteria", "influence", "DDF"),
    ("drug", "impact", "microbiome"),
    ("DDF", "change abundance", "microbiome"),
    ("microbiome", "located in", "anatomical location"),
    ("microbiome", "used by", "biomedical technique"),
    ("chemical", "produced by", "microbiome"),
    ("dietary supplement", "impact", "microbiome"),
    ("bacteria", "located in", "animal"),
    ("animal", "used by", "biomedical technique"),
    ("chemical", "impact", "bacteria"),
    ("chemical", "located in", "animal"),
    ("food", "impact", "bacteria"),
    ("microbiome", "compared to", "microbiome"),
    ("human", "used by", "biomedical technique"),
    ("bacteria", "change expression", "gene"),
    ("chemical", "located in", "human"),
    ("drug", "interact", "chemical"),
    ("food", "administered", "human"),
    ("DDF", "change abundance", "bacteria"),
    ("chemical", "interact", "chemical"),
    ("chemical", "part of", "chemical"),
    ("dietary supplement", "impact", "bacteria"),
    ("DDF", "interact", "chemical"),
    ("food", "impact", "microbiome"),
    ("food", "influence", "DDF"),
    ("bacteria", "located in", "human"),
    ("dietary supplement", "administered", "human"),
    ("bacteria", "interact", "chemical"),
    ("drug", "change expression", "gene"),
    ("drug", "impact", "bacteria"),
    ("drug", "administered", "human"),
    ("anatomical location", "located in", "human"),
    ("dietary supplement", "change expression", "gene"),
    ("chemical", "change expression", "gene"),
    ("bacteria", "interact", "bacteria"),
    ("drug", "interact", "drug"),
    ("microbiome", "change expression", "gene"),
    ("bacteria", "interact", "drug"),
    ("food", "change expression", "gene")
]

def remove_illegal_relations(data, legal_relations=None):
    """Drop relations whose (subject_label, predicate, object_label) triple is not allowed.

    Builds a new dict with the same pmids as ``data``. For every article it:
    - keeps ``metadata`` (the same object, not a copy);
    - re-creates each entity with only ``start_idx``, ``end_idx``, ``location``,
      ``text_span`` and ``label``;
    - keeps only the relations whose label/predicate triple is allowed, each
      re-created with only the eleven subject/predicate/object fields.
    Any other keys on articles, entities or relations are dropped.

    Args:
        data: mapping pmid -> article dict with ``metadata``, ``entities`` and
            ``relations`` keys.
        legal_relations: iterable of allowed
            ``(subject_label, predicate, object_label)`` triples. Defaults to the
            module-level ``LEGAL_RELATIONS``.

    Returns:
        The filtered dict. Relation totals and the distinct discarded triples
        are printed as a side effect.

    Raises:
        KeyError: if an entity or kept relation lacks one of the copied keys.
    """
    if legal_relations is None:
        legal_relations = LEGAL_RELATIONS
    # Set lookup is O(1) per relation; scanning the list is O(len(list)).
    allowed = set(legal_relations)

    entity_keys = ("start_idx", "end_idx", "location", "text_span", "label")
    relation_keys = (
        "subject_start_idx", "subject_end_idx", "subject_location",
        "subject_text_span", "subject_label",
        "predicate",
        "object_start_idx", "object_end_idx", "object_location",
        "object_text_span", "object_label",
    )

    filtered = {}
    total_rels = 0
    discarded_rels = 0
    discarded_triples = set()

    for pmid, article in data.items():
        # Labels are copied unchanged (the old `x if x != 'DDF' else 'DDF'` was a no-op).
        entities = [{key: entity[key] for key in entity_keys} for entity in article['entities']]

        kept = []
        for relation in article['relations']:
            total_rels += 1
            triple = (relation["subject_label"], relation["predicate"], relation["object_label"])
            if triple in allowed:
                kept.append({key: relation[key] for key in relation_keys})
            else:
                discarded_rels += 1
                discarded_triples.add(triple)

        filtered[pmid] = {'metadata': article['metadata'], 'entities': entities, 'relations': kept}

    kept_rels = total_rels - discarded_rels
    print(f'total_rels: {total_rels}')
    print(f'kept_rels: {kept_rels}')
    print(f'discarded_rels: {discarded_rels}')
    print()
    print(f'discarded_rels_set ({len(discarded_triples)} distinct triples):')
    # Sorted so the log is stable across runs (set order follows string hashing).
    for triple in sorted(discarded_triples, key=str):
        print(triple)

    return filtered


In [None]:
# Assign the filtered result back to `predictions`. Every later cell (sorting,
# tag-based relations, the final JSON dump) reads `predictions`. Keeping the
# result only in `dump_dict` silently wrote the unfiltered relations to disk.
dump_dict = remove_illegal_relations(predictions)
predictions = dump_dict

Sort entities and relations

In [None]:
def sort_entities(release_dict):
    """Sort each article's ``entities`` list in place.

    Entities located in the title come first, then all others. Within each group,
    entities are ordered by ``start_idx``. Returns None.
    """
    def by_position(entity):
        # False (title) sorts before True (any other location).
        not_in_title = entity["location"] != "title"
        return not_in_title, entity["start_idx"]

    for article in release_dict.values():
        article["entities"].sort(key=by_position)

In [None]:
# Reorder each article's entities in place: title entities first, then by start offset.
sort_entities(predictions)

In [None]:
def sort_relations(release_dict):
    """Sort each article's ``relations`` list in place by subject position.

    Relations whose subject lies in the title come first. Within each group they
    are ordered by ``subject_start_idx``. ``list.sort`` is stable, so relations
    with the same subject keep their existing relative order. Returns None.
    """
    def by_subject_position(relation):
        # False (title) sorts before True (any other location).
        not_in_title = relation["subject_location"] != "title"
        return not_in_title, relation["subject_start_idx"]

    for article in release_dict.values():
        article["relations"].sort(key=by_subject_position)

In [None]:
# Reorder each article's relations in place: title subjects first, then by subject start offset.
sort_relations(predictions)

Generate Binary Tag Based Relations

In [None]:
def add_binary_tag_based_relations_to_release_dict(release_dict):
    """Populate ``binary_tag_based_relations`` for every article, in place.

    Each entry is a distinct ``(subject_label, object_label)`` pair taken from the
    article's ``relations``.
    - New pairs are appended in sorted order, so the output file is identical
      across runs. Iterating a set of strings follows hash order, which changes
      between interpreter sessions.
    - Pairs already present in an existing list are not appended again, so
      re-running the cell does not duplicate entries.

    Args:
        release_dict: mapping pmid -> article dict with a ``relations`` list.
    """
    fields = ("subject_label", "object_label")
    for article in release_dict.values():
        pairs = {tuple(relation[f] for f in fields) for relation in article["relations"]}
        entries = article.setdefault("binary_tag_based_relations", [])
        already_present = {tuple(entry[f] for f in fields) for entry in entries}
        for values in sorted(pairs - already_present):
            entries.append(dict(zip(fields, values)))

In [None]:
# Add the distinct (subject_label, object_label) pairs of each article under "binary_tag_based_relations".
add_binary_tag_based_relations_to_release_dict(predictions)

Generate Ternary Tag Based Relations

In [None]:
def add_ternary_tag_based_relations_to_release_dict(release_dict):
    """Populate ``ternary_tag_based_relations`` for every article, in place.

    Each entry is a distinct ``(subject_label, predicate, object_label)`` triple
    taken from the article's ``relations``.
    - New triples are appended in sorted order, so the output file is identical
      across runs. Iterating a set of strings follows hash order, which changes
      between interpreter sessions.
    - Triples already present in an existing list are not appended again, so
      re-running the cell does not duplicate entries.

    Args:
        release_dict: mapping pmid -> article dict with a ``relations`` list.
    """
    fields = ("subject_label", "predicate", "object_label")
    for article in release_dict.values():
        triplets = {tuple(relation[f] for f in fields) for relation in article["relations"]}
        entries = article.setdefault("ternary_tag_based_relations", [])
        already_present = {tuple(entry[f] for f in fields) for entry in entries}
        for values in sorted(triplets - already_present):
            entries.append(dict(zip(fields, values)))

In [None]:

# Add the distinct (subject_label, predicate, object_label) triples of each article under "ternary_tag_based_relations".
add_ternary_tag_based_relations_to_release_dict(predictions)

Generate Ternary Mention Based Relations

In [None]:
def add_ternary_mention_based_relations_to_release_dict(release_dict):
    """Populate ``ternary_mention_based_relations`` for every article, in place.

    Each entry is a distinct
    ``(subject_text_span, subject_label, predicate, object_text_span, object_label)``
    tuple taken from the article's ``relations``.
    - New tuples are appended in sorted order, so the output file is identical
      across runs. Iterating a set of strings follows hash order, which changes
      between interpreter sessions.
    - Tuples already present in an existing list are not appended again, so
      re-running the cell does not duplicate entries.

    Args:
        release_dict: mapping pmid -> article dict with a ``relations`` list.
    """
    fields = ("subject_text_span", "subject_label", "predicate", "object_text_span", "object_label")
    for article in release_dict.values():
        mentions = {tuple(relation[f] for f in fields) for relation in article["relations"]}
        entries = article.setdefault("ternary_mention_based_relations", [])
        already_present = {tuple(entry[f] for f in fields) for entry in entries}
        for values in sorted(mentions - already_present):
            entries.append(dict(zip(fields, values)))

In [None]:
# Add the distinct (subject span, subject label, predicate, object span, object label) tuples under "ternary_mention_based_relations".
add_ternary_mention_based_relations_to_release_dict(predictions)

In [None]:
# Write the evaluation-format file next to the raw predictions. The "_eval_format"
# suffix is inserted before the extension, and only once, so re-running this cell
# does not produce "..._eval_format_eval_format.json". splitext touches only the
# real extension, unlike str.replace, which rewrites every ".json" substring.
base, ext = os.path.splitext(output_path)
if not base.endswith("_eval_format"):
    output_path = f"{base}_eval_format{ext}"

with open(output_path, 'w', encoding='utf-8') as file:
    json.dump(predictions, file, indent=2)