# Revisar el score del 10% de los datos

## Librerías

In [None]:
import os
import json
import random
from typing import List
import time
import csv
import math

import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import shared_functions as custom_sharfun  #el archivo .py con funciones compartidas
import evaluation_metric as custom_metrics

from peft import PeftModel, LoraConfig
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from transformers import logging as hf_logging
hf_logging.set_verbosity_warning()

# Cargar el modelo entrenado

In [None]:
# Root folder of the training run and the sub-folder the checkpoint is read from
OUTPUT_DIR = "output/results/v10"
ADAPTER_DIR = os.path.join(OUTPUT_DIR, "modfinal_full")

# Device selection: start from CPU and switch to the GPU only when PyTorch sees one
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"Using device: {DEVICE}")

In [None]:
pt_file = os.path.join(ADAPTER_DIR, "weights.pt")
# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints produced by a trusted
# training run. It is needed here because the checkpoint stores plain config
# dicts (tokenizer id, bnb_config, LoRA config) next to the tensors.
pt_loaded = torch.load(pt_file, map_location=DEVICE, weights_only=False)

print("Cargando tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(pt_loaded["tokenizer"])
# Decoder-only models must be left-padded for batched generation; with
# right-padding transformers warns that generation results may be incorrect.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

print("Cargando modelo base en 4 bits...")
bnb_config = BitsAndBytesConfig(**pt_loaded["bnb_config"])

print("Modelo base:", pt_loaded["model_id"])
model_base = AutoModelForCausalLM.from_pretrained(
    pt_loaded["model_id"],
    quantization_config=bnb_config,
    device_map="auto"
)

print("Cargando modelo LoRA...")
# The LoRA configuration must match the one used at training time, otherwise
# the saved adapter tensors will not line up with the injected LoRA modules.
config_lora = LoraConfig(
    r=pt_loaded["config"]["lora_r"],
    lora_alpha=pt_loaded["config"]["lora_alpha"],
    lora_dropout=pt_loaded["config"]["lora_drop"],
    target_modules=pt_loaded["config"]["lora_target_mods"],
    bias=pt_loaded["config"]["lora_bias"],
    task_type=pt_loaded["config"]["lora_task_type"]
)

model = PeftModel(model_base, config_lora)
# strict=False is kept because the checkpoint presumably holds only the adapter
# weights (TODO confirm against the training script). The downside is that a
# key mismatch is silently ignored, so inspect the result and report it instead
# of evaluating an un-adapted base model without noticing.
load_result = model.load_state_dict(pt_loaded["peft"], strict=False)
if load_result.unexpected_keys:
    print(
        f"ADVERTENCIA: {len(load_result.unexpected_keys)} claves del checkpoint "
        f"no existen en el modelo y fueron ignoradas, p. ej. {load_result.unexpected_keys[:3]}"
    )
# Heuristic: PEFT names LoRA parameters with a "lora_" infix; any of them still
# missing after loading means the adapter was not (fully) restored.
missing_lora_keys = [key for key in load_result.missing_keys if "lora_" in key]
if missing_lora_keys:
    print(
        f"ADVERTENCIA: {len(missing_lora_keys)} pesos LoRA no se cargaron desde "
        f"el checkpoint, p. ej. {missing_lora_keys[:3]}"
    )

model.eval()

In [None]:
# Tokenization: upper bound (in tokens) applied when prompts are truncated
MAX_LENGTH = 2048
# Evaluation: generation budget per example and number of prompts per batch
GEN_MAX_NEW_TOKENS = 512
BATCH_SIZE_EVAL = 14  # tune to the available GPU memory to speed up evaluation

# Reuse the seed stored with the training configuration for every RNG in play
GLB_SEED = pt_loaded["config"]["seed"]
random.seed(GLB_SEED)
np.random.seed(GLB_SEED)
torch.manual_seed(GLB_SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(GLB_SEED)

# Cargando los datos de validación
Corresponden al 10% de los datos que fueron separados durante la etapa de entrenamiento.

In [None]:
# Folder holding the validation split (presumably written by the training stage — TODO confirm)
OUTPUT_DIR_VALDATA = os.path.join(OUTPUT_DIR, "datavalidation")
# Load the saved dataset; later cells read each example's "natural_language" and "json_data" fields
val_list = load_from_disk(OUTPUT_DIR_VALDATA)

## Funciones adaptadas para ejecutar en lotes

Al tokenizar se usa `padding="longest"`, que respeta el left-padding configurado en el tokenizer.

In [None]:
def generate_json_raw_batch(
    texts: List[str],
    tokenizer,
    model,
    device,
    max_new_tokens: int,
    max_length: int,
    batch_size: int = 8,
):
    """Generate one raw prediction string per input text, in batches.

    Each text is turned into a prompt with ``custom_sharfun.build_prompt``,
    tokenized with padding to the longest prompt of the batch, and completed
    with greedy decoding. The decoded sequence is then trimmed to the span
    between its first ``{`` and its last ``}`` when both are present.

    Args:
        texts: Input texts to build prompts from.
        tokenizer: Tokenizer used for encoding and decoding; it should already
            be configured for left padding.
        model: Model exposing ``eval()`` and ``generate()``.
        device: Device the encoded batch is moved to.
        max_new_tokens: Maximum number of tokens generated per example.
        max_length: Maximum prompt length in tokens; longer prompts are truncated.
        batch_size: Number of prompts processed per ``generate`` call.

    Returns:
        A list of strings, one per entry of ``texts`` and in the same order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Compare against None explicitly: a pad token id of 0 is valid, and
    # `pad_token_id or eos_token_id` would silently replace it with EOS.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # Typographic quotes are mapped to ASCII quotes so the output can be parsed as JSON.
    quote_table = str.maketrans({"“": '"', "”": '"', "’": "'"})

    # Inference mode for the whole run; no need to re-apply it on every batch.
    model.eval()

    outputs = []
    total_batches = math.ceil(len(texts) / batch_size)
    for i in tqdm(range(0, len(texts), batch_size), desc="Generating", total=total_batches):
        batch = texts[i:i + batch_size]
        prompts = [custom_sharfun.build_prompt(t) for t in batch]

        enc = tokenizer(
            prompts,
            return_tensors='pt',
            truncation=True,
            padding="longest",
            max_length=max_length,
        ).to(device)

        with torch.inference_mode():
            out = model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding
                pad_token_id=pad_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True
            )

        # NOTE(review): `out` is decoded as returned by generate(); for
        # decoder-only models this presumably includes the prompt tokens too.
        decoded = tokenizer.batch_decode(out, skip_special_tokens=True)

        for d in decoded:
            d = d.translate(quote_table)
            # Keep only the outermost {...} span when the text has both braces.
            if "{" in d and "}" in d:
                d = d[d.find("{"):d.rfind("}") + 1]
            outputs.append(d)

    return outputs

Función mejorada para extraer el JSON

In [None]:
def _find_json_start(text: str) -> int:
    """Return the index of the '{' where the JSON answer should start, or -1."""
    # Preferred anchor: the object that opens with the "buyer" key.
    pos = text.find('{"buyer":')
    if pos != -1:
        return pos

    # Next: the first '{' that follows the "JSON:" marker line.
    marker = "\nJSON:\n"
    pos = text.find(marker)
    if pos != -1:
        after_marker = text.find("{", pos + len(marker))
        if after_marker != -1:
            return after_marker

    # Last resort: the first '{' anywhere in the text (-1 when there is none).
    return text.find("{")


def extract_json_from_text(text: str):
    """Extract and parse the first JSON object found in ``text``.

    The start of the object is located at the ``{"buyer":`` marker, then after
    a ``"\\nJSON:\\n"`` marker, and finally at the first ``{`` in the text.

    Parsing is done with ``json.JSONDecoder.raw_decode``, which understands
    string literals, so braces inside string values no longer cut the object
    short. If strict parsing fails, the first brace-balanced span is retried
    with single quotes replaced by double quotes.

    Args:
        text: Text that may contain a JSON object surrounded by other content.

    Returns:
        The parsed object, or ``None`` when no parseable object is found.
    """
    start = _find_json_start(text)
    if start == -1:
        return None

    # Strict attempt: parse one JSON value starting exactly at `start`,
    # ignoring whatever text follows it.
    try:
        parsed, _end = json.JSONDecoder().raw_decode(text, start)
        return parsed
    except json.JSONDecodeError:
        pass

    # Lenient attempt for single-quoted pseudo-JSON: take the first span whose
    # braces balance and swap the quote style before parsing.
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                candidate = text[start:i + 1]
                try:
                    return json.loads(candidate.replace("'", '"'))
                except json.JSONDecodeError:
                    return None

    # Braces never balanced (e.g. truncated generation).
    return None

In [None]:
start_time = time.time()

print(f"Generando validación para evaluar F1-score: total datos = {len(val_list)} ...")

# Inputs and references, in dataset order
texts = [ex["natural_language"] for ex in val_list]
true_jsons = [ex["json_data"] for ex in val_list]
nat_langs = texts  # alias kept for compatibility with the original cell

pred_raw_list = generate_json_raw_batch(
    texts=texts,
    tokenizer=tokenizer,
    model=model,
    device=DEVICE,
    max_new_tokens=GEN_MAX_NEW_TOKENS,
    max_length=MAX_LENGTH,
    batch_size=BATCH_SIZE_EVAL,
)

results = []
metric_errors = []  # (example index, error) for every evaluate_json failure
for idx, (nat_langs_save, raw, true_json) in enumerate(zip(nat_langs, pred_raw_list, true_jsons)):
    # A failed extraction is scored as an empty object.
    pred_obj = extract_json_from_text(raw)
    if pred_obj is None:
        pred_obj = {}

    try:
        f1 = custom_metrics.evaluate_json(true_json, json.dumps(pred_obj, ensure_ascii=False))
    except Exception as exc:
        # Best-effort fallback kept from the original cell (exact match => 1.0),
        # but the failure is recorded instead of being silently swallowed.
        metric_errors.append((idx, repr(exc)))
        f1 = float(1.0 if pred_obj == true_json else 0.0)

    results.append({
        "nat_language": nat_langs_save,
        "raw_prediction": raw,
        "prediction": pred_obj,
        "true_json": true_json,
        "f1_score": f1
    })

if metric_errors:
    first_idx, first_err = metric_errors[0]
    print(
        f"ADVERTENCIA: evaluate_json falló en {len(metric_errors)} ejemplo(s); "
        f"se usó la comparación exacta como respaldo. Primer error (ejemplo {first_idx}): {first_err}"
    )

end_time = time.time()
print( custom_sharfun.print_time_execution("Etapa validación de datos", start_time, end_time) )

In [None]:
# Resumen de resultados de validación: gráfico de líneas con el F1 de cada ejemplo
# One F1 value per validation example, in dataset order
f1_scores = [entry['f1_score'] for entry in results]

# Line plot with 1-based example numbers on the x axis
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(range(1, len(f1_scores) + 1), f1_scores, marker='o')
ax.set_title("F1 Scores por Ejemplo de Validación")
ax.set_xlabel("Ejemplo de Validación")
ax.set_ylabel("F1 Score")
ax.grid(True)
plt.show()

In [None]:
# Save the validation results as a pipe-delimited CSV
OUTPUT_DIR_VAL = os.path.join(OUTPUT_DIR, "result_validation")
os.makedirs(OUTPUT_DIR_VAL, exist_ok=True)
csv_path = os.path.join(OUTPUT_DIR_VAL, 'validation_results.csv')
with open(csv_path, 'w', encoding='utf-8', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=['text','f1','raw','pred','true'], delimiter='|')
    writer.writeheader()
    # Prediction and reference are serialized to JSON text so each fits in one CSV field
    writer.writerows(
        {
            'text': row['nat_language'],
            'f1': row['f1_score'],
            'raw': row['raw_prediction'],
            'pred': json.dumps(row['prediction'], ensure_ascii=False),
            'true': json.dumps(row['true_json'], ensure_ascii=False),
        }
        for row in results
    )
print('CSV guardado en', csv_path)

In [None]:
# F1 histogram: 10 bins, saved as a PNG under OUTPUT_DIR and shown inline
f1_scores = [item['f1_score'] for item in results]
hist_path = os.path.join(OUTPUT_DIR, 'f1_distribution.png')

plt.figure()
plt.hist(f1_scores, bins=10)
plt.title('Distribución de F1')
plt.xlabel('F1')
plt.ylabel('Frecuencia')
# Save before show(): showing may leave the current figure empty afterwards
plt.savefig(hist_path)
plt.show()
plt.close()
print('Histograma guardado en', hist_path)

In [None]:
# Show the 10 validation examples with the lowest F1 score.
# The entries of `results` use the keys 'raw_prediction', 'prediction' and
# 'true_json'; the short names 'raw', 'pred' and 'true' exist only as CSV
# column headers, so indexing with them raised KeyError.
sorted_by_f1 = sorted(results, key=lambda x: x['f1_score'])
print('\nPeores 10 ejemplos:')
for r in sorted_by_f1[:10]:
    print(f"F1 Score: {r['f1_score']}")
    print('Texto:', r['raw_prediction'])
    print("*"*90)
    print('Pred:', r['prediction'])
    print('True:', r['true_json'])
    print('-'*150)