In [7]:
from transformers import pipeline

In [8]:
oracle = pipeline("text-classification", model="facebook/bart-large-mnli", top_k=None)

Device set to use cuda:0


In [9]:
def classify_contradictions(player_action, world_facts, retrieved_memories, 
                            threshold=0.6):
    """
    Uses NLI to classify contradictions between the player action and
    each fact/memory in context. Returns ranked list of contradiction risks.
    """
    contradictions = []

    for fact in world_facts + retrieved_memories:
        text_pair = f"Premise: {fact}\nHypothesis: {player_action}"
        result = oracle(text_pair)[0]

        contradiction_score = 0.0

        for r in result:
            if r["label"].lower() == "contradiction":
                contradiction_score = r["score"]
                break

        if contradiction_score > threshold:
            contradictions.append({
                "fact": fact,
                "score": contradiction_score
            })

        contradictions.sort(key=lambda x: x["score"], reverse=True)
        return contradictions

In [19]:
import json

dataset_path = "../tests/contradiction_suite_v1.0.jsonl"

dataset = []
with open(dataset_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        try:
            dataset.append(json.loads(line))
        except json.JSONDecodeError as e:
            print(f"Skipping malformed line: {e}")

print(f"✅ Loaded {len(dataset)} test cases.\n")

success_count = 0
fail_count = 0
results_summary = []

for case in dataset:
    player_action = case.get("player_action", "")
    world_facts = case.get("world_facts", [])
    retrieved_memories = case.get("retrieved_memories", [])

    contradictions = classify_contradictions(player_action, world_facts, retrieved_memories)

    predicted = "contradiction" if contradictions else "no_contradiction"
    expected = case.get("expected_behavior", "no_contradiction")

    status = "PASS" if predicted == expected else "FAIL"
    
    if status == "PASS":
        success_count += 1
    else:
        fail_count += 1

    results_summary.append({
        "id": case["id"],
        "status": status,
        "predicted": predicted,
        "expected": expected,
        "player_action": player_action,
        "world_facts": world_facts,
        "retrieved_memories": retrieved_memories,
        "contradictions": contradictions
    })

print(f"\n===== TEST SUMMARY =====")
print(f"✅ Passed: {success_count}")
print(f"❌ Failed: {fail_count}\n")

for r in results_summary:
    print(f"{r['id']}: {r['status']}")
    print(f"  Player Action     : {r['player_action']}")
    print(f"  World Facts       : {r['world_facts']}")
    print(f"  Retrieved Memories: {r['retrieved_memories']}")
    print(f"  Expected Behavior : {r['expected']}")
    print(f"  Predicted Behavior: {r['predicted']}")
    if r['contradictions']:
        print(f"  Contradictions Detected:")
        for c in r['contradictions']:
            print(f"    - Fact: {c['fact']}, Score: {c['score']:.3f}")
    print("-" * 50)

✅ Loaded 38 test cases.


===== TEST SUMMARY =====
✅ Passed: 16
❌ Failed: 22

cs_001: PASS
  Player Action     : I attack the dragon
  World Facts       : ['The great dragon is dead']
  Retrieved Memories: ['The party killed the dragon']
  Expected Behavior : contradiction
  Predicted Behavior: contradiction
  Contradictions Detected:
    - Fact: The great dragon is dead, Score: 0.880
--------------------------------------------------
cs_002: FAIL
  Player Action     : I take the treasure
  World Facts       : ['treasure is guarded by dragon']
  Retrieved Memories: ['Dragon is asleep in the cave']
  Expected Behavior : contradiction
  Predicted Behavior: no_contradiction
--------------------------------------------------
cs_003: FAIL
  Player Action     : I ask the wizard for help
  World Facts       : ['wizard is hostile']
  Retrieved Memories: ['Wizard helped the party']
  Expected Behavior : contradiction
  Predicted Behavior: no_contradiction
---------------------------------------