In [67]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [68]:
import torch
import os
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

CUDA available: True
GPU: Tesla T4


In [69]:
!pip install -q transformers datasets jsonlines scikit-learn

In [70]:
!rm -rf cis5300_project

In [71]:
!git clone https://github.com/asxd-10/cis5300_project.git

Cloning into 'cis5300_project'...
remote: Enumerating objects: 199, done.[K
remote: Counting objects: 100% (199/199), done.[K
remote: Compressing objects: 100% (170/170), done.[K
remote: Total 199 (delta 100), reused 81 (delta 22), pack-reused 0 (from 0)[K
Receiving objects: 100% (199/199), 14.18 MiB | 15.60 MiB/s, done.
Resolving deltas: 100% (100/100), done.


In [72]:
import sys
try:
    os.chdir('cis5300_project')
    print(f"Current Working Directory changed to: {os.getcwd()}")
except FileNotFoundError:
    print("Error: The 'cis5300_project' directory was not found in the current location.")
sys.path.append('cis5300_project')
from src.common.data_utils import load_claims, load_corpus

print("Loading data")
train_claims = load_claims('data/scifact/data/claims_train.jsonl')
dev_claims = load_claims('data/scifact/data/claims_dev.jsonl')
corpus = load_corpus('data/scifact/data/corpus.jsonl')

print(f"{len(train_claims)} training claims")
print(f"{len(dev_claims)} dev claims")
print(f"{len(corpus)} documents")


Current Working Directory changed to: /content/cis5300_project/cis5300_project/cis5300_project/cis5300_project/cis5300_project
Loading data
809 training claims
300 dev claims
5183 documents


In [73]:
from src.claim_verification.model import ClaimVerifier

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = ClaimVerifier()
model = model.to(device)

print("Model loaded to GPU")


Using device: cuda
Model loaded to GPU


In [74]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn as nn

class SciFact_Dataset(Dataset):
    def __init__(self, claims, corpus, tokenizer):
        self.claims = claims
        self.corpus = corpus
        self.tokenizer = tokenizer

        # Filter claims with evidence
        self.valid_claims = [c for c in claims if c.evidence and c.label]
        print(f"  Using {len(self.valid_claims)}/{len(claims)} claims with evidence")

        self.label_map = {'SUPPORT': 0, 'CONTRADICT': 1, 'NOT_ENOUGH_INFO': 2}

    def __len__(self):
        return len(self.valid_claims)

    def __getitem__(self, idx):
        claim = self.valid_claims[idx]

        # Get first evidence doc (keep as string for evidence lookup!)
        doc_id_str = list(claim.evidence.keys())[0]  # Keep as STRING
        doc_id_int = int(doc_id_str)  # Convert to INT for corpus lookup
        doc = self.corpus[doc_id_int]

        # Create input
        text = claim.claim
        num_sents = min(len(doc.abstract), 10)
        for sent in doc.abstract[:num_sents]:
            text += " [SEP] " + sent

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Label
        label = self.label_map.get(claim.label, 2)

        # FIXED: Create evidence mask using STRING key
        evidence_mask = torch.zeros(20)

        # Use STRING key to access evidence (CRITICAL!)
        if doc_id_str in claim.evidence:
            for ev_entry in claim.evidence[doc_id_str]:
                for sent_idx in ev_entry['sentences']:
                    if sent_idx < num_sents:
                        evidence_mask[sent_idx] = 1.0

        # Create sentence positions
        sentence_positions = torch.zeros(20, dtype=torch.long)
        claim_tokens = self.tokenizer.encode(claim.claim, add_special_tokens=True)
        current_pos = len(claim_tokens)

        for i in range(num_sents):
            sent_tokens = self.tokenizer.encode(doc.abstract[i], add_special_tokens=False)
            sentence_positions[i] = current_pos + len(sent_tokens)
            current_pos += len(sent_tokens) + 1

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'sentence_positions': sentence_positions,
            'label': torch.tensor(label),
            'evidence_mask': evidence_mask
        }

In [75]:
print("Testing fixed dataset...")
train_dataset = SciFact_Dataset(train_claims, corpus, model.tokenizer)

# Check first example
test_sample = train_dataset[0]
print(f"\nEvidence mask: {test_sample['evidence_mask'][:10]}")
print(f"Has evidence: {test_sample['evidence_mask'].sum().item() > 0}")
print(f"Num evidence sents: {int(test_sample['evidence_mask'].sum().item())}")

# Check batch
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
batch = next(iter(train_loader))
num_with_ev = (batch['evidence_mask'].sum(dim=1) > 0).sum().item()
print(f"\nBatch: {num_with_ev}/8 examples have evidence")

Testing fixed dataset...
  Using 505/809 claims with evidence

Evidence mask: tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])
Has evidence: True
Num evidence sents: 1

Batch: 7/8 examples have evidence


In [76]:
# Create dataset
print("Creating dataset")
train_dataset = SciFact_Dataset(train_claims, corpus, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

print(f"Dataset ready: {len(train_dataset)} examples, {len(train_loader)} batches")

Creating dataset
  Using 505/809 claims with evidence
Dataset ready: 505 examples, 64 batches


In [77]:
optimizer = AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

In [78]:
print("Checking evidence masks...")
sample_batch = next(iter(train_loader))

print(f"\nBatch shapes:")
print(f"  evidence_mask: {sample_batch['evidence_mask'].shape}")
print(f"  sentence_positions: {sample_batch['sentence_positions'].shape}")

print(f"\nFirst example:")
print(f"  Evidence mask: {sample_batch['evidence_mask'][0]}")
print(f"  Sentence positions: {sample_batch['sentence_positions'][0]}")

# Count how many have evidence
num_with_evidence = (sample_batch['evidence_mask'].sum(dim=1) > 0).sum().item()
print(f"\nExamples with evidence: {num_with_evidence}/{sample_batch['evidence_mask'].size(0)}")

# Check if sentence positions are reasonable
print(f"Sentence position range: {sample_batch['sentence_positions'].min()}-{sample_batch['sentence_positions'].max()}")

Checking evidence masks...

Batch shapes:
  evidence_mask: torch.Size([8, 20])
  sentence_positions: torch.Size([8, 20])

First example:
  Evidence mask: tensor([1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])
  Sentence positions: tensor([ 46, 103, 178, 227, 275, 335, 385,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0])

Examples with evidence: 5/8
Sentence position range: 0-385


In [79]:
print("\n" + "="*60)
print("TRAINING")

model.train()
num_epochs = 8

for epoch in range(num_epochs):
    total_loss = 0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        # Move to GPU
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        sentence_positions = batch['sentence_positions'].to(device)
        labels = batch['label'].to(device)
        evidence_mask = batch['evidence_mask'].to(device)

        # Forward with evidence prediction
        label_logits, evidence_logits = model(input_ids, attention_mask, sentence_positions)

        # Multi-task loss
        label_loss = criterion(label_logits, labels)
        evidence_loss = nn.BCEWithLogitsLoss()(evidence_logits, evidence_mask)
        loss = label_loss + 2.0 * evidence_loss  # 0.5 or 1.0 weight on evidence

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Stats
        total_loss += loss.item()
        pred = label_logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

        evidence_pred = (torch.sigmoid(evidence_logits) > 0.5).float()
        # Only count non-padding positions
        valid_positions = (sentence_positions > 0).float() # All positions
        correct_per_example = ((evidence_pred == evidence_mask) * valid_positions).sum(dim=1)
        total_per_example = valid_positions.sum(dim=1)
        evidence_acc = (correct_per_example / total_per_example).mean().item()

        # Update progress bar
        progress_bar.set_postfix({
          'loss': f'{loss.item():.4f}',
          'acc': f'{100*correct/total:.1f}%',
          'ev_acc': f'{100*evidence_acc:.1f}%'
      })



    avg_loss = total_loss / len(train_loader)
    accuracy = 100 * correct / total

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {accuracy:.2f}%")

    # Save checkpoint to Google Drive
    checkpoint_path = f'models/claim_verifier/model_epoch{epoch+1}.pt'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
    print(f"  Saved: {checkpoint_path}")

print("\nTraining complete!")



TRAINING


Epoch 1/8: 100%|██████████| 64/64 [00:53<00:00,  1.20it/s, loss=1.5078, acc=63.8%, ev_acc=87.5%]



Epoch 1 Summary:
  Loss: 1.8233
  Accuracy: 63.76%
  Saved: models/claim_verifier/model_epoch1.pt


Epoch 2/8: 100%|██████████| 64/64 [00:52<00:00,  1.21it/s, loss=1.6988, acc=65.7%, ev_acc=62.5%]



Epoch 2 Summary:
  Loss: 1.7307
  Accuracy: 65.74%
  Saved: models/claim_verifier/model_epoch2.pt


Epoch 3/8: 100%|██████████| 64/64 [00:52<00:00,  1.21it/s, loss=2.2979, acc=67.5%, ev_acc=90.0%]



Epoch 3 Summary:
  Loss: 1.6591
  Accuracy: 67.52%
  Saved: models/claim_verifier/model_epoch3.pt


Epoch 4/8: 100%|██████████| 64/64 [00:52<00:00,  1.22it/s, loss=1.3389, acc=73.5%, ev_acc=50.0%]



Epoch 4 Summary:
  Loss: 1.5087
  Accuracy: 73.47%
  Saved: models/claim_verifier/model_epoch4.pt


Epoch 5/8: 100%|██████████| 64/64 [00:52<00:00,  1.21it/s, loss=1.0781, acc=81.6%, ev_acc=90.0%]



Epoch 5 Summary:
  Loss: 1.3201
  Accuracy: 81.58%
  Saved: models/claim_verifier/model_epoch5.pt


Epoch 6/8: 100%|██████████| 64/64 [00:53<00:00,  1.20it/s, loss=4.5197, acc=86.7%, ev_acc=87.5%]



Epoch 6 Summary:
  Loss: 1.2565
  Accuracy: 86.73%
  Saved: models/claim_verifier/model_epoch6.pt


Epoch 7/8: 100%|██████████| 64/64 [00:52<00:00,  1.22it/s, loss=1.0620, acc=91.3%, ev_acc=85.7%]



Epoch 7 Summary:
  Loss: 1.0708
  Accuracy: 91.29%
  Saved: models/claim_verifier/model_epoch7.pt


Epoch 8/8: 100%|██████████| 64/64 [00:53<00:00,  1.21it/s, loss=1.0213, acc=93.9%, ev_acc=100.0%]



Epoch 8 Summary:
  Loss: 0.9679
  Accuracy: 93.86%
  Saved: models/claim_verifier/model_epoch8.pt

Training complete!


In [80]:
print("\n" + "="*60)
print("EVALUATION ON DEV SET")
print("="*60)

model.eval()
dev_dataset = SciFact_Dataset(dev_claims, corpus, model.tokenizer)
dev_loader = DataLoader(dev_dataset, batch_size=16, shuffle=False)

correct = 0
total = 0

with torch.no_grad():
    for batch in tqdm(dev_loader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        label_logits, _ = model(input_ids, attention_mask, sentence_positions=None)
        pred = label_logits.argmax(dim=1)

        correct += (pred == labels).sum().item()
        total += labels.size(0)

accuracy = 100 * correct / total
print(f"\nDev Set Label Accuracy: {accuracy:.2f}%")
print(f"(Simple baseline was close to 0%, so anything is better!)")


EVALUATION ON DEV SET
  Using 188/300 claims with evidence


Evaluating: 100%|██████████| 12/12 [00:06<00:00,  1.88it/s]


Dev Set Label Accuracy: 64.89%
(Simple baseline was close to 0%, so anything is better!)





In [81]:
# THRESHOLD EXPERIMENT
print("\n" + "="*60)
print("TESTING DIFFERENT THRESHOLDS")
print("="*60)

for threshold in [0.30, 0.35, 0.40, 0.50, 0.55]:
    print(f"\n--- Threshold: {threshold} ---")

    model.eval()
    predictions = []

    with torch.no_grad():
        for claim in tqdm(dev_claims, desc=f"Threshold {threshold}", leave=False):
            if not hasattr(claim, 'cited_doc_ids') or not claim.cited_doc_ids:
                predictions.append({'id': claim.id, 'label': 'NOT_ENOUGH_INFO', 'evidence': {}})
                continue

            doc_id = int(claim.cited_doc_ids[0])
            if doc_id not in corpus:
                predictions.append({'id': claim.id, 'label': 'NOT_ENOUGH_INFO', 'evidence': {}})
                continue

            doc = corpus[doc_id]
            text = claim.claim
            num_sents = min(len(doc.abstract), 10)
            for sent in doc.abstract[:num_sents]:
                text += " [SEP] " + sent

            encoding = model.tokenizer(text, max_length=512, padding='max_length',
                                      truncation=True, return_tensors='pt').to(device)

            sentence_positions = torch.zeros(1, 20, dtype=torch.long).to(device)
            claim_tokens = model.tokenizer.encode(claim.claim, add_special_tokens=True)
            current_pos = len(claim_tokens)

            for i in range(num_sents):
                sent_tokens = model.tokenizer.encode(doc.abstract[i], add_special_tokens=False)
                sentence_positions[0, i] = current_pos + len(sent_tokens)
                current_pos += len(sent_tokens) + 1

            label_logits, evidence_logits = model(encoding['input_ids'],
                                                   encoding['attention_mask'],
                                                   sentence_positions)

            pred_label_idx = label_logits.argmax(dim=1).item()
            label_map = {0: 'SUPPORT', 1: 'CONTRADICT', 2: 'NOT_ENOUGH_INFO'}
            pred_label = label_map[pred_label_idx]

            evidence_probs = torch.sigmoid(evidence_logits[0])
            pred_evidence_sents = [i for i, prob in enumerate(evidence_probs[:num_sents])
                                   if prob > threshold]

            prediction = {'id': claim.id, 'label': pred_label, 'evidence': {}}
            if pred_evidence_sents:
                prediction['evidence'][str(doc_id)] = [{
                    'sentences': pred_evidence_sents,
                    'label': pred_label
                }]
            else:
                prediction['label'] = 'NOT_ENOUGH_INFO'

            predictions.append(prediction)

    # Save and evaluate
    output_path = f'output/dev/scibert_thresh{int(threshold*100)}.jsonl'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with jsonlines.open(output_path, 'w') as writer:
        writer.write_all(predictions)

    # Quick eval
    !python src/evaluation/score_claims.py \
      --gold data/scifact/data/claims_dev.jsonl \
      --predictions {output_path}


TESTING DIFFERENT THRESHOLDS

--- Threshold: 0.3 ---




CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.8230
  F1:        0.9029

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1779
  Recall:    0.3607
  F1:        0.2383

Label-only:
  Accuracy:  0.0000

Interpretation:
 Retrieval is excellent (oracle or near-oracle)
  Evidence extraction is improving but below target

--- Threshold: 0.35 ---




CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.8142
  F1:        0.8976

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1822
  Recall:    0.3579
  F1:        0.2415

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target

--- Threshold: 0.4 ---




CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.8083
  F1:        0.8940

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1838
  Recall:    0.3525
  F1:        0.2416

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target

--- Threshold: 0.5 ---




CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.8024
  F1:        0.8903

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1863
  Recall:    0.3333
  F1:        0.2390

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target

--- Threshold: 0.55 ---




CLAIM VERIFICATION EVALUATION

Loading data...
  Gold claims: 300
  Predictions: 300

Computing metrics...

RESULTS

Abstract-level (Retrieval):
  Precision: 1.0000
  Recall:    0.8024
  F1:        0.8903

Sentence-level (Evidence + Label): PRIMARY METRIC
  Precision: 0.1909
  Recall:    0.3306
  F1:        0.2420

Label-only:
  Accuracy:  0.0000

Interpretation:
  Retrieval is working reasonably
  Evidence extraction is improving but below target


In [82]:
import jsonlines

# Check gold format
with jsonlines.open('data/scifact/data/claims_dev.jsonl') as f:
    for claim in f:
        if 'evidence' in claim and claim['evidence']:
            print("Sample gold evidence:")
            print(f"  Keys: {list(claim['evidence'].keys())}")
            print(f"  Key type: {type(list(claim['evidence'].keys())[0])}")
            print(f"  Full evidence: {claim['evidence']}")
            break

# # Check your prediction format
# with jsonlines.open('output/dev/scibert_predictions.jsonl') as f:
#     for pred in f:
#         if pred['evidence']:
#             print("\nSample prediction evidence:")
#             print(f"  Keys: {list(pred['evidence'].keys())}")
#             print(f"  Key type: {type(list(pred['evidence'].keys())[0])}")
#             print(f"  Full evidence: {pred['evidence']}")
#             break

Sample gold evidence:
  Keys: ['14717500']
  Key type: <class 'str'>
  Full evidence: {'14717500': [{'sentences': [2, 5], 'label': 'SUPPORT'}, {'sentences': [7], 'label': 'SUPPORT'}]}


In [83]:
# Debug: Check what labels you're actually predicting
# with jsonlines.open('output/dev/scibert_predictions.jsonl') as f:
#     sample_preds = [p for p in f][:5]
#     print("\nSample predictions:")
#     for p in sample_preds:
#         print(f"  ID {p['id']}: label='{p['label']}', evidence={bool(p['evidence'])}")

# Check gold format
with jsonlines.open('data/scifact/data/claims_dev.jsonl') as f:
    sample_gold = [c for c in f][:5]
    print("\nSample gold:")
    for g in sample_gold:
        print(f"  ID {g['id']}: label='{g.get('label', 'MISSING')}'")


Sample gold:
  ID 1: label='MISSING'
  ID 3: label='MISSING'
  ID 5: label='MISSING'
  ID 13: label='MISSING'
  ID 36: label='MISSING'
