In [None]:
# Fetch the GutBrainIE collection and unpack it into ./data
# (Annotations/, Articles/ and Test_Data/ end up directly under data/).
# The archive is extracted straight into the target directory and existing files
# are overwritten, so this cell can be re-run without "File exists" /
# "Directory not empty" failures from mkdir and mv.
!wget -O data.zip https://github.com/MMartinelli-hub/ATA_Tutorship/raw/refs/heads/main/gutbrainie/data/GutBrainIE_Full_Collection_2025.zip
!mkdir -p data
!unzip -o -q data.zip -d data
!rm -rf data.zip

# REBEL: Relation Extraction By End-to-end Language Generation

This notebook demonstrates **REBEL** (Relation Extraction By End-to-end Language generation), a seq2seq model based on BART that performs end-to-end relation extraction.

## Key Features:
- **End-to-end approach**: No separate entity extraction and relation classification steps
- **Sequence-to-sequence**: Generates relation triplets as linearized text sequences
- **Zero-shot capable**: Can extract relations without task-specific fine-tuning
- **Output format**: Each triplet is linearized as `<triplet> subject <subj> object <obj> predicate`

## References:
- Paper: [REBEL: Relation Extraction By End-to-end Language generation (EMNLP 2021)](https://aclanthology.org/2021.findings-emnlp.204/)
- Model: [Babelscape/rebel-large](https://huggingface.co/Babelscape/rebel-large)
- Repository: [babelscape/rebel](https://github.com/babelscape/rebel)

## Setup and Imports

In [None]:
import json
import os
from typing import List, Dict
from tqdm import tqdm
from collections import Counter
from transformers import pipeline
import torch

# Report the runtime environment so the reader can see up front whether
# CUDA is visible to PyTorch.
print("Setup complete")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Load REBEL Model

We'll use the pre-trained REBEL-large model from Hugging Face. This model is trained on multiple RE datasets and can extract over 200 relation types.

In [None]:
# Initialize REBEL pipeline
# `device` and `triplet_extractor` are notebook-level names reused by later cells.
print("Loading REBEL model...")
device = 0 if torch.cuda.is_available() else -1  # Use GPU if available (0), otherwise -1 (reported as CPU below)

# The same checkpoint id is used for both the model and its tokenizer.
triplet_extractor = pipeline(
    'text2text-generation',
    model='Babelscape/rebel-large',
    tokenizer='Babelscape/rebel-large',
    device=device
)

print(f"REBEL model loaded")
print(f"  Device: {'GPU' if device == 0 else 'CPU'}")

## Triplet Extraction Function

REBEL outputs triplets in a special format using tokens like `<triplet>`, `<subj>`, and `<obj>`. We need to parse this output.

In [None]:
def extract_triplets(text: str) -> List[Dict[str, str]]:
    """
    Turn a decoded REBEL generation into a list of relation triplets.

    The linearized form is
    ``<triplet> subject <subj> object <obj> predicate <subj> object <obj> predicate <triplet> ...``
    i.e. one subject may be followed by several (object, predicate) pairs.

    Args:
        text: Decoded model output, possibly still containing ``<s>``, ``<pad>``
            and ``</s>``.

    Returns:
        One dict per triplet with the keys 'head' (subject), 'type' (predicate)
        and 'tail' (object), in the order they were generated.
    """
    collected = []

    # One text buffer per parser state. The state letter is named after the
    # marker that was seen last, so the words that follow belong to:
    #   't' (after <triplet>) -> subject
    #   's' (after <subj>)    -> object
    #   'o' (after <obj>)     -> predicate
    buffers = {'t': '', 's': '', 'o': ''}
    state = 'x'  # nothing seen yet: words before the first marker are ignored

    def emit():
        collected.append({
            'head': buffers['t'].strip(),
            'type': buffers['o'].strip(),
            'tail': buffers['s'].strip()
        })

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for word in cleaned.split():
        if word == "<triplet>":
            state = 't'
            # A pending predicate means the previous triplet is complete
            if buffers['o'] != '':
                emit()
                buffers['o'] = ''
            buffers['t'] = ''
        elif word == "<subj>":
            state = 's'
            # Same subject, next object: flush the pair collected so far
            if buffers['o'] != '':
                emit()
            buffers['s'] = ''
        elif word == "<obj>":
            state = 'o'
            buffers['o'] = ''
        elif state in buffers:
            buffers[state] += ' ' + word

    # The last triplet has no closing marker; keep it only if it is complete
    if all(part != '' for part in buffers.values()):
        emit()

    return collected


print("Triplet extraction function defined")

## Test REBEL on Example Text

In [None]:
# Test with example text
# (the place name is spelled with a real "ü"; the previous literal contained the
# mis-decoded byte sequence "Ã¼", which fed garbage characters to the model)
example_text = "Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic."

print(f"Input text:\n{example_text}\n")

# Generate triplets
# Ask the pipeline for token ids rather than text so the decode step below
# controls how the output is turned back into a string.
generated_output = triplet_extractor(
    example_text,
    return_tensors=True,
    return_text=False
)

# Decode and extract triplets
extracted_text = triplet_extractor.tokenizer.batch_decode(
    [generated_output[0]["generated_token_ids"]]
)

print(f"Raw REBEL output:\n{extracted_text[0]}\n")

# Parse triplets
triplets = extract_triplets(extracted_text[0])

print(f"Extracted triplets ({len(triplets)}):")
for i, triplet in enumerate(triplets, 1):
    print(f"  {i}. ({triplet['head']}) --[{triplet['type']}]--> ({triplet['tail']})")

## Load GutBrainIE Data

In [None]:
def load_re_data(file_paths):
    """
    Read several JSON annotation files and merge them into one dictionary.

    Files are processed in the order given. Each file is expected to hold a
    JSON object; its top-level keys are merged with ``dict.update``, so a key
    that appears in more than one file keeps the value from the last file.
    A path that does not exist is reported with a warning and skipped.

    Args:
        file_paths: Iterable of paths to JSON files.

    Returns:
        Dict containing the merged contents of every file that was found.
    """
    merged = {}

    for path in file_paths:
        if not os.path.exists(path):
            print(f"Warning: {path} not found")
            continue

        with open(path, 'r', encoding='utf-8') as handle:
            documents = json.load(handle)

        merged.update(documents)
        print(f"Loaded {len(documents)} documents from {os.path.basename(path)}")

    return merged


# Load dev data for testing
# The download cell at the top of the notebook places the collection in ./data,
# while this path historically pointed one level up. Keep ../data whenever it
# exists and only fall back to ./data when that is the one that is present.
data_root = "data" if (os.path.isdir("data") and not os.path.isdir("../data")) else "../data"
dev_data = load_re_data([f"{data_root}/Annotations/Dev/json_format/dev.json"])
print(f"\nTotal dev documents: {len(dev_data)}")

## Load Training Data

In [None]:
# Load training data from three quality levels
# The download cell at the top of the notebook places the collection in ./data,
# while these paths historically pointed one level up. Keep ../data whenever it
# exists and only fall back to ./data when that is the one that is present.
data_root = "data" if (os.path.isdir("data") and not os.path.isdir("../data")) else "../data"
train_files = [
    f"{data_root}/Annotations/Train/gold_quality/json_format/train_gold.json",
    f"{data_root}/Annotations/Train/platinum_quality/json_format/train_platinum.json",
    f"{data_root}/Annotations/Train/silver_quality/json_format/train_silver.json"
]

# Files are merged in list order; a document id present in several files keeps
# the entry from the file listed last.
train_data = load_re_data(train_files)
print(f"\nTotal training documents: {len(train_data)}")

## Prepare Training Data in REBEL Format

REBEL expects data in a specific linearized format:
- Input: Raw text
- Output: `<triplet> subject <subj> object <obj> predicate <triplet> ...`

In [None]:
def create_full_text(title, abstract):
    """Return the title and the abstract joined by a single space."""
    return "{} {}".format(title, abstract)

def format_triplets_for_rebel(relations):
    """
    Linearize annotated relations into the REBEL target string.

    Relations that share the same (subject text, subject label) pair are
    grouped, in order of first appearance, so that each subject is written once
    and followed by all of its ``<subj> object <obj> predicate`` pairs.

    Args:
        relations: List of relation dicts providing 'subject_text_span',
            'subject_label', 'object_text_span' and 'predicate'.

    Returns:
        The linearized string, or "" when there are no relations.
    """
    if not relations:
        return ""

    # dicts keep insertion order, so subjects come out in first-seen order
    grouped = {}
    for relation in relations:
        subject_key = (relation['subject_text_span'], relation['subject_label'])
        grouped.setdefault(subject_key, []).append(relation)

    chunks = []
    for (subject_text, _subject_label), subject_relations in grouped.items():
        pieces = [subject_text]
        for relation in subject_relations:
            pieces.extend((
                '<subj>', relation['object_text_span'],
                '<obj>', relation['predicate'],
            ))
        chunks.append(' '.join(pieces))

    return '<triplet> ' + ' <triplet> '.join(chunks)


def prepare_training_data(data):
    """
    Build seq2seq training pairs from annotated articles.

    Args:
        data: Dict mapping a document id to an article dict that provides
            'metadata' (with 'title' and 'abstract') and 'relations'.

    Returns:
        List of dicts with the keys 'input' (title + abstract), 'output'
        (relations linearized in REBEL format) and 'pmid'. Articles whose
        relation list is empty are left out.
    """
    prepared = []

    for pmid, article in tqdm(data.items(), desc="Preparing training data"):
        metadata = article['metadata']
        full_text = create_full_text(metadata['title'], metadata['abstract'])
        relations = article['relations']

        # Only documents that actually carry relations give a training target
        if relations:
            prepared.append({
                'input': full_text,
                'output': format_triplets_for_rebel(relations),
                'pmid': pmid
            })

    return prepared


print("Training data preparation functions defined")

In [None]:
# Prepare training and validation examples
# Both lists only contain documents with at least one relation
# (prepare_training_data skips the others).
train_examples = prepare_training_data(train_data)
dev_examples = prepare_training_data(dev_data)

print(f"\nTraining examples: {len(train_examples)}")
print(f"Validation examples: {len(dev_examples)}")

# Show example
# Guarded so the cell does not fail when no training example was produced;
# both fields are cut to their first 150 characters for display.
if train_examples:
    example = train_examples[0]
    print(f"\nExample training instance:")
    print(f"  Input: {example['input'][:150]}...")
    print(f"  Output: {example['output'][:150]}...")

## Create Dataset for Fine-tuning

In [None]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq


class REBELDataset(Dataset):
    """
    Map-style dataset that tokenizes REBEL training pairs on access.

    Each example is a dict with an 'input' text and an 'output' target string.
    Items are returned unpadded; batching and padding are left to the collator.
    """

    def __init__(self, examples, tokenizer, max_input_length=512, max_target_length=256):
        """
        Args:
            examples: Sequence of dicts with 'input' and 'output' keys.
            tokenizer: Callable tokenizer accepting the source text positionally
                and the target through ``text_target``.
            max_input_length: Truncation limit passed for the source text.
            max_target_length: Truncation limit passed for the target text.
        """
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        """Number of examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the tokenized source of example ``idx`` with the target ids stored under 'labels'."""
        record = self.examples[idx]

        # Source side
        encoded = self.tokenizer(
            record['input'],
            max_length=self.max_input_length,
            truncation=True,
            padding=False  # left to the collator
        )

        # Target side
        target = self.tokenizer(
            text_target=record['output'],
            max_length=self.max_target_length,
            truncation=True,
            padding=False
        )
        encoded["labels"] = target["input_ids"]

        return encoded


# Initialize tokenizer for training
# Same checkpoint id as the inference pipeline loaded earlier.
tokenizer = AutoTokenizer.from_pretrained('Babelscape/rebel-large')

# Create datasets
# Default length limits of REBELDataset apply (512 input / 256 target).
train_dataset = REBELDataset(train_examples, tokenizer)
val_dataset = REBELDataset(dev_examples, tokenizer)

print(f"Training dataset: {len(train_dataset)} examples")
print(f"Validation dataset: {len(val_dataset)} examples")

## Fine-tune REBEL Model

We'll fine-tune the pre-trained REBEL model on GutBrainIE data to adapt it to the biomedical domain and our specific relation types.

In [None]:
# Load model for fine-tuning
model = AutoModelForSeq2SeqLM.from_pretrained('Babelscape/rebel-large')

# Training configuration
# `output_dir` is reused by the saving, loading and prediction-export cells below.
output_dir = "models/rebel_finetuned"

# Evaluation and checkpointing both happen once per epoch, and the best
# checkpoint is selected on eval_loss.
# NOTE(review): `eval_strategy` is the newer spelling of this argument; older
# transformers releases expect `evaluation_strategy` — confirm the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size = 16
    num_train_epochs=5,
    weight_decay=0.01,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    predict_with_generate=True,
    generation_max_length=256,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to="none",
    push_to_hub=False,
)

# Data collator for dynamic padding
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True
)

print("Model and training configuration ready")
print(f"  Output directory: {output_dir}")
print(f"  Epochs: {training_args.num_train_epochs}")
# Per-device figure: the product below does not include the number of devices.
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")

In [None]:
# Initialize trainer
# No compute_metrics is supplied, so model selection relies on the eval loss
# configured in the training arguments.
# NOTE(review): recent transformers releases rename the `tokenizer` argument to
# `processing_class` — confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Trainer initialized")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

## Start Training

**Note:** This will take considerable time depending on hardware (several hours on CPU, ~30-60 minutes on GPU).

In [None]:
# Start training
print("="*60)
print("Starting REBEL fine-tuning...")
print("="*60)

# Wall-clock timing around the whole training call
import time
training_start_time = time.time()

train_result = trainer.train()

training_duration = time.time() - training_start_time

print("\n" + "="*60)
print("TRAINING COMPLETED!")
print("="*60)
# Duration is measured in seconds and reported in minutes
print(f"Training time: {training_duration/60:.2f} minutes")
print(f"Final train loss: {train_result.training_loss:.4f}")

## Save Fine-tuned Model

In [None]:
# Save the fine-tuned model
print("Saving fine-tuned model...")

# Model and tokenizer are written to the same directory so that the directory
# can be passed as both `model` and `tokenizer` when building a pipeline.
os.makedirs(output_dir, exist_ok=True)
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model saved to: {output_dir}")
print("\nYou can now use this fine-tuned model by loading:")
print(f"  model_path = '{output_dir}'")
print(f"  triplet_extractor = pipeline('text2text-generation', model=model_path, tokenizer=model_path)")

## Load Fine-tuned Model for Inference

Now we'll use the fine-tuned model for predictions. If no fine-tuned model is found on disk, the cell below falls back to the pre-trained `Babelscape/rebel-large`.

In [None]:
# Load the fine-tuned model for inference
print("Loading fine-tuned REBEL model for inference...")

# Use fine-tuned model if it exists, otherwise use pre-trained.
# The directory alone is not enough evidence: it can exist without a saved
# model in it (for instance when training was interrupted before the save
# cell ran), and pointing the pipeline at it would then fail. A saved model
# always comes with its config.json, so that file is what is checked.
has_finetuned_model = os.path.isfile(os.path.join(output_dir, "config.json"))
model_path = output_dir if has_finetuned_model else 'Babelscape/rebel-large'

# Replaces the pre-trained pipeline created earlier under the same name
triplet_extractor = pipeline(
    'text2text-generation',
    model=model_path,
    tokenizer=model_path,
    device=device
)

print(f"Model loaded from: {model_path}")
print(f"  Device: {'GPU' if device == 0 else 'CPU'}")

---

## Inference with Fine-tuned Model

The sections below will use the fine-tuned model for predictions on the dev set.

## Prediction Function with Entity Matching

REBEL extracts entities and relations from raw text. We need to match them back to the entities in our dataset.

In [None]:
def match_entity_to_annotations(entity_text, entities_list, full_text):
    """
    Match an extracted entity text to annotated entities.

    Matching is case-insensitive and ignores surrounding whitespace. An exact
    match on the annotated span is preferred; otherwise the first annotated
    entity whose span contains the extracted text, or is contained in it, is
    returned.

    Args:
        entity_text: Text span extracted by REBEL
        entities_list: List of annotated entities from dataset (dicts with a
            'text_span' key)
        full_text: Full document text (accepted for interface compatibility;
            not used by the matching)

    Returns:
        Matched entity dict or None. Empty or whitespace-only ``entity_text``
        never matches.
    """
    entity_text_lower = entity_text.lower().strip()

    # The empty string is a substring of every span, so without this guard an
    # empty head/tail coming out of the parser would be "matched" to whichever
    # entity happens to be listed first.
    if not entity_text_lower:
        return None

    # Try exact match first
    for entity in entities_list:
        if entity['text_span'].lower().strip() == entity_text_lower:
            return entity

    # Try substring match
    for entity in entities_list:
        entity_span_lower = entity['text_span'].lower().strip()
        # Same reasoning as above, for an annotation with an empty span
        if not entity_span_lower:
            continue
        if entity_text_lower in entity_span_lower or entity_span_lower in entity_text_lower:
            return entity

    return None


def predict_relations_rebel(pmid, article, triplet_extractor, max_length=512):
    """
    Predict relations using REBEL model.

    The title and abstract are concatenated and passed to the pipeline, the
    generated sequence is parsed into triplets, and every triplet whose head
    and tail can both be matched to an annotated entity is turned into a
    relation record. Triplets with an unmatched head or tail are dropped.

    Args:
        pmid: Document ID (only used in the error message)
        article: Article data with 'metadata' ('title', 'abstract') and 'entities'
        triplet_extractor: REBEL pipeline
        max_length: Forwarded to the pipeline call as ``max_length``
            (NOTE(review): for a text2text pipeline this is presumably a
            generation argument bounding the generated sequence, not the
            input — confirm against the transformers documentation)

    Returns:
        List of predicted relations. Each relation carries the subject_* and
        object_* fields (start_idx, end_idx, location, text_span, label) copied
        unchanged from the matched annotations, plus the generated 'predicate'.
        An empty list is returned if generation or parsing raises.
    """
    title = article['metadata']['title']
    abstract = article['metadata']['abstract']
    full_text = create_full_text(title, abstract)
    entities = article['entities']

    # Entity offsets are reported exactly as annotated (relative to their own
    # 'location'), so no re-basing onto the concatenated text is needed.

    # Generate triplets using REBEL
    try:
        generated_output = triplet_extractor(
            full_text,
            return_tensors=True,
            return_text=False,
            max_length=max_length
        )

        extracted_text = triplet_extractor.tokenizer.batch_decode(
            [generated_output[0]["generated_token_ids"]]
        )

        triplets = extract_triplets(extracted_text[0])
    except Exception as e:
        # Deliberate best effort: one failing document must not abort a full
        # inference run, so the error is reported and the document yields nothing.
        print(f"Error processing {pmid}: {e}")
        return []

    # Match triplets to annotated entities
    predicted_relations = []

    for triplet in triplets:
        # Find matching entities
        subject_entity = match_entity_to_annotations(triplet['head'], entities, full_text)
        object_entity = match_entity_to_annotations(triplet['tail'], entities, full_text)

        # Only add if both entities matched
        if subject_entity and object_entity:
            predicted_relations.append({
                "subject_start_idx": subject_entity['start_idx'],
                "subject_end_idx": subject_entity['end_idx'],
                "subject_location": subject_entity['location'],
                "subject_text_span": subject_entity['text_span'],
                "subject_label": subject_entity['label'],
                "predicate": triplet['type'],
                "object_start_idx": object_entity['start_idx'],
                "object_end_idx": object_entity['end_idx'],
                "object_location": object_entity['location'],
                "object_text_span": object_entity['text_span'],
                "object_label": object_entity['label']
            })

    return predicted_relations


print("Prediction functions defined")

## Test on Single Document

In [None]:
# Test on first document
# Indexing [0] raises IndexError when dev_data is empty, e.g. when the dev file
# was not found and load_re_data returned an empty dict.
test_pmid = list(dev_data.keys())[0]
test_article = dev_data[test_pmid]

print(f"Testing on document: {test_pmid}")
print(f"Title: {test_article['metadata']['title'][:100]}...")
print(f"\nNumber of annotated entities: {len(test_article['entities'])}")
print(f"Number of gold relations: {len(test_article['relations'])}")

# Predict relations
predicted_relations = predict_relations_rebel(test_pmid, test_article, triplet_extractor)

print(f"\nPredicted {len(predicted_relations)} relations")
print("\nSample predicted relations:")
# At most the first five predictions are shown
for relation in predicted_relations[:5]:
    print(f"  ({relation['subject_text_span']} [{relation['subject_label']}]) "
          f"--[{relation['predicate']}]--> "
          f"({relation['object_text_span']} [{relation['object_label']}])")

## Run Predictions on Full Dev Set

**Note:** This may take some time depending on dataset size and hardware.

In [None]:
# Run inference on full dev set
print("Running REBEL inference on dev set...")

# Maps each document id to {"relations": [...]}; this is the structure that is
# written to disk by the saving cell.
predictions = {}

# Documents are processed one at a time; a document whose generation fails
# contributes an empty relation list (see predict_relations_rebel).
for pmid, article in tqdm(dev_data.items(), desc="Predicting relations"):
    predicted_relations = predict_relations_rebel(pmid, article, triplet_extractor)
    predictions[pmid] = {"relations": predicted_relations}

total_relations = sum(len(p['relations']) for p in predictions.values())
print(f"\nInference completed: {len(predictions)} documents")
print(f"  Total relations predicted: {total_relations}")

## Save Predictions

In [None]:
# Save predictions to file
# Use different filename depending on whether model was fine-tuned.
# The label is derived from `model_path`, i.e. from the choice that was actually
# made when the inference pipeline was loaded, instead of probing the output
# directory a second time (its state may have changed since, which would
# mislabel the predictions).
model_type = "finetuned" if model_path == output_dir else "pretrained"
output_path = f"predictions/rebel_{model_type}_predictions.json"

os.makedirs("predictions", exist_ok=True)

# ensure_ascii=False keeps non-ASCII characters of the text spans readable
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(predictions, f, ensure_ascii=False, indent=2)

print(f"Predictions saved to {output_path}")
print(f"  Model type: {model_type}")

## Analyze Predictions

In [None]:
# Show example predictions
print("Example Predictions:\n")

# First three documents in dev_data order (fewer if the set is smaller)
sample_pmids = list(dev_data.keys())[:3]

for pmid in sample_pmids:
    article = dev_data[pmid]
    pred = predictions[pmid]
    
    print(f"Document PMID: {pmid}")
    print(f"Title: {article['metadata']['title'][:80]}...")
    print(f"\nNumber of entities: {len(article['entities'])}")
    print(f"Number of gold relations: {len(article['relations'])}")
    print(f"Number of predicted relations: {len(pred['relations'])}")
    
    # Show predicted relations
    # At most five per document, in the same layout as the single-document test
    print("\nPredicted relations:")
    for relation in pred['relations'][:5]:
        print(f"  ({relation['subject_text_span']} [{relation['subject_label']}]) "
              f"--[{relation['predicate']}]--> "
              f"({relation['object_text_span']} [{relation['object_label']}])")
    
    print("-" * 80)
    print()

## Relation Distribution Analysis

In [None]:
# Tally how often each predicate occurs among the predicted relations ...
pred_predicate_counts = Counter(
    relation['predicate']
    for pred in predictions.values()
    for relation in pred['relations']
)

# ... and among the gold annotations of the same documents
gold_predicate_counts = Counter(
    relation['predicate']
    for article in dev_data.values()
    for relation in article['relations']
)

# Side-by-side table, one row per predicate seen on either side
print("Relation Distribution by Predicate:")
print("=" * 70)
print(f"{'Predicate':<40} {'Gold':<10} {'Predicted':<10}")
print("-" * 70)

all_predicates = sorted(set(gold_predicate_counts) | set(pred_predicate_counts))
for predicate in all_predicates:
    # A Counter yields 0 for a predicate it has never seen
    gold_count = gold_predicate_counts[predicate]
    pred_count = pred_predicate_counts[predicate]
    print(f"{predicate:<40} {gold_count:<10} {pred_count:<10}")

print("-" * 70)
gold_total = sum(gold_predicate_counts.values())
pred_total = sum(pred_predicate_counts.values())
print(f"{'TOTAL':<40} {gold_total:<10} {pred_total:<10}")

print(f"\nREBEL predicted {len(pred_predicate_counts)} unique relation types")
print(f"Dataset contains {len(gold_predicate_counts)} unique relation types")