# Extension: Graph Neural Network for Claim Verification

## Overview

This notebook implements a Graph Neural Network (GNN) extension to improve claim verification by modeling relationships between claims and document sentences.

### Motivation

Current models (SciBERT, PubMedBERT) process sentences independently, missing:
- **Claim-sentence relationships**: Which sentences directly address the claim?
- **Sentence-sentence dependencies**: How do sentences relate to each other?
- **Document structure**: Sequential and semantic relationships

GNNs can explicitly model these relationships through graph attention, potentially improving evidence extraction.

### Architecture

**Graph Construction**:
- **Nodes**: 1 claim node + N sentence nodes (one per sentence)
- **Edges**:
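  - Claim ↔ Sentence (all sentences, bidirectional; the cosine similarity is stored as the edge attribute, although the GAT layers as written do not use it)
  - Sentence ↔ Sentence (sequential edges between adjacent sentences, plus semantic edges between non-adjacent sentences whose cosine similarity exceeds `SIMILARITY_THRESHOLD`)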

**Model**:
- BERT (SciBERT/PubMedBERT) encodes the claim and each sentence separately and is fine-tuned end-to-end (**gradients flow through BERT**)
- GAT (Graph Attention Network) refines representations
- Hybrid: Concatenate BERT + GNN embeddings
- Classifiers: Label (3-way) + Evidence (binary per sentence)
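The classifier input size follows from the configuration: each GAT layer concatenates its attention heads (`concat=True`), so the GNN output is `gnn_hidden_dim * gnn_num_heads`, and hybrid mode adds the BERT hidden size on top. A minimal sketch of that arithmetic, assuming a 768-dimensional BERT-base encoder (the helper `classifier_input_dim` is hypothetical, not part of the project code):

```python
def classifier_input_dim(bert_hidden=768, gnn_hidden_dim=256, gnn_num_heads=4, use_hybrid=True):
    """Input width of the label/evidence classifiers (hypothetical helper)."""
    # GATConv(concat=True) concatenates the per-head outputs
    gnn_output_dim = gnn_hidden_dim * gnn_num_heads
    # Hybrid mode concatenates the BERT [CLS] vector with the GNN output
    return bert_hidden + gnn_output_dim if use_hybrid else gnn_output_dim


print(classifier_input_dim())                  # hybrid mode
print(classifier_input_dim(use_hybrid=False))  # GNN-only mode
```

With the default configuration, the first hidden layer of each classifier is therefore `1792 // 2 = 896` units wide.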

### Expected Impact

- **Target**: +3 to 5 F1 points over the baseline with the fixes applied (trainable BERT, NEI training)
- **Mechanism**: Fine-tuned BERT + relational modeling through the GNN
- **Best Case**: +5 to 7 F1 points if the GNN captures important relationships

### Baseline Comparison

| Model | F1 | Notes |
|-------|-----|-------|
| SciBERT Baseline | 24.20% | Milestone 2 |
| PubMedBERT | 39.30% | Sentence-pair architecture |
| **SciBERT + GNN** | 18.98% | Frozen BERT, no NEI training |
| **SciBERT + GNN** | ? | Trainable BERT, with NEI training |
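The targets above, and the improvement printed by the evaluation cell, compare F1 scores by subtraction, which yields a difference in percentage points rather than a relative percentage. A small sketch separating the two, using the frozen-BERT row of the table as the example (the helper name `f1_delta` is illustrative):

```python
def f1_delta(model_f1, baseline_f1):
    """Return (absolute change in F1 points, relative change in percent)."""
    points = model_f1 - baseline_f1
    relative = 100.0 * points / baseline_f1
    return points, relative


# Frozen SciBERT + GNN vs. the SciBERT baseline, both in percent
points, relative = f1_delta(18.98, 24.20)
print(f"{points:+.2f} F1 points ({relative:+.1f}% relative)")
```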

---

In [1]:
# Setup: mount Google Drive and check GPU availability
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import torch
import os
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

Mounted at /content/drive
CUDA available: True
GPU: Tesla T4


In [2]:
# Install dependencies
%pip install -q transformers datasets jsonlines scikit-learn tqdm
%pip install -q torch-geometric  # GNN library

In [3]:
# Clone repository
!rm -rf cis5300_project
!git clone https://github.com/asxd-10/cis5300_project.git

import sys
os.chdir('cis5300_project')
sys.path.append('.')
print(f"Current directory: {os.getcwd()}")

Cloning into 'cis5300_project'...
remote: Enumerating objects: 286, done.
remote: Counting objects: 100% (286/286), done.
remote: Compressing objects: 100% (249/249), done.
remote: Total 286 (delta 160), reused 97 (delta 30), pack-reused 0 (from 0)
Receiving objects: 100% (286/286), 14.30 MiB | 12.86 MiB/s, done.
Resolving deltas: 100% (160/160), done.
Current directory: /content/cis5300_project


In [4]:
# Configuration
import torch
import random
import numpy as np

# Model configuration
MODEL_NAME = 'allenai/scibert_scivocab_uncased'  # or 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
USE_PUBMEDBERT = False  # Set to True to use PubMedBERT

if USE_PUBMEDBERT:
    MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'

# GNN configuration
GNN_HIDDEN_DIM = 256  # GNN hidden dimension (can tune: 128, 256, 512)
GNN_NUM_LAYERS = 2  # Number of GAT layers (can tune: 1, 2, 3)
GNN_NUM_HEADS = 4  # Attention heads per layer (can tune: 2, 4, 8)
GNN_DROPOUT = 0.1
SIMILARITY_THRESHOLD = 0.3  # Min similarity for sentence-sentence edges

# Training configuration
BATCH_SIZE = 4  # Smaller batch for graphs (can increase if memory allows)
LEARNING_RATE = 1e-4  # Single LR shared by BERT and the GNN (can tune: 5e-5, 1e-4, 2e-4)
EPOCHS = 6
MAX_SENTENCES = 20  # Max sentences per document
EVIDENCE_LOSS_WEIGHT = 2.0
NEGATIVE_RATIO = 0.3  # Add 30% NEI examples for training

# Hybrid configuration
USE_HYBRID = True  # Concatenate BERT + GNN embeddings (recommended)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print(f"Using device: {device}")
print(f"Model: {MODEL_NAME}")
print(f"GNN: {GNN_NUM_LAYERS} layers, {GNN_NUM_HEADS} heads, hidden_dim={GNN_HIDDEN_DIM}")
print(f"Hybrid mode: {USE_HYBRID}")

Using device: cuda
Model: allenai/scibert_scivocab_uncased
GNN: 2 layers, 4 heads, hidden_dim=256
Hybrid mode: True


In [5]:
# Load data
from src.common.data_utils import load_claims, load_corpus

train_claims = load_claims('data/scifact/data/claims_train.jsonl')
dev_claims = load_claims('data/scifact/data/claims_dev.jsonl')
corpus = load_corpus('data/scifact/data/corpus.jsonl')

print(f"{len(train_claims)} training claims")
print(f"{len(dev_claims)} dev claims")
print(f"{len(corpus)} documents")

809 training claims
300 dev claims
5183 documents
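The dataset class further down reads rationales from `claim.evidence[str(doc_id)]`, a list of entries whose `sentences` field holds sentence indices into the abstract. A minimal stdlib sketch of turning one such record into the per-sentence evidence mask, assuming the claim record has that layout (the sample record is invented for illustration and `evidence_mask` is a hypothetical helper):

```python
def evidence_mask(claim, doc_id, max_sentences=20):
    """Binary mask over the first max_sentences sentences of one abstract."""
    mask = [0.0] * max_sentences
    for entry in claim.get("evidence", {}).get(str(doc_id), []):
        for sent_idx in entry.get("sentences", []):
            # Rationale sentences past the truncation limit are dropped
            if sent_idx < max_sentences:
                mask[sent_idx] = 1.0
    return mask


# Invented example record; evidence is keyed by the document id as a string
record = {
    "id": 1,
    "claim": "X reduces Y.",
    "evidence": {"123": [{"sentences": [0, 4], "label": "SUPPORT"},
                         {"sentences": [25], "label": "SUPPORT"}]},
    "cited_doc_ids": [123],
}
mask = evidence_mask(record, record["cited_doc_ids"][0])
print([i for i, v in enumerate(mask) if v])
```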


In [6]:
# Graph Construction Utilities
import torch
from torch_geometric.data import Data, Batch
import torch.nn.functional as F

def cosine_similarity_torch(a, b):
    """Compute cosine similarity between two tensors."""
    a_norm = F.normalize(a, p=2, dim=-1)
    b_norm = F.normalize(b, p=2, dim=-1)
    return (a_norm * b_norm).sum(dim=-1)

def build_claim_sentence_graph(claim_emb, sent_embs, similarity_threshold=0.3):
    """
    Build graph connecting claim and sentences.

    Args:
        claim_emb: [768] - claim embedding from BERT
        sent_embs: [N, 768] - sentence embeddings from BERT
        similarity_threshold: Min similarity for sentence-sentence edges

    Returns:
        edge_index: [2, E] - edge connections
        edge_weights: [E] - cosine similarities (optional)
    """
    N = len(sent_embs)
    if N == 0:
        # Empty document - just claim node
        return torch.empty((2, 0), dtype=torch.long), torch.empty((0,), dtype=torch.float)

    edges = []
    weights = []

    # Claim-sentence edges (bidirectional)
    # Node 0 = claim, Nodes 1..N = sentences
    for i in range(N):
        # Claim -> Sentence i
        sim = cosine_similarity_torch(claim_emb.unsqueeze(0), sent_embs[i:i+1]).item()
        edges.append([0, i+1])
        weights.append(sim)

        # Sentence i -> Claim (bidirectional)
        edges.append([i+1, 0])
        weights.append(sim)

    # Sentence-sentence edges
    # Sequential edges (adjacent sentences)
    for i in range(N-1):
        sim = cosine_similarity_torch(sent_embs[i:i+1], sent_embs[i+1:i+2]).item()
        # i -> i+1
        edges.append([i+1, i+2])
        weights.append(sim)
        # i+1 -> i (bidirectional)
        edges.append([i+2, i+1])
        weights.append(sim)

    # Semantic edges (high similarity, not adjacent)
    for i in range(N):
        for j in range(i+2, N):  # Skip adjacent (already added)
            sim = cosine_similarity_torch(sent_embs[i:i+1], sent_embs[j:j+1]).item()
            if sim > similarity_threshold:
                edges.append([i+1, j+1])
                weights.append(sim)
                edges.append([j+1, i+1])
                weights.append(sim)

    if len(edges) == 0:
        # Safeguard only: for N > 0 the claim-sentence loop above always adds edges
        for i in range(N):
            edges.append([0, i+1])
            weights.append(1.0)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_weights = torch.tensor(weights, dtype=torch.float)

    return edge_index, edge_weights

print("Graph construction utilities defined")

Graph construction utilities defined
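To make the resulting topology concrete, here is a dependency-free sketch that reproduces the same edge layout from precomputed cosine similarities (node 0 is the claim, node `i + 1` is sentence `i`). The function name and the plain-list inputs are assumptions for illustration; the notebook's version computes the similarities from BERT embeddings with torch.

```python
def build_edges(claim_sims, sent_sims, threshold=0.3):
    """
    claim_sims[i]   : cosine(claim, sentence i)
    sent_sims[i][j] : cosine(sentence i, sentence j)
    Returns (edges, weights) in the same order as build_claim_sentence_graph.
    """
    n = len(claim_sims)
    edges, weights = [], []
    # Claim <-> every sentence
    for i in range(n):
        edges += [(0, i + 1), (i + 1, 0)]
        weights += [claim_sims[i]] * 2
    # Sequential edges between adjacent sentences
    for i in range(n - 1):
        edges += [(i + 1, i + 2), (i + 2, i + 1)]
        weights += [sent_sims[i][i + 1]] * 2
    # Semantic edges between non-adjacent sentences above the threshold
    for i in range(n):
        for j in range(i + 2, n):
            if sent_sims[i][j] > threshold:
                edges += [(i + 1, j + 1), (j + 1, i + 1)]
                weights += [sent_sims[i][j]] * 2
    return edges, weights


# Toy 4-sentence document: only the non-adjacent pair (0, 2) clears 0.3
claim_sims = [0.8, 0.1, 0.5, 0.2]
sent_sims = [
    [1.0, 0.4, 0.6, 0.1],
    [0.4, 1.0, 0.2, 0.2],
    [0.6, 0.2, 1.0, 0.3],
    [0.1, 0.2, 0.3, 1.0],
]
edges, weights = build_edges(claim_sims, sent_sims)
print(len(edges))
```

Every document therefore has at least `4N - 2` directed edges (`2N` claim edges plus `2(N - 1)` sequential ones), while the semantic edges can grow quadratically with `N` in the worst case.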


In [7]:
# GNN-Enhanced Model Architecture
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch_geometric.nn import GATConv

class GNNClaimVerifier(nn.Module):
    """
    GNN-enhanced claim verification model.
    Combines BERT encoding with Graph Attention Network.
    """

    def __init__(self, model_name='allenai/scibert_scivocab_uncased',
                 num_labels=3, max_sentences=20,
                 gnn_hidden_dim=256, gnn_num_layers=2, gnn_num_heads=4,
                 gnn_dropout=0.1, use_hybrid=True):
        super().__init__()

        # BERT encoder
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size  # 768 for SciBERT

        # GNN layers
        self.gnn_layers = nn.ModuleList()
        self.gnn_hidden_dim = gnn_hidden_dim
        self.use_hybrid = use_hybrid

        # First GAT layer
        self.gnn_layers.append(
            GATConv(hidden_size, gnn_hidden_dim, heads=gnn_num_heads,
                   dropout=gnn_dropout, concat=True)
        )

        # Additional GAT layers
        for _ in range(gnn_num_layers - 1):
            self.gnn_layers.append(
                GATConv(gnn_hidden_dim * gnn_num_heads, gnn_hidden_dim,
                       heads=gnn_num_heads, dropout=gnn_dropout, concat=True)
            )

        # Final GNN output dimension
        gnn_output_dim = gnn_hidden_dim * gnn_num_heads

        # Classification heads
        if use_hybrid:
            # Hybrid: Concatenate BERT + GNN
            classifier_input_dim = hidden_size + gnn_output_dim
        else:
            # GNN only
            classifier_input_dim = gnn_output_dim

        # Label classifier (uses claim node)
        self.label_classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, classifier_input_dim // 2),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(classifier_input_dim // 2, num_labels)
        )

        # Evidence classifier (uses sentence nodes)
        self.evidence_classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, classifier_input_dim // 2),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(classifier_input_dim // 2, 1)  # Binary
        )

        self.max_sentences = max_sentences
        self.hidden_size = hidden_size
        self.gnn_output_dim = gnn_output_dim

    def encode_claim_and_sentences(self, claim_text, sentences, device):
        """
        Encode claim and sentences separately using BERT.

        Returns:
            claim_emb: [768] - claim embedding
            sent_embs: [N, 768] - sentence embeddings
        """
        # Encode claim
        claim_encoding = self.tokenizer(
            claim_text,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)

        claim_output = self.encoder(**claim_encoding)
        claim_emb = claim_output.last_hidden_state[:, 0, :].squeeze(0)  # [CLS] token

        # Encode sentences
        sent_embs = []
        for sent in sentences:
            sent_encoding = self.tokenizer(
                sent,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            ).to(device)

            sent_output = self.encoder(**sent_encoding)
            sent_emb = sent_output.last_hidden_state[:, 0, :].squeeze(0)  # [CLS] token
            sent_embs.append(sent_emb)

        if len(sent_embs) > 0:
            sent_embs = torch.stack(sent_embs)  # [N, 768]
        else:
            # Empty document
            sent_embs = torch.zeros((0, claim_emb.size(0)), device=device)

        return claim_emb, sent_embs

    def forward(self, claim_tokens_list, sentence_tokens_lists, num_sentences_list, similarity_thresholds):
        """
        Forward pass with tokenized inputs - BERT is now trainable.

        Args:
            claim_tokens_list: List of tokenized claims [batch_size]
            sentence_tokens_lists: List of lists of tokenized sentences [batch_size, num_sents]
            num_sentences_list: List of actual sentence counts per example
            similarity_thresholds: List of similarity thresholds per example

        Returns:
            label_logits: [batch, 3]
            evidence_logits: [batch, max_sentences]
        """
        batch_size = len(claim_tokens_list)
        all_graphs = []

        # Encode each example and build graph (WITH gradients - BERT trainable)
        for i in range(batch_size):
            claim_tokens = claim_tokens_list[i].to(self.encoder.device)
            sentence_tokens_list = sentence_tokens_lists[i]
            num_sents = num_sentences_list[i]
            sim_thresh = similarity_thresholds[i]

            # Encode claim (WITH gradients - BERT trainable)
            claim_output = self.encoder(**claim_tokens)
            claim_emb = claim_output.last_hidden_state[:, 0, :].squeeze(0)  # [768]

            # Encode sentences (WITH gradients)
            sent_embs = []
            for sent_tokens in sentence_tokens_list:
                sent_tokens_device = sent_tokens.to(self.encoder.device)
                sent_output = self.encoder(**sent_tokens_device)
                sent_emb = sent_output.last_hidden_state[:, 0, :].squeeze(0)  # [768]
                sent_embs.append(sent_emb)

            if len(sent_embs) > 0:
                sent_embs = torch.stack(sent_embs)  # [N, 768]
            else:
                sent_embs = torch.zeros((0, claim_emb.size(0)), device=claim_emb.device)

            # Build graph
            edge_index, edge_weights = build_claim_sentence_graph(
                claim_emb, sent_embs, sim_thresh
            )

            # Create node features
            node_features = torch.cat([claim_emb.unsqueeze(0), sent_embs], dim=0)

            # Create graph
            graph = Data(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_weights
            )
            all_graphs.append(graph)

        # Batch graphs
        batched_graphs = Batch.from_data_list(all_graphs).to(claim_emb.device)

        # Process through GNN
        x = batched_graphs.x
        edge_index = batched_graphs.edge_index
        bert_embeddings = x.clone()

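        # Apply GNN layers
        # Note: only edge_index is passed, so the cosine-similarity edge_attr
        # does not influence GAT attention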
        for gnn_layer in self.gnn_layers:
            x = gnn_layer(x, edge_index)
            x = F.relu(x)

        # Split and classify
        batch_assignments = batched_graphs.batch
        label_logits_list = []
        evidence_logits_list = []

        for graph_idx in range(batch_size):
            graph_mask = (batch_assignments == graph_idx)
            graph_nodes_gnn = x[graph_mask]
            graph_nodes_bert = bert_embeddings[graph_mask]

            claim_node_gnn = graph_nodes_gnn[0]
            claim_node_bert = graph_nodes_bert[0]
            sentence_nodes_gnn = graph_nodes_gnn[1:]
            sentence_nodes_bert = graph_nodes_bert[1:]

            # Pad sentences
            num_sents = len(sentence_nodes_gnn)
            if num_sents < self.max_sentences:
                padding_gnn = torch.zeros(self.max_sentences - num_sents,
                                           self.gnn_output_dim, device=sentence_nodes_gnn.device)
                sentence_nodes_gnn = torch.cat([sentence_nodes_gnn, padding_gnn])
                padding_bert = torch.zeros(self.max_sentences - num_sents,
                                           self.hidden_size, device=sentence_nodes_bert.device)
                sentence_nodes_bert = torch.cat([sentence_nodes_bert, padding_bert])
            else:
                sentence_nodes_gnn = sentence_nodes_gnn[:self.max_sentences]
                sentence_nodes_bert = sentence_nodes_bert[:self.max_sentences]

            # Hybrid
            if self.use_hybrid:
                claim_combined = torch.cat([claim_node_bert, claim_node_gnn], dim=-1)
                sentence_combined = torch.cat([sentence_nodes_bert, sentence_nodes_gnn], dim=-1)
            else:
                claim_combined = claim_node_gnn
                sentence_combined = sentence_nodes_gnn

            label_logits_list.append(self.label_classifier(claim_combined))
            evidence_logits_list.append(self.evidence_classifier(sentence_combined).squeeze(-1))

        label_logits = torch.stack(label_logits_list)
        evidence_logits = torch.stack(evidence_logits_list)

        return label_logits, evidence_logits

print("GNN model defined")


GNN model defined
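The forward pass relies on PyG's `Batch.from_data_list`, which stacks the node features of all graphs and shifts each graph's `edge_index` by the number of nodes placed before it; the resulting `batch` vector records which graph each node came from, and that is what the `graph_mask` lookup uses. A pure-Python sketch of that bookkeeping (the helper `collate_edges` is hypothetical, not a PyG function):

```python
def collate_edges(graphs):
    """
    graphs: list of (num_nodes, edge_list) pairs with graph-local node ids.
    Returns (batch_vector, shifted_edges) laid out the way PyG batching does.
    """
    batch_vector, shifted_edges = [], []
    offset = 0
    for graph_idx, (num_nodes, edges) in enumerate(graphs):
        # Every node of this graph is tagged with its graph index
        batch_vector += [graph_idx] * num_nodes
        # Local node ids are shifted by the number of nodes already placed
        shifted_edges += [(src + offset, dst + offset) for src, dst in edges]
        offset += num_nodes
    return batch_vector, shifted_edges


# Two graphs: claim + 2 sentences, then claim + 1 sentence
g1 = (3, [(0, 1), (1, 0), (0, 2), (2, 0)])
g2 = (2, [(0, 1), (1, 0)])
batch_vector, edges = collate_edges([g1, g2])
print(batch_vector)
print(edges)
```

Because the claim is always local node 0, the first row selected by `batch == graph_idx` is that graph's claim node, which is why `graph_nodes_gnn[0]` is used as the claim representation.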


In [8]:
# Dataset Class for GNN
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import random

class SciFactGNNDataset(Dataset):
    """Dataset that builds graphs for GNN training. Returns tokenized inputs for end-to-end training."""

    def __init__(self, claims, corpus, tokenizer, device,
                 max_sentences=20, similarity_threshold=0.3, mode='train',
                 negative_ratio=0.3, random_seed=42):
        self.claims = claims
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.device = device
        self.max_sentences = max_sentences
        self.similarity_threshold = similarity_threshold
        self.mode = mode
        self.label_map = {'SUPPORT': 0, 'CONTRADICT': 1, 'NOT_ENOUGH_INFO': 2}
        self.negative_ratio = negative_ratio

        # Build examples list
        self.examples = []

        if mode == 'train':
            # Positive examples: claims with evidence
            claims_with_evidence = [c for c in claims if c.evidence and c.label]
            for claim in claims_with_evidence:
                if claim.cited_doc_ids:
                    doc_id = int(claim.cited_doc_ids[0])
                    if doc_id in corpus:
                        self.examples.append({
                            'claim': claim,
                            'doc_id': doc_id,
                            'is_negative': False
                        })

            # Negative examples: NEI claims with random documents
            nei_claims = [c for c in claims if not c.evidence or c.label == 'NOT_ENOUGH_INFO']
            num_negatives = int(len(claims_with_evidence) * negative_ratio)
            available_doc_ids = list(corpus.keys())
            random.seed(random_seed)

            for i in range(num_negatives):
                claim = random.choice(nei_claims)
                cited_doc_ids = [int(d) for d in claim.cited_doc_ids] if claim.cited_doc_ids else []
                candidate_docs = [d for d in available_doc_ids if d not in cited_doc_ids]
                doc_id = random.choice(candidate_docs) if candidate_docs else random.choice(available_doc_ids)
                self.examples.append({
                    'claim': claim,
                    'doc_id': doc_id,
                    'is_negative': True
                })

            print(f"Dataset: {len(self.examples)} examples ({len(claims_with_evidence)} positive, {num_negatives} negative)")
        else:
            # Dev: all claims
            for claim in claims:
                if claim.cited_doc_ids:
                    doc_id = int(claim.cited_doc_ids[0])
                    if doc_id in corpus:
                        self.examples.append({
                            'claim': claim,
                            'doc_id': doc_id,
                            'is_negative': False
                        })
            print(f"Dataset: {len(self.examples)} examples ({mode})")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        claim = example['claim']
        doc_id = example['doc_id']
        is_negative = example['is_negative']

        doc = self.corpus[doc_id]
        sentences = doc.abstract[:self.max_sentences]

        # Tokenize claim and sentences
        claim_tokens = self.tokenizer(
            claim.claim,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        sentence_tokens_list = []
        for sent in sentences:
            sent_tokens = self.tokenizer(
                sent,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            sentence_tokens_list.append(sent_tokens)

        # Labels
        if is_negative:
            label = 2  # NOT_ENOUGH_INFO
        else:
            label = self.label_map.get(claim.label, 2)

        # Evidence mask
        evidence_mask = torch.zeros(self.max_sentences)
        if not is_negative:
            doc_id_str = str(doc.doc_id)
            if doc_id_str in claim.evidence:
                for ev_entry in claim.evidence[doc_id_str]:
                    for sent_idx in ev_entry.get('sentences', []):
                        if sent_idx < self.max_sentences:
                            evidence_mask[sent_idx] = 1.0

        return {
            'claim_tokens': claim_tokens,
            'sentence_tokens_list': sentence_tokens_list,
            'num_sentences': len(sentences),
            'similarity_threshold': self.similarity_threshold,
            'claim_node_idx': 0,
            'label': torch.tensor(label),
            'evidence_mask': evidence_mask
        }

print("GNN dataset class defined (FIXED: returns tokenized inputs, includes NEI)")


GNN dataset class defined (FIXED: returns tokenized inputs, includes NEI)
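The negative sampler in the dataset calls `random.seed(random_seed)`, which also resets Python's global RNG for any code that runs afterwards. One alternative, sketched here with a dedicated `random.Random` instance, keeps the sampling reproducible without that side effect (the function and argument names are illustrative, not part of the project code):

```python
import random


def sample_negatives(nei_claims, doc_ids, cited, num_negatives, seed=42):
    """
    nei_claims: claim ids eligible for NOT_ENOUGH_INFO pairs
    doc_ids:    all corpus document ids
    cited:      dict claim_id -> set of cited doc ids (excluded when possible)
    Returns a list of (claim_id, doc_id) pairs.
    """
    rng = random.Random(seed)  # private RNG: the global random state is untouched
    pairs = []
    for _ in range(num_negatives):
        claim_id = rng.choice(nei_claims)
        excluded = cited.get(claim_id, set())
        candidates = [d for d in doc_ids if d not in excluded]
        # Fall back to any document if every one is cited
        pairs.append((claim_id, rng.choice(candidates or doc_ids)))
    return pairs


pairs = sample_negatives([1, 2], list(range(10)), {1: {0, 1, 2}, 2: {5}}, num_negatives=5)
print(pairs)
```

In CPython the module-level `random.choice` is a method of a hidden `Random` instance, so for the same seed and the same sequence of calls this should draw the same samples as the original code.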


In [9]:
# Training Setup
# Initialize model
model = GNNClaimVerifier(
    model_name=MODEL_NAME,
    gnn_hidden_dim=GNN_HIDDEN_DIM,
    gnn_num_layers=GNN_NUM_LAYERS,
    gnn_num_heads=GNN_NUM_HEADS,
    gnn_dropout=GNN_DROPOUT,
    use_hybrid=USE_HYBRID
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Create datasets
train_dataset = SciFactGNNDataset(
    train_claims, corpus, model.tokenizer, device,
    max_sentences=MAX_SENTENCES,
    similarity_threshold=SIMILARITY_THRESHOLD,
    mode='train',
    negative_ratio=NEGATIVE_RATIO
)

dev_dataset = SciFactGNNDataset(
    dev_claims, corpus, model.tokenizer, device,
    max_sentences=MAX_SENTENCES,
    similarity_threshold=SIMILARITY_THRESHOLD,
    mode='dev'
)

# Custom collate function
def collate_graphs(batch):
    claim_tokens_list = [item['claim_tokens'] for item in batch]
    sentence_tokens_lists = [item['sentence_tokens_list'] for item in batch]
    num_sentences_list = [item['num_sentences'] for item in batch]
    similarity_thresholds = [item['similarity_threshold'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])
    evidence_masks = torch.stack([item['evidence_mask'] for item in batch])

    return {
        'claim_tokens_list': claim_tokens_list,
        'sentence_tokens_lists': sentence_tokens_lists,
        'num_sentences_list': num_sentences_list,
        'similarity_thresholds': similarity_thresholds,
        'labels': labels,
        'evidence_masks': evidence_masks
    }

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_graphs)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_graphs)

print(f"Training batches: {len(train_loader)}")
print(f"Dev batches: {len(dev_loader)}")

# Optimizer and loss
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
label_criterion = nn.CrossEntropyLoss()
evidence_criterion = nn.BCEWithLogitsLoss()

print("Training setup complete")




Model parameters: 114,976,260
Trainable parameters: 114,976,260
Dataset: 656 examples (505 positive, 151 negative)
Dataset: 300 examples (dev)
Training batches: 164
Dev batches: 75
Training setup complete
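`BCEWithLogitsLoss()` as configured above averages over all `MAX_SENTENCES` positions, including the zero-padded slots of short abstracts, so those slots add easy negatives to the evidence loss. One possible variant averages only over real sentences using `num_sentences_list`; in torch this would multiply the output of `nn.BCEWithLogitsLoss(reduction='none')` by a validity mask. A dependency-free sketch of the masked loss (the function name is illustrative):

```python
import math


def masked_bce_with_logits(logits, targets, num_sentences):
    """
    logits, targets: per-example lists of length max_sentences
    num_sentences:   real sentence count per example; later slots are padding
    Returns the mean binary cross-entropy over real sentences only.
    """
    total, count = 0.0, 0
    for ex_logits, ex_targets, n in zip(logits, targets, num_sentences):
        for x, y in zip(ex_logits[:n], ex_targets[:n]):
            # Numerically stable form: max(x, 0) - x*y + log(1 + exp(-|x|))
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / max(count, 1)


# Padded slots carry large logits; they are ignored by the masked loss
logits = [[0.0, 0.0, 9.0, 9.0], [0.0, 9.0, 9.0, 9.0]]
targets = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
print(masked_bce_with_logits(logits, targets, num_sentences=[2, 1]))
```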


In [10]:
# Training Loop
from tqdm import tqdm
import os

model.train()
os.makedirs('models/claim_verifier', exist_ok=True)

for epoch in range(EPOCHS):
    total_loss = 0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch in progress_bar:
        claim_tokens_list = batch['claim_tokens_list']
        sentence_tokens_lists = batch['sentence_tokens_lists']
        num_sentences_list = batch['num_sentences_list']
        similarity_thresholds = batch['similarity_thresholds']
        labels = batch['labels'].to(device)
        evidence_masks = batch['evidence_masks'].to(device)

        # Forward
        label_logits, evidence_logits = model(
            claim_tokens_list,
            sentence_tokens_lists,
            num_sentences_list,
            similarity_thresholds
        )

        # Losses
        label_loss = label_criterion(label_logits, labels)
        evidence_loss = evidence_criterion(evidence_logits, evidence_masks)
        loss = label_loss + EVIDENCE_LOSS_WEIGHT * evidence_loss

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Stats
        total_loss += loss.item()
        pred = label_logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100*correct/total:.1f}%'
        })

    avg_loss = total_loss / len(train_loader)
    accuracy = 100 * correct / total

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Label Accuracy: {accuracy:.2f}%")

    # Save checkpoint
    checkpoint_path = f'models/claim_verifier/gnn_epoch{epoch+1}.pt'
    torch.save(model.state_dict(), checkpoint_path)
    print(f"  Saved: {checkpoint_path}")

print("\nTraining complete!")


Epoch 1/6: 100%|██████████| 164/164 [04:28<00:00,  1.64s/it, loss=1.8270, acc=48.5%]



Epoch 1 Summary:
  Loss: 2.0679
  Label Accuracy: 48.48%
  Saved: models/claim_verifier/gnn_epoch1.pt


Epoch 2/6: 100%|██████████| 164/164 [04:12<00:00,  1.54s/it, loss=3.2523, acc=58.2%]



Epoch 2 Summary:
  Loss: 1.7405
  Label Accuracy: 58.23%
  Saved: models/claim_verifier/gnn_epoch2.pt


Epoch 3/6: 100%|██████████| 164/164 [04:16<00:00,  1.56s/it, loss=1.2993, acc=69.2%]



Epoch 3 Summary:
  Loss: 1.3711
  Label Accuracy: 69.21%
  Saved: models/claim_verifier/gnn_epoch3.pt


Epoch 4/6: 100%|██████████| 164/164 [04:15<00:00,  1.56s/it, loss=0.6202, acc=78.2%]



Epoch 4 Summary:
  Loss: 1.0205
  Label Accuracy: 78.20%
  Saved: models/claim_verifier/gnn_epoch4.pt


Epoch 5/6: 100%|██████████| 164/164 [04:15<00:00,  1.56s/it, loss=1.6540, acc=85.1%]



Epoch 5 Summary:
  Loss: 0.7544
  Label Accuracy: 85.06%
  Saved: models/claim_verifier/gnn_epoch5.pt


Epoch 6/6: 100%|██████████| 164/164 [04:13<00:00,  1.55s/it, loss=0.1205, acc=88.3%]



Epoch 6 Summary:
  Loss: 0.5764
  Label Accuracy: 88.26%
  Saved: models/claim_verifier/gnn_epoch6.pt

Training complete!
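The evaluation cell below runs a full forward pass over the dev set for every threshold, although only the cut-off changes between runs. One way to avoid the repeated BERT passes is to cache each claim's predicted label and evidence probabilities once, then apply the thresholds afterwards. A minimal sketch of the thresholding step, assuming the probabilities have already been collected into a dict keyed by claim id (the names `cached` and `predictions_for_threshold` are hypothetical):

```python
def predictions_for_threshold(cached, threshold):
    """
    cached: claim_id -> {"label": str, "doc_id": int, "probs": [float, ...]}
    Returns prediction dicts in the format written by generate_predictions_gnn.
    """
    predictions = []
    for claim_id, entry in cached.items():
        sents = [i for i, p in enumerate(entry["probs"]) if p > threshold]
        pred = {"id": claim_id, "label": entry["label"], "evidence": {}}
        if sents:
            pred["evidence"][str(entry["doc_id"])] = [
                {"sentences": sents, "label": entry["label"]}
            ]
        predictions.append(pred)
    return predictions


# Probabilities computed once, thresholds swept cheaply afterwards
cached = {
    7: {"label": "SUPPORT", "doc_id": 123, "probs": [0.62, 0.41, 0.33]},
    9: {"label": "NOT_ENOUGH_INFO", "doc_id": 456, "probs": [0.20, 0.35]},
}
for t in [0.30, 0.40, 0.50]:
    preds = predictions_for_threshold(cached, t)
    print(t, [p["evidence"] for p in preds])
```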


In [12]:
# Evaluation: Generate Predictions
import jsonlines
import subprocess
from tqdm import tqdm
import torch
from torch_geometric.data import Batch, Data

def generate_predictions_gnn(claims, corpus, model, device, threshold=0.5):
    """Generate predictions using GNN model."""
    model.eval()
    predictions = []

    with torch.no_grad():
        for claim in tqdm(claims, desc="Generating predictions"):
            if not hasattr(claim, 'cited_doc_ids') or not claim.cited_doc_ids:
                predictions.append({'id': claim.id, 'label': 'NOT_ENOUGH_INFO', 'evidence': {}})
                continue

            doc_id = int(claim.cited_doc_ids[0])
            if doc_id not in corpus:
                predictions.append({'id': claim.id, 'label': 'NOT_ENOUGH_INFO', 'evidence': {}})
                continue

            doc = corpus[doc_id]
            sentences = doc.abstract[:MAX_SENTENCES]

            # Tokenize claim & sentences exactly like the training dataset does
            claim_tokens = model.tokenizer(
                claim.claim,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            ).to(device)

            sentence_tokens_list = []
            for sent in sentences:
                sent_tokens = model.tokenizer(
                    sent,
                    max_length=128,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                ).to(device)
                sentence_tokens_list.append(sent_tokens)

            # Model expects lists of length batch_size (here = 1)
            claim_tokens_list       = [claim_tokens]          # list with 1 element
            sentence_tokens_lists   = [sentence_tokens_list]  # list of list

            # Prepare additional args
            num_sents = len(sentences)
            num_sentences_list = [num_sents]
            similarity_thresholds = [SIMILARITY_THRESHOLD]  # plain list, as in training

            #  ACTUAL FORWARD CALL
            label_logits, evidence_logits = model(
                claim_tokens_list=claim_tokens_list,
                sentence_tokens_lists=sentence_tokens_lists,
                num_sentences_list=num_sentences_list,
                similarity_thresholds=similarity_thresholds
            )

            # Get predictions
            pred_label_idx = label_logits.argmax(dim=1).item()
            label_map = {0: 'SUPPORT', 1: 'CONTRADICT', 2: 'NOT_ENOUGH_INFO'}
            pred_label = label_map[pred_label_idx]

            evidence_probs = torch.sigmoid(evidence_logits[0])
            pred_evidence_sents = [i for i, prob in enumerate(evidence_probs[:num_sents])
                                   if prob > threshold]

            prediction = {'id': claim.id, 'label': pred_label, 'evidence': {}}
            if pred_evidence_sents:
                prediction['evidence'][str(doc_id)] = [{
                    'sentences': pred_evidence_sents,
                    'label': pred_label
                }]
            predictions.append(prediction)  # Always append here

    return predictions

# Test different thresholds
print("EVALUATION ON DEV SET")
best_f1 = 0.0
best_threshold = 0.5
best_precision = 0.0
best_recall = 0.0
threshold_results = []

for threshold in [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60]:
    print(f"\n--- Testing threshold: {threshold} ---")

    predictions = generate_predictions_gnn(dev_claims, corpus, model, device, threshold)

    # Save predictions
    output_path = f'output/dev/gnn_thresh{int(threshold*100)}.jsonl'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with jsonlines.open(output_path, 'w') as writer:
        writer.write_all(predictions)

    # Evaluate
    result = subprocess.run(
        ['python', 'src/evaluation/score_claims.py',
         '--gold', 'data/scifact/data/claims_dev.jsonl',
         '--predictions', output_path],
        capture_output=True,
        text=True
    )

    print(result.stdout)

    # Extract metrics
    try:
        lines = result.stdout.split('\n')
        precision = None
        recall = None
        f1 = None

        for i, line in enumerate(lines):
            if 'Sentence-level' in line:
                for j in range(i, min(i+10, len(lines))):
                    if 'Precision:' in lines[j]:
                        precision = float(lines[j].split('Precision:')[1].strip())
                    elif 'Recall:' in lines[j]:
                        recall = float(lines[j].split('Recall:')[1].strip())
                    elif 'F1:' in lines[j]:
                        f1_str = lines[j].split('F1:')[1].strip()
                        f1 = float(f1_str)
                        break
                break

        if f1 is not None:
            threshold_results.append({
                'threshold': threshold,
                'precision': precision,
                'recall': recall,
                'f1': f1
            })

            if f1 > best_f1:
                best_f1 = f1
                best_threshold = threshold
                best_precision = precision if precision else 0.0
                best_recall = recall if recall else 0.0
    except Exception as e:
        print(f"Warning: Could not parse metrics: {e}")

# Print summary
print(f"\n{'='*60}")
print("THRESHOLD COMPARISON")
print("="*60)
print(f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} {'F1':<12}")
print("-" * 60)
for row in threshold_results:  # `row` avoids shadowing the subprocess `result` above
    print(f"{row['threshold']:<12.2f} {row['precision'] or 0.0:<12.4f} {row['recall'] or 0.0:<12.4f} {row['f1']:<12.4f}")

print(f"\n{'='*60}")
print("BEST RESULTS")
print("="*60)
print(f"Best Threshold: {best_threshold}")
print(f"  Precision: {best_precision:.4f} ({best_precision*100:.2f}%)")
print(f"  Recall:    {best_recall:.4f} ({best_recall*100:.2f}%)")
print(f"  F1:        {best_f1:.4f} ({best_f1*100:.2f}%)")
print("\nBaseline F1: 24.20%")
print(f"Change vs. baseline: {best_f1*100 - 24.20:+.2f} percentage points")
print("="*60)

EVALUATION ON DEV SET

--- Testing threshold: 0.3 ---


Generating predictions: 100%|██████████| 300/300 [00:39<00:00,  7.58it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.7994
  F1:        0.8885

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1202
  Recall:    0.2213
  F1:        0.1558

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


--- Testing threshold: 0.35 ---


Generating predictions: 100%|██████████| 300/300 [00:35<00:00,  8.39it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.7906
  F1:        0.8830

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1216
  Recall:    0.2049
  F1:        0.1526

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


--- Testing threshold: 0.4 ---


Generating predictions: 100%|██████████| 300/300 [00:36<00:00,  8.21it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.7699
  F1:        0.8700

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1183
  Recall:    0.1940
  F1:        0.1470

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


--- Testing threshold: 0.45 ---


Generating predictions: 100%|██████████| 300/300 [00:35<00:00,  8.46it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.7552
  F1:        0.8605

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1228
  Recall:    0.1913
  F1:        0.1496

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


--- Testing threshold: 0.5 ---


Generating predictions: 100%|██████████| 300/300 [00:36<00:00,  8.29it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.7345
  F1:        0.8469

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1208
  Recall:    0.1776
  F1:        0.1438

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


--- Testing threshold: 0.55 ---


Generating predictions: 100%|██████████| 300/300 [00:36<00:00,  8.25it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.6932
  F1:        0.8188

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1237
  Recall:    0.1667
  F1:        0.1420

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


--- Testing threshold: 0.6 ---


Generating predictions: 100%|██████████| 300/300 [00:36<00:00,  8.26it/s]


CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.6667
  F1:        0.8000

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1274
  Recall:    0.1639
  F1:        0.1434

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


THRESHOLD COMPARISON
Threshold    Precision    Recall       F1          
------------------------------------------------------------
0.30         0.1202       0.2213       0.1558      
0.35         0.1216       0.2049       0.1526      
0.40         0.1183       0.1940       0.1470      
0.45         0.1228       0.1913       0.1496      
0.50         0.1208       0.1776       0.1438      
0.55         0.1237       0.1667       0.1420      
0.60         0.1274       0.1639       0.1434      

BEST RESULTS
Best Threshold: 0.3
  Precision: 0
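The metric-extraction loop in the sweep cell stops at the first `F1:` line and silently skips a threshold if the scorer's layout changes. A slightly more reusable parser is sketched below, assuming the scorer keeps printing a header line followed by indented `Precision:`/`Recall:`/`F1:` lines and a closing blank line, as in the output above. It can pull either the abstract-level or the sentence-level block. `parse_metric_block` is a hypothetical helper, not part of `score_claims.py`.

```python
import re
from typing import Dict, Optional

# One metric line, e.g. "  Recall:    0.2213"
METRIC_LINE = re.compile(r"^\s*(Precision|Recall|F1):\s*([0-9]*\.?[0-9]+)\s*$")


def parse_metric_block(stdout: str, header: str = "Sentence-level") -> Dict[str, Optional[float]]:
    """Return precision/recall/F1 from the block after the first line containing `header`.

    Metrics that are not found stay None, so the caller can skip the threshold explicitly.
    """
    metrics: Dict[str, Optional[float]] = {"precision": None, "recall": None, "f1": None}
    in_block = False
    for line in stdout.splitlines():
        if not in_block:
            in_block = header in line
            continue
        if not line.strip():
            break  # a blank line closes the metric block
        match = METRIC_LINE.match(line)
        if match:
            metrics[match.group(1).lower()] = float(match.group(2))
    return metrics


# Excerpt of the scorer output for threshold 0.30
sample_stdout = """Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.7994
  F1:        0.8885

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1202
  Recall:    0.2213
  F1:        0.1558
"""

print(parse_metric_block(sample_stdout))
print(parse_metric_block(sample_stdout, header="Abstract-level"))
```

In the sweep loop, `metrics = parse_metric_block(result.stdout)` followed by `if metrics['f1'] is None: continue` could replace the nested parsing loop.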

## Results Summary

### Comparison with Baseline

| Model | F1 | Precision | Recall | Notes |
|-------|-----|-----------|--------|-------|
| SciBERT Baseline | 24.20% | 19.09% | 33.06% | Milestone 2 |
| **SciBERT + GNN** | **15.58%** | **12.02%** | **22.13%** | This extension (dev set, evidence threshold 0.30) |
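As a quick consistency check on these numbers, sentence-level F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). Recomputing it from the baseline row and from the best dev-sweep row (threshold 0.30) should reproduce the reported F1 to four decimals. `f1_score` here is a small illustrative helper, not the scorer's code.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# SciBERT baseline: P = 19.09%, R = 33.06%
baseline_f1 = f1_score(0.1909, 0.3306)
# SciBERT + GNN, dev sweep at threshold 0.30: P = 0.1202, R = 0.2213
gnn_f1 = f1_score(0.1202, 0.2213)

print(f"Baseline F1:     {baseline_f1:.4f}")
print(f"GNN F1 (t=0.30): {gnn_f1:.4f}")
```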

### Analysis

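On the dev set, the best sentence-level F1 is 15.58% (precision 12.02%, recall 22.13%) at an evidence threshold of 0.30, which is 8.62 percentage points below the SciBERT baseline. Across the sweep, raising the threshold lowers recall steadily (22.13% at 0.30 down to 16.39% at 0.60) while precision stays roughly flat (about 11.8% to 12.7%), so F1 generally falls as the threshold rises. Because the best value sits at the lowest threshold tested, thresholds below 0.30 may be worth trying.

The scorer also reports a label-only accuracy of 0.0000 at every threshold, even though sentence-level F1, which requires a correct label, is non-zero. This inconsistency suggests a problem in the label-only metric or its inputs and should be checked before drawing conclusions about label prediction.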
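The comparison can be reproduced directly from the sweep output. The sketch below copies the seven sentence-level rows from the THRESHOLD COMPARISON table, picks the best threshold, reports the change against the 24.20% baseline in percentage points, and flags when the best threshold lies at the edge of the tested range. `summarize_sweep` is a hypothetical helper; the notebook's `threshold_results` list already has the same row shape.

```python
from typing import Dict, List, Tuple

# Sentence-level dev results copied from the THRESHOLD COMPARISON output
SWEEP: List[Dict[str, float]] = [
    {"threshold": 0.30, "precision": 0.1202, "recall": 0.2213, "f1": 0.1558},
    {"threshold": 0.35, "precision": 0.1216, "recall": 0.2049, "f1": 0.1526},
    {"threshold": 0.40, "precision": 0.1183, "recall": 0.1940, "f1": 0.1470},
    {"threshold": 0.45, "precision": 0.1228, "recall": 0.1913, "f1": 0.1496},
    {"threshold": 0.50, "precision": 0.1208, "recall": 0.1776, "f1": 0.1438},
    {"threshold": 0.55, "precision": 0.1237, "recall": 0.1667, "f1": 0.1420},
    {"threshold": 0.60, "precision": 0.1274, "recall": 0.1639, "f1": 0.1434},
]
BASELINE_F1 = 0.2420  # SciBERT baseline (Milestone 2)


def summarize_sweep(rows: List[Dict[str, float]],
                    baseline_f1: float) -> Tuple[Dict[str, float], float, bool]:
    """Return (best row by F1, F1 change vs. baseline in percentage points, best-is-on-edge flag)."""
    best_row = max(rows, key=lambda row: row["f1"])
    delta = (best_row["f1"] - baseline_f1) * 100
    thresholds = [row["threshold"] for row in rows]
    edge = best_row["threshold"] in (min(thresholds), max(thresholds))
    return best_row, delta, edge


best, delta_pp, on_edge = summarize_sweep(SWEEP, BASELINE_F1)
print(f"Best threshold: {best['threshold']:.2f} "
      f"(P={best['precision']:.4f}, R={best['recall']:.4f}, F1={best['f1']:.4f})")
print(f"Change vs. baseline: {delta_pp:+.2f} percentage points")
if on_edge:
    print("Best threshold is at the edge of the sweep; the optimum may lie outside the tested range.")
```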

### Key Hyperparameters Used

- GNN Hidden Dim: 256
- GNN Layers: 2
- GNN Heads: 4
- Similarity Threshold (semantic sentence-sentence edges in the graph): 0.3
- Hybrid Mode (concatenate BERT + GNN embeddings): True
- Learning Rate: 1e-4
- Batch Size: 4
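The similarity threshold listed above applies to graph construction, not to evidence selection. Following the Architecture section at the top of the notebook (claim linked to every sentence and weighted by cosine similarity, adjacent sentences linked, other sentence pairs linked when similarity exceeds the threshold), a minimal pure-Python sketch of the edge list might look like this. `build_edges` is hypothetical, and weighting sentence-sentence edges by cosine similarity is an assumption; the notebook's own implementation may compute similarities and store edges differently.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int, float]  # (source node, target node, weight)


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def build_edges(claim: Sequence[float], sents: List[Sequence[float]],
                sim_threshold: float = 0.3) -> List[Edge]:
    """Build a bidirectional edge list. Node 0 is the claim, node i + 1 is sentence i."""
    edges: List[Edge] = []

    def add_both(a: int, b: int, weight: float) -> None:
        edges.append((a, b, weight))
        edges.append((b, a, weight))

    # Claim <-> every sentence, weighted by cosine similarity
    for i, sent in enumerate(sents):
        add_both(0, i + 1, cosine(claim, sent))

    # Sentence <-> sentence: always link neighbors; link other pairs only above the threshold.
    # Using the cosine similarity as the weight here is an assumption.
    for i in range(len(sents)):
        for j in range(i + 1, len(sents)):
            sim = cosine(sents[i], sents[j])
            if j == i + 1 or sim > sim_threshold:
                add_both(i + 1, j + 1, sim)
    return edges


# Toy 2-D "embeddings": sentences 0 and 2 are similar but not adjacent
claim_vec = [1.0, 0.0]
sentence_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
for edge in build_edges(claim_vec, sentence_vecs):
    print(edge)
```

Raising `sim_threshold` prunes semantic edges between distant sentences while leaving the claim edges and the sequential chain intact.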

### Expected vs Actual

- **Expected**: +2-4 percentage points F1 over the baseline
- **Actual**: -8.62 percentage points (15.58% vs. 24.20% baseline F1; dev set, evidence threshold 0.30)
