In [1]:
import numpy as np
import re
import ast
import warnings
from typing import List, Tuple, Dict, Any
from dataclasses import dataclass
from sentence_transformers import CrossEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Define the Result Data Structure
@dataclass
class ContradictionResult:
    has_contradiction: bool
    confidence: float
    contradicting_pairs: List[Tuple[str, str]]
    explanation: str

print("Libraries imported successfully.")

  from .autonotebook import tqdm as notebook_tqdm


Libraries imported successfully.


In [2]:
class SemanticContradictionDetector:
    """
    Detects semantic contradictions within a single document using a Cross-Encoder.
    """
    
    def __init__(self, model_name: str = "default"):
        # Use the "Small" DeBERTa model for speed and accuracy
        target_model = "cross-encoder/nli-deberta-v3-small" if model_name == "default" else model_name
        
        print(f"Loading model: {target_model}...")
        self.model = CrossEncoder(target_model, device='cpu')
        
        # --- ROBUST AUTO-CALIBRATION ---
        print("Calibrating label mapping...")
        calibration_data = [
            ("The door is open.", "The door is closed."),       # Hard Contradiction
            ("The cat is sleeping.", "The animal is resting."), # Hard Entailment
        ]
        
        scores = self.model.predict(calibration_data)
        
        # Dynamically determine which index is 'Contradiction'
        self.contradiction_id = np.argmax(scores[0])
        self.entailment_id = np.argmax(scores[1])
        
        print(f"Calibration Complete. Contradiction Index: {self.contradiction_id}")
        
        # Fallback safety check
        if self.contradiction_id == self.entailment_id:
            self.contradiction_id = 2 

    def _softmax(self, x):
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()
    
    def preprocess(self, text: str) -> List[str]:
        text = text.strip()
        # Robust splitting by punctuation
        sentences = re.split(r'(?<=[.!?])\s+', text)
        
        cleaned = []
        for s in sentences:
            s = s.strip()
            if len(s) < 5: continue
            
            # --- CLEAN DISCOURSE MARKERS ---
            s = re.sub(r'^(However|But|Although|Yet),?\s*', '', s, flags=re.IGNORECASE)
            s = re.sub(r'\s*,?\s*(though|however)\.?$', '.', s, flags=re.IGNORECASE)
            cleaned.append(s)
        return cleaned
    
    def extract_claims(self, sentences: List[str]) -> List[Dict[str, Any]]:
        return [{"id": i, "text": s} for i, s in enumerate(sentences)]
    
    def check_contradiction(self, claim_a: Dict, claim_b: Dict) -> Tuple[bool, float]:
        text_a = claim_a['text']
        text_b = claim_b['text']
        
        # --- BIDIRECTIONAL CHECK ---
        inputs = [(text_a, text_b), (text_b, text_a)]
        
        scores = self.model.predict(inputs)
        probs_0 = self._softmax(scores[0])
        probs_1 = self._softmax(scores[1])
        
        score_0 = probs_0[self.contradiction_id]
        score_1 = probs_1[self.contradiction_id]
        
        max_score = max(score_0, score_1)
        
        # Threshold 0.80 filters out soft errors
        return (max_score > 0.80), float(max_score)
    
    def analyze(self, text: str) -> ContradictionResult:
        sentences = self.preprocess(text)
        if len(sentences) < 2:
            return ContradictionResult(False, 0.0, [], "Insufficient text.")
            
        claims = self.extract_claims(sentences)
        contradictions = []
        max_confidence = 0.0
        
        for i in range(len(claims)):
            for j in range(i + 1, len(claims)):
                if claims[i]['text'] == claims[j]['text']:
                    continue
                    
                is_contra, conf = self.check_contradiction(claims[i], claims[j])
                
                if is_contra:
                    contradictions.append((claims[i]['text'], claims[j]['text']))
                    max_confidence = max(max_confidence, conf)
        
        has_contradiction = len(contradictions) > 0
        explanation = f"Found {len(contradictions)} logical inconsistencies." if has_contradiction else "Consistent."
        final_conf = max_confidence if has_contradiction else (1.0 - max_confidence)
        
        return ContradictionResult(
            has_contradiction=has_contradiction,
            confidence=round(final_conf, 3),
            contradicting_pairs=contradictions,
            explanation=explanation
        )

print("Detector class defined.")

Detector class defined.


In [3]:
def evaluate(detector: SemanticContradictionDetector, test_data: List[Dict]) -> Dict[str, float]:
    y_true = []
    y_pred = []
    
    print("\n--- Detailed Evaluation ---")
    for item in test_data:
        text = item['text']
        expected = item['has_contradiction']
        
        result = detector.analyze(text)
        
        y_true.append(expected)
        y_pred.append(result.has_contradiction)
        
        status = "CORRECT" if expected == result.has_contradiction else "WRONG"
        print(f"ID {item.get('id', '?')}: {status} | Pred: {result.has_contradiction} | True: {expected}")
        
        if result.has_contradiction:
             print(f"  > Detected: {result.contradicting_pairs[0]}")
        elif status == "WRONG":
             print(f"  > Fail Context: {text[:60]}...")

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }

print("Evaluation function defined.")

Evaluation function defined.


In [4]:
# Initialize detector
detector = SemanticContradictionDetector()

Loading model: cross-encoder/nli-deberta-v3-small...
Calibrating label mapping...
Calibration Complete. Contradiction Index: 0


In [5]:
try:
    with open('dataset.txt', 'r') as f:
        content = f.read()
        SAMPLE_REVIEWS = ast.literal_eval(content)
        print(f"Loaded {len(SAMPLE_REVIEWS)} samples from dataset.txt")
except FileNotFoundError:
    print("Warning: 'dataset.txt' not found. Using fallback sample data.")
    SAMPLE_REVIEWS = [
        {"id": 1, "text": "This laptop is fast. However, I wait 5 mins for chrome.", "has_contradiction": True},
        {"id": 2, "text": "Great camera. Good photos.", "has_contradiction": False},
        {"id": 3, "text": "Durable phone. Screen cracked immediately.", "has_contradiction": True}
    ]

# Show the first example
print(f"\nExample Text:\n{SAMPLE_REVIEWS[0]['text']}")

Loaded 8 samples from dataset.txt

Example Text:
This laptop is incredibly fast. Boot time is under 10 seconds. However, I find myself waiting 5 minutes just to open Chrome. The performance is unmatched in this price range.


In [6]:
# Run Analysis on all reviews
print("\n--- Running Analysis ---")
for review in SAMPLE_REVIEWS:
    result = detector.analyze(review["text"])
    print(f"Review {review['id']}: {result.has_contradiction} (Conf: {result.confidence})")

# Calculate Final Metrics
metrics = evaluate(detector, SAMPLE_REVIEWS)
print(f"\nFinal Metrics: {metrics}")


--- Running Analysis ---
Review 1: True (Conf: 0.998)
Review 2: False (Conf: 1.0)
Review 3: True (Conf: 1.0)
Review 4: True (Conf: 0.999)
Review 5: False (Conf: 1.0)
Review 6: True (Conf: 1.0)
Review 7: True (Conf: 1.0)
Review 8: False (Conf: 1.0)

--- Detailed Evaluation ---
ID 1: CORRECT | Pred: True | True: True
  > Detected: ('Boot time is under 10 seconds.', 'I find myself waiting 5 minutes just to open Chrome.')
ID 2: CORRECT | Pred: False | True: False
ID 3: CORRECT | Pred: True | True: True
  > Detected: ("I've never had a phone this durable.", 'Dropped it multiple times with no damage.')
ID 4: CORRECT | Pred: True | True: True
  > Detected: ('Customer service was unhelpful and rude.', 'They resolved my issue within minutes and even gave me a discount.')
ID 5: CORRECT | Pred: False | True: False
ID 6: CORRECT | Pred: True | True: True
  > Detected: ('Shipping was lightning fast - arrived in 2 days.', 'The three-week wait was worth it.')
ID 7: CORRECT | Pred: True | True: True
