# Relation Extraction with Russian BERT

This notebook fine-tunes a BERT model (`DeepPavlov/rubert-base-cased`) for Relation Extraction (RE) on the provided dataset.

## Approach
We use the **Entity Marker** approach:
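1. For each ordered pair of distinct entities in a sentence, we wrap the first (head) entity in `<e1>` … `</e1>` and the second (tail) entity in `<e2>` … `</e2>`. Both directions, (A, B) and (B, A), become separate examples.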
2. The model classifies the relation between the marked entities.
3. If the dataset annotates no relation for a pair, the pair is labeled `NO_RELATION`.

In [1]:
!pip install transformers datasets torch scikit-learn tqdm

In [2]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm.auto import tqdm
import os

# Configuration
# Build the dataset path with os.path.join so it resolves on both Windows and
# POSIX systems; a hard-coded backslash path is read as one literal file name
# on Linux/macOS.
DATASET_PATH = os.path.join("..", "datasets", "process", "RE_dataset")
MODEL_NAME = "DeepPavlov/rubert-base-cased"  # Hugging Face model id
MAX_LEN = 128  # max sub-word tokens per example; longer inputs get truncated
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 2e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

## 1. Load and Process Data

In [None]:
def load_data(file_path):
    """Read a JSON-Lines file and return its decoded records as a list.

    Blank (whitespace-only) lines are ignored. A line that is not valid JSON
    is reported on stdout (first 50 characters) and skipped instead of
    aborting the whole load.

    Args:
        file_path: Path to a UTF-8 encoded JSON-Lines file.

    Returns:
        list: One decoded JSON value per valid, non-blank line, in file order.
    """
    records = []
    with open(file_path, encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                try:
                    records.append(json.loads(stripped))
                except json.JSONDecodeError:
                    print(f"Skipping invalid line: {stripped[:50]}...")
    return records

raw_data = load_data(DATASET_PATH)
# Fail fast with a clear message: with no records, the pipeline would only
# break later inside train_test_split with a less obvious error.
if not raw_data:
    raise ValueError(f"No valid records were loaded from {DATASET_PATH!r}.")
print(f"Loaded {len(raw_data)} sentences.")

In [None]:
def process_examples(data):
    """Expand annotated sentences into one example per ordered entity pair.

    Args:
        data: Iterable of sentence records. Each record is read as a dict with
            'tokens', 'entities' (each entity unpacked as (start, end, type))
            and an optional 'relations' list (each relation unpacked as
            (head_start, tail_start, label)).

    Returns:
        (examples, labels): ``examples`` is a list of dicts with keys
        'tokens', 'e1_span', 'e2_span' and 'label'; ``labels`` is the sorted
        list of distinct labels found in 'relations'. "NO_RELATION" is used
        for unannotated pairs but is not itself added to ``labels`` here.
    """
    pair_examples = []
    seen_labels = set()

    for record in data:
        sentence_tokens = record['tokens']
        entity_list = record['entities']

        # Relations are keyed only by the start token of the head and tail
        # entities; a later relation with the same (head_start, tail_start)
        # overwrites an earlier one.
        label_by_starts = {}
        for head_start, tail_start, rel_label in record.get('relations', []):
            label_by_starts[(head_start, tail_start)] = rel_label
            seen_labels.add(rel_label)

        # Every ordered pair of distinct entities becomes an example, so both
        # (A, B) and (B, A) are emitted.
        for head_pos, head in enumerate(entity_list):
            for tail_pos, tail in enumerate(entity_list):
                if head_pos == tail_pos:
                    continue
                h_start, h_end, _ = head
                t_start, t_end, _ = tail
                pair_examples.append({
                    'tokens': sentence_tokens,
                    'e1_span': (h_start, h_end),
                    'e2_span': (t_start, t_end),
                    'label': label_by_starts.get((h_start, t_start), "NO_RELATION"),
                })

    return pair_examples, sorted(seen_labels)

examples, relation_labels = process_examples(raw_data)

# Unannotated pairs are labeled "NO_RELATION" by process_examples, so that
# class needs an id whether or not any relation record uses it.
if "NO_RELATION" not in relation_labels:
    relation_labels.append("NO_RELATION")

label2id = {name: idx for idx, name in enumerate(relation_labels)}
id2label = dict(zip(label2id.values(), label2id.keys()))

print(f"Generated {len(examples)} examples.")
print(f"Relation types: {relation_labels}")

## 2. Tokenization and Dataset Class
We insert the special markers `<e1>` … `</e1>` around the head (subject) entity and `<e2>` … `</e2>` around the tail (object) entity before tokenizing the sentence.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Entity markers placed around the head (<e1>...</e1>) and tail (<e2>...</e2>)
# entities. Registering them as additional special tokens makes the tokenizer
# treat each marker as one token instead of splitting it into sub-words; the
# model's embedding matrix must be resized to the new vocabulary size.
special_tokens = dict(additional_special_tokens=['<e1>', '</e1>', '<e2>', '</e2>'])
tokenizer.add_special_tokens(special_tokens)

class REDataset(Dataset):
    """Torch dataset turning entity-pair examples into marked, tokenized inputs.

    Each example is a dict with 'tokens' (list of word strings), 'e1_span' and
    'e2_span' ((start, end) token indices) and 'label' (a key of ``label2id``).
    The head span is wrapped in ``<e1> ... </e1>`` and the tail span in
    ``<e2> ... </e2>``; the closing marker is placed right after the token at
    the end index, i.e. end indices are treated as inclusive.
    """

    def __init__(self, examples, tokenizer, label2id, max_len=128):
        self.examples = examples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return padded 'input_ids', 'attention_mask' and a scalar 'labels' tensor."""
        example = self.examples[idx]
        words = example['tokens']
        head_start, head_end = example['e1_span']
        tail_start, tail_end = example['e2_span']

        # Head markers are listed before tail markers so that, when both
        # entities open or close at the same token, <e1>/</e1> comes first.
        openers = (('<e1>', head_start), ('<e2>', tail_start))
        closers = (('</e1>', head_end), ('</e2>', tail_end))

        # NOTE(review): assumes entity end indices are inclusive -- TODO confirm
        # against the dataset. If they are exclusive, closing markers land one
        # token late, and an end equal to len(words) never gets a closing marker.
        pieces = []
        for pos, word in enumerate(words):
            pieces.extend(tag for tag, at in openers if pos == at)
            pieces.append(word)
            pieces.extend(tag for tag, at in closers if pos == at)

        # NOTE(review): with truncation=True, markers past max_len sub-word
        # tokens are cut off, leaving the example without one or both markers;
        # consider windowing the text around the entities.
        encoded = self.tokenizer(
            " ".join(pieces),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(self.label2id[example['label']], dtype=torch.long),
        }

# Split at the sentence level, not the entity-pair level. All pairs built
# from one sentence share its tokens, so splitting `examples` directly would
# put near-duplicate inputs (same sentence, different markers) into both the
# train and validation sets and inflate validation scores.
# label2id is built from the full dataset above, so labels that occur in only
# one split still have an id.
train_sentences, val_sentences = train_test_split(raw_data, test_size=0.2, random_state=42)
train_ex = process_examples(train_sentences)[0]
val_ex = process_examples(val_sentences)[0]

train_dataset = REDataset(train_ex, tokenizer, label2id, MAX_LEN)
val_dataset = REDataset(val_ex, tokenizer, label2id, MAX_LEN)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

## 3. Model Setup

In [None]:
# Store the label mapping in the model config so that the saved checkpoint
# (and anything loaded from it) reports relation names instead of the default
# LABEL_0, LABEL_1, ... placeholders.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
# Resize embeddings because we added special tokens
model.resize_token_embeddings(len(tokenizer))
model.to(DEVICE)

## 4. Training Loop

In [None]:
# AdamW over all model parameters. The learning rate follows a linear schedule
# with no warmup, spanning every optimizer step of training (train_epoch steps
# the scheduler once per batch).
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

def train_epoch(model, data_loader, optimizer, scheduler, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Each batch is read as a dict with 'input_ids', 'attention_mask' and
    'labels' tensors, which are moved to ``device``. The model is called with
    the labels and its ``.loss`` output is optimized. Gradients are clipped to
    a global norm of 1.0 before each optimizer step, and the scheduler is
    stepped once per batch.

    Returns:
        float: sum of per-batch losses divided by ``len(data_loader)``
        (a ZeroDivisionError is raised if the loader has length 0).
    """
    model.train()
    running_loss = 0

    for batch in tqdm(data_loader, desc="Training"):
        inputs = {
            name: batch[name].to(device)
            for name in ('input_ids', 'attention_mask', 'labels')
        }

        optimizer.zero_grad()
        loss = model(**inputs).loss
        running_loss += loss.item()

        loss.backward()
        # Clip to guard against exploding gradients during fine-tuning.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    return running_loss / len(data_loader)

def eval_model(model, data_loader, device):
    """Predict a class id for every example in ``data_loader`` without gradients.

    Returns:
        (true_labels, preds): two lists, in loader order, holding the gold
        label id and the arg-max predicted id of each example. Elements are
        NumPy integer scalars taken from the CPU copies of the tensors.
    """
    model.eval()
    gold, predicted_ids = [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            best = torch.max(logits, dim=1).indices

            predicted_ids.extend(best.cpu().numpy())
            gold.extend(labels.cpu().numpy())

    return gold, predicted_ids

In [None]:
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, DEVICE)
    print(f"Train loss: {train_loss:.4f}")

    true_labels, preds = eval_model(model, val_loader, DEVICE)
    # Report every known class, including ones that never appear in this
    # validation split or are never predicted. zero_division=0 scores those
    # ill-defined metrics as 0.0 without emitting an UndefinedMetricWarning
    # for each such class on every epoch.
    print(classification_report(
        true_labels,
        preds,
        labels=list(label2id.values()),
        target_names=list(label2id.keys()),
        zero_division=0,
    ))

## 5. Save Model

In [None]:
output_dir = "./re_bert_model"
# exist_ok=True removes the check-then-create race of an exists() test and is
# a no-op when the directory is already there.
os.makedirs(output_dir, exist_ok=True)

model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")