# PsyCoMark Multi-task Training: Marker Types + Conspiracy Detection

Joint training with DistilRoBERTa for:
- Token classification (BIO tags for marker types)
- Document classification (conspiracy: Yes/No/Can't tell)


In [None]:
# Environment setup (run first)
import os

# Process-wide switches, applied in one call; this cell is meant to run
# before the cells that import the heavier libraries.
os.environ.update(
    {
        "TRANSFORMERS_NO_TORCHVISION": "1",  # avoid torchvision import path
        "TOKENIZERS_PARALLELISM": "false",  # quieter logs
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
)


In [None]:
# Install dependencies
!pip -q install transformers==4.46.1 datasets==2.20.0 accelerate==0.34.2 seqeval==1.2.2 scikit-learn


In [None]:
# Imports
import json, random, os, math
from collections import defaultdict, Counter
from typing import List, Dict, Any, Tuple

import torch
import numpy as np
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoConfig, AutoModel, Trainer, TrainingArguments, DataCollatorWithPadding
)
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.metrics import f1_score as sk_f1_score, classification_report as sk_classification_report

print("torch:", torch.__version__)
import transformers, datasets
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("CUDA available:", torch.cuda.is_available())


In [None]:
# Configuration
SEED = 42
MODEL_NAME = "distilroberta-base"
MAX_LEN = 512
DOC_LABELS = ["No", "Yes", "Can't tell"]  # order matters
MARKER_TYPES = ["Actor", "Action", "Victim", "Evidence", "Effect"]
# "O" comes first, followed by a B-/I- pair per marker type in MARKER_TYPES order.
BIO_LABELS = ["O", *(f"{prefix}-{marker}" for marker in MARKER_TYPES for prefix in "BI")]

# Loss weights - boost token learning (token weight increased from 1.0)
LAMBDA_TOKEN, LAMBDA_DOC = 3.0, 1.0

# Optimisation hyperparameters
BATCH_SIZE, EPOCHS = 16, 4
LR, WEIGHT_DECAY, WARMUP_RATIO = 2e-5, 0.01, 0.1

# Seed every RNG this notebook touches.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

print(f"BIO labels ({len(BIO_LABELS)}): {BIO_LABELS}")
print(f"Doc labels ({len(DOC_LABELS)}): {DOC_LABELS}")
print(f"Loss weights: token={LAMBDA_TOKEN}, doc={LAMBDA_DOC}")


In [None]:
# Load and split data
DATA_PATH = "/content/train_rehydrated.jsonl"
# Raise explicitly rather than assert, so the check survives `python -O`.
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError("Upload train_rehydrated.jsonl to /content first.")

# One JSON object per line. The context manager closes the handle, and
# blank lines (e.g. a trailing newline) are skipped instead of crashing json.loads.
with open(DATA_PATH, "r", encoding="utf-8") as data_file:
    rows = [json.loads(line) for line in data_file if line.strip()]
print(f"Total rows: {len(rows)}")

# Stratified 70/15/15 split by 'conspiracy'
by_label = defaultdict(list)
for r in rows:
    by_label[r["conspiracy"]].append(r)

splits = {"train": [], "val": [], "test": []}
for lbl, items in by_label.items():
    # Shuffle within each label so every split keeps the label proportions.
    random.shuffle(items)
    n = len(items)
    n_train = int(0.70 * n)
    n_val = int(0.15 * n)
    # int() truncates, so the remainder from rounding lands in the test slice.
    train, val, test = items[:n_train], items[n_train:n_train+n_val], items[n_train+n_val:]
    splits["train"].extend(train)
    splits["val"].extend(val)
    splits["test"].extend(test)

def summarize_split(name, data):
    """Print the size of a split and its per-label counts of 'conspiracy'."""
    label_counts = Counter()
    for row in data:
        label_counts[row["conspiracy"]] += 1
    total = len(data)
    print(f"{name}: n={total}  {dict(label_counts)}")

# Report size and label distribution of each split to sanity-check the stratification.
summarize_split("train", splits["train"])
summarize_split("val", splits["val"])
summarize_split("test", splits["test"])


In [None]:
# Label mappings
# Document labels: position in DOC_LABELS is the class id.
doc_label2id = dict(zip(DOC_LABELS, range(len(DOC_LABELS))))
doc_id2label = {idx: name for name, idx in doc_label2id.items()}

# BIO tags: position in BIO_LABELS is the tag id.
tag2id = dict(zip(BIO_LABELS, range(len(BIO_LABELS))))
id2tag = {idx: tag for tag, idx in tag2id.items()}

print("Doc label mapping:", doc_label2id)
print("BIO tag mapping (first 10):", dict(list(tag2id.items())[:10]))


In [None]:
# Tokenizer
# Loaded once and shared by prepare_examples and the collator (pad_token_id).
# NOTE(review): use_fast=True is presumably required because char_spans_to_bio
# reads per-token character offsets from the encoding — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, token=False)
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Special tokens: {tokenizer.special_tokens_map}")


In [None]:
# Char span to BIO alignment
def char_spans_to_bio(text: str, markers: List[Dict[str, Any]], encoding) -> List[int]:
    """Project character-level marker spans onto tokens as BIO tag ids.

    Args:
        text: The source text. Not read by this function; kept for the
            call signature.
        markers: Dicts with "type", "startIndex" and "endIndex". Entries
            whose type is not in MARKER_TYPES or whose indices are not
            ints are ignored. ``None`` is treated as no markers.
        encoding: Tokenizer output exposing ``input_ids`` and
            ``encodings[0].offsets`` (one ``(start, end)`` pair per token).

    Returns:
        One tag id per token. Tokens with empty offsets always get "O".
    """
    offsets = encoding.encodings[0].offsets
    tags = ["O"] * len(encoding.input_ids)

    for marker in markers or []:
        marker_type = marker.get("type")
        if marker_type not in MARKER_TYPES:
            continue
        span_start = marker.get("startIndex")
        span_end = marker.get("endIndex")
        if not (isinstance(span_start, int) and isinstance(span_end, int)):
            continue

        started = False
        for idx, (tok_start, tok_end) in enumerate(offsets):
            # Tokens without a character extent cannot overlap anything.
            if tok_start is None or tok_end is None or tok_start == tok_end:
                continue
            # Half-open intervals [tok_start, tok_end) and [span_start, span_end).
            overlaps = max(tok_start, span_start) < min(tok_end, span_end)
            if not overlaps:
                if started:
                    # Already walked past the end of this marker.
                    break
                continue
            # First writer wins: a token tagged by an earlier marker is kept.
            if tags[idx] == "O":
                tags[idx] = f"I-{marker_type}" if started else f"B-{marker_type}"
            started = True

    outside_id = tag2id["O"]
    # Force "O" on (0, 0) offsets; unknown tag strings also fall back to "O".
    return [
        outside_id
        if (offsets[idx] == (0, 0) or offsets[idx][0] is None)
        else tag2id.get(tag, outside_id)
        for idx, tag in enumerate(tags)
    ]

def prepare_examples(example):
    """Tokenize one raw record and attach its token-level and document-level labels.

    Reads "text", "conspiracy" and (optionally) "markers" from ``example``.
    Returns unpadded ``input_ids`` / ``attention_mask`` plus ``token_labels``
    (BIO tag ids) and ``doc_label`` (document class id).
    """
    raw_text = example["text"]
    # Offsets are requested so marker character spans can be aligned to tokens.
    encoded = tokenizer(
        raw_text,
        padding=False,
        truncation=True,
        max_length=MAX_LEN,
        return_offsets_mapping=True,
    )
    bio_ids = char_spans_to_bio(raw_text, example.get("markers", []), encoded)
    class_id = doc_label2id[example["conspiracy"]]

    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "token_labels": bio_ids,
        "doc_label": class_id,
    }

def to_hf_dataset(items):
    """Build a Dataset from raw records and replace its columns with model features."""
    raw_dataset = Dataset.from_list(items)
    original_columns = raw_dataset.column_names
    # Drop every raw column; only what prepare_examples returns is kept.
    return raw_dataset.map(prepare_examples, remove_columns=original_columns)

# Featurize each split; the local "val" split is exposed under the key "validation".
ds = DatasetDict(
    {
        hf_name: to_hf_dataset(splits[split_key])
        for hf_name, split_key in (("train", "train"), ("validation", "val"), ("test", "test"))
    }
)

print("Dataset created successfully")
print(f"Train: {len(ds['train'])} examples")
print(f"Validation: {len(ds['validation'])} examples")
print(f"Test: {len(ds['test'])} examples")


In [None]:
# Multi-task collator
class MultiTaskCollator:
    """Batch features for joint token/document training.

    Right-pads ``input_ids`` with the tokenizer's pad id, ``attention_mask``
    with 0 and ``token_labels`` with ``label_pad_id``, then returns long
    tensors under the keys ``input_ids``, ``attention_mask``, ``labels``
    (token labels) and ``doc_labels``.
    """

    def __init__(self, tokenizer, label_pad_id=-100):
        self.tokenizer = tokenizer
        self.label_pad_id = label_pad_id

    def __call__(self, features):
        label_rows = [feat["token_labels"] for feat in features]
        doc_ids = [feat["doc_label"] for feat in features]
        id_rows = [feat["input_ids"] for feat in features]
        mask_rows = [feat["attention_mask"] for feat in features]

        # Inputs and masks are padded to the longest input sequence in the batch.
        width = max(map(len, id_rows))
        padded_ids = [
            row + [self.tokenizer.pad_token_id] * (width - len(row))
            for row in id_rows
        ]
        padded_masks = [
            mask + [0] * (width - len(row))
            for row, mask in zip(id_rows, mask_rows)
        ]

        # Token labels are padded to the longest label sequence in the batch.
        label_width = max(map(len, label_rows))
        padded_labels = [
            row + [self.label_pad_id] * (label_width - len(row))
            for row in label_rows
        ]

        batch = {}
        batch["input_ids"] = torch.tensor(padded_ids, dtype=torch.long)
        batch["attention_mask"] = torch.tensor(padded_masks, dtype=torch.long)
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        batch["doc_labels"] = torch.tensor(doc_ids, dtype=torch.long)
        return batch

# Single collator instance; reused by the Trainer and by evaluate_doc's DataLoader.
collator = MultiTaskCollator(tokenizer)
print("Collator created")


In [None]:
# Multi-task model with improved token learning
class MultiTaskDistilRoberta(torch.nn.Module):
    """Shared encoder with a per-token classification head and a document head.

    The token head is a linear layer over every position of the encoder's
    last hidden state. The document head is a linear layer over the
    attention-masked mean of those hidden states. When both label tensors
    are supplied, the loss is
    ``LAMBDA_TOKEN * token_loss + LAMBDA_DOC * doc_loss``.

    Args:
        base_model_name: Name or path passed to ``AutoConfig`` / ``AutoModel``.
        num_token_labels: Output size of the token head.
        num_doc_labels: Output size of the document head.
    """

    def __init__(self, base_model_name, num_token_labels, num_doc_labels):
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name, token=False)
        self.encoder = AutoModel.from_pretrained(base_model_name, config=self.config, token=False)
        hidden = self.config.hidden_size

        # Heads
        self.token_classifier = torch.nn.Linear(hidden, num_token_labels)
        self.doc_classifier = torch.nn.Linear(hidden, num_doc_labels)

        # Positions labelled -100 (padding) do not contribute to the token loss.
        self.loss_token = torch.nn.CrossEntropyLoss(ignore_index=-100)

        # Doc class weights (inverse frequency)
        # NOTE(review): three hard-coded weights — presumably derived from the
        # training label distribution, and only valid when num_doc_labels == 3.
        doc_weights = torch.tensor([0.72, 0.93, 1.83], dtype=torch.float)
        self.loss_doc = torch.nn.CrossEntropyLoss(weight=doc_weights)

    def forward(self, input_ids=None, attention_mask=None, labels=None, doc_labels=None):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids, passed to the encoder.
            attention_mask: 1 for real tokens, 0 for padding. ``None`` means
                every position is treated as a real token.
            labels: Token-level targets, or ``None``.
            doc_labels: Document-level targets, or ``None``.

        Returns:
            Dict with ``loss`` (``None`` unless both label tensors are given),
            ``logits`` (token head, (B,T,Ct)) and ``doc_logits`` (B,Cd).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state  # (B,T,H)

        # Token head
        token_logits = self.token_classifier(last_hidden)  # (B,T,Ct)

        # The signature allows attention_mask=None; without this fallback the
        # pooling below would fail on None.unsqueeze.
        if attention_mask is None:
            attention_mask = torch.ones(
                last_hidden.shape[:2], dtype=torch.long, device=last_hidden.device
            )

        # Doc head: mean-pool over non-padding positions
        mask = attention_mask.unsqueeze(-1)  # (B,T,1)
        summed = (last_hidden * mask).sum(dim=1)  # (B,H)
        lengths = mask.sum(dim=1).clamp(min=1)   # (B,1); clamp avoids division by zero
        pooled = summed / lengths
        doc_logits = self.doc_classifier(pooled)  # (B,Cd)

        loss = None
        if labels is not None and doc_labels is not None:
            lt = self.loss_token(token_logits.view(-1, token_logits.size(-1)), labels.view(-1))
            ld = self.loss_doc(doc_logits, doc_labels)
            loss = LAMBDA_TOKEN * lt + LAMBDA_DOC * ld

        return {"loss": loss, "logits": token_logits, "doc_logits": doc_logits}

# Head sizes follow the label inventories: one output per BIO tag and per document class.
model = MultiTaskDistilRoberta(MODEL_NAME, num_token_labels=len(BIO_LABELS), num_doc_labels=len(DOC_LABELS))
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Token head: {model.token_classifier}")
print(f"Doc head: {model.doc_classifier}")


In [None]:
# Metrics: Macro Overlap F1 (IoU>=0.5) for markers, Weighted F1 for doc
import itertools
from sklearn.metrics import f1_score as skf1

def chunks_to_spans(tags):
    """Convert a BIO tag sequence into typed spans.

    Args:
        tags: Sequence of tag strings ("O", "B-<type>", "I-<type>").

    Returns:
        List of ``(start_idx, end_idx_inclusive, type)`` tuples in order of
        appearance. A span is closed by "O", by a new "B-" tag, or by an
        "I-" tag whose type differs from the open span (which then starts a
        new span of its own type). A stray "I-" with no open span is treated
        as "B-".
    """
    spans = []
    start, ttype = None, None
    # list(...) lets tuples and other sequences through; the trailing "O"
    # is a sentinel that flushes a span still open at the end.
    for i, tag in enumerate(list(tags) + ["O"]):
        if tag.startswith("B-"):
            if start is not None:
                spans.append((start, i - 1, ttype))
            start, ttype = i, tag[2:]
        elif tag.startswith("I-"):
            if start is None:
                # treat as B-
                start, ttype = i, tag[2:]
            elif tag[2:] != ttype:
                # Type changed mid-span: close the open span instead of
                # absorbing this token under the wrong type.
                spans.append((start, i - 1, ttype))
                start, ttype = i, tag[2:]
        else:
            if start is not None:
                spans.append((start, i - 1, ttype))
                start, ttype = None, None
    return spans

def iou(a_start, a_end, b_start, b_end):
    """Intersection-over-union of two index ranges with inclusive endpoints."""
    overlap = min(a_end, b_end) - max(a_start, b_start) + 1
    if overlap < 0:
        overlap = 0  # disjoint ranges
    len_a = a_end - a_start + 1
    len_b = b_end - b_start + 1
    total = len_a + len_b - overlap
    if total > 0:
        return overlap / total
    return 0.0

def macro_overlap_f1(true_tags_list, pred_tags_list, types=MARKER_TYPES, iou_thresh=0.5):
    """Macro-averaged span F1 where a prediction counts if its IoU is high enough.

    For each sequence and each type, predicted spans are matched greedily
    (in prediction order, each to its best still-unmatched gold span of the
    same type). A match with IoU >= ``iou_thresh`` is a true positive;
    unmatched predictions are false positives and unmatched gold spans are
    false negatives. Counts are pooled over all sequences per type.

    Args:
        true_tags_list: Gold BIO tag sequences, one per example.
        pred_tags_list: Predicted BIO tag sequences, aligned with the gold ones.
        types: Marker types to score.
        iou_thresh: Minimum IoU for a match.

    Returns:
        ``(macro_f1, per_type_f1)`` — the mean of the per-type F1 values and
        the dict of those values. A type with no matches scores 0.0.
    """
    # [tp, fp, fn] per requested type.
    counts = {t: [0, 0, 0] for t in types}

    def spans_by_type(tags):
        # Parse each sequence once and bucket its spans by type, instead of
        # re-parsing the same sequence for every marker type.
        grouped = {}
        for s, e, tt in chunks_to_spans(tags):
            grouped.setdefault(tt, []).append((s, e))
        return grouped

    for true_tags, pred_tags in zip(true_tags_list, pred_tags_list):
        true_by_type = spans_by_type(true_tags)
        pred_by_type = spans_by_type(pred_tags)
        for t, tally in counts.items():
            true_spans = true_by_type.get(t, [])
            pred_spans = pred_by_type.get(t, [])
            used = set()
            # match preds to trues greedily by IoU
            for ps, pe in pred_spans:
                best_iou, best_j = 0.0, -1
                for j, (ts, te) in enumerate(true_spans):
                    if j in used:
                        continue
                    score = iou(ps, pe, ts, te)
                    if score > best_iou:
                        best_iou, best_j = score, j
                # best_j stays -1 when nothing overlaps; never count that as a
                # match, even if iou_thresh is set to 0 or below.
                if best_j >= 0 and best_iou >= iou_thresh:
                    tally[0] += 1
                    used.add(best_j)
                else:
                    tally[1] += 1
            tally[2] += len(true_spans) - len(used)

    per_type_f1 = {}
    for t, (tp, fp, fn) in counts.items():
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        per_type_f1[t] = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    macro_f1 = float(np.mean(list(per_type_f1.values()))) if per_type_f1 else 0.0
    return macro_f1, per_type_f1

def compute_metrics_fn(eval_pred):
    """Compute marker overlap-F1 metrics from the Trainer's evaluation output.

    Accepts either an object with ``predictions`` / ``label_ids`` or a
    ``(predictions, labels)`` pair. When either side is a tuple or list,
    its first element is used. Positions labelled -100 are dropped.

    Returns:
        Dict with ``token_macro_overlap_f1`` plus one ``f1_<type>`` entry per
        marker type.
    """
    # Unpack EvalPrediction
    if hasattr(eval_pred, "predictions"):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
    else:
        logits, labels = eval_pred

    # Multi-output case: keep the first element of each side.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    if isinstance(labels, (tuple, list)):
        labels = labels[0]

    logits = np.asarray(logits)
    labels = np.asarray(labels)
    if logits.ndim < 2:
        return {"token_macro_overlap_f1": 0.0}

    # 3-D input is per-class scores; 2-D input is taken as already-decoded ids.
    preds = logits.argmax(axis=-1) if logits.ndim >= 3 else logits

    # Build tag sequences
    pred_tags, true_tags = [], []
    for pred_row, label_row in zip(preds, labels):
        kept = [
            (int(pred_id), int(label_id))
            for pred_id, label_id in zip(np.ravel(pred_row), np.ravel(label_row))
            if int(label_id) != -100
        ]
        if not kept:
            # Sequence with no scored positions contributes nothing.
            continue
        pred_tags.append([id2tag[pred_id] for pred_id, _ in kept])
        true_tags.append([id2tag[label_id] for _, label_id in kept])

    macro_f1, per_type = macro_overlap_f1(true_tags, pred_tags)
    # Primary metric first, then the per-type breakdown.
    return {
        "token_macro_overlap_f1": macro_f1,
        **{f"f1_{marker}": score for marker, score in per_type.items()},
    }

print("Metrics updated: Macro Overlap F1 (IoU>=0.5) for markers")


In [None]:
# Training arguments
# Evaluate and checkpoint once per epoch, then reload the checkpoint with the
# best marker overlap-F1 at the end of training.
args = TrainingArguments(
    output_dir="/content/outputs",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WEIGHT_DECAY,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    warmup_ratio=WARMUP_RATIO,
    load_best_model_at_end=True,
    # "eval_" + the key returned by compute_metrics_fn.
    metric_for_best_model="eval_token_macro_overlap_f1",
    greater_is_better=True,
    report_to="none",
    seed=SEED,
    dataloader_drop_last=False,
    # The collator reads "token_labels" and "doc_label", which are not
    # forward() arguments, so those columns must be kept.
    remove_unused_columns=False,
)

class WrappedTrainer(Trainer):
    """Trainer whose loss is the combined multi-task loss computed by the model."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward both label sets to the model and return its "loss" entry.

        ``num_items_in_batch`` is accepted for signature compatibility and
        not used. Returns ``(loss, outputs)`` when ``return_outputs`` is true,
        otherwise just the loss.
        """
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs.get("labels"),
            doc_labels=inputs.get("doc_labels"),
        )
        loss = outputs["loss"]
        if return_outputs:
            return loss, outputs
        return loss

# Wire the model, data, custom collator and marker metric into the trainer.
# NOTE(review): newer transformers releases prefer `processing_class=` over
# `tokenizer=` — confirm against the pinned version before upgrading.
trainer = WrappedTrainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics_fn,
)

print("Trainer created")


In [None]:
# Training
# Runs the full schedule configured in `args` (per-epoch evaluation and checkpointing).
print("Starting training...")
trainer.train()
print("Training completed!")


In [None]:
# Evaluation functions
# Competition metric for doc task: Weighted F1 on binary mapping (Can't tell -> No)
from sklearn.metrics import f1_score as sk_f1, classification_report as sk_clf_report

def _map_to_binary(labels: list[int]) -> list[int]:
    """Collapse 3-way document class ids to binary: "Yes" -> 1, "No" / "Can't tell" -> 0."""
    # DOC_LABELS = ["No", "Yes", "Can't tell"]
    negative_names = ("No", "Can't tell")
    binary = []
    for class_id in labels:
        binary.append(0 if DOC_LABELS[class_id] in negative_names else 1)
    return binary

def evaluate_doc(dataset):
    """Score the document head on ``dataset`` with binary weighted F1.

    Runs the global ``model`` in eval mode over the dataset, takes the
    argmax of ``doc_logits``, maps predictions and gold labels to binary
    (Can't tell -> No), prints the weighted F1 and a classification report,
    and returns the weighted F1.

    The model is left in eval mode afterwards.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collator)
    model.eval()
    preds, trues = [], []
    device = next(model.parameters()).device

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # No labels are passed: only the logits are needed here.
            out = model(input_ids=input_ids, attention_mask=attention_mask, labels=None, doc_labels=None)
            preds.extend(out["doc_logits"].argmax(dim=-1).cpu().tolist())
            trues.extend(batch["doc_labels"].tolist())

    # Map to binary
    preds_bin = _map_to_binary(preds)
    trues_bin = _map_to_binary(trues)

    weighted_f1 = sk_f1(trues_bin, preds_bin, average="weighted", zero_division=0)
    print(f"Doc Weighted-F1 (binary Yes/No, Can't tell->No): {weighted_f1:.4f}")
    # Pin labels so the two target names always line up, even when a split
    # (or the predictions) contains only one of the two classes.
    print(
        sk_clf_report(
            trues_bin,
            preds_bin,
            labels=[0, 1],
            target_names=["No", "Yes"],
            zero_division=0,
        )
    )
    return weighted_f1

print("Evaluation functions updated (Weighted F1, binary mapping)")


In [None]:
# Validation results
print("=== VALIDATION RESULTS ===")
# No dataset argument: scores the eval_dataset the trainer was built with (ds["validation"]).
val_token_metrics = trainer.evaluate()
print("Validation (token) metrics:", val_token_metrics)

# The document head is scored separately with its own binary weighted-F1 routine.
print("\nValidation (doc) metrics:")
val_doc_f1 = evaluate_doc(ds["validation"])


In [None]:
# Test results
print("=== TEST RESULTS ===")
# Same marker metric as validation, computed on the held-out test split.
test_token_metrics = trainer.evaluate(ds["test"])
print("Test (token) metrics:", test_token_metrics)

# Document head on the test split (binary weighted F1).
print("\nTest (doc) metrics:")
test_doc_f1 = evaluate_doc(ds["test"])


In [None]:
# Summary
print("=== FINAL SUMMARY ===")
# compute_metrics_fn reports the marker score as "token_macro_overlap_f1", which
# shows up in the evaluate() results as "eval_token_macro_overlap_f1".
print(f"Token Macro Overlap F1 - Val: {val_token_metrics['eval_token_macro_overlap_f1']:.4f}, Test: {test_token_metrics['eval_token_macro_overlap_f1']:.4f}")
# evaluate_doc returns the binary weighted F1, not a macro F1.
print(f"Doc Weighted-F1 (binary) - Val: {val_doc_f1:.4f}, Test: {test_doc_f1:.4f}")
print(f"Loss weights used: token={LAMBDA_TOKEN}, doc={LAMBDA_DOC}")
print(f"Training completed with {EPOCHS} epochs")
