In [None]:
# This script was run in a Google Colab environment.

# Install the Hugging Face libraries imported below. These are notebook shell
# escapes; no versions are pinned, so the installed releases depend on when the
# cell is run.
!pip install -q transformers  evaluate
# Upgrade `datasets` to the newest available release.
!pip install -U datasets
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import evaluate



In [None]:
# Mount Google Drive at /content/drive; the dataset and model paths used in the
# following cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Map each split name to its JSON file on the mounted Drive, then load both
# splits through the generic "json" loader.
cuad_json_files = {
    "train": "/content/drive/MyDrive/cuad_project/cuad_qa_train.json",
    "validation": "/content/drive/MyDrive/cuad_project/cuad_qa_test.json",
}
dataset = load_dataset("json", data_files=cuad_json_files)


In [None]:
import json
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer,
    TrainerCallback,
)
from datasets import load_dataset, DatasetDict
import evaluate
import random

# === Load CUAD-style dataset from JSON
# The "validation" split is read from the test JSON file.
split_files = dict(
    train="/content/drive/MyDrive/cuad_project/cuad_qa_train.json",
    validation="/content/drive/MyDrive/cuad_project/cuad_qa_test.json",
)
dataset = load_dataset("json", data_files=split_files)

print("✅ Dataset loaded:", dataset)
n_train, n_val = (len(dataset[split_name]) for split_name in ("train", "validation"))
print("🧪 Train size:", n_train, "| Val size:", n_val)

# === Load tokenizer and model
# Both are pulled from the same pretrained checkpoint so vocab and weights match.
model_name = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# === Preprocessing with overflow and keeping answers
def preprocess(example):
    """Tokenize question/context pairs into fixed-length, overlapping windows.

    Long contexts are split into several windows (``max_length=384`` tokens,
    ``stride=128`` tokens of overlap), so the output usually has MORE rows than
    the input. Every original column is therefore re-expanded so that each
    window carries the values of the row it was cut from.

    Args:
        example: Either a batch (dict of lists, as passed by
            ``Dataset.map(..., batched=True)``) or a single row (dict of
            scalars). Must contain ``"question"`` and ``"context"``.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``,
        ``offset_mapping``, ``overflow_to_sample_mapping``) plus every input
        column, all with one entry per window.
    """
    questions = example["question"]
    # A single row has a plain string here; a batch has a list of strings.
    is_batched = not isinstance(questions, str)

    tokenized = tokenizer(
        questions,
        example["context"],
        truncation="only_second",
        padding="max_length",
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True
    )

    # For window i, sample_map[i] is the index of the input row it came from.
    sample_map = tokenized["overflow_to_sample_mapping"]

    # Re-expand every input column (answers, context, question, id, ...) so it
    # lines up with the windows. Repeating the whole batch per window, as a
    # naive `[column] * n` does, would attach the wrong answers to each window,
    # and leaving a column untouched would leave it with the old row count.
    for column in list(example.keys()):
        values = example[column]
        if is_batched:
            tokenized[column] = [values[source_row] for source_row in sample_map]
        else:
            tokenized[column] = [values] * len(sample_map)

    return tokenized

# `preprocess` can emit several windows per input row, so the row count changes.
# Dropping the input columns avoids a length mismatch between the new windows
# and untouched original columns; columns that `preprocess` itself returns
# under the same name are kept. Each split is mapped with its own column list.
tokenized_dataset = DatasetDict({
    split_name: split.map(
        preprocess,
        batched=True,
        remove_columns=split.column_names,
    )
    for split_name, split in dataset.items()
})

# === Span alignment
def _first_answer(answers):
    """Return ``(answer_start, text)`` of the first labelled answer, or ``(None, "")``.

    Accepts both scalar fields (``{"text": "...", "answer_start": 3}``) and
    SQuAD-style list fields (``{"text": ["..."], "answer_start": [3]}``).
    """
    if not answers:
        return None, ""
    start = answers.get("answer_start")
    text = answers.get("text")
    if isinstance(start, (list, tuple)):
        start = start[0] if start else None
    if isinstance(text, (list, tuple)):
        text = text[0] if text else ""
    if start is None or start < 0 or not text:
        return None, ""
    return int(start), text


def _context_token_range(offsets):
    """Return ``(first, last)`` token indices of the context part of one window.

    Relies on special tokens (start, separators, padding) carrying a ``(0, 0)``
    offset. The layout is assumed to be
    ``<special> question <special...> context <special...>``.
    If there is no context token, ``first > last``.
    """
    def is_special(span):
        return span[0] == 0 and span[1] == 0

    total = len(offsets)
    first = 1
    while first < total and not is_special(offsets[first]):
        first += 1  # skip the question tokens
    while first < total and is_special(offsets[first]):
        first += 1  # skip the separator(s) between question and context
    last = total - 1
    while last >= first and is_special(offsets[last]):
        last -= 1  # skip trailing separator and padding
    return first, last


def add_token_positions(example, no_answer_index=0):
    """Convert the character-level answer into start/end token indices.

    Only context tokens are considered: question-token offsets are relative to
    the question string and must not be compared with answer positions, which
    are relative to the context.

    Args:
        example: One tokenized window with ``"offset_mapping"`` (one
            ``(start_char, end_char)`` pair per token) and ``"answers"``.
        no_answer_index: Token index used for both positions when the window
            has no labelled answer or does not fully contain it. Defaults to 0,
            the first token of the window.

    Returns:
        ``example`` with ``"start_positions"`` and ``"end_positions"`` set.
        Both keys are always present so every row has the same columns.
    """
    offsets = example["offset_mapping"]
    start_token = end_token = no_answer_index

    start_char, answer_text = _first_answer(example["answers"])
    ctx_first, ctx_last = _context_token_range(offsets)

    if start_char is not None and ctx_first <= ctx_last:
        end_char = start_char + len(answer_text)
        # The answer must lie completely inside this window's slice of context.
        if offsets[ctx_first][0] <= start_char and offsets[ctx_last][1] >= end_char:
            # Last context token that begins at or before the answer start.
            probe = ctx_first
            while probe <= ctx_last and offsets[probe][0] <= start_char:
                probe += 1
            start_token = probe - 1
            # First context token that ends at or after the answer end.
            probe = ctx_last
            while probe >= ctx_first and offsets[probe][1] >= end_char:
                probe -= 1
            end_token = probe + 1

    example["start_positions"] = start_token
    example["end_positions"] = end_token
    return example

tokenized_dataset = tokenized_dataset.map(add_token_positions, batched=False)

# === Filter and cleanup
def _has_token_span(ex):
    """Keep a row only if both token positions are present and not null.

    A plain ``"start_positions" in ex`` only tests that the column exists, which
    is the same for every row and therefore cannot drop unlabeled rows.
    """
    return ex.get("start_positions") is not None and ex.get("end_positions") is not None

tokenized_dataset = tokenized_dataset.filter(_has_token_span)
# These two columns are no longer needed once the token positions exist.
tokenized_dataset = tokenized_dataset.remove_columns(["offset_mapping", "answers"])

# === Evaluation metrics
squad_metric = evaluate.load("squad")
def compute_metrics(eval_pred):
    """Score predicted answer spans against the labelled spans with the SQuAD metric.

    Both the prediction and the reference are decoded from the same window's
    ``input_ids``, so they are directly comparable. The labelled positions are
    token indices and must not be used to slice the raw context string.

    Args:
        eval_pred: Trainer evaluation output whose ``predictions`` is the pair
            ``(start_logits, end_logits)``, one row per evaluated window.

    Returns:
        The dict produced by ``squad_metric.compute``.

    Note:
        Row ``i`` of the logits is paired with row ``i`` of
        ``tokenized_dataset["validation"]``; this assumes evaluation ran on that
        dataset in its stored order.
    """
    start_logits, end_logits = eval_pred.predictions
    # One vectorised argmax per array instead of building a tensor per row.
    best_starts = torch.as_tensor(start_logits).argmax(dim=-1).tolist()
    best_ends = torch.as_tensor(end_logits).argmax(dim=-1).tolist()

    predictions = []
    references = []

    for i, example in enumerate(tokenized_dataset["validation"]):
        input_ids = example["input_ids"]

        start = best_starts[i]
        end = max(best_ends[i], start)  # clamp an end that precedes the start
        pred_answer = tokenizer.decode(input_ids[start:end + 1], skip_special_tokens=True)

        # Labelled span is inclusive on both ends, in token space.
        gold_start = example["start_positions"]
        gold_end = example["end_positions"]
        gold_answer = tokenizer.decode(input_ids[gold_start:gold_end + 1], skip_special_tokens=True)

        predictions.append({"id": str(i), "prediction_text": pred_answer})
        # `answer_start` is a placeholder; only the text is compared here.
        references.append({"id": str(i), "answers": {"text": [gold_answer], "answer_start": [0]}})
    return squad_metric.compute(predictions=predictions, references=references)

# === Logging predictions after each epoch
class PrintPredictionsCallback(TrainerCallback):
    """Print a few randomly chosen validation predictions after each evaluation.

    Each sample is scored on its own stored window (``input_ids`` /
    ``attention_mask``), so the predicted span and the ground-truth span printed
    next to it refer to the same tokens.

    Args:
        num_samples: Maximum number of validation rows to print per evaluation.
    """

    def __init__(self, num_samples=5):
        super().__init__()
        self.num_samples = num_samples

    def on_evaluate(self, args, state, control, **kwargs):
        """Run the current model on sampled validation windows and print the results."""
        print("\n📝 Sample Predictions:")
        val_set = tokenized_dataset["validation"]

        # Prefer the model handed over by the Trainer; fall back to the notebook global.
        eval_model = kwargs.get("model")
        if eval_model is None:
            eval_model = model
        # Inputs must live on the same device as the model (GPU during training).
        device = next(eval_model.parameters()).device

        # Never ask for more samples than there are rows.
        sample_size = min(self.num_samples, len(val_set))
        indices = random.sample(range(len(val_set)), k=sample_size)

        was_training = eval_model.training
        eval_model.eval()
        try:
            for idx in indices:
                ex = val_set[idx]
                token_ids = ex["input_ids"]
                input_ids = torch.tensor([token_ids], device=device)
                mask = ex.get("attention_mask")
                attention_mask = (
                    torch.tensor([mask], device=device)
                    if mask is not None
                    else torch.ones_like(input_ids)
                )

                with torch.no_grad():
                    outputs = eval_model(input_ids=input_ids, attention_mask=attention_mask)
                start = torch.argmax(outputs.start_logits[0]).item()
                end = torch.argmax(outputs.end_logits[0]).item()
                if end < start:
                    end = start
                pred_answer = tokenizer.decode(token_ids[start:end + 1], skip_special_tokens=True)

                print(f"\n📌 Q: {ex['question']}")
                print(f"🤖 Predicted: {pred_answer}")
                print(f"✅ Ground Truth (tokens): {tokenizer.decode(token_ids[ex['start_positions']:ex['end_positions'] + 1], skip_special_tokens=True)}")
        finally:
            # Restore whatever mode the model was in before sampling.
            eval_model.train(was_training)

# === Training config
import inspect


def _accepted_kwarg(callable_obj, preferred, legacy):
    """Return `preferred` if `callable_obj` accepts that keyword, else `legacy`."""
    return preferred if preferred in inspect.signature(callable_obj).parameters else legacy


# transformers is installed unpinned, and two keywords used here were renamed
# across releases. Pick whichever spelling the installed version accepts.
_eval_strategy_kw = _accepted_kwarg(TrainingArguments.__init__, "eval_strategy", "evaluation_strategy")
_tokenizer_kw = _accepted_kwarg(Trainer.__init__, "processing_class", "tokenizer")

args = TrainingArguments(
    output_dir="/content/cuad_roberta",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,  # effective train batch = 8 * 2 per device
    num_train_epochs=3,
    learning_rate=2e-5,
    warmup_steps=100,
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    logging_steps=100,
    report_to="none",
    **{_eval_strategy_kw: "epoch"},  # evaluate once per epoch
)

# === Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics,
    callbacks=[PrintPredictionsCallback()],
    **{_tokenizer_kw: tokenizer},
)

# === Train!
trainer.train()

# === Save
trainer.save_model("/content/drive/MyDrive/cuad_project/cuad_model")
tokenizer.save_pretrained("/content/drive/MyDrive/cuad_project/cuad_model")


In [None]:
# Write the in-memory model and tokenizer to the project folder on Drive.
cuad_export_dir = "/content/drive/MyDrive/cuad_project/cuad_model"
model.save_pretrained(cuad_export_dir)
tokenizer.save_pretrained(cuad_export_dir)


In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# === Load test data ===
# The file is parsed as one JSON document.
test_json_path = "/content/drive/MyDrive/cuad_project/cuad_qa_test.json"
with open(test_json_path, mode="r", encoding="utf-8") as test_file:
    test_data = json.load(test_file)

# === Load trained model and tokenizer ===
model_path = "/content/drive/MyDrive/cuad_project/cuad_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# `.eval()` returns the module itself, so loading and switching mode can be chained.
model = AutoModelForQuestionAnswering.from_pretrained(model_path).eval()

# === Clause extraction logic ===
def extract_clause(question, context, model, tokenizer, max_length=512, stride=128, batch_size=8):
    """Extract the best answer span for `question` from the whole `context`.

    The context is split into overlapping windows so text beyond the first
    `max_length` tokens is also searched. In every window the best valid span
    (start <= end, both inside the context part) is scored as
    ``start_logit + end_logit``; the highest-scoring span over all windows wins.
    If the model's "no answer" score (both logits at token 0) in the window
    where that score is lowest still beats the best span, an empty string is
    returned.

    Args:
        question: Question text.
        context: Full document text to search.
        model: Question-answering model returning ``start_logits`` / ``end_logits``.
        tokenizer: Fast tokenizer matching `model` (overflow windows and
            ``sequence_ids`` are required).
        max_length: Token length of each window.
        stride: Number of overlapping context tokens between consecutive windows.
        batch_size: Number of windows sent through the model at once.

    Returns:
        The decoded answer text, or ``""`` when no answer is predicted.
    """
    encoded = tokenizer(
        question,
        context,
        truncation="only_second",  # never cut the question, only the context
        padding="max_length",      # equal-length windows so they stack into a tensor
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_tensors="pt"
    )
    input_ids = encoded["input_ids"]
    num_windows, window_len = input_ids.shape

    print("\n🔍 Tokenized input (first 50 tokens):")
    print(tokenizer.convert_ids_to_tokens(input_ids[0].tolist())[:50])
    print("📏 Total input length:", window_len, "| windows:", num_windows)

    # True where a token belongs to the context (sequence id 1); question,
    # special and padding tokens are excluded from candidate spans.
    context_mask = torch.tensor(
        [[seq_id == 1 for seq_id in encoded.sequence_ids(w)] for w in range(num_windows)],
        dtype=torch.bool,
    )
    # Pass only tensors the model understands (drops the overflow bookkeeping).
    model_inputs = {
        key: value for key, value in encoded.items()
        if key in ("input_ids", "attention_mask", "token_type_ids")
    }
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # Upper-triangular mask: a span's end may not precede its start.
    valid_pairs = torch.triu(torch.ones(window_len, window_len)).bool()

    best_score, best_span = float("-inf"), None
    min_null_score = float("inf")

    for lo in range(0, num_windows, batch_size):
        batch = {key: value[lo:lo + batch_size].to(device) for key, value in model_inputs.items()}
        with torch.no_grad():
            outputs = model(**batch)
        starts = outputs.start_logits.float().cpu()
        ends = outputs.end_logits.float().cpu()

        for row in range(starts.size(0)):
            window = lo + row
            # Token 0 is assumed to be the start-of-sequence token used for "no answer".
            null_score = (starts[row, 0] + ends[row, 0]).item()
            min_null_score = min(min_null_score, null_score)

            in_context = context_mask[window]
            if not bool(in_context.any()):
                continue
            start_scores = starts[row].masked_fill(~in_context, float("-inf"))
            end_scores = ends[row].masked_fill(~in_context, float("-inf"))
            pair_scores = (start_scores[:, None] + end_scores[None, :]).masked_fill(
                ~valid_pairs, float("-inf")
            )
            start, end = divmod(int(torch.argmax(pair_scores)), window_len)
            score = pair_scores[start, end].item()
            if score > best_score:
                best_score, best_span = score, (window, start, end)

    if best_span is None or min_null_score > best_score:
        print("\n🔁 Model predicts no answer for this question.")
        return ""

    window, start, end = best_span
    print(f"\n🔁 Predicted token span: {start} to {end} (window {window + 1} of {num_windows}, score {best_score:.2f})")

    tokens = input_ids[window][start:end + 1]
    return tokenizer.decode(tokens, skip_special_tokens=True)


# === Choose contract ID (trimmed) ===
contract_id = "DovaPharmaceuticalsInc_20181108_10-Q_EX-10.2_11414857_EX-10.2_Promotion Agreement"

# === Filter questions for this contract ===
# Entry ids start with the contract name, so a prefix match selects its questions.
contract_entries = [e for e in test_data if e["id"].startswith(contract_id)]
print(f"\n📄 Found {len(contract_entries)} questions for contract: {contract_id}\n")

if not contract_entries:
    print("⚠️ No entries found. Check contract_id formatting.")
    # Raise instead of calling exit(): inside a notebook kernel exit() asks the
    # kernel to shut down rather than cleanly stopping this cell.
    raise SystemExit(1)

# === Show clause types
for i, e in enumerate(contract_entries):
    # The clause type is the part of the id after the last "__".
    print(f"{i + 1}. {e['id'].split('__')[-1]}")

# === Choose one to evaluate
raw_choice = input("\n🎯 Choose a question index to run (1-based): ")
q_index = int(raw_choice) - 1
# Reject out-of-range input; without this, "0" would become -1 and silently
# select the LAST entry.
if not 0 <= q_index < len(contract_entries):
    raise IndexError(
        f"Question index must be between 1 and {len(contract_entries)}, got {raw_choice!r}"
    )

selected_entry = contract_entries[q_index]
question = selected_entry["question"]
context = selected_entry["context"]

print("\n🧠 Question:")
print(question)

print("\n📜 Context (first 500 chars):")
print(context[:500])

# === Model inference
answer = extract_clause(question, context, model, tokenizer)

# === Output results
print("\n🤖 Model's Answer:")
print(answer)

print("\n✅ Ground Truth:")
# `answers` may be missing or null for an entry; treat both as "no label".
gt_answer = (selected_entry.get("answers") or {}).get("text", "[No labeled answer]")
print(gt_answer)
