This notebook is a sample training run for a binary (BEFORE/AFTER) pairwise temporal-relation classifier; the already-trained binary module is provided in `models`.

In [None]:
#Imports
import os
import glob
import time
import re
from collections import Counter

import numpy as np
from bs4 import BeautifulSoup, NavigableString, Tag

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

from transformers import (
    RobertaTokenizer,
    RobertaForSequenceClassification,
)

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
)

from tqdm.auto import tqdm

print("Torch:", torch.__version__)

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Binary label scheme: only BEFORE / AFTER relations are kept from MATRES.
LABEL_MAP = dict(BEFORE=0, AFTER=1)

# Inverse lookup (id -> name) and the number of output classes.
ID2LABEL = dict((label_id, name) for name, label_id in LABEL_MAP.items())
NUM_LABELS = len(ID2LABEL)

print("Binary label map:", LABEL_MAP)


In [None]:
#Paths
# Root for all artifacts; "insert-path" is a placeholder to replace before running.
BASE_PATH = "insert-path"

# The best binary checkpoint (model + tokenizer) is written to this directory.
MODEL_SAVE_PATH = os.path.join(BASE_PATH, "temporal_models", "roberta_matres_binary")
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

print("PATHS", BASE_PATH, MODEL_SAVE_PATH)


In [None]:
#Clone MATRES and StructTempRel repos if needed

import subprocess

def clone_if_needed(repo_url: str, target_dir: str) -> bool:
    """Clone *repo_url* into *target_dir* unless that path already exists.

    Failures are reported rather than raised, so the notebook keeps going and
    the file checks that follow show what is missing. The return value lets
    callers react to a failed clone if they want to.

    Args:
        repo_url: Git URL (or local path) to clone from.
        target_dir: Destination directory.

    Returns:
        True if *target_dir* is present afterwards (it already existed or the
        clone succeeded). False if ``git clone`` exited non-zero or the ``git``
        executable could not be found.
    """
    if os.path.exists(target_dir):
        print(f"[SKIP] {target_dir} already exists")
        return True

    print(f"[CLONE] {repo_url} → {target_dir}")
    try:
        result = subprocess.run(
            ["git", "clone", repo_url, target_dir],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        # subprocess.run raises this when the git executable itself is missing.
        print(f"clone failed for {target_dir}")
        print("git executable not found on PATH")
        return False

    if result.returncode != 0:
        print(f"clone failed for {target_dir}")
        print(result.stderr)
        return False

    print(f"cloned {target_dir}")
    return True


# --- Fetch MATRES annotations and the StructTempRel TempEval3 documents ---
print("=" * 60)
print("Cloning repositories (if needed)...")
print("=" * 60)

for repo_url, repo_dir in (
    ("https://github.com/qiangning/MATRES.git", "MATRES"),
    ("https://github.com/qiangning/StructTempRel-EMNLP17.git", "StructTempRel-EMNLP17"),
):
    clone_if_needed(repo_url, repo_dir)

# MATRES relation files: TimeBank + AQUAINT feed train/dev, Platinum is the test set.
matres_files = [
    "MATRES/timebank.txt",
    "MATRES/aquaint.txt",
    "MATRES/platinum.txt",
]

print("\nVerifying MATRES files:")
for path in matres_files:
    present = os.path.exists(path)
    print(f"  {'[OK]' if present else '[MISSING]'} {path}")


# Gold TempEval3 .tml documents, searched recursively under this root.
TEMPEVAL_ROOT = "StructTempRel-EMNLP17/data/TempEval3"

tml_files = glob.glob(os.path.join(TEMPEVAL_ROOT, "**", "*.tml"), recursive=True)
print(f"  Found {len(tml_files)} .tml files")

if not tml_files:
    print("not found")
else:
    print("  Sample files:")
    for path in sorted(tml_files)[:5]:
        print("   •", path)

In [None]:

"""
Parse *gold* TempEval3 .tml files from:
    StructTempRel-EMNLP17/data/TempEval3/**/*.tml

We extract:
- Full document text (TEXT)
- EVENT texts and their character spans
- eiid -> eid -> event mapping
- numeric eiid -> {'text', 'eid', 'eiid_str', 'start', 'end'}
"""

# These re-import names already imported earlier in the file; re-binding them
# here is harmless.
from bs4 import BeautifulSoup, NavigableString, Tag
import glob
import re


# Root of the gold TempEval3 .tml tree inside the StructTempRel-EMNLP17 clone
# (same value as the TEMPEVAL_ROOT assigned in the cloning cell).
TEMPEVAL_ROOT = "StructTempRel-EMNLP17/data/TempEval3"

def parse_tempeval3_file(filepath):
    """
    Parse one gold TempEval3 .tml file into its text plus per-event char spans.

    The TEXT element is walked in document order and its text is rebuilt
    piece by piece. Each EVENT's span is therefore the exact offset at which
    it was emitted, even when the same surface string occurs several times.

    Returns:
        {
          'doc_id': str,          # file name with ".tml" removed
          'text': str,            # reconstructed TEXT content
          'events': {
              numeric_eiid: {     # integer part of the eiid, e.g. "ei1191" -> 1191
                  'text': str,
                  'eid': str,
                  'eiid_str': str,
                  'start': int,
                  'end': int,     # exclusive
              }, ...
          }
        }
        or None if the file has no TEXT element.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        markup = handle.read()

    soup = BeautifulSoup(markup, "xml")
    doc_id = os.path.basename(filepath).replace(".tml", "")

    text_root = soup.find("TEXT")
    if text_root is None:
        return None

    def walk(node):
        # Yield (chunk, eid) pairs in document order.
        # - An EVENT carrying an eid is emitted whole, tagged with its eid.
        # - Raw text nodes are emitted with eid None.
        # - Any other tag (including an EVENT without an eid) is descended into.
        if isinstance(node, NavigableString):
            yield str(node), None
        elif isinstance(node, Tag):
            event_id = node.get("eid") if node.name == "EVENT" else None
            if event_id:
                yield node.get_text(), event_id
            else:
                for child in node.children:
                    yield from walk(child)

    pieces = []
    eid_to_text = {}
    eid_to_pos = {}
    cursor = 0
    for top_level in text_root.children:
        for chunk, event_id in walk(top_level):
            if event_id is not None:
                eid_to_text[event_id] = chunk
                eid_to_pos[event_id] = (cursor, cursor + len(chunk))
            pieces.append(chunk)
            cursor += len(chunk)
    full_text = "".join(pieces)

    # MAKEINSTANCE links an instance id (eiid, e.g. "ei1191") to an EVENT eid
    # (eventID, e.g. "e433"). Entries missing either attribute are ignored.
    eiid_to_eid = {}
    for instance in soup.find_all("MAKEINSTANCE"):
        instance_id, event_id = instance.get("eiid"), instance.get("eventID")
        if instance_id and event_id:
            eiid_to_eid[instance_id] = event_id

    # Key events by the integer part of their eiid, the form MATRES rows use.
    events = {}
    for eiid_str, eid in eiid_to_eid.items():
        if eid not in eid_to_text:
            continue
        digits = re.search(r"\d+", eiid_str)
        if digits is None:
            continue
        start, end = eid_to_pos.get(eid, (-1, -1))
        events[int(digits.group())] = {
            "text": eid_to_text[eid],
            "eid": eid,
            "eiid_str": eiid_str,
            "start": start,
            "end": end,
        }

    return {"doc_id": doc_id, "text": full_text, "events": events}


def parse_tempeval3_corpus(root_dir):
    """
    Parse every .tml file under *root_dir* (searched recursively).

    Files are visited in sorted order. If two files share a doc_id, the later
    one in that order wins, so the outcome no longer depends on filesystem
    listing order. A file that raises while parsing is skipped (best effort),
    but it is counted, and the first few failures are printed instead of
    being silently dropped.

    Returns:
        (tempeval_docs, tempeval_docs_normalized)
            tempeval_docs: doc_id -> parsed document.
            tempeval_docs_normalized: lower-cased doc_id with "-", "_" and "."
                removed -> the same document, used for fuzzy id lookups.
    """
    tml_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.tml"), recursive=True))
    print(f"  Found {len(tml_files)} .tml files under {root_dir}")

    tempeval_docs = {}
    tempeval_docs_normalized = {}
    no_text = 0
    duplicates = 0
    failures = []

    for fp in tml_files:
        try:
            doc = parse_tempeval3_file(fp)
        except Exception as e:  # best effort: one bad file must not abort the corpus
            failures.append((fp, e))
            continue

        if doc is None:
            no_text += 1
            continue

        doc_id = doc["doc_id"]
        if doc_id in tempeval_docs:
            duplicates += 1
        tempeval_docs[doc_id] = doc

        norm_id = doc_id.lower().replace("-", "").replace("_", "").replace(".", "")
        tempeval_docs_normalized[norm_id] = doc

    # Count events over the documents actually kept (duplicates overwritten).
    total_events = sum(len(d["events"]) for d in tempeval_docs.values())

    print("Parsed TempEval3 corpus")
    print(f"  Documents:        {len(tempeval_docs)}")
    print(f"  Events:           {total_events}")
    print(f"  Without TEXT:     {no_text}")
    print(f"  Duplicate doc_id: {duplicates}")
    print(f"  Parse errors:     {len(failures)}")
    for path, err in failures[:5]:
        print(f"    {path}: {type(err).__name__}: {err}")

    return tempeval_docs, tempeval_docs_normalized


# Parse the gold corpus once. Later cells look documents up by raw or normalized id.
(
    tempeval_docs,
    tempeval_docs_normalized,
) = parse_tempeval3_corpus(TEMPEVAL_ROOT)


In [None]:
#Load MATRES file

def extract_event_context(doc_data, event1, event2, window_chars=200):
    """
    Cut a text window that covers both events plus surrounding text.

    The window starts ``window_chars`` characters before the earlier event's
    start and ends ``window_chars`` characters after the later event's end,
    clamped to the document bounds. If either event has a negative start,
    the whole document is returned with the unchanged start offsets.

    Returns:
        (context, e1_pos, e2_pos): the window text and each event's start
        offset relative to that window.
    """
    text = doc_data["text"]
    (start1, end1), (start2, end2) = (
        (event1["start"], event1["end"]),
        (event2["start"], event2["end"]),
    )

    if min(start1, start2) < 0:
        # No usable span for at least one event: hand back the full document.
        return text, start1, start2

    lo = max(0, min(start1, start2) - window_chars)
    hi = min(len(text), max(end1, end2) + window_chars)
    return text[lo:hi], start1 - lo, start2 - lo


def load_matres_binary(filepath, tempeval_docs, tempeval_docs_normalized, window_chars=200):
    """
    Load a MATRES relation file and keep ONLY BEFORE / AFTER relations.

    Each row is tab-separated:
    ``doc_id, <col2>, <col3>, eiid1, eiid2, relation``. Only columns 1, 4, 5
    and 6 are used. Rows whose relation is not in LABEL_MAP (e.g. EQUAL,
    VAGUE) are skipped and counted.

    Args:
        filepath: Path to a MATRES .txt file.
        tempeval_docs: doc_id -> parsed TempEval3 document.
        tempeval_docs_normalized: normalized doc_id -> parsed document, used
            when the raw id is not found.
        window_chars: Characters of context kept on each side of the event
            pair (passed to extract_event_context). Defaults to 200.

    Returns a list of examples:
        {
          'doc_id', 'event1_text', 'event2_text',
          'context', 'e1_pos', 'e2_pos',
          'label', 'label_name'
        }
    """
    examples = []
    total = malformed = skipped_relation = miss_doc = miss_event = bad_context = 0

    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) != 6:
                # Blank lines are expected; anything else is worth reporting.
                if line.strip():
                    malformed += 1
                continue

            doc_id, _, _, eiid1_str, eiid2_str, relation = parts
            total += 1

            if relation not in LABEL_MAP:
                # EQUAL, VAGUE, etc. are out of scope for the binary setup.
                skipped_relation += 1
                continue

            # Try the raw doc id first, then the normalized form.
            doc = tempeval_docs.get(doc_id)
            if doc is None:
                norm = doc_id.lower().replace("-", "").replace("_", "").replace(".", "")
                doc = tempeval_docs_normalized.get(norm)
            if doc is None:
                miss_doc += 1
                continue

            # MATRES uses numeric eiids (the integer part of "ei<N>").
            try:
                eiid1 = int(eiid1_str)
                eiid2 = int(eiid2_str)
            except ValueError:
                miss_event += 1
                continue

            events = doc["events"]
            if eiid1 not in events or eiid2 not in events:
                miss_event += 1
                continue

            e1 = events[eiid1]
            e2 = events[eiid2]

            context, e1_pos, e2_pos = extract_event_context(
                doc, e1, e2, window_chars=window_chars
            )

            # Sanity check: both event strings must appear somewhere in the context.
            if e1["text"] not in context or e2["text"] not in context:
                bad_context += 1
                continue

            examples.append(
                {
                    "doc_id": doc_id,
                    "event1_text": e1["text"],
                    "event2_text": e2["text"],
                    "context": context,
                    "e1_pos": e1_pos,
                    "e2_pos": e2_pos,
                    "label": LABEL_MAP[relation],
                    "label_name": relation,
                }
            )

    # Every counted row is exactly one of: valid, non-binary, missing doc,
    # missing event, or bad context.
    print(f"Loaded {filepath}")
    print(f"  Total rows:      {total}")
    print(f"  Valid examples:  {len(examples)}")
    print(f"  Non-binary rels: {skipped_relation}")
    print(f"  Missing docs:    {miss_doc}")
    print(f"  Missing events:  {miss_event}")
    print(f"  Bad contexts:    {bad_context}")
    print(f"  Malformed lines: {malformed}")
    return examples


timebank_examples = load_matres_binary("MATRES/timebank.txt", tempeval_docs, tempeval_docs_normalized)
aquaint_examples = load_matres_binary("MATRES/aquaint.txt", tempeval_docs, tempeval_docs_normalized)
platinum_examples = load_matres_binary("MATRES/platinum.txt", tempeval_docs, tempeval_docs_normalized)

# TimeBank + AQUAINT form the train/dev pool; Platinum is held out as the test set.
train_dev_examples = [*timebank_examples, *aquaint_examples]
test_examples = platinum_examples

from sklearn.model_selection import train_test_split

# 90/10 split of the pool, stratified by label.
# NOTE(review): the split is per example, not per document. Pairs from one
# document can land in both train and dev, and their ±200-char context
# windows may overlap, so dev scores may be optimistic. Consider a
# document-level split if dev drives model selection.
train_examples, dev_examples = train_test_split(
    train_dev_examples,
    stratify=[ex["label"] for ex in train_dev_examples],
    test_size=0.1,
    random_state=42,
)

In [None]:
#Dataset with <e1>/<e2> markers

class MATRESDatasetWithMarkers(Dataset):
    """Torch dataset that wraps both events of each example in <e1>/<e2> marker tags.

    Items are tokenized on access, padded/truncated to ``max_length``, and
    returned with the integer relation label.
    """

    def __init__(self, examples, tokenizer, max_length=256):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]

        marked_text = self._insert_markers(
            record["context"],
            record["event1_text"],
            record["event2_text"],
            record["e1_pos"],
            record["e2_pos"],
        )

        batch = self.tokenizer(
            marked_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # The tokenizer returns a batch of one; drop that leading dimension.
        item = {key: batch[key].squeeze(0) for key in ("input_ids", "attention_mask")}
        item["labels"] = torch.tensor(record["label"], dtype=torch.long)
        return item

    def _insert_markers(self, context, e1, e2, p1, p2):
        """
        Return *context* with e1 wrapped as "<e1> ... </e1>" and e2 as "<e2> ... </e2>".

        The offsets p1/p2 are used only when BOTH point exactly at their event
        text. Otherwise the first occurrence of each string is tagged instead.
        """
        tagged1 = f"<e1> {e1} </e1>"
        tagged2 = f"<e2> {e2} </e2>"

        def matches(offset, needle):
            if offset is None or offset < 0:
                return False
            stop = offset + len(needle)
            return stop <= len(context) and context[offset:stop] == needle

        ok1 = matches(p1, e1)
        ok2 = matches(p2, e2)

        if ok1 and ok2:
            # Splice the tags in left-to-right order of their offsets.
            if p1 < p2:
                head, tail = (p1, e1, tagged1), (p2, e2, tagged2)
            else:
                head, tail = (p2, e2, tagged2), (p1, e1, tagged1)
            (h_pos, h_text, h_tag), (t_pos, t_text, t_tag) = head, tail
            return "".join(
                [
                    context[:h_pos],
                    h_tag,
                    context[h_pos + len(h_text):t_pos],
                    t_tag,
                    context[t_pos + len(t_text):],
                ]
            )

        # Fallback: tag the first occurrence of each string, e1 first.
        result = context
        for needle, tagged in ((e1, tagged1), (e2, tagged2)):
            if needle in result:
                result = result.replace(needle, tagged, 1)
        return result


from transformers import RobertaTokenizer


tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Register the four marker strings as special tokens so each one maps to a single token id.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}
)

# One marker-wrapped dataset per split, all sharing the same tokenizer.
train_dataset, dev_dataset, test_dataset = (
    MATRESDatasetWithMarkers(split_examples, tokenizer)
    for split_examples in (train_examples, dev_examples, test_examples)
)

print("Dataset sizes:")
print(f"  Train: {len(train_dataset)}")
print(f"  Dev:   {len(dev_dataset)}")
print(f"  Test:  {len(test_dataset)}")


In [None]:
#Models Setup -> Using RobertaForSequenceClassification, not necessary, just in line with tempeval model

from transformers import RobertaForSequenceClassification

# Seed torch's global RNG so the freshly initialised classification head and
# the shuffle order of train_loader are repeatable across runs. The
# train/dev split already uses random_state=42.
SEED = 42
torch.manual_seed(SEED)

# Sequence-classification head on roberta-base. Not strictly required for a
# pairwise task, but kept in line with the TempEval model setup.
model = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=NUM_LABELS,   # 2
)
# The tokenizer gained four marker tokens, so the embedding matrix must grow to match.
model.resize_token_embeddings(len(tokenizer))
model.to(device)

print("Model on device:", device)
print("  num_labels:", NUM_LABELS)
print("  vocab size:", len(tokenizer))

from torch.utils.data import DataLoader

BATCH_SIZE = 16

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader   = DataLoader(dev_dataset,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

print("\nDataLoaders:")
print("  train batches:", len(train_loader))
print("  dev batches:  ", len(dev_loader))
print("  test batches: ", len(test_loader))

# Class weights ("balanced": n_samples / (n_classes * class_count)) to offset label imbalance.
train_labels = np.array([ex["label"] for ex in train_examples])
unique = np.array(sorted(set(train_labels)))

from sklearn.utils.class_weight import compute_class_weight

# CrossEntropyLoss needs exactly one weight per output class, in id order.
# Fail loudly here instead of with a shape error deep inside the first forward pass.
if not np.array_equal(unique, np.arange(NUM_LABELS)):
    raise ValueError(
        f"Training labels {unique.tolist()} do not cover all {NUM_LABELS} classes "
        f"(expected ids 0..{NUM_LABELS - 1}); cannot build per-class loss weights."
    )

class_weights = compute_class_weight("balanced", classes=unique, y=train_labels)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float, device=device)

print("\nClass distribution:")
from collections import Counter
counts = Counter(train_labels.tolist())
# Pair each label with its weight by position, not by using the label as an index.
for lab, weight in zip(unique.tolist(), class_weights):
    c = counts[lab]
    pct = 100.0 * c / len(train_labels)
    print(f"  {ID2LABEL[lab]:6s}: {c:4d} ({pct:5.1f}%) → weight={weight:.3f}")

loss_fn = CrossEntropyLoss(weight=class_weights_tensor)
print("\nUsing weighted CrossEntropyLoss.")


In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm.auto import tqdm

EPOCHS = 10
LR = 2e-5
WARMUP_RATIO = 0.1

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)

total_steps = len(train_loader) * EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)


def lr_lambda(step):
    """LR multiplier: linear warm-up over warmup_steps, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, float(total_steps - step) / max(1, total_steps - warmup_steps))


scheduler = LambdaLR(optimizer, lr_lambda)

# Explicit class order for per-class metrics. Without it, sklearn returns
# entries only for classes present in a split, and per_class_f1[1] would
# raise IndexError.
METRIC_LABELS = list(range(NUM_LABELS))


def _predict(net, loader, desc):
    """Run *net* over *loader* without gradients; return (predictions, gold labels) arrays."""
    net.eval()
    preds, golds = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            logits = net(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            golds.extend(batch["labels"].cpu().numpy())
    return np.array(preds), np.array(golds)


print("=" * 60)
print("Training configuration:")
print(f"  Epochs:       {EPOCHS}")
print(f"  Batch size:   {BATCH_SIZE}")
print(f"  LR:           {LR}")
print(f"  Total steps:  {total_steps}")
print(f"  Warmup steps: {warmup_steps}")
print("=" * 60)

# Start below any achievable F1 so the first epoch always writes a checkpoint.
# With 0.0, a run scoring F1 == 0 every epoch would never save, and loading
# MODEL_SAVE_PATH for testing would fail.
best_dev_f1 = -1.0
start_time = time.time()

for epoch in range(EPOCHS):
    print("\n" + "=" * 60)
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print("=" * 60)

    # Training
    model.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        # Labels are not passed to the model; the class-weighted loss is computed here.
        loss = loss_fn(outputs.logits, batch["labels"].to(device))

        train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    avg_loss = train_loss / max(1, len(train_loader))
    print(f"  Avg train loss: {avg_loss:.4f}")

    # Validation
    dev_preds, dev_golds = _predict(model, dev_loader, "Evaluating")

    acc = accuracy_score(dev_golds, dev_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(
        dev_golds, dev_preds, average="macro", zero_division=0
    )
    _, _, per_class_f1, _ = precision_recall_fscore_support(
        dev_golds, dev_preds, labels=METRIC_LABELS, average=None, zero_division=0
    )

    print(f"  Dev accuracy : {acc:.4f}")
    print(f"  Dev macro F1 : {f1:.4f}")
    print("  Per-class F1 :")
    print(f"    BEFORE: {per_class_f1[0]:.4f}")
    print(f"    AFTER : {per_class_f1[1]:.4f}")

    # Keep only the checkpoint with the best dev macro F1.
    if f1 > best_dev_f1:
        best_dev_f1 = f1
        print(f"New best macro F1: {f1:.4f} – saving model...")
        model.save_pretrained(MODEL_SAVE_PATH)
        tokenizer.save_pretrained(MODEL_SAVE_PATH)

elapsed = (time.time() - start_time) / 60.0
print("\n" + "=" * 60)
print("Training complete.")
print(f"  Total time: {elapsed:.2f} minutes")
print(f"  Best dev macro F1: {best_dev_f1:.4f}")
print("=" * 60)


In [None]:
#Test evaluation on MATRES Platinum
from sklearn.metrics import classification_report

print("=" * 60)
print("Evaluating best model on MATRES Platinum (test set)...")
print("=" * 60)

# Reload the checkpoint with the best dev macro F1, not the last-epoch weights.
best_model = RobertaForSequenceClassification.from_pretrained(MODEL_SAVE_PATH)
# NOTE(review): the saved checkpoint should already carry the enlarged
# vocabulary, so this resize is expected to be a no-op safeguard. Confirm.
best_model.resize_token_embeddings(len(tokenizer))
best_model.to(device)
best_model.eval()

test_preds, test_golds = [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        outputs = best_model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        preds = torch.argmax(outputs.logits, dim=-1)

        test_preds.extend(preds.cpu().numpy())
        test_golds.extend(batch["labels"].cpu().numpy())

test_preds = np.array(test_preds)
test_golds = np.array(test_golds)

# Fixed class order (0..NUM_LABELS-1) so per-class arrays and the report
# always list every class. Without it, a class missing from both gold and
# predictions causes an IndexError / target_names mismatch.
report_labels = list(range(NUM_LABELS))
report_names = [ID2LABEL[i] for i in report_labels]

acc = accuracy_score(test_golds, test_preds)
prec, rec, f1, _ = precision_recall_fscore_support(
    test_golds, test_preds, average="macro", zero_division=0
)
_, _, per_class_f1, _ = precision_recall_fscore_support(
    test_golds, test_preds, labels=report_labels, average=None, zero_division=0
)

print("\nTEST RESULTS (binary BEFORE/AFTER):")
print("=" * 60)
print(f"  Accuracy : {acc:.4f}")
print(f"  Macro F1 : {f1:.4f}")
print(f"  Precision: {prec:.4f}")
print(f"  Recall   : {rec:.4f}")
print("\n  Per-class F1:")
print(f"    BEFORE: {per_class_f1[0]:.4f}")
print(f"    AFTER : {per_class_f1[1]:.4f}")

print("\nClassification report:")
print(classification_report(
    test_golds,
    test_preds,
    labels=report_labels,
    target_names=report_names,
    digits=4,
    zero_division=0,
))


In [None]:
# The training loop already wrote the best-dev-F1 checkpoint into MODEL_SAVE_PATH.
BEST_MODEL_DIR = MODEL_SAVE_PATH
print("Model saved to:", BEST_MODEL_DIR)