In [1]:
import pandas as pd
import json
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, ConfusionMatrixDisplay
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import CamembertTokenizer, CamembertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
import matplotlib.pyplot as plt

In [2]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce GTX 1080'

In [3]:
# Load the data
train_df = pd.read_csv("../data/train.csv")
train_df = train_df.set_index("id")
train_df.entities = train_df.entities.apply(json.loads)  # Parse entities
train_df.relations = train_df.relations.apply(json.loads)  # Parse relations

# Display the first few rows
train_df.head()

Unnamed: 0_level_0,text,entities,relations
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
181,"Anam Destresse, président de l'ONG ""Ma passion...","[{'id': 0, 'mentions': [{'value': 'accident', ...","[[0, STARTED_IN, 9], [7, IS_LOCATED_IN, 9], [5..."
31669,"À Paris, le 8 avril 2022, l'usine de déodorant...","[{'id': 0, 'mentions': [{'value': 'explosé', '...","[[9, IS_LOCATED_IN, 8], [11, OPERATES_IN, 8], ..."
51470,"En Espagne, dans une région agricole, une cont...","[{'id': 0, 'mentions': [{'value': 'contaminati...","[[7, IS_PART_OF, 8], [9, OPERATES_IN, 1], [0, ..."
51332,Un important incendie a fait des ravages dans ...,"[{'id': 0, 'mentions': [{'value': 'incendie', ...","[[12, IS_IN_CONTACT_WITH, 5], [0, IS_LOCATED_I..."
1131,« Je coule » : onze heures après avoir envoyé ...,"[{'id': 0, 'mentions': [{'value': 'renversé', ...","[[9, IS_LOCATED_IN, 2], [0, START_DATE, 17], [..."


In [4]:
# Extract all unique relation types
relation_types = set()
for relations in train_df.relations:
    for rel in relations:
        relation_types.add(rel[1])  # rel[1] is the relation type
relation_types = list(relation_types)
print(len(relation_types), "Relation types:", relation_types)

37 Relation types: ['START_DATE', 'END_DATE', 'IS_OF_NATIONALITY', 'INITIATED', 'IS_LOCATED_IN', 'HAS_FOR_LENGTH', 'IS_REGISTERED_AS', 'HAS_CONSEQUENCE', 'WEIGHS', 'INJURED_NUMBER', 'GENDER_FEMALE', 'HAS_FOR_WIDTH', 'HAS_CONTROL_OVER', 'HAS_LONGITUDE', 'IS_OF_SIZE', 'IS_COOPERATING_WITH', 'HAS_CATEGORY', 'HAS_FAMILY_RELATIONSHIP', 'HAS_LATITUDE', 'HAS_COLOR', 'IS_IN_CONTACT_WITH', 'GENDER_MALE', 'DEATHS_NUMBER', 'OPERATES_IN', 'HAS_QUANTITY', 'IS_PART_OF', 'CREATED', 'WAS_DISSOLVED_IN', 'RESIDES_IN', 'WAS_CREATED_IN', 'IS_AT_ODDS_WITH', 'IS_BORN_ON', 'HAS_FOR_HEIGHT', 'IS_DEAD_ON', 'STARTED_IN', 'IS_BORN_IN', 'DIED_IN']


In [5]:
ONTOLOGY_RELATIONS = [
    "HAS_CONTROL_OVER",
    "STARTED_IN",
    "IS_LOCATED_IN",
    "HAS_CATEGORY",
    "IS_PART_OF",
    "INJURED_NUMBER",
    "IS_OF_NATIONALITY",
    "OPERATES_IN",
    "INITIATED",
    "RESIDES_IN",
    "HAS_CONSEQUENCE",
    "IS_COOPERATING_WITH",
    "IS_IN_CONTACT_WITH",
    "IS_OF_SIZE",
    "HAS_QUANTITY",
    "HAS_FOR_LENGTH",
    "IS_BORN_IN",
    "WEIGHS",
    "HAS_FOR_WIDTH",
    "HAS_COLOR",
    "HAS_LATITUDE",
    "IS_REGISTERED_AS",
    "IS_AT_ODDS_WITH",
    "CREATED",
    "HAS_FAMILY_RELATIONSHIP",
    "DEATHS_NUMBER",
    "HAS_FOR_HEIGHT",
    "HAS_LONGITUDE",
    "IS_DEAD_ON",
    "START_DATE",
    "END_DATE",
    "WAS_CREATED_IN",
    "IS_BORN_ON",
    "WAS_DISSOLVED_IN",
    "DIED_IN",
    "GENDER_FEMALE",
    "GENDER_MALE",
]

In [6]:
print(len(ONTOLOGY_RELATIONS))

37


# 1 - Preprocess

In [7]:
# Function to mark entities in the text
def mark_entities(text, entities):
    marked_text = text
    for entity in entities:
        mention = entity["mentions"][0]
        value, start, end = mention["value"], mention["start"], mention["end"]
        marked_text = f"{marked_text[:start]}[E{entity['id']}]{value}[/E{entity['id']}]{marked_text[end:]}"
    return marked_text

# Prepare the dataset
def prepare_dataset(df, relation_types):
    data = []
    for _, row in df.iterrows():
        text = row["text"]
        entities = row["entities"]
        relations = row["relations"]
        marked_text = mark_entities(text, entities)
        for rel in relations:
            e1_id, rel_type, e2_id = rel
            data.append({"text": marked_text, "label": relation_types.index(rel_type)})
    return data

# Split the dataset
dataset = prepare_dataset(train_df, relation_types)
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

In [8]:
# Load CamemBERT tokenizer
tokenizer = CamembertTokenizer.from_pretrained("camembert-base")

# Custom Dataset Class
class RelationDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
        )
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "labels": item["label"],
        }

train_dataset = RelationDataset(train_data, tokenizer)
val_dataset = RelationDataset(val_data, tokenizer)

In [10]:
# Data collator for dynamic padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 2 - Train

In [11]:
# Load CamemBERT with a classification head
model = CamembertForSequenceClassification.from_pretrained(
    "camembert-base",
    num_labels=len(relation_types)
)

Some weights of CamembertForSequenceClassification were not initialized from the model checkpoint at camembert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Compute metrics function
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    # Overall metrics
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
    
    # Per-class metrics
    class_report = classification_report(labels, predictions, target_names=relation_types, output_dict=True, zero_division=0)
    
    return {
        'accuracy': accuracy,
        'f1_weighted': f1,
        'precision_weighted': precision,
        'recall_weighted': recall,
        **{f'f1_{cls}': class_report[cls]['f1-score'] for cls in relation_types},  # Per-class F1
        **{f'precision_{cls}': class_report[cls]['precision'] for cls in relation_types},  # Per-class precision
        **{f'recall_{cls}': class_report[cls]['recall'] for cls in relation_types},  # Per-class recall
    }

In [13]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="../results",
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Trainer with data collator
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset, # Automatically shuffled
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [14]:
# Train the model
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

# 3 - Evaluate

In [None]:
def evaluate_per_class(model, dataloader, relation_types, device='cuda'):
    model = model.to(device)
    model.eval()
    predictions, labels = [], []

    # Process the dataset in batches
    for batch in dataloader:
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device),
            'labels': batch['labels'].to(device)
        }
        
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        labels.extend(inputs['labels'].cpu().numpy())

    # Compute per-class metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, labels=list(range(len(relation_types))), average=None, zero_division=0
    )

    # Create a DataFrame for better readability
    report_df = pd.DataFrame({
        'Class': relation_types,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'Support': support
    })

    # Generate confusion matrix
    plt.figure(figsize=(20, 8))
    cm = confusion_matrix(labels, predictions, labels=list(range(len(relation_types))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=relation_types)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation=90)
    plt.title("Confusion Matrix")
    plt.show()

    return report_df

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator,
)

# Evaluate per-class performance on the validation set
report_df = evaluate_per_class(model, val_loader, relation_types)

# Print the report
print("Per-Class Performance Report:")
print(report_df)

In [None]:
# Save the model
model.save_pretrained("../camembert-relation-extraction")
tokenizer.save_pretrained("../camembert-relation-extraction")

# 4 - Predict

In [None]:
# Load the model
model = CamembertForSequenceClassification.from_pretrained("../camembert-relation-extraction")
tokenizer = CamembertTokenizer.from_pretrained("../camembert-relation-extraction")

# Example prediction
test_text = "[E1]Anam Destresse[/E1] a été blessé dans un accident impliquant un [E2]bus[/E2]."
inputs = tokenizer(test_text, return_tensors="pt", padding=True, truncation=True)
outputs = model(**inputs)
predicted_label = outputs.logits.argmax(dim=-1).item()
predicted_relation = relation_types[predicted_label]

print(f"Predicted relation: {predicted_relation}")

Class Weights:
Use class weights in the loss function to penalize misclassifications in minority classes more heavily.

Oversampling/Downsampling:
Oversample minority classes or downsample majority classes during training.

Data Augmentation:
Augment data for minority classes to balance the dataset.