In [None]:
import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import re
import json
from collections import Counter, defaultdict
from rapidfuzz import fuzz
import textwrap
import os


In [None]:

# ==============================================================================
# --- 1. CORE CONFIGURATION ---
# ==============================================================================

# --- Model & Path Configuration ---
# NOTE: Replace with the actual paths to your fine-tuned model adapters
# Maps a short model label to the directory its adapter is loaded from.
# The same directory is also used to load that model's tokenizer.
MODEL_PATHS = {
    "gemma_ft_1": "./lora_finetuned_model_1",
    "gemma_ft_2": "./lora_finetuned_model_2",
    "gemma_ft_3": "./lora_finetuned_model_3"
}
# Identifier of the base checkpoint each adapter is applied on top of.
BASE_MODEL_NAME = "google/gemma-3-4b-it"

# --- Trustworthiness & Consensus Configuration ---
# Trust scores are kept per model, per entity class, per field.
# INITIAL_TRUST_SCORE is the fallback used whenever no score is stored yet;
# MIN_TRUST_SCORE / MAX_TRUST_SCORE are the clamping bounds applied on every update.
INITIAL_TRUST_SCORE = 1.0
MIN_TRUST_SCORE = 0.5
MAX_TRUST_SCORE = 2.0
# Step added (model agreed with the consensus) or subtracted (it did not) per
# field, depending on whether the entity was rated "Confident" or "Probable".
CONFIDENT_SCORE_ADJUSTMENT = 0.1
PROBABLE_SCORE_ADJUSTMENT = 0.05

# NOTE(review): TRUST_TIERS is not referenced by any other code visible in this
# notebook; the thresholds appear to be reserved for a tier classification that
# is not implemented here -- confirm before relying on it.
TRUST_TIERS = {
    "Trustworthy": 1.5,
    "Mid-tier": 0.8,
    "Untrustworthy": MIN_TRUST_SCORE
}
SIMILARITY_THRESHOLD = 85 # Two predictions of the same class are clustered when fuzz.ratio of their 'valor' texts is strictly above this value

# --- Master List of All Possible Fields and Classes ---
# ALL_CLASSES mirrors the class codes listed in the prompt; it is used to build
# the default trust-score table when no gold standard data is available.
ALL_CLASSES = ['S', 'D', 'M', 'DS', 'PM', 'SV', 'R', 'P']
# Field names the consensus logic iterates over.
# NOTE(review): the prompt in INSTRUCTIONS asks the models for a `classe` key,
# whereas this list and the consensus code use `class` -- confirm which key the
# fine-tuned models actually emit.
ALL_FIELDS = ['id', 'class', 'valor', 'estado_clinico', 'loc-temp', 'quem', 'relacionamento_id']
# --- Prompt Instructions for the Models ---
# Portuguese-language task description that is placed verbatim at the start of
# every prompt, ahead of the diary text. It describes the expected JSON list
# format, the allowed class codes and gives a worked example.
# textwrap.dedent removes any common leading indentation and .strip() drops the
# blank lines produced by the opening/closing triple quotes.
# NOTE(review): the prompt names the class key `classe`, while the rest of the
# notebook indexes predictions with `class`.
INSTRUCTIONS = textwrap.dedent("""
Você é um assistente especializado em Reconhecimento de Entidades Nomeadas (NER) na área médica.
Seu objetivo é identificar e extrair termos relevantes de um diário médico de cardiologia.
Formate a saída como uma lista JSON válida, onde cada objeto representa uma entidade identificada.

Instruções para o formato JSON:
1. A saída DEVE ser uma lista JSON `[...]`.
2. Cada objeto na lista deve ter as seguintes chaves:
    - `id`: (Inteiro) ID numérico sequencial da entidade (1, 2, 3...).
    - `classe`: (String) A classe da entidade (use os códigos: S, D, M, DS, PM, SV, R, P).
    - `valor`: (String) O texto exato da entidade identificada.
    - `estado_clinico`: (String) No caso de um diagnóstico ou sintoma, indica se a pessoa tem essa condição ("+"), se não tem ("-") ou se é uma possibilidade ("?"). Omita se não aplicável.
    - `loc-temp`: (String) O valor de localização temporal deve estar presente apenas caso não seja o momento atual (hoje) (ex: "ontem", "2014", "na infância", "há 3 dias"). Omita se não aplicável.
    - `quem`: (String) O valor de quem tem uma condição, diagnóstico, medicação, etc., presente apenas no caso de não se referir ao próprio paciente (ex: "familiar(filho, mãe, etc.)"). Omita se não aplicável.
    - `relacionamento_id`: (Inteiro) Inclua esta chave *apenas* quando 2 classes estão relacionadas, indicando o `id` da entidade relacionada (ex: o `id` do Medicamento para uma Dosagem). Omita esta chave noutros casos.
3. Certifique-se de que a saída final seja um JSON estritamente válido. Não inclua nenhum texto antes ou depois da lista JSON, nem marcadores como ```json.

Classes permitidas e seus significados:
    - S: Sintoma
    - D: Diagnóstico
    - M: Medicamento
    - DS: Dosagem
    - PM: Procedimento Médico
    - SV: Sinal Vital
    - R: Resultado (de exame, etc.)
    - P: Progresso (do paciente)

Exemplos de Saída JSON Válida:
[
  {"id": 1, "classe": "S", "valor": "Dor torácica", "estado_clinico": "+", "loc-temp": "2015", "quem": "mãe"},
  {"id": 2, "classe": "M", "valor": "Aspirina", "loc-temp": "ontem"},
  {"id": 3, "classe": "D", "valor": "Hipertensão", "estado_clinico": "-","quem": "familiar"},
  {"id": 4, "classe": "DS", "valor": "100", "relacionamento_id": 2},
  {"id": 5, "classe": "SV", "valor": "PA 120/80 mmHg"},
  {"id": 6, "classe": "PM", "valor": "Exame sangue"},
  {"id": 7, "classe": "R", "valor": "sangue normal", "relacionamento_id": 6}
]
""").strip()


In [None]:

# ==============================================================================
# --- 2. MODEL AND DATA LOADING ---
# ==============================================================================

def load_all_models(model_paths, base_model_name):
    """Load each adapter in ``model_paths`` on top of its own copy of the base model.

    Args:
        model_paths: Mapping of model label -> adapter directory. The tokenizer
            is loaded from the same directory as the adapter.
        base_model_name: Identifier of the base checkpoint passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        dict: ``{label: (model, tokenizer)}`` for every adapter that loaded.
        A label whose loading raised is reported on stdout and left out.
    """
    loaded = {}
    for label, adapter_dir in model_paths.items():
        print(f"Loading model '{label}' from {adapter_dir}...")
        try:
            # A fresh base model is instantiated for every adapter.
            backbone = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                attn_implementation="eager",
            )
            adapted_model = PeftModel.from_pretrained(backbone, adapter_dir)
            adapter_tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
        except Exception as err:
            # Best effort: one broken adapter must not prevent the others from loading.
            print(f"  [ERROR] Could not load model '{label}'. Skipping. Error: {err}")
            continue
        loaded[label] = (adapted_model, adapter_tokenizer)
    return loaded

def load_gold_standard_data(file_path):
    """Read the gold standard annotations used by the preparation phase.

    Args:
        file_path: Path of a UTF-8 encoded JSON file.

    Returns:
        The parsed JSON content, or an empty list when the file does not exist
        or does not contain valid JSON (both cases are reported on stdout).
    """
    print(f"Loading gold standard data from '{file_path}'...")
    try:
        with open(file_path, 'r', encoding='utf-8') as source:
            raw_text = source.read()
        gold_entries = json.loads(raw_text)
    except FileNotFoundError:
        print(f"  [ERROR] Gold standard file not found at '{file_path}'. Cannot run preparation phase.")
        gold_entries = []
    except json.JSONDecodeError:
        print(f"  [ERROR] Gold standard file '{file_path}' is not valid JSON.")
        gold_entries = []
    return gold_entries


In [None]:

# ==============================================================================
# --- 3. PREDICTION AND PREPARATION PHASE ---
# ==============================================================================

def get_predictions_from_model(text, model, tokenizer, instructions):
    """Generate entity predictions for ``text`` with one model and parse them.

    The prompt is ``instructions`` followed by the diary text. Only the tokens
    generated *after* the prompt are decoded: the prompt itself contains square
    brackets (the ``[...]`` hint and the example list), so searching the full
    decoded sequence for a JSON list would start matching inside the prompt and
    never yield valid JSON.

    Args:
        text: Diary text to analyse.
        model: Object exposing ``generate(**inputs, ...)``; if it has a
            ``device`` attribute the tokenized inputs are moved to it.
        tokenizer: Callable returning a mapping with ``input_ids`` and exposing
            ``decode`` and ``eos_token_id``.
        instructions: Task description placed at the start of the prompt.

    Returns:
        list[dict]: The parsed entities. Non-dict items are dropped, and an
        entity that carries the prompt's ``classe`` key but no ``class`` key has
        it renamed to ``class``, the key the rest of the pipeline reads.
        An empty list is returned when generation fails, no JSON list is found,
        or the JSON is invalid.
    """
    prompt = f"{instructions}\n\nDiário Médico para Análise:\n---\n{text}\n---\n\nSaída JSON:"

    try:
        inputs = tokenizer(prompt, return_tensors="pt")

        # Keep inputs and weights on the same device (matters once the model is on GPU).
        device = getattr(model, "device", None)
        if device is not None and hasattr(inputs, "to"):
            inputs = inputs.to(device)

        outputs = model.generate(**inputs, max_new_tokens=1024, pad_token_id=tokenizer.eos_token_id)

        # The generated sequence starts with the prompt tokens; skip them.
        prompt_length = inputs["input_ids"].shape[-1]
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Greedy match from the first '[' to the last ']' of the completion,
        # which also tolerates stray text or code fences around the list.
        json_str_match = re.search(r'\[.*\]', completion, re.DOTALL)
        if not json_str_match:
            return []
        parsed = json.loads(json_str_match.group(0))
    except Exception as e:
        print(f"  [WARNING] Model prediction failed or produced invalid JSON. Error: {e}")
        return []

    if not isinstance(parsed, list):
        return []

    entities = []
    for item in parsed:
        if not isinstance(item, dict):
            continue  # e.g. a nested list or bare string inside the JSON list
        if 'class' not in item and 'classe' in item:
            item['class'] = item.pop('classe')
        entities.append(item)
    return entities

def run_preparation_phase(models, gold_standard_data, instructions):
    """Score every model against the gold standard and derive initial trust scores.

    For each gold entry every model labels ``entry["text"]``. A prediction is
    matched to a gold label when both have the same class and the
    ``fuzz.ratio`` of their 'valor' texts is above 95; for the first matching
    prediction, each gold field whose stringified value equals the prediction's
    counts as correct. The per (class, field) accuracy is mapped to
    ``INITIAL_TRUST_SCORE + (accuracy - 0.5) * 2`` and clamped to
    ``[MIN_TRUST_SCORE, MAX_TRUST_SCORE]``.

    Model output is untrusted, so entities that are not dicts or that lack a
    class or 'valor' are ignored instead of aborting the whole phase. The
    prompt's ``classe`` key is accepted as an alias of ``class``.

    Args:
        models: Mapping of model name -> ``(model, tokenizer)``.
        gold_standard_data: Sized iterable of dicts with a "text" string and a
            "labels" list of entity dicts.
        instructions: Prompt text forwarded to ``get_predictions_from_model``.

    Returns:
        dict: ``{model_name: {class_name: {field: score}}}`` covering every
        (class, field) pair that occurs in the gold labels.
    """
    print("\n🚀 Starting Preparation Phase...")

    def _canonical(entity):
        """Return a copy of ``entity`` keyed by 'class', or None if it is unusable."""
        if not isinstance(entity, dict):
            return None
        normalized = dict(entity)
        if 'class' not in normalized and 'classe' in normalized:
            normalized['class'] = normalized.pop('classe')
        if normalized.get('class') is None or normalized.get('valor') is None:
            return None
        return normalized

    correct_counts = {m: defaultdict(lambda: defaultdict(int)) for m in models}
    total_counts = defaultdict(lambda: defaultdict(int))

    for i, entry in enumerate(gold_standard_data, 1):
        print(f"  Processing gold standard entry {i}/{len(gold_standard_data)}...")
        text = entry["text"]

        gold_labels = []
        for raw_label in entry["labels"]:
            label = _canonical(raw_label)
            if label is None:
                print(f"  [WARNING] Skipping gold label without class/valor in entry {i}: {raw_label!r}")
                continue
            gold_labels.append(label)

        # Update total counts for each field in the gold standard
        for gold_label in gold_labels:
            for field in gold_label:
                total_counts[gold_label['class']][field] += 1

        for model_name, (model, tokenizer) in models.items():
            raw_predictions = get_predictions_from_model(text, model, tokenizer, instructions)
            model_predictions = [
                pred for pred in map(_canonical, raw_predictions or []) if pred is not None
            ]

            # Compare model predictions against gold labels field by field
            for gold_label in gold_labels:
                for model_pred in model_predictions:
                    same_class = gold_label['class'] == model_pred['class']
                    # str() guards against numeric 'valor' values emitted by a model.
                    if same_class and fuzz.ratio(str(gold_label['valor']), str(model_pred['valor'])) > 95:
                        # Found a matching entity, now check each field
                        for field, gold_value in gold_label.items():
                            if str(model_pred.get(field)) == str(gold_value):
                                correct_counts[model_name][gold_label['class']][field] += 1
                        break  # only the first matching prediction is scored

    # Calculate and store the initial trust scores
    initial_trust_scores = {m: defaultdict(dict) for m in models}
    for model_name in models:
        for class_name, fields in total_counts.items():
            for field, total in fields.items():
                accuracy = correct_counts[model_name][class_name][field] / total if total > 0 else 0
                # Scale accuracy to our score range and clamp it
                score = INITIAL_TRUST_SCORE + (accuracy - 0.5) * 2
                initial_trust_scores[model_name][class_name][field] = max(MIN_TRUST_SCORE, min(score, MAX_TRUST_SCORE))

    print("✅ Preparation Phase Complete.")
    return initial_trust_scores


In [None]:


# ==============================================================================
# --- 4. ADVANCED CONSENSUS LOGIC ---
# ==============================================================================

def get_field_consensus(cluster, class_name, field_name, trust_scores):
    """Pick the consensus value of one field among the predictions of a cluster.

    Every prediction votes with its value for ``field_name``; a prediction that
    lacks the field votes for omitting it. A strict plurality wins. When several
    values share the top vote count, the tie is broken by the prediction whose
    source model has the highest trust score for ``(class_name, field_name)``.
    Only predictions that voted for one of the tied values take part in the
    tie-break, so a minority value can never win.

    Args:
        cluster: List of prediction dicts, each tagged with a 'model' key.
        class_name: Entity class of the cluster (used to look up trust scores).
        field_name: Field to resolve.
        trust_scores: ``{model: {class: {field: score}}}``; missing entries
            fall back to ``INITIAL_TRUST_SCORE``.

    Returns:
        The consensus value, or ``None`` when the consensus is to omit the
        field or when the cluster is empty.
    """
    if not cluster:
        return None

    votes = Counter(p.get(field_name, "OMIT") for p in cluster)
    ranked = votes.most_common()
    top_count = ranked[0][1]
    leaders = {value for value, count in ranked if count == top_count}

    if len(leaders) == 1:
        consensus = ranked[0][0]
    else:  # Tie-breaker
        def trust_of(pred):
            """Trust score of the model that produced ``pred`` for this class/field."""
            return (
                trust_scores.get(pred.get('model'), {})
                .get(class_name, {})
                .get(field_name, INITIAL_TRUST_SCORE)
            )

        contenders = [p for p in cluster if p.get(field_name, "OMIT") in leaders]
        # max() keeps the first contender on equal trust, i.e. cluster order decides.
        consensus = max(contenders, key=trust_of).get(field_name, "OMIT")

    return consensus if consensus != "OMIT" else None

def get_tiered_consensus(text, models, instructions, trust_scores):
    """Label ``text`` with every model and merge the results into consensus entities.

    Steps:
      1. Collect the predictions of all models, tagging each with its model name.
         Entries that are not dicts or lack a class/'valor' are ignored; the
         prompt's ``classe`` key is accepted as an alias of ``class``.
      2. Cluster predictions that share a class and whose 'valor' texts have a
         ``fuzz.ratio`` above ``SIMILARITY_THRESHOLD`` with the cluster's first member.
      3. Resolve every field of a cluster with ``get_field_consensus``.
      4. Rate the entity: "Confident" when every model produced the consensus
         'valor', "Probable" when exactly one did not, otherwise
         "Needs Human Review".
      5. For "Confident"/"Probable" entities, raise or lower each model's trust
         score for this class and each field, clamped to
         ``[MIN_TRUST_SCORE, MAX_TRUST_SCORE]``.

    Args:
        text: Diary text to label.
        models: Mapping of model name -> ``(model, tokenizer)``.
        instructions: Prompt text forwarded to ``get_predictions_from_model``.
        trust_scores: ``{model: {class: {field: score}}}``. Updated in place.

    Returns:
        list[dict]: Consensus entities with 'class', the resolved fields,
        'confidence' and a sequential 'id' starting at 1.
    """

    # 1. Get Predictions from all models
    all_predictions = []
    for model_name, (model, tokenizer) in models.items():
        preds = get_predictions_from_model(text, model, tokenizer, instructions)
        for p in preds or []:
            if not isinstance(p, dict):
                continue
            if 'class' not in p and 'classe' in p:
                p['class'] = p.pop('classe')
            if p.get('class') is None or p.get('valor') is None:
                continue  # cannot be clustered without class and text
            p['model'] = model_name  # Tag prediction with its source model
            all_predictions.append(p)

    # 2. Cluster predictions by class and valor similarity
    clusters = []
    for pred in all_predictions:
        for cluster in clusters:
            anchor = cluster[0]
            # str() guards against numeric 'valor' values emitted by a model.
            if pred['class'] == anchor['class'] and fuzz.ratio(str(pred['valor']), str(anchor['valor'])) > SIMILARITY_THRESHOLD:
                cluster.append(pred)
                break
        else:
            clusters.append([pred])

    # 3. Build consensus for each cluster
    final_entities = []
    for cluster in clusters:
        class_name = cluster[0]['class']

        # Build consensus for each field
        consensus_obj = {"class": class_name}
        for field in ALL_FIELDS:
            if field in ('id', 'class'):
                continue  # 'id' is assigned at the end, 'class' is fixed by the cluster

            consensus_value = get_field_consensus(cluster, class_name, field, trust_scores)
            if consensus_value is not None:
                consensus_obj[field] = consensus_value

        # 4. Determine Confidence Tier
        consensus_valor = consensus_obj.get('valor')
        agreeing_models = {p['model'] for p in cluster if p['valor'] == consensus_valor}
        disagreeing_model_count = len(models) - len(agreeing_models)

        confidence_tier = "Needs Human Review"  # Default
        if disagreeing_model_count == 0:
            confidence_tier = "Confident"
        elif disagreeing_model_count == 1:
            confidence_tier = "Probable"

        consensus_obj["confidence"] = confidence_tier
        final_entities.append(consensus_obj)

        # 5. Update Trust Scores (Conditional & Tiered)
        if confidence_tier in ["Confident", "Probable"]:
            increment = CONFIDENT_SCORE_ADJUSTMENT if confidence_tier == "Confident" else PROBABLE_SCORE_ADJUSTMENT

            for model_name in models:
                # First prediction of this model in the cluster, if it made one.
                model_pred = next((p for p in cluster if p['model'] == model_name), None)

                # Scores live at trust_scores[model][class][field]; write to the
                # same place they are read from so that updates take effect.
                class_scores = trust_scores.setdefault(model_name, {}).setdefault(class_name, {})

                for field in ALL_FIELDS:
                    if field == 'id':
                        continue

                    current_score = class_scores.get(field, INITIAL_TRUST_SCORE)

                    # A model that made no prediction for this entity is penalised.
                    if model_pred and str(model_pred.get(field)) == str(consensus_obj.get(field)):
                        new_score = current_score + increment
                    else:
                        new_score = current_score - increment

                    # Clamp the score
                    class_scores[field] = max(MIN_TRUST_SCORE, min(new_score, MAX_TRUST_SCORE))

    # Assign final sequential IDs
    for i, entity in enumerate(final_entities, 1):
        entity['id'] = i

    return final_entities

In [None]:


# ==============================================================================
# --- 5. MAIN EXECUTION SCRIPT ---
# ==============================================================================

def main():
    """Run the pipeline end to end.

    Loads the models, derives trust scores from the gold standard file (or
    falls back to uniform default scores), labels the unlabeled diary file with
    the tiered consensus, prints the results and finally releases the models.
    Returns early, without doing anything else, when no model could be loaded.
    """

    # --- Load Models ---
    ensemble = load_all_models(MODEL_PATHS, BASE_MODEL_NAME)
    if not ensemble:
        print("No models were loaded. Exiting.")
        return

    # --- Run Preparation Phase ---
    gold_entries = load_gold_standard_data("gold_standard_diaries.json")
    if not gold_entries:
        # Initialize with default scores if no gold data
        trust_scores = {
            model_name: {
                class_code: dict.fromkeys(ALL_FIELDS, INITIAL_TRUST_SCORE)
                for class_code in ALL_CLASSES
            }
            for model_name in ensemble
        }
        print("\n--- Initializing with default trust scores ---")
    else:
        trust_scores = run_preparation_phase(ensemble, gold_entries, INSTRUCTIONS)
        print("\n--- Initial Trust Scores ---")
        print(json.dumps(trust_scores, indent=2))

    # --- Process New Unlabeled Diaries ---
    unlabeled_file = "diarios-cardiologia-amostra.txt" # or "diarios-psicologia-amostra.txt"
    print(f"\n\n🚀 Starting Labeling Phase for '{unlabeled_file}'...")

    try:
        with open(unlabeled_file, 'r', encoding='utf-8') as diary_file:
            diary_text = diary_file.read()
            # The whole file is handled as a single document; split it here
            # if entry-by-entry processing is needed.
            consensus_labels = get_tiered_consensus(diary_text, ensemble, INSTRUCTIONS, trust_scores)

            print("\n\n✅ --- FINAL CONSENSUS LABELS --- ✅")
            print(json.dumps(consensus_labels, indent=2, ensure_ascii=False))

            # get_tiered_consensus receives trust_scores and may update it in place.
            print("\n\n--- Final Trust Scores ---")
            print(json.dumps(trust_scores, indent=2))
    except FileNotFoundError:
        print(f"  [ERROR] Unlabeled diary file not found at: '{unlabeled_file}'")

    # --- Cleanup ---
    print("\nCleaning up resources...")
    del ensemble
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Script finished.")


# Entry point: run the whole pipeline only when this code is executed as the
# top-level module, not when it is imported from elsewhere.
if __name__ == "__main__":
    # NOTE: To run this script, you must have:
    # 1. Your fine-tuned model directories at the paths specified in MODEL_PATHS.
    # 2. A 'gold_standard_diaries.json' file for the preparation phase
    #    (without it, main() falls back to default trust scores).
    # 3. The unlabeled diary file you wish to process.
    main()

