# In [None]:  (notebook cell prompt left over from the export)
import numpy as np
import pandas as pd
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import scipy.stats as stats
import pickle
import os
import torch # Apenas para verifica√ß√£o de GPU, se dispon√≠vel

# --- 0. GPU CHECK (carried over from the original snippet) ---
# Torch is imported only for this hardware report; the training below uses TensorFlow.
print("--- Verifica√ß√£o de Hardware ---")
if not torch.cuda.is_available():
    print("CUDA not available. GPU not detected via Torch (TF usar√° o que estiver dispon√≠vel).")
else:
    print("GPU Name:", torch.cuda.get_device_name(0))
    # Scale total_memory by 1024**3 so the figure matches the "GB" label.
    print(
        "Total VRAM:",
        round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2),
        "GB",
    )

# --- COMPATIBILITY FIX (ALL COMPONENTS) ---
# Import Adam, SparseCategoricalCrossentropy and EarlyStopping from the first
# source that can be imported, in this order of preference:
#   1. tf_keras
#   2. tensorflow.keras, with Adam taken from optimizers.legacy
#   3. tensorflow.keras with its default Adam
# All three names always come from the same branch, and the branch taken is
# printed so the run log shows which implementation is in use.
# NOTE(review): presumably this order exists because the transformers TF models
# need a specific Keras flavour — confirm against the installed versions.
try:
    from tf_keras.optimizers import Adam
    from tf_keras.losses import SparseCategoricalCrossentropy
    from tf_keras.callbacks import EarlyStopping
    print("‚úÖ Usando tf_keras para compatibilidade total.")
except ImportError:
    # tf_keras is not installed: try the legacy optimizer shipped with TensorFlow.
    try:
        from tensorflow.keras.optimizers.legacy import Adam
        from tensorflow.keras.losses import SparseCategoricalCrossentropy
        from tensorflow.keras.callbacks import EarlyStopping
        print("‚ö†Ô∏è tf_keras n√£o encontrado. Usando tensorflow.keras.optimizers.legacy.")
    except ImportError:
        # Last resort: the standard tensorflow.keras components.
        from tensorflow.keras.optimizers import Adam
        from tensorflow.keras.losses import SparseCategoricalCrossentropy
        from tensorflow.keras.callbacks import EarlyStopping
        print("‚ö†Ô∏è Usando tensorflow.keras padr√£o.")

# --- LEGALBERT SETTINGS ---
# Identifier handed to from_pretrained() for both the tokenizer and the classifier.
MODEL_NAME = "casehold/legalbert"
MAX_LENGTH = 256 # LegalBERT can take a longer context; 256 is kept from the original snippet
BATCH_SIZE = 16
# Upper bound passed to model.fit(); the EarlyStopping callback is what is
# expected to end training in practice.
EPOCHS = 1000 
# Passed as `patience` to EarlyStopping, which monitors val_accuracy.
PATIENCE = 3  
LEARNING_RATE = 2e-5
# Passed as cache_dir to from_pretrained(); created up front so it always exists.
CACHE_DIR = "./legalbert_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# --- CONFIDENCE INTERVAL HELPER ---
def ci95(values):
    """Return the two-sided 95% confidence interval for the mean of *values*.

    The half-width uses the Student's t critical value with ``n - 1`` degrees
    of freedom instead of the normal-approximation constant 1.96. With the
    handful of scores produced by k-fold cross-validation, 1.96 gives an
    interval that is noticeably too narrow (for 5 folds the correct factor
    is about 2.78).

    Args:
        values: Sequence of numeric scores, e.g. one metric value per fold.

    Returns:
        Tuple ``(lower, upper)``. ``(nan, nan)`` is returned when fewer than
        two values are given, because the standard error is undefined then.
    """
    values = np.asarray(values, dtype=float)
    n = len(values)
    # Guard before any statistics so an empty input does not trigger
    # NumPy's "mean of empty slice" warning.
    if n <= 1:
        return np.nan, np.nan
    mean = np.mean(values)
    se = stats.sem(values)  # sample standard error (ddof=1)
    h = stats.t.ppf(0.975, df=n - 1) * se
    return mean - h, mean + h

# --- 1. DATA LOADING (same as the RoBERTa CV script) ---
# Load the precomputed (train_idx, val_idx) pairs that drive the CV loop.
if not os.path.exists('kfolds_resampled_indices.pkl'):
    print("‚ùå ERRO: Arquivo 'kfolds_resampled_indices.pkl' n√£o encontrado.")
    # Without the file the run can only continue if an earlier cell already
    # defined the folds. Otherwise stop here: merely printing would let the
    # script tokenize the whole dataset and then die with a NameError at the
    # cross-validation loop.
    if 'loaded_kfolds_indices' not in locals():
        raise FileNotFoundError(
            "kfolds_resampled_indices.pkl not found and no "
            "'loaded_kfolds_indices' is defined in this session."
        )
else:
    # NOTE: pickle.load can execute code embedded in the file; only load fold
    # files produced by a trusted run of this project.
    with open('kfolds_resampled_indices.pkl', 'rb') as f:
        loaded_kfolds_indices = pickle.load(f)
    print("‚úÖ √çndices K-Fold carregados.")

# --- 2. DATA PREPARATION ---
# Reuse X_train_resampled / y_train_resampled when they already exist in the
# session; otherwise rebuild them from the resampled CSV in the working directory.
if not ('X_train_resampled' in locals() and 'y_train_resampled' in locals()):
    if not os.path.exists('train_resampled_full.csv'):
        raise ValueError("ERRO: Vari√°veis X_train_resampled/y_train_resampled n√£o encontradas e CSV n√£o achado.")

    print("Carregando 'train_resampled_full.csv'...")
    df_train_resampled = pd.read_csv('train_resampled_full.csv')

    # Prefer the corrected text column when the CSV has it; adjust the column
    # name here if the file uses a different one.
    if 'content_corrected' in df_train_resampled.columns:
        col_text = 'content_corrected'
    else:
        col_text = 'content'

    X_train_resampled = df_train_resampled[col_text]
    y_train_resampled = df_train_resampled['target']

# Plain arrays: texts forced to str for the tokenizer, labels indexed per fold.
texts = X_train_resampled.astype(str).values
labels = y_train_resampled.values

# --- 3. GLOBAL TOKENIZATION ---
# Tokenize the whole corpus once; each fold later slices these arrays by index.
print(f"--- Tokenizando dados para CV com {MODEL_NAME} ---")
# AutoTokenizer resolves the tokenizer that belongs to MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
encodings = tokenizer(
    texts.tolist(),
    truncation=True,
    # Pad every sequence to MAX_LENGTH so the result is a fixed-width array.
    padding='max_length',
    max_length=MAX_LENGTH,
    return_tensors='tf',
)

# Convert to NumPy so the folds can be selected with integer index arrays.
input_ids_all, attention_mask_all = (
    encodings[field].numpy() for field in ('input_ids', 'attention_mask')
)

# --- 4. CROSS-VALIDATION LOOP ---
# Trains a fresh classifier on every fold and records weighted metrics on that
# fold's validation split. Each list gets one entry per fold.
cv_metrics = {
    'accuracy': [],
    'precision': [],
    'recall': [],
    'f1': []
}

# The class count depends only on the full label array, so it is computed once
# rather than on every fold.
num_classes = len(np.unique(labels))

print(f"\n--- Iniciando Valida√ß√£o Cruzada ({len(loaded_kfolds_indices)} Folds) ---")

for fold, (train_idx, val_idx) in enumerate(loaded_kfolds_indices):
    print(f"\nüîÑ Training Fold {fold + 1}/{len(loaded_kfolds_indices)}...")

    # A. Clear the Keras session left over from the previous fold.
    # NOTE(review): if the tf_keras import branch was taken, tf.keras may not be
    # the backend the model actually uses — confirm this call has the intended effect.
    tf.keras.backend.clear_session()

    # B. Slice the pre-tokenized arrays for this fold.
    X_train_ids = input_ids_all[train_idx]
    X_train_mask = attention_mask_all[train_idx]
    y_train_fold = labels[train_idx]

    X_val_ids = input_ids_all[val_idx]
    X_val_mask = attention_mask_all[val_idx]
    y_val_fold = labels[val_idx]

    # C. Datasets
    # Shuffle with a buffer that covers the whole training split. A fixed buffer
    # of 1000 only mixes nearby samples once the split is larger than that, so
    # any ordering present in the fold indices would largely survive into the
    # batches.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        ({'input_ids': X_train_ids, 'attention_mask': X_train_mask}, y_train_fold)
    ).shuffle(buffer_size=len(y_train_fold)).batch(BATCH_SIZE)

    # The validation split keeps its order so predictions line up with y_val_fold.
    val_dataset = tf.data.Dataset.from_tensor_slices(
        ({'input_ids': X_val_ids, 'attention_mask': X_val_mask}, y_val_fold)
    ).batch(BATCH_SIZE)

    # D. Instantiate a new model for this fold so no weights carry over.
    model = TFAutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_classes,
        cache_dir=CACHE_DIR
    )

    # Optimizer and loss; the model outputs logits, hence from_logits=True.
    optimizer_inst = Adam(learning_rate=LEARNING_RATE)
    loss_inst = SparseCategoricalCrossentropy(from_logits=True)

    model.compile(optimizer=optimizer_inst, loss=loss_inst, metrics=['accuracy'])

    # E. Train
    # NOTE(review): early stopping monitors the same split that is scored in
    # step F, so the best epoch is selected on the evaluation data and the
    # reported metrics may be optimistic.
    early_stop = EarlyStopping(
        monitor='val_accuracy',
        mode='max',
        patience=PATIENCE,
        restore_best_weights=True,
        verbose=1
    )

    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        callbacks=[early_stop],
        verbose=1
    )

    # F. Evaluate on the fold's validation split.
    predictions = model.predict(val_dataset)
    y_pred_logits = predictions.logits
    y_pred = np.argmax(y_pred_logits, axis=1)

    acc = accuracy_score(y_val_fold, y_pred)
    # Weighted averaging; zero_division=0 scores classes with no predictions as 0.
    prec, rec, f1, _ = precision_recall_fscore_support(y_val_fold, y_pred, average='weighted', zero_division=0)

    cv_metrics['accuracy'].append(acc)
    cv_metrics['precision'].append(prec)
    cv_metrics['recall'].append(rec)
    cv_metrics['f1'].append(f1)

    print(f"‚úÖ Fold {fold+1} Result -> Acc: {acc:.4f}, F1: {f1:.4f}")

# --- 5. FINAL RESULTS ---
print(f"\n--- üìä Resultados Consolidados ({MODEL_NAME}) ---")

# One entry of summary statistics per metric, aggregated over the folds.
results_summary = {}
for metric_name, values in cv_metrics.items():
    mean_val, std_val = np.mean(values), np.std(values)
    ci_lower, ci_upper = ci95(values)

    results_summary[metric_name] = dict(
        Mean=mean_val,
        Std=std_val,
        CI_Lower=ci_lower,
        CI_Upper=ci_upper,
    )

    print(f"{metric_name.capitalize()}: {mean_val:.4f} ¬± {std_val:.4f} (95% CI: [{ci_lower:.4f}, {ci_upper:.4f}])")

# Transpose so that metrics are rows and statistics are columns.
df_results_cv = pd.DataFrame(results_summary).transpose()
print("\nTabela de Resultados CV:")
print(df_results_cv)

# Save the table to the current working directory.
df_results_cv.to_csv('resultados_validacao_cruzada_legalbert.csv')
print("\nResultados salvos em 'resultados_validacao_cruzada_legalbert.csv'")