In [2]:
import re
import os
import json
from itertools import islice
from rapidfuzz import fuzz
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict

In [5]:
class DataCleaner:
    """Three-stage filter for LLM-extracted feedback principles.

    Each input record is a dict with the keys ``prompt_id``,
    ``original_feedback`` and ``extracted_json``.  The pipeline is:

    1. Hallucination check: the quoted evidence must fuzzy-match the feedback.
    2. Ambiguity check: only clear "Yes" / "No" labels survive.
    3. Consensus check: when several records share a ``prompt_id``, a principle
       is kept only if another record names a semantically similar principle.

    Args:
        input_file: Path of the JSONL file read by :meth:`run`.
        output_file: Path of the JSONL file written by :meth:`run`.
        model_name: Name handed to ``SentenceTransformer``.  The default is an
            8B-parameter checkpoint; the previously used "all-MiniLM-L6-v2"
            is a much lighter alternative.
        hallucination_threshold: ``fuzz.partial_ratio`` score that evidence
            must strictly exceed to be accepted.
        consensus_threshold: Cosine similarity that a principle must strictly
            exceed against at least one peer principle.
    """

    # Class-level defaults, overridden per instance in __init__.
    hallucination_threshold = 60
    consensus_threshold = 0.8

    def __init__(
        self,
        input_file="",
        output_file="final_training_data.jsonl",
        model_name="Qwen/Qwen3-Embedding-8B",
        hallucination_threshold=60,
        consensus_threshold=0.8,
    ):
        self.input_file = input_file
        self.output_file = output_file
        self.hallucination_threshold = hallucination_threshold
        self.consensus_threshold = consensus_threshold

        # Embedding model used by the consensus check (filter 3).
        print("Loading Sentence Transformer model...")
        self.encoder = SentenceTransformer(model_name)

    def _parse_json_safely(self, json_str):
        """Parse LLM output that should contain a JSON document.

        Tolerates a surrounding Markdown code fence and extra prose before or
        after the outermost ``{...}``.  It does not repair structurally broken
        JSON such as a missing closing brace.

        Args:
            json_str: Raw model output.  An already-decoded dict is returned
                unchanged; any other non-string value yields ``None``.

        Returns:
            The decoded JSON value, or ``None`` when nothing could be parsed.
        """
        if not json_str:
            return None
        if isinstance(json_str, dict):
            return json_str
        if not isinstance(json_str, str):
            return None

        # 1. Strip Markdown code fences (```json ... ```).
        match = re.search(r"```(?:json)?\s*(.*)\s*```", json_str, re.DOTALL)
        if match:
            json_str = match.group(1)

        # 2. Try the (possibly unfenced) string as-is.
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass  # fall through to the brace-slicing fallback

        # 3. Fallback: keep only the span from the first '{' to the last '}',
        #    which drops any chatter the model put around the object.
        start = json_str.find("{")
        end = json_str.rfind("}")
        if start != -1 and end > start:
            try:
                return json.loads(json_str[start : end + 1])
            except json.JSONDecodeError:
                pass

        print(f"Warning: Could not parse JSON: {json_str[:50]}...")
        return None

    @staticmethod
    def _normalize_label(label):
        """Map a free-text label to "Yes" / "No", or ``None`` if it is unclear.

        Matching is done on whole words so that labels such as
        "Not applicable", "None" or "Unknown" are not mistaken for "No".
        Anything containing "partially" is treated as ambiguous.
        """
        label = label.strip().lower()
        if "partially" in label:
            return None
        words = re.findall(r"[a-z]+", label)
        if "yes" in words:
            return "Yes"
        if "no" in words:
            return "No"
        return None

    def filter_step_1_and_2(self, raw_data):
        """Apply the hallucination and ambiguity filters.

        Filter 1: Hallucination Check (RapidFuzz score > hallucination_threshold)
        Filter 2: Ambiguity Check (drop 'partially' and non Yes/No labels)

        Args:
            raw_data: Iterable of records with the keys ``prompt_id``,
                ``original_feedback`` (a string or a list of strings) and
                ``extracted_json``.

        Returns:
            A list of ``{"prompt_id": ..., "principles": [...]}`` dicts, one
            per record that kept at least one principle.  Every principle is
            a dict with the keys ``principle``, ``label`` and ``evidence``.
        """
        clean_candidates = []

        for entry in raw_data:
            feedback_text = entry["original_feedback"]
            raw_principles = self._parse_json_safely(entry["extracted_json"])

            # Only a non-empty JSON object can be iterated as name -> value.
            if not isinstance(raw_principles, dict) or not raw_principles:
                continue

            # Several feedback variants: compare against the longest one.
            if isinstance(feedback_text, list):
                variants = [t for t in feedback_text if isinstance(t, str)]
                feedback_text = max(variants, key=len) if variants else ""
            if not isinstance(feedback_text, str) or not feedback_text:
                continue  # nothing to verify evidence against
            feedback_lower = feedback_text.lower()

            valid_principles_for_this_entry = []

            for principle_name, value in raw_principles.items():
                # Expected shape: "<evidence> - <label>".
                if not isinstance(value, str) or "-" not in value:
                    continue

                # Split on the LAST hyphen so hyphens inside the evidence survive.
                evidence_span, _, raw_label = value.rpartition("-")
                evidence_span = evidence_span.strip()

                # --- FILTER 2: Ambiguity ---
                final_label = self._normalize_label(raw_label)
                if final_label is None:
                    continue

                # --- FILTER 1: Hallucination ---
                if not evidence_span:
                    continue  # an empty quote cannot be grounded in the text
                match_score = fuzz.partial_ratio(evidence_span.lower(), feedback_lower)
                if match_score <= self.hallucination_threshold:
                    continue

                valid_principles_for_this_entry.append({
                    "principle": principle_name,
                    "label": final_label,
                    "evidence": evidence_span
                })

            if valid_principles_for_this_entry:
                clean_candidates.append({
                    "prompt_id": entry['prompt_id'],
                    "principles": valid_principles_for_this_entry
                })

        print(f"Step 1 & 2 Complete. {len(clean_candidates)} valid entries kept.")
        return clean_candidates

    def filter_step_3(self, candidates):
        """Filter 3: Consensus Check.

        Candidates are grouped by ``prompt_id``.  A group with a single record
        is passed through unchanged (no consensus is possible).  In larger
        groups a principle is kept only when its name has cosine similarity
        above ``consensus_threshold`` with a principle name from another
        record of the same group.

        NOTE(review): only principle names are compared; two records that
        agree on the name but disagree on the Yes/No label both pass.

        Args:
            candidates: Output of :meth:`filter_step_1_and_2`.

        Returns:
            A list of ``{"prompt_id": ..., "verified_principles": [...]}``
            dicts; a multi-record group contributes one dict per record that
            kept at least one principle.
        """
        # Group by prompt_id
        grouped = defaultdict(list)
        for c in candidates:
            grouped[c['prompt_id']].append(c['principles'])

        final_dataset = []

        for p_id, annotator_lists in grouped.items():

            if len(annotator_lists) < 2:
                final_dataset.append({
                    "prompt_id": p_id,
                    "verified_principles": annotator_lists[0]
                })
                continue

            # Encode every record's principle names once instead of once per
            # peer comparison; empty lists are marked None and skipped.
            embeddings = [
                self.encoder.encode([p['principle'] for p in principles]) if principles else None
                for principles in annotator_lists
            ]

            for i, current_list in enumerate(annotator_lists):
                if embeddings[i] is None:
                    continue

                # Per peer record: best similarity of each current principle.
                peer_best = [
                    cosine_similarity(embeddings[i], other).max(axis=1)
                    for j, other in enumerate(embeddings)
                    if j != i and other is not None
                ]
                if not peer_best:
                    continue

                valid_principles = [
                    p_obj
                    for k, p_obj in enumerate(current_list)
                    if max(best[k] for best in peer_best) > self.consensus_threshold
                ]

                if valid_principles:
                    final_dataset.append({
                        "prompt_id": p_id,
                        "verified_principles": valid_principles
                    })

        print(f"Step 3 Complete. {len(final_dataset)} verified entries.")
        return final_dataset

    def run(self):
        """Load ``input_file``, run all three filters and write ``output_file``.

        The input is read as JSON Lines; blank lines are ignored.  A line that
        is not valid JSON still raises ``json.JSONDecodeError``.
        """
        # 1. Load Raw Data
        print("Loading raw data...")
        raw_data = []
        with open(self.input_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                raw_data.append(json.loads(line))

        # 2. Run Filters 1 & 2
        candidates = self.filter_step_1_and_2(raw_data)

        # 3. Run Filter 3
        final_data = self.filter_step_3(candidates)

        with open(self.output_file, "w", encoding="utf-8") as f:
            for entry in final_data:
                f.write(json.dumps(entry) + "\n")
        print(f"Done! Final dataset saved to {self.output_file}")

In [7]:
# Build the cleaner; the constructor also loads the sentence-embedding model.
cleaner = DataCleaner(
    input_file="./data/extracted_principles.jsonl",
)

Loading Sentence Transformer model...


Fetching 4 files: 100%|██████████| 4/4 [16:15<00:00, 243.78s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [01:17<00:00, 19.41s/it]


In [8]:
# Run the whole pipeline: read the JSONL input, apply filters 1-3 and
# write the surviving entries to the cleaner's output file.
cleaner.run()

Loading raw data...
  "Accuracy of Field Terminator": "The response ...
  "Correctness of Code": "However, the code seem...
Step 1 & 2 Complete. 34194 valid entries kept.
Step 3 Complete. 34194 verified entries.
Done! Final dataset saved to final_training_data.jsonl


In [88]:
# Load the raw extraction records for ad-hoc inspection.
data_path = "./data/extracted_principles.jsonl"

# Explicit UTF-8 avoids locale-dependent decoding; blank lines (e.g. a
# trailing empty line) are skipped instead of failing in json.loads.
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = [json.loads(line) for line in f if line.strip()]

In [76]:
# Collect the 'extracted_json' field of every record
# (presumably the raw, still-unparsed LLM output; verify against the data).
principles = [record['extracted_json'] for record in raw_data]