In [None]:
!pip install -q pathway sentence-transformers transformers torch pandas numpy scikit-learn tqdm


In [3]:
import pathway as pw
import pandas as pd
import numpy as np
import glob
import torch
from tqdm.auto import tqdm

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification



In [16]:
# =========================
# Mount Google Drive
# =========================
from google.colab import drive
drive.mount('/content/drive')

# =========================
# File Paths (Colab Drive)
# =========================
DATA_DIR = '/content/drive/MyDrive/Dataset'

TEXT_DATA_GLOB = f'{DATA_DIR}/Books/*.txt'
TRAIN_PATH = f'{DATA_DIR}/train.csv'
TEST_PATH = f'{DATA_DIR}/test.csv'
EMBEDDINGS_PATH = f'/content/drive/MyDrive/chunk_embeddings.npy'
CHUNK_TEXT_PATH = f'/content/drive/MyDrive/chunk_texts.txt'
RESULTS_PATH = f'/content/drive/MyDrive/results.csv'


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
def clean_text(text: str) -> str:
    return text.replace("\n\n", "\n").strip()

def build_backstory(row):
    return f"""
Character: {row['char']}
Context: {row['caption']}
Backstory:
{row['content']}
"""

def extract_claims(backstory: str, max_claims=4):
    sentences = [
        s.strip()
        for s in backstory.replace("\n", " ").split(".")
        if len(s.strip()) > 25
    ]

    claims = []
    for s in sentences:
        if any(
            kw in s.lower()
            for kw in ["was", "were", "became", "led", "joined", "escaped", "imprisoned"]
        ):
            claims.append(s)

    return claims[:max_claims]

def explain_verdict(verdict: str, claim: str, excerpt: str) -> str:
    if verdict == "SUPPORTS":
        return "This excerpt supports the claim by directly aligning with it."
    elif verdict == "CONTRADICTS":
        return "This excerpt contradicts the claim by presenting conflicting information."
    else:
        return "This excerpt does not clearly support or contradict the claim."


In [18]:
def chunk_text(text, max_chars=1200, overlap=200):
    chunks = []
    start = 0
    while start < len(text):
        end = start + max_chars
        chunks.append(text[start:end])
        start += max_chars - overlap
    return chunks


In [19]:
novels = pw.io.fs.read(TEXT_DATA_GLOB, format="plaintext_by_file")


novels = novels.with_columns(
    clean_text=pw.apply(clean_text, novels.data)
)

novels_selected = novels.select(
    text=novels.clean_text
)

pw.run()


Output()

In [20]:
texts = []
for path in glob.glob(TEXT_DATA_GLOB):
    with open(path, "r", encoding="utf-8") as f:
        texts.append(clean_text(f.read()))


chunk_texts = []
for text in texts:
    chunk_texts.extend(
        c for c in chunk_text(text)
        if len(c) > 200
    )

print("Total chunks:", len(chunk_texts))


Total chunks: 3453


In [None]:
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

embeddings = []
batch_size = 32

for i in range(0, len(chunk_texts), batch_size):
    batch = chunk_texts[i:i + batch_size]
    emb = embed_model.encode(batch, normalize_embeddings=True)
    embeddings.append(emb)

chunk_embeddings = np.vstack(embeddings)


In [None]:
np.save(EMBEDDINGS_PATH, chunk_embeddings)

with open(CHUNK_TEXT_PATH, "w", encoding="utf-8") as f:
    for c in chunk_texts:
        f.write(c.replace("\n", " ") + "\n---\n")


In [22]:
chunk_embeddings = np.load(EMBEDDINGS_PATH)

chunk_texts = []
with open(CHUNK_TEXT_PATH, "r", encoding="utf-8") as f:
    current = []
    for line in f:
        if line.strip() == "---":
            chunk_texts.append(" ".join(current))
            current = []
        else:
            current.append(line.strip())



In [23]:
def knn_search(query_emb, k=4):
    sims = np.dot(chunk_embeddings, query_emb)
    top_idx = np.argsort(sims)[-k:][::-1]
    return [chunk_texts[i] for i in top_idx]


In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli"
).to(device)

nli_model.eval()

LABEL_MAP = {
    0: "CONTRADICTS",
    1: "INSUFFICIENT",
    2: "SUPPORTS"
}


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

In [25]:
def nli_verdict(claim, evidence_chunks):
    if not evidence_chunks:
        return "INSUFFICIENT", 0.0

    premise = " ".join(evidence_chunks[:2])
    premise = premise[:1500]

    hypothesis = claim

    inputs = tokenizer(
        premise,
        hypothesis,
        return_tensors="pt",
        truncation=True,
        max_length=512
    ).to(device)

    with torch.no_grad():
        logits = nli_model(**inputs).logits
        probs = torch.softmax(logits, dim=1)[0]
        pred = torch.argmax(probs).item()

    return LABEL_MAP[pred], probs[pred].item()


In [26]:
def aggregate_verdicts(verdicts_with_conf):
    supports = 0
    insufficient = 0
    contradicts = 0

    for verdict, conf in verdicts_with_conf:
        if verdict == "SUPPORTS":
            supports += 1
        elif verdict == "CONTRADICTS" and conf > 0.55:
            contradicts += 1
        else:
            insufficient += 1

    total = supports + insufficient

    # Any clear contradiction → inconsistent
    if contradicts >= 1:
        return 0

    # Too many unsupported claims → inconsistent
    if total >= 2 and supports / total < 0.34:
        return 0

    return 1


In [27]:
def build_rationale(dossier, final_label):
    supporting_excerpts = []
    contradicting_excerpts = []
    neutral_excerpts = []

    for d in dossier:
        for e in d.get("evidence", []):
            excerpt = e.get("excerpt", "").strip()
            if not excerpt:
                continue

            if d["verdict"] == "SUPPORTS":
                supporting_excerpts.append(excerpt)
            elif d["verdict"] == "CONTRADICTS":
                contradicting_excerpts.append(excerpt)
            else:
                neutral_excerpts.append(excerpt)

    # --- INCONSISTENT ---
    if final_label == 0:
        if contradicting_excerpts:
            ex = contradicting_excerpts[0][:250]
            return (
                "The backstory is inconsistent with the novel. "
                "Retrieved passages contradict key elements of the backstory. "
                f"For example: \"{ex}...\""
            )
        else:
            return (
                "The backstory is inconsistent with the novel. "
                "The retrieved passages conflict with the described events or character details."
            )

    # --- CONSISTENT ---
    else:
        if supporting_excerpts:
            ex = supporting_excerpts[0][:250]
            return (
                "The backstory is consistent with the novel. "
                "Retrieved passages partially support the described events. "
                f"For example: \"{ex}...\""
            )

        if neutral_excerpts:
            ex = neutral_excerpts[0][:200]
            return (
                "The backstory is consistent with the novel. "
                "Retrieved passages mention related characters or events and do not contradict the backstory. "
                f"For example: \"{ex}...\""
            )

        return (
            "The backstory is consistent with the novel. "
            "No retrieved passages contradict the described events or character details."
        )


In [28]:
def predict_consistency(backstory: str):
    claims = extract_claims(backstory)

    verdicts_with_conf = []
    dossier = []

    for claim in claims:
        claim_emb = embed_model.encode(
            [claim], normalize_embeddings=True
        )[0]

        evidence_chunks = knn_search(claim_emb, k=5)

        verdict, conf = nli_verdict(claim, evidence_chunks)
        verdicts_with_conf.append((verdict, conf))  # ✅ FIX

        dossier.append({
            "claim": claim,
            "verdict": verdict,
            "confidence": conf,
            "evidence": [
                {
                    "excerpt": e,
                    "analysis": explain_verdict(verdict, claim, e)
                } for e in evidence_chunks
            ]
        })

    final_label = aggregate_verdicts(verdicts_with_conf)
    rationale = build_rationale(dossier, final_label)
    return final_label, rationale


In [29]:
train_df = pd.read_csv(TRAIN_PATH)
test_df = pd.read_csv(TEST_PATH)



In [30]:
results = []

for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    backstory = build_backstory(row)
    prediction, rationale = predict_consistency(backstory)

    results.append({
        "story_id": row["id"],
        "prediction": prediction,
        "rationale": rationale
    })

results_df = pd.DataFrame(results)
results_df.to_csv(RESULTS_PATH, index=False)

print("Saved results.csv with", len(results_df), "rows")


  0%|          | 0/60 [00:00<?, ?it/s]

Saved results.csv with 60 rows


In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm.auto import tqdm

y_true = []
y_pred = []

for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Evaluating on train set"):
    backstory = build_backstory(row)
    prediction, _ = predict_consistency(backstory)

    gold = 1 if row["label"].strip().lower() == "consistent" else 0

    y_true.append(gold)
    y_pred.append(prediction)

acc = accuracy_score(y_true, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary"
)

print("==== TRAIN SET METRICS ====")
print(f"Accuracy : {acc:.3f}")
print(f"Precision: {prec:.3f}")
print(f"Recall   : {rec:.3f}")
print(f"F1 Score : {f1:.3f}")


Evaluating on train set:   0%|          | 0/80 [00:00<?, ?it/s]

==== TRAIN SET METRICS ====
Accuracy : 0.600
Precision: 0.627
Recall   : 0.922
F1 Score : 0.746
