In [7]:
import spacy
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import networkx as nx
from dateutil.parser import parse as date_parse
import re
from typing import List, Dict, Tuple

# Load models
nlp = spacy.load("en_core_web_trf")
legal_bert_tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
legal_bert_model = AutoModelForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased")
legal_classifier = pipeline("text-classification", model=legal_bert_model, tokenizer=legal_bert_tokenizer)

class LegalFactChecker:
    def __init__(self):
        self.legal_kb = self.load_legal_kb()

    def load_legal_kb(self) -> Dict:
        # Simplified legal knowledge base
        return {
            "legal_aid": {
                "entities": ["Scottish Legal Aid Board"],
                "procedures": ["application", "refusal", "granting"],
                "conditions": ["financial eligibility", "merits test"]
            },
            "appeal": {
                "entities": ["High Court of Justiciary", "Appeal Court"],
                "procedures": ["lodging", "hearing", "decision"],
                "grounds": ["miscarriage of justice", "error in law"]
            }
        }

    def extract_facts(self, text: str) -> List[Dict]:
        doc = nlp(text)
        facts = []
        for sent in doc.sents:
            for token in sent:
                if token.dep_ == "ROOT" and token.pos_ == "VERB":
                    fact = {
                        'predicate': token.lemma_,
                        'subject': ' '.join([t.text for t in token.subtree if t.dep_ == "nsubj"]),
                        'object': ' '.join([t.text for t in token.subtree if t.dep_ in ["dobj", "pobj"]]),
                        'text': sent.text
                    }
                    facts.append(fact)
        return facts

    def extract_entities(self, text: str) -> Dict[str, List[str]]:
        doc = nlp(text)
        entities = {}
        for ent in doc.ents:
            if ent.label_ not in entities:
                entities[ent.label_] = []
            entities[ent.label_].append(ent.text)
        return entities

    def extract_dates(self, text: str) -> List[Tuple[str, str]]:
        date_pattern = r'\b(\d{1,2}\s+\w+\s+\d{4}|\d{4})\b'
        matches = re.findall(date_pattern, text)
        return [(match, str(date_parse(match))) for match in matches]

    def build_event_graph(self, facts: List[Dict]) -> nx.DiGraph:
        G = nx.DiGraph()
        for i, fact in enumerate(facts):
            G.add_node(i, **fact)
            for j, other_fact in enumerate(facts):
                if i != j:
                    if fact['subject'] in other_fact['text'] or fact['object'] in other_fact['text']:
                        G.add_edge(i, j)
        return G

    def check_temporal_consistency(self, dates: List[Tuple[str, str]]) -> List[str]:
        inconsistencies = []
        sorted_dates = sorted(dates, key=lambda x: date_parse(x[1]))
        for i in range(len(sorted_dates) - 1):
            if (date_parse(sorted_dates[i+1][1]) - date_parse(sorted_dates[i][1])).days < 0:
                inconsistencies.append(f"Temporal inconsistency: {sorted_dates[i][0]} is after {sorted_dates[i+1][0]}")
        return inconsistencies

    def validate_legal_claims(self, facts: List[Dict]) -> List[str]:
        invalid_claims = []
        for fact in facts:
            legal_context = legal_classifier(fact['text'])[0]
            if legal_context['label'] in self.legal_kb:
                kb_entry = self.legal_kb[legal_context['label']]
                if not any(entity in fact['text'] for entity in kb_entry['entities']):
                    invalid_claims.append(f"Missing legal entity in claim: {fact['text']}")
                if not any(proc in fact['text'] for proc in kb_entry['procedures']):
                    invalid_claims.append(f"Invalid legal procedure in claim: {fact['text']}")
        return invalid_claims

    def compare_facts(self, original_facts: List[Dict], generated_facts: List[Dict]) -> List[str]:
        fabrications = []
        for gen_fact in generated_facts:
            if not any(self.fact_similarity(gen_fact, orig_fact) > 0.8 for orig_fact in original_facts):
                fabrications.append(f"Potential fabrication: {gen_fact['text']}")
        return fabrications

    def fact_similarity(self, fact1: Dict, fact2: Dict) -> float:
        # Simplified similarity measure
        elements1 = set([fact1['predicate'], fact1['subject'], fact1['object']])
        elements2 = set([fact2['predicate'], fact2['subject'], fact2['object']])
        return len(elements1 & elements2) / len(elements1 | elements2)

    def check_fabrication(self, original_text: str, generated_text: str) -> Dict:
        original_facts = self.extract_facts(original_text)
        generated_facts = self.extract_facts(generated_text)
        original_entities = self.extract_entities(original_text)
        generated_entities = self.extract_entities(generated_text)
        original_dates = self.extract_dates(original_text)
        generated_dates = self.extract_dates(generated_text)

        fact_fabrications = self.compare_facts(original_facts, generated_facts)
        entity_fabrications = {k: list(set(v) - set(original_entities.get(k, []))) for k, v in generated_entities.items()}
        temporal_inconsistencies = self.check_temporal_consistency(generated_dates)
        invalid_legal_claims = self.validate_legal_claims(generated_facts)

        original_graph = self.build_event_graph(original_facts)
        generated_graph = self.build_event_graph(generated_facts)
        structural_inconsistencies = [
            f"Structural inconsistency in event {i}" 
            for i in range(len(generated_facts)) 
            if i in generated_graph and generated_graph.out_degree(i) != original_graph.out_degree(i)
        ]

        return {
            "fact_fabrications": fact_fabrications,
            "entity_fabrications": entity_fabrications,
            "temporal_inconsistencies": temporal_inconsistencies,
            "invalid_legal_claims": invalid_legal_claims,
            "structural_inconsistencies": structural_inconsistencies
        }


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [11]:
from utils import load_json, load_txt
original_text = load_json("../bigger_study_sample/001-57899.json")['facts']
generated_text = load_txt("./gpt-4/001-57899.txt")

In [12]:
# Usage
checker = LegalFactChecker()


results = checker.check_fabrication(original_text, generated_text)

for category, issues in results.items():
    if issues:
        print(f"\n{category.replace('_', ' ').title()}:")
        for issue in issues:
            print(f"- {issue}")


Fact Fabrications:
- Potential fabrication: **Answer:**  
The case involves Mr. Anthony Boner, a British citizen, born in 1960, who was convicted of multiple criminal offenses, including assault and armed robbery, in a trial held in the High Court of Justiciary, Scotland, between March 29 and April 10, 1990.
- Potential fabrication: 
   - Mr. Boner, along with two others, was arrested following an investigation and was charged with assault, armed robbery, wilful damage, and firearm-related offenses.
- Potential fabrication: - During the trial, a witness, Mrs. G., was allowed to give evidence after being present in the courtroom before her testimony, which the defense objected to.
- Potential fabrication: The trial judge ruled that her earlier presence did not affect the fairness of her testimony.
- Potential fabrication: Mrs. G.'s evidence implicated Mr. Boner in the robbery.
- Potential fabrication: 
   - The jury found Mr. Boner guilty of all charges, and he was sentenced to eight y