A notebook for data manipulation before contrastive training.

## Data and Libraries

In [8]:
import pandas as pd
import json
import re

import seaborn as sns
import matplotlib.pyplot as plt

import numpy as np
from tqdm import tqdm

from sklearn.metrics.pairwise import cosine_similarity

from collections import defaultdict

#### Helper functions:

In [9]:

def load_rel_dict(path):
    """Load the relation-id -> relation-name mapping from a text file.

    The first line of the file is treated as a header and skipped. Every
    following non-blank line must contain a numeric id in parentheses; the
    text after that id, stripped of surrounding whitespace, becomes the
    relation name. A line such as ``(12) is a threat`` therefore yields the
    entry ``{12: "is a threat"}``.

    Args:
        path: Path to the relation dictionary file (read as UTF-8).

    Returns:
        dict[int, str]: Relation id -> relation name.

    Raises:
        ValueError: If a non-blank data line contains no ``(<digits>)`` id.
    """
    rel_dict = {}
    with open(path, encoding="utf-8") as f:
        next(f, None)  # skip the header line (no-op on an empty file)
        # Iterating the file keeps the last relation even when the file does
        # not end with a newline.
        for line_no, line in enumerate(f, start=2):
            if not line.strip():
                continue
            # maxsplit=1: only the first "(id)" is the key; a later
            # parenthesised number is part of the relation text.
            parts = re.split(r"\((\d+)\)", line, maxsplit=1)
            if len(parts) != 3:
                raise ValueError(
                    f"{path}: line {line_no} has no '(<id>)' marker: {line!r}"
                )
            rel_dict[int(parts[1])] = parts[2].strip()

    return rel_dict




# Column order of the frame returned by parse_labelstudio (also used to give
# an empty export a well-formed, zero-row frame).
_LS_COLUMNS = [
    'sent', 'r',
    'e1', 'e1_t', 'e1_start_pos', 'e1_end_pos',
    'e2', 'e2_t', 'e2_start_pos', 'e2_end_pos',
    'paper_id', 'sentence_id',
]


def _entity_fields(value):
    """Return (text, type, start, end) of a span value, with empty defaults."""
    # `or [""]` covers both a missing entity ({}) and an empty label list.
    labels = value.get("labels") or [""]
    return (
        value.get("text", ""),
        labels[0],
        value.get("start", 0),
        value.get("end", 0),
    )


def _relation_record(sentence, relation_name, head, tail, paper_id, sentence_id):
    """Build one output row from two (text, type, start, end) entity tuples."""
    e1_text, e1_type, e1_start_pos, e1_end_pos = head
    e2_text, e2_type, e2_start_pos, e2_end_pos = tail
    return {
        'sent': sentence,
        'r': relation_name,
        'e1': e1_text,
        'e1_t': e1_type,
        'e1_start_pos': e1_start_pos,
        'e1_end_pos': e1_end_pos,
        'e2': e2_text,
        'e2_t': e2_type,
        'e2_start_pos': e2_start_pos,
        'e2_end_pos': e2_end_pos,
        'paper_id': paper_id,
        'sentence_id': sentence_id,
    }


def parse_labelstudio(path: str, relation_mapping) -> pd.DataFrame:
    """Flatten a Label Studio JSON export into one row per relation instance.

    For every annotation of every task, each result of type ``"relation"``
    is resolved against the ``"labels"`` results (entity spans) of the same
    annotation. The first relation label is converted to ``int`` and looked
    up in ``relation_mapping``; a missing/empty label is treated as id 0 and
    an unmapped id yields the relation name ``""``.

    A relation with direction ``"right"`` goes from ``from_id`` (e1) to
    ``to_id`` (e2); any other direction swaps the two. Bidirectional
    relations additionally emit the reversed row.

    Args:
        path: Path to the Label Studio JSON export (UTF-8).
        relation_mapping: Mapping of integer relation id -> relation name.

    Returns:
        pd.DataFrame with the columns ``sent, r, e1, e1_t, e1_start_pos,
        e1_end_pos, e2, e2_t, e2_start_pos, e2_end_pos, paper_id,
        sentence_id``. An export without relations gives a zero-row frame
        with these columns.

    Raises:
        ValueError: If a relation label is not an integer string.
        KeyError: If a relation lacks ``from_id`` / ``to_id``.
    """
    with open(path, "r", encoding="utf-8") as file:
        dataset = json.load(file)

    ls_tuples = []
    for entry in dataset:
        data = entry.get("data", {})
        sentence = data.get("sentence", "")
        paper_id = data.get("paper_id")
        sentence_id = data.get("sentence_id")

        for annotation in entry.get("annotations", []):
            results = annotation.get("result", [])
            entities = {
                e["id"]: e["value"] for e in results if e.get("type") == "labels"
            }

            for relation in results:
                if relation.get("type") != "relation":
                    continue

                from_id = relation["from_id"]
                to_id = relation["to_id"]
                direction = relation.get("direction", "right")

                # An absent or empty label list falls back to relation id 0.
                relation_labels = relation.get("labels") or [""]
                relation_type = relation_labels[0] or "0"
                relation_name = relation_mapping.get(int(relation_type), "")

                # Unknown entity ids resolve to {} and thus to empty fields.
                head_id, tail_id = (
                    (from_id, to_id) if direction == "right" else (to_id, from_id)
                )
                head = _entity_fields(entities.get(head_id, {}))
                tail = _entity_fields(entities.get(tail_id, {}))

                ls_tuples.append(
                    _relation_record(
                        sentence, relation_name, head, tail, paper_id, sentence_id
                    )
                )

                # If bidirectional, add the reverse too.
                # NOTE(review): Label Studio appears to export "bi" for
                # bidirectional relations - confirm against the export.
                # "undirected" was the only value checked before and is
                # still accepted.
                if direction in ("bi", "undirected"):
                    ls_tuples.append(
                        _relation_record(
                            sentence, relation_name, tail, head, paper_id, sentence_id
                        )
                    )

    return pd.DataFrame(ls_tuples, columns=_LS_COLUMNS)


#### Data

In [10]:
# Label Studio JSON export that is parsed into the relation DataFrame below
# (relative path, resolved against the notebook's working directory).
ls_filepath = "DATA/project-6-at-2025-06-24-06-31-0c74618b.json"
# Text file with the relation-id -> relation-name dictionary.
# NOTE(review): absolute, machine-specific path - the notebook will not run
# on another machine without editing this line.
rel_dictpath = "/home/p0l3/RAD/DROP/RELdata_simple.txt"

In [11]:
# Build the {relation id: relation name} lookup first; parse_labelstudio uses
# it to turn the numeric relation labels of the export into relation names.
relation_mapping = load_rel_dict(rel_dictpath)
# One row per annotated relation instance.
df = parse_labelstudio(ls_filepath, relation_mapping)

## Contrastive Dataset Creation

**Note:** For a more in-depth analysis of the data, see the original notebook: [relation_emb_stats](https://github.com/P0L3/rel_dis/blob/main/REL_DIS/relation_emb_stats.ipynb).

#### Helper Functions

In [13]:
def compute_pos_level_1_pairs(df):
    """Hardest Positive: rows sharing the same r, e1_t and e2_t.

    Rows are bucketed by the (r, e1_t, e2_t) triple; every row of a bucket
    with at least two members is paired with all other members of that
    bucket. Rows whose triple is unique do not appear in the result.

    Returns:
        dict: index label -> list of index labels of the matching rows.
    """
    partners = defaultdict(list)
    for _key, members in df.groupby(['r', 'e1_t', 'e2_t']):
        member_ids = members.index.tolist()
        if len(member_ids) < 2:
            # A bucket of one has nobody to pair with.
            continue
        for anchor in member_ids:
            others = [candidate for candidate in member_ids if candidate != anchor]
            partners[anchor] += others
    return dict(partners)

def compute_pos_level_2_pairs(df):
    """Medium Positive: same r and exactly one matching entity type (XOR).

    Within each relation group, row ``j`` is a partner of row ``i`` when
    either ``e1_t`` or ``e2_t`` is equal between the two rows, but not both.
    Rows without such a partner do not appear in the result.

    Returns:
        dict: index label -> list of index labels of the matching rows.
    """
    pairs = defaultdict(list)
    for _, group in df.groupby('r'):
        # Extract the three needed fields once per group. Nesting iterrows()
        # would rebuild a Series for every one of the n*n comparisons.
        rows = list(zip(
            group.index.tolist(),
            group['e1_t'].tolist(),
            group['e2_t'].tolist(),
        ))
        for i, e1_i, e2_i in rows:
            for j, e1_j, e2_j in rows:
                if i == j:
                    continue
                # Exactly one of the two entity types matches.
                if (e1_i == e1_j) != (e2_i == e2_j):
                    pairs[i].append(j)
    return dict(pairs)

def compute_pos_level_3_pairs(df):
    """Easiest Positive: same r, both entity types different.

    Within each relation group, row ``j`` is a partner of row ``i`` when
    ``e1_t`` differs and ``e2_t`` differs between the two rows. Rows without
    such a partner do not appear in the result.

    Returns:
        dict: index label -> list of index labels of the matching rows.
    """
    pairs = defaultdict(list)
    for _, group in df.groupby('r'):
        # Extract the three needed fields once per group. Nesting iterrows()
        # would rebuild a Series for every one of the n*n comparisons.
        rows = list(zip(
            group.index.tolist(),
            group['e1_t'].tolist(),
            group['e2_t'].tolist(),
        ))
        for i, e1_i, e2_i in rows:
            for j, e1_j, e2_j in rows:
                if i == j:
                    continue
                if (e1_i != e1_j) and (e2_i != e2_j):
                    pairs[i].append(j)
    return dict(pairs)
    
def compute_neg_level_3_pairs(df):
    """Easiest Negative: different r, same e1_t and same e2_t.

    The result is stored in the ``neg_level_3`` column by the driver cell.
    Rows without such a partner do not appear in the result.

    Args:
        df: Frame with at least the columns ``r``, ``e1_t`` and ``e2_t``;
            its index labels identify the rows.

    Returns:
        dict: index label -> list of index labels of the matching rows.
    """
    # Join on the three relevant columns only. Merging the full frame would
    # drag every other column (sentences, previously added list columns)
    # through the join twice.
    # rename_axis gives the id column a fixed name even if the index is named.
    slim = df[['r', 'e1_t', 'e2_t']].rename_axis('idx').reset_index()
    merged = pd.merge(slim, slim, on=['e1_t', 'e2_t'])
    filtered = merged[(merged.r_x != merged.r_y) & (merged.idx_x != merged.idx_y)]
    return filtered.groupby('idx_x')['idx_y'].apply(list).to_dict()

def compute_neg_level_2_pairs(df):
    """Medium Negative: different r and exactly one matching entity type.

    Two self-joins are combined: rows that share ``e1_t`` but differ in
    ``e2_t``, and rows that share ``e2_t`` but differ in ``e1_t``. Rows
    without such a partner do not appear in the result.

    Args:
        df: Frame with at least the columns ``r``, ``e1_t`` and ``e2_t``;
            its index labels identify the rows.

    Returns:
        dict: index label -> list of index labels of the matching rows.
    """
    # Join on the three relevant columns only, so unrelated columns are not
    # duplicated through the joins.
    # rename_axis gives the id column a fixed name even if the index is named.
    slim = df[['r', 'e1_t', 'e2_t']].rename_axis('idx').reset_index()

    # Same first entity type, different second entity type.
    m1 = pd.merge(slim, slim, on='e1_t')
    cond1 = (m1.r_x != m1.r_y) & (m1.e2_t_x != m1.e2_t_y) & (m1.idx_x != m1.idx_y)
    p1 = m1[cond1][['idx_x', 'idx_y']]

    # Same second entity type, different first entity type.
    m2 = pd.merge(slim, slim, on='e2_t')
    cond2 = (m2.r_x != m2.r_y) & (m2.e1_t_x != m2.e1_t_y) & (m2.idx_x != m2.idx_y)
    p2 = m2[cond2][['idx_x', 'idx_y']]

    all_pairs = pd.concat([p1, p2]).drop_duplicates()
    return all_pairs.groupby('idx_x')['idx_y'].apply(list).to_dict()

def compute_neg_level_1_pairs(df):
    """Hardest Negative: different r, different e1_t and different e2_t.

    The result is stored in the ``neg_level_1`` column by the driver cell.
    Rows without such a partner do not appear in the result.

    Args:
        df: Frame with at least the columns ``r``, ``e1_t`` and ``e2_t``;
            its index labels identify the rows.

    Returns:
        dict: index label -> list of index labels of the matching rows.
    """
    # The cross join has len(df)**2 rows, so it is built from the three
    # relevant columns only; every extra column would be copied n*n times.
    # rename_axis gives the id column a fixed name even if the index is named.
    slim = df[['r', 'e1_t', 'e2_t']].rename_axis('idx').reset_index()
    merged = pd.merge(slim, slim, how='cross')
    condition = (
        (merged.r_x != merged.r_y) &
        (merged.e1_t_x != merged.e1_t_y) &
        (merged.e2_t_x != merged.e2_t_y) &
        (merged.idx_x != merged.idx_y)
    )
    filtered = merged[condition]
    return filtered.groupby('idx_x')['idx_y'].apply(list).to_dict()

### Dataframe construction

In [14]:

print("Computing pairs for all 6 levels...")

# Output column name -> function that returns {index label: [partner labels]}.
level_funcs = {
    'pos_level_1': compute_pos_level_1_pairs,
    'pos_level_2': compute_pos_level_2_pairs,
    'pos_level_3': compute_pos_level_3_pairs,
    'neg_level_1': compute_neg_level_1_pairs,
    'neg_level_2': compute_neg_level_2_pairs,
    'neg_level_3': compute_neg_level_3_pairs,
}

# Every level is computed from the same three-column snapshot, so no function
# receives (or has to join over) the list columns added for earlier levels,
# and re-running the cell gives the same result.
pair_source = df[['r', 'e1_t', 'e2_t']].copy()

for col_name, func in level_funcs.items():
    print(f"  - Computing {col_name}...")
    pair_dict = func(pair_source)

    # Rows without any partner at this level are absent from pair_dict;
    # each of them gets its own fresh empty list.
    df[col_name] = [pair_dict.get(idx, []) for idx in df.index]

print("\n--- Augmented DataFrame ---")
# To display the lists properly
pd.set_option('display.max_colwidth', 100)

# Display the relevant columns
display_cols = ['sent'] + list(level_funcs.keys())
df[display_cols].head()

Computing pairs for all 6 levels...
  - Computing pos_level_1...
  - Computing pos_level_2...
  - Computing pos_level_3...
  - Computing neg_level_1...
  - Computing neg_level_2...
  - Computing neg_level_3...

--- Augmented DataFrame ---


In [17]:
# Store the augmented frame next to the source export as "<name>_df.pickle".
# Only a trailing ".json" is stripped and the new suffix is always appended,
# so the output path can never be the input path. str.replace would rewrite
# every ".json" in the path, and return it unchanged if there was none.
json_suffix = ".json"
if ls_filepath.endswith(json_suffix):
    df_base = ls_filepath[:-len(json_suffix)]
else:
    df_base = ls_filepath
df_path = df_base + "_df.pickle"
df.to_pickle(df_path)

## Data Specification

### Data Specification: Pre-computed Contrastive Pairs for Relation Extraction

**Version:** 1.2
**Date:** June 25, 2025
**Contact:** Andrija Poleksić (andrija.poleksic@uniri.hr)

#### 1. Overview

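This document specifies the format for a dataset designed for contrastive learning of sentence representations for relation extraction. The dataset is delivered as a single tabular file that can be loaded into a pandas DataFrame. This notebook writes it as a pandas pickle (`*_df.pickle`); exports to other formats (e.g., CSV, Parquet) are possible, but note that CSV stores the list-valued columns as strings that must be parsed back into lists.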

The primary unit of this dataset is a **relation instance**. Each row represents a single semantic relation (`r`) between two entities (`e1`, `e2`)—with their precise character-level positions—found within a source sentence (`sent`).

A key feature of this dataset is the pre-computation of six columns (`pos_level_1` through `neg_level_3`). These columns contain lists of indices pointing to other relation instances, categorized by similarity. This structure is purpose-built to facilitate advanced training strategies, such as weighted contrastive loss, where the "hardness" of a pair determines the strength of the training signal.

#### 2. File Format & Schema

The data is structured as a table. This notebook stores it as a pandas pickle (`*_df.pickle`); it may also be stored in a `.csv` or `.parquet` file. When loaded, it forms a pandas DataFrame where the **index of the DataFrame serves as the unique ID for each relation instance**.

**Schema Definition:**

| Column Name     | Data Type             | Description                                                                                                                              | Example from Data                                                              |
| :-------------- | :-------------------- | :--------------------------------------------------------------------------------------------------------------------------------------- | :----------------------------------------------------------------------------- |
| `index`         | Integer (Implicit)    | The unique identifier for each relation instance. All `pos_` and `neg_` level columns contain lists of these index values.                  | `0`                                                                            |
| `sent`          | String                | The full source sentence from which the relation was extracted. May be duplicated across rows if a sentence contains multiple relations. | `"The marine algal toxin domoic acid is an important threat..."`              |
| `r`             | String                | The relation phrase or label connecting the two entities.                                                                                | `"is a threat"`                                                                |
| `e1`            | String                | The text of the first entity (subject) in the relation.                                                                                  | `"domoic acid"`                                                                |
| `e1_t`          | String                | The designated type of the first entity.                                                                                                 | `"Chemical"`                                                                   |
| `e1_start_pos`| Integer             | The start character index of `e1` within `sent` (inclusive).                                                                         | `23`                                                                       |
| `e1_end_pos`  | Integer             | The end character index of `e1` within `sent` (exclusive).                                                                           | `34`                                                                       |
| `e2`            | String                | The text of the second entity (object) in the relation.                                                                                  | `"marine mammal health"`                                                       |
| `e2_t`          | String                | The designated type of the second entity.                                                                                                | `"Other"`                                                                      |
| `e2_start_pos`| Integer             | The start character index of `e2` within `sent` (inclusive).                                                                         | `61`                                                                       |
| `e2_end_pos`  | Integer             | The end character index of `e2` within `sent` (exclusive).                                                                           | `81`                                                                       |
| `paper_id`      | Integer               | An identifier for the source document or paper.                                                                                          | `10167`                                                                        |
| `sentence_id`   | Integer               | An identifier for the sentence within its source document.                                                                               | `0`                                                                            |
| `pos_level_1`   | List of Integers      | A list of indices pointing to "Hardest Positive" examples.                                                                               | `[449]`                                                                        |
| `pos_level_2`   | List of Integers      | A list of indices pointing to "Medium Positive" examples.                                                                                | `[1202]`                                                                       |
| `pos_level_3`   | List of Integers      | A list of indices pointing to "Easiest Positive" examples.                                                                               | `[2017]`                                                                       |
| `neg_level_1`   | List of Integers      | A list of indices pointing to "Hardest Negative" examples.                                                                               | `[5, 8, 9, 10, ...]`                                                           |
| `neg_level_2`   | List of Integers      | A list of indices pointing to "Medium Negative" examples.                                                                                | `[1, 2, 3, 4, ...]`                                                            |
| `neg_level_3`   | List of Integers      | A list of indices pointing to "Easiest Negative" examples.                                                                               | `[11, 52, 119, ...]`                                                           |

---

#### 3. Definitions of Similarity Levels

The categorization of positive and negative levels determines the desired strength of the training signal. **"Harder" examples are those that should be pulled closest (for positives) or pushed furthest away (for negatives) during training.**

##### 3.1. Positive Levels (Same Relation)
A "positive" pair always shares the **same `r`** (relation) as the anchor. The hardness level defines how close their embeddings should become.

| Level         | Condition for Inclusion                                          | Hardness Interpretation                                                                |
| :------------ | :--------------------------------------------------------------- | :------------------------------------------------------------------------------------- |
| `pos_level_1` | **Same** `r`, **Same** `e1_t`, **Same** `e2_t`                     | **Hardest Positive**. Structurally identical; should be pulled very close in the embedding space. |
| `pos_level_2` | **Same** `r`, one matching entity type, one different entity type. | **Medium Positive**. Shares a core context; should be pulled moderately close.         |
| `pos_level_3` | **Same** `r`, **Different** `e1_t`, **Different** `e2_t`           | **Easiest Positive**. Shares only the abstract relation; requires a standard "pull" force. |

##### 3.2. Negative Levels (Different Relation)
A "negative" pair always has a **different `r`** (relation) from the anchor. The hardness level defines how far their embeddings should be pushed apart.

| Level         | Condition for Inclusion                                            | Hardness Interpretation                                                              |
| :------------ | :----------------------------------------------------------------- | :----------------------------------------------------------------------------------- |
| `neg_level_1` | **Different** `r`, **Different** `e1_t`, **Different** `e2_t`        | **Hardest Negative**. Structurally and semantically very different; should be pushed furthest away. |
| `neg_level_2` | **Different** `r`, one matching entity type, one different entity type. | **Medium Negative**. Represents a moderate level of dissimilarity, requiring an intermediate push-away force. |
| `neg_level_3` | **Different** `r`, **Same** `e1_t`, **Same** `e2_t`                  | **Easiest Negative**. Semantically confusable; requires a small "push" just beyond the decision margin. |

---

#### 4. Example Row

| Field           | Value                                                                                         |
| :-------------- | :-------------------------------------------------------------------------------------------- |
| **index**       | `0`                                                                                           |
| **sent**        | `"The marine algal toxin domoic acid is an important threat to marine mammal health..."`        |
| **r**           | `"is a threat"`                                                                               |
| **e1**          | `"domoic acid"`                                                                               |
| **e1_t**        | `"Chemical"`                                                                                  |
| **e1_start_pos**| `23`                                                                                          |
| **e1_end_pos**  | `34`                                                                                          |
| **e2**          | `"marine mammal health"`                                                                      |
| **e2_t**        | `"Other"`                                                                                     |
| **e2_start_pos**| `61`                                                                                          |
| **e2_end_pos**  | `81`                                                                                          |
| **paper_id**    | `10167`                                                                                       |
| **sentence_id** | `0`                                                                                           |
| **pos_level_1** | `[449]`                                                                                       |
| **pos_level_2** | `[]`                                                                                          |
| **pos_level_3** | `[2017]`                                                                                      |
| **neg_level_1** | `[5, 8, 9, 10, ...]`                                                                          |
| **neg_level_2** | `[1, 2, 3, 4, ...]`                                                                           |
| **neg_level_3** | `[11, 52, 119, ...]`                                                                          |