In [None]:
# INSTALL & IMPORTS (run once)
!pip install -q transformers datasets jsonlines torch scikit-learn tqdm

import os
import json
import jsonlines
from pathlib import Path
from tqdm.auto import tqdm
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score

print('torch:', torch.__version__)
print('Using device:', 'cuda' if torch.cuda.is_available() else 'cpu')


torch: 2.8.0+cu126
Using device: cuda


In [None]:
# CONFIGURATION - adjust as needed
FOLD_ID = 1
BATCH_SIZE = 16
EPOCHS = 6
LR = 2e-5
MAX_LEN = 256
MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

DATA_DIR = 'data/scifact/data'
OUTPUT_DIR = 'output/pairwise_pubmedbert'
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
print('Config set. Output path:', OUTPUT_DIR)

Config set. Output path: output/pairwise_pubmedbert


In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
!rm -rf cis5300_project

In [None]:
!git clone https://github.com/asxd-10/cis5300_project.git

Cloning into 'cis5300_project'...
remote: Enumerating objects: 248, done.[K
remote: Counting objects: 100% (248/248), done.[K
remote: Compressing objects: 100% (211/211), done.[K
remote: Total 248 (delta 132), reused 97 (delta 30), pack-reused 0 (from 0)[K
Receiving objects: 100% (248/248), 14.21 MiB | 8.87 MiB/s, done.
Resolving deltas: 100% (132/132), done.
Updating files: 100% (57/57), done.


In [None]:
# Data utilities (adapted)
import jsonlines
from typing import List, Dict, Any

class Claim:
    def __init__(self, id: int, claim: str, evidence: Dict = None, label: str = None, cited_doc_ids: List = None):
        self.id = id
        self.claim = claim
        self.evidence = evidence or {}
        self.label = label
        self.cited_doc_ids = cited_doc_ids or []

    def to_dict(self) -> Dict:
        return {'id': self.id,'claim': self.claim,'evidence': self.evidence,'label': self.label,'cited_doc_ids': self.cited_doc_ids}

    @classmethod
    def from_dict(cls, data: Dict):
        raw_label = data.get('label')
        evidence = data.get('evidence', {})
        if raw_label is not None:
            label = raw_label
        else:
            labels = []
            for doc_ev in evidence.values():
                for sent in doc_ev:
                    labels.append(sent.get('label'))
            if 'CONTRADICT' in labels:
                label = 'CONTRADICT'
            elif 'SUPPORT' in labels:
                label = 'SUPPORT'
            else:
                label = 'NOT_ENOUGH_INFO'
        return cls(id=data['id'], claim=data['claim'], evidence=evidence, label=label, cited_doc_ids=data.get('cited_doc_ids', []))

class Document:
    def __init__(self, doc_id: int, title: str, abstract: List[str]):
        self.doc_id = int(doc_id)
        self.title = title
        self.abstract = abstract

    def to_dict(self) -> Dict:
        return {'doc_id': self.doc_id,'title': self.title,'abstract': self.abstract}

    @classmethod
    def from_dict(cls, data: Dict):
        return cls(doc_id=int(data['doc_id']), title=data.get('title',''), abstract=data.get('abstract', []))


def load_claims(filepath: str):
    claims = []
    with jsonlines.open(filepath) as reader:
        for obj in reader:
            claims.append(Claim.from_dict(obj))
    return claims

def load_corpus(filepath: str):
    corpus = {}
    with jsonlines.open(filepath) as reader:
        for obj in reader:
            doc = Document.from_dict(obj)
            corpus[doc.doc_id] = doc
    return corpus

print('Data utilities loaded')


Data utilities loaded


In [None]:
# LOAD DATA (fold-specific). Adjust paths if needed
corpus = load_corpus(f'{DATA_DIR}/corpus.jsonl')
print('Corpus size:', len(corpus))

train_claims = load_claims(f'{DATA_DIR}/cross_validation/fold_{FOLD_ID}/claims_train_{FOLD_ID}.jsonl')
dev_claims = load_claims(f'{DATA_DIR}/cross_validation/fold_{FOLD_ID}/claims_dev_{FOLD_ID}.jsonl')
print('Train claims:', len(train_claims), 'Dev claims:', len(dev_claims))

# Quick inspect
print('\nExample claim (train):')
print(train_claims[0].id, train_claims[0].claim)
print('\nExample doc (from corpus):')
first_doc_id = train_claims[0].cited_doc_ids[0] if train_claims[0].cited_doc_ids else None
if first_doc_id and int(first_doc_id) in corpus:
    doc = corpus[int(first_doc_id)]
    print('doc id', doc.doc_id, 'num sents', len(doc.abstract))


Corpus size: 5183
Train claims: 887 Dev claims: 222

Example claim (train):
0 0-dimensional biomaterials lack inductive properties.

Example doc (from corpus):
doc id 31715818 num sents 4


In [None]:
# Sentence-pair dataset: creates (claim, sentence) examples for sentences from cited docs
class SciFactSentencePairDataset(Dataset):
    def __init__(self, claims, corpus, tokenizer, max_len=256, mode='train'):
        self.examples = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mode = mode
        self.label_map = {'SUPPORT': 0, 'CONTRADICT': 1, 'NOT_ENOUGH_INFO': 2}

        for claim in claims:
            for doc_id in claim.cited_doc_ids:
                # doc_id in files may be int already
                doc_int = int(doc_id)
                if doc_int not in corpus:
                    continue
                doc = corpus[doc_int]
                for sent_idx, sent in enumerate(doc.abstract):
                    # ground truth: is this sentence evidence for claim?
                    is_evidence = 0
                    evidence_label = 'NOT_ENOUGH_INFO'
                    if claim.evidence and str(doc.doc_id) in claim.evidence:
                        for ev in claim.evidence[str(doc.doc_id)]:
                            if sent_idx in ev.get('sentences', []):
                                is_evidence = 1
                                evidence_label = ev.get('label')
                                break

                    # label for the claim-level verdict
                    claim_label = self.label_map.get(claim.label, 2)

                    # Build an example
                    self.examples.append({
                        'claim_id': claim.id,
                        'doc_id': doc.doc_id,
                        'sent_idx': sent_idx,
                        'claim': claim.claim,
                        'sentence': sent,
                        'is_evidence': is_evidence,
                        'claim_label': claim_label,
                        'evidence_label_str': evidence_label
                    })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        # tokenize as pair
        encoding = self.tokenizer(
            ex['claim'], ex['sentence'],
            truncation='only_second',
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        item = {k: v.squeeze(0) for k, v in encoding.items()}
        item['is_evidence'] = torch.tensor(ex['is_evidence'], dtype=torch.float)
        item['claim_label'] = torch.tensor(ex['claim_label'], dtype=torch.long)
        item['claim_id'] = ex['claim_id']
        item['doc_id'] = ex['doc_id']
        item['sent_idx'] = ex['sent_idx']
        return item

print('Dataset class defined')


Dataset class defined


In [None]:
# Tokenizer and Datasets
print('Loading tokenizer:', MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = SciFactSentencePairDataset(train_claims, corpus, tokenizer, max_len=MAX_LEN, mode='train')
dev_dataset = SciFactSentencePairDataset(dev_claims, corpus, tokenizer, max_len=MAX_LEN, mode='dev')

print('Train examples (sentence pairs):', len(train_dataset))
print('Dev examples (sentence pairs):', len(dev_dataset))

# Quick class balance
from collections import Counter
train_evidence_counts = Counter([ex['is_evidence'] for ex in train_dataset.examples])
print('Train evidence counts (0/1):', train_evidence_counts)

train_claim_label_counts = Counter([ex['claim_label'] for ex in train_dataset.examples])
print('Train claim label counts (0:S,1:C,2:NEI):', train_claim_label_counts)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)


Loading tokenizer: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
Train examples (sentence pairs): 9204
Dev examples (sentence pairs): 2343
Train evidence counts (0/1): Counter({0: 8091, 1: 1113})
Train claim label counts (0:S,1:C,2:NEI): Counter({0: 3886, 2: 3027, 1: 2291})


In [None]:
# Model: PubMedBERT encoder + two heads
class PubMedBERT_SciFact(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        # Evidence classifier (binary) on pooled output
        self.evidence_head = nn.Linear(hidden, 1)
        # Claim-level classifier (3-way) -- we will aggregate sentence scores per claim for final verdict
        self.claim_classifier = nn.Linear(hidden, 3)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = outputs.last_hidden_state[:, 0, :]  # CLS
        evidence_logit = self.evidence_head(pooled).squeeze(-1)  # [batch]
        claim_logit = self.claim_classifier(pooled)  # [batch, 3]
        return evidence_logit, claim_logit

model = PubMedBERT_SciFact(MODEL_NAME).to(DEVICE)
print('Model instantiated. Hidden size:', model.encoder.config.hidden_size)


Model instantiated. Hidden size: 768


In [None]:
# Losses and optimizer
# Evidence: BCEWithLogits
# Claim-level: CrossEntropy (we will aggregate per-claim during training/validation)

loss_fn_evidence = nn.BCEWithLogitsLoss()
loss_fn_claim = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=LR)

# Scheduler placeholder (optional)
# total_steps = len(train_loader) * EPOCHS
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

print('Optim and loss set')


Optim and loss set


In [None]:
# TRAINING LOOP (pairwise)
from collections import defaultdict

def train_one_epoch():
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_loader, desc='Train'):
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        evidence_labels = batch['is_evidence'].to(DEVICE)
        claim_labels = batch['claim_label'].to(DEVICE)

        optimizer.zero_grad()
        ev_logits, claim_logits = model(input_ids, attention_mask)

        # Evidence loss (per sentence)
        loss_ev = loss_fn_evidence(ev_logits, evidence_labels)

        # Claim loss: aggregate sentence-level pooled representations per claim id in this batch
        # We'll compute claim_logits per sentence then average logits for sentences of same claim
        # Build mapping claim_id -> indices
        claim_ids = batch['claim_id']
        claim_to_idxs = defaultdict(list)
        for i, cid in enumerate(claim_ids):
            claim_to_idxs[cid].append(i)

        # average claim_logits across sentences for same claim to get claim-level logit
        agg_claim_logits = []
        agg_claim_labels = []
        for cid, idxs in claim_to_idxs.items():
            logits_avg = claim_logits[idxs].mean(dim=0, keepdim=True)  # [1,3]
            agg_claim_logits.append(logits_avg)
            # claim label from first index (all same)
            agg_claim_labels.append(claim_labels[idxs[0]])

        agg_claim_logits = torch.cat(agg_claim_logits, dim=0).to(DEVICE)
        agg_claim_labels = torch.stack(agg_claim_labels).to(DEVICE)

        loss_claim = loss_fn_claim(agg_claim_logits, agg_claim_labels)

        # Total loss
        LAMBDA = 2.0  # weight for evidence, tune this
        loss = loss_claim + LAMBDA * loss_ev
        loss.backward()
        optimizer.step()
        # scheduler.step()

        running_loss += loss.item()
    return running_loss / len(train_loader)

print('Training function ready')


Training function ready


In [None]:
# EVALUATION: produce predictions at claim-level and sentence-level for scorer compatible output

def evaluate_and_predict(dev_claims, corpus, tokenizer, output_predictions_path):
    model.eval()
    predictions = []

    # We'll score every cited sentence for each dev claim and choose evidence sentences where evidence_prob > threshold
    THRESH = 0.5

    with torch.no_grad():
        for claim in tqdm(dev_claims, desc='Predicting dev claims'):
            claim_id = claim.id
            pred_entry = {'id': claim_id, 'label': 'NOT_ENOUGH_INFO', 'evidence': {}}

            # if no cited docs, default NEI
            if not claim.cited_doc_ids:
                predictions.append(pred_entry)
                continue

            # For each cited doc, score each sentence
            for doc_id in claim.cited_doc_ids:
                doc_int = int(doc_id)
                if doc_int not in corpus:
                    continue
                doc = corpus[doc_int]
                sent_scores = []
                sent_probs = []
                for sidx, sent in enumerate(doc.abstract):
                    encoding = tokenizer(claim.claim, sent, truncation='only_second', max_length=MAX_LEN, padding='max_length', return_tensors='pt')
                    input_ids = encoding['input_ids'].to(DEVICE)
                    attention_mask = encoding['attention_mask'].to(DEVICE)
                    ev_logits, claim_logits = model(input_ids, attention_mask)
                    ev_prob = torch.sigmoid(ev_logits).item()
                    sent_scores.append((sidx, ev_prob, claim_logits.squeeze(0).cpu().numpy()))
                    sent_probs.append(ev_prob)

                # choose evidence sentences above threshold
                chosen = [ (idx,prob,logits) for idx,prob,logits in sent_scores if prob >= THRESH ]
                # if none chosen but max prob > 0, optionally select top-k (k=1)
                if not chosen and len(sent_scores) > 0:
                    # pick top 1 as fallback
                    best = max(sent_scores, key=lambda x: x[1])
                    chosen = [best]

                if chosen:
                    # decide claim label based on average claim_logits of chosen sentences
                    logits = np.stack([c[2] for c in chosen], axis=0)
                    avg_logits = logits.mean(axis=0)
                    pred_label_idx = int(np.argmax(avg_logits))
                    label_map_rev = {0: 'SUPPORT', 1: 'CONTRADICT', 2: 'NOT_ENOUGH_INFO'}
                    pred_label_str = label_map_rev[pred_label_idx]

                    pred_entry['label'] = pred_label_str
                    pred_entry['evidence'][str(doc_int)] = [{
                        'sentences': [int(c[0]) for c in chosen],
                        'label': pred_label_str
                    }]

            predictions.append(pred_entry)

    # save predictions
    with jsonlines.open(output_predictions_path, mode='w') as writer:
        writer.write_all(predictions)
    return predictions

print('Evaluation function ready')


Evaluation function ready


In [None]:
# RUN TRAINING + VALIDATION
best_f1 = 0.0
best_path = None
for epoch in range(EPOCHS):
    train_loss = train_one_epoch()
    print(f'Epoch {epoch+1}/{EPOCHS} train loss: {train_loss:.4f}')

    preds_path = f"{OUTPUT_DIR}/pubmedbert_fold{FOLD_ID}_epoch{epoch+1}_preds.jsonl"
    _ = evaluate_and_predict(dev_claims, corpus, tokenizer, preds_path)
    # compute sentence-level F1 using the external scorer (user has it). We'll call it if present.
    print('Saved predictions to', preds_path)

print('Training finished')


Train:   0%|          | 0/576 [00:00<?, ?it/s]

Epoch 1/6 train loss: 1.1142


Predicting dev claims:   0%|          | 0/222 [00:00<?, ?it/s]

Saved predictions to output/pairwise_pubmedbert/pubmedbert_fold1_epoch1_preds.jsonl


Train:   0%|          | 0/576 [00:00<?, ?it/s]

Epoch 2/6 train loss: 0.4080


Predicting dev claims:   0%|          | 0/222 [00:00<?, ?it/s]

Saved predictions to output/pairwise_pubmedbert/pubmedbert_fold1_epoch2_preds.jsonl


Train:   0%|          | 0/576 [00:00<?, ?it/s]

Epoch 3/6 train loss: 0.2367


Predicting dev claims:   0%|          | 0/222 [00:00<?, ?it/s]

Saved predictions to output/pairwise_pubmedbert/pubmedbert_fold1_epoch3_preds.jsonl


Train:   0%|          | 0/576 [00:00<?, ?it/s]

Epoch 4/6 train loss: 0.1496


Predicting dev claims:   0%|          | 0/222 [00:00<?, ?it/s]

Saved predictions to output/pairwise_pubmedbert/pubmedbert_fold1_epoch4_preds.jsonl


Train:   0%|          | 0/576 [00:00<?, ?it/s]

Epoch 5/6 train loss: 0.1103


Predicting dev claims:   0%|          | 0/222 [00:00<?, ?it/s]

Saved predictions to output/pairwise_pubmedbert/pubmedbert_fold1_epoch5_preds.jsonl


Train:   0%|          | 0/576 [00:00<?, ?it/s]

Epoch 6/6 train loss: 0.0856


Predicting dev claims:   0%|          | 0/222 [00:00<?, ?it/s]

Saved predictions to output/pairwise_pubmedbert/pubmedbert_fold1_epoch6_preds.jsonl
Training finished


In [43]:
SCORER_PATH = 'src/evaluation/score_claims.py'
if os.path.exists(SCORER_PATH):
    gold_file = f'{DATA_DIR}/cross_validation/fold_{FOLD_ID}/claims_dev_{FOLD_ID}.jsonl'
    pred_file = f"{OUTPUT_DIR}/pubmedbert_fold{FOLD_ID}_epoch{EPOCHS}_preds.jsonl"
    print('Running scorer...')
    !python {SCORER_PATH} --gold {gold_file} --predictions {pred_file}
else:
    print('Scorer not found at', SCORER_PATH)

Running scorer...
CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 222
  Predictions: 222

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    1.0000
  F1:        1.0000

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.2821
  Recall:    0.6475
  F1:        0.3930

Label-only:
  Accuracy:  0.0000

Interpretation:
 Retrieval is excellent (oracle or near-oracle)
  Evidence extraction is competitive
