
# Brain MRI Pipelines — notebook executável (sem Tkinter)

Este notebook reorganiza o código do app em Tkinter em um fluxo que pode ser executado diretamente: segmentação dos ventrículos, criação do dataset combinado e treinos (SVM, XGBoost e DenseNet). As saídas são gravadas em `output/` e os passos podem ser executados de forma independente.



## Fluxo sugerido
1) Configurar caminhos e dependências.  
2) (Opcional) Processar/segmentar as imagens NIfTI em lote.  
3) Gerar o dataset `exam_level_dataset_split.csv` combinando descritores com o CSV demográfico.  
4) Treinar modelos clássicos (SVM / XGBoost).  
5) Treinar DenseNet (classificação ou regressão).  

Todas as funções aqui dentro evitam qualquer chamada ao Tkinter e usam apenas código headless.


In [None]:
# Notebook auto-notes: keep split consistent; original_path rebuilt in code.
# Configura caminhos base e valida se dataset/output e CSVs estão presentes.

from pathlib import Path
from datetime import datetime

# Project layout: every path is resolved relative to the notebook's working directory.
BASE_DIR = Path.cwd()
DATASET_DIR = BASE_DIR / "axl"  # folder scanned below for *.nii* exams
NOT_VIABLE_DIR = BASE_DIR / "not_viable"
CSV_DEMOGRAPHIC = BASE_DIR / "oasis_longitudinal_demographic.csv"

# Generated artefacts all live under OUTPUT_DIR.
OUTPUT_DIR = BASE_DIR / "output"
DESCRIPTORS_CSV = OUTPUT_DIR / "ventricle_descriptors.csv"
EXAM_SPLIT_CSV = OUTPUT_DIR / "exam_level_dataset_split.csv"
HISTORY_JSON = OUTPUT_DIR / "training_experiments.json"

# Create the writable folders up front (non-recursive: BASE_DIR itself must exist).
OUTPUT_DIR.mkdir(exist_ok=True)
NOT_VIABLE_DIR.mkdir(exist_ok=True)

# Quick status report of what is already available on disk.
print(f"Base: {BASE_DIR}")
print(f"Dataset: {DATASET_DIR.exists()} | Arquivos NIfTI: {len(list(DATASET_DIR.glob('*.nii*')))}")
print(f"CSV demográfico: {CSV_DEMOGRAPHIC.exists()}")
print(f"Descritores existentes: {DESCRIPTORS_CSV.exists()}")
print(f"Split existente: {EXAM_SPLIT_CSV.exists()}")


In [None]:

from brain_mri.utils.dependencies import ensure_dependencies

# Check the packages listed in requirements.txt. `missing` is whatever
# ensure_dependencies reports as still unsatisfied (presumably packages it could
# not install by itself — confirm in brain_mri.utils.dependencies).
missing = ensure_dependencies(BASE_DIR / "requirements.txt")
if not missing:
    print("Dependências principais atendidas.")
else:
    print("Ainda faltam pacotes para instalar manualmente:", ", ".join(missing))


In [None]:

import pandas as pd

# Summarise the raw exams and any CSVs already produced by earlier runs.
nii_files = sorted(DATASET_DIR.glob('*.nii*'))
print(f"Total de exames NIfTI encontrados: {len(nii_files)}")
if nii_files:
    print("Exemplos:")
    # Show at most three file names, one per line.
    print("\n".join(f" - {f.name}" for f in nii_files[:3]))

if DESCRIPTORS_CSV.exists():
    df_desc = pd.read_csv(DESCRIPTORS_CSV)
    print(f"Descritores: {len(df_desc)} linhas, colunas: {df_desc.columns.tolist()[:8]} ...")
if EXAM_SPLIT_CSV.exists():
    df_split = pd.read_csv(EXAM_SPLIT_CSV)
    display(df_split.head())  # IPython's rich display; only available inside Jupyter


## Funções utilitárias (sem Tkinter)

In [None]:
# Notebook auto-notes: keep split consistent; original_path rebuilt in code.
# Utilidades compartilhadas: bootstrap de dependências, helpers de logging/plot,
# resolução de caminhos originais e cálculo longitudinal para descritores.

import json
import re
import shutil
import time
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

try:
    import nibabel as nib
except ImportError:
    nib = None
from PIL import Image

from brain_mri.utils.image_utils import ImageUtils
from brain_mri.ml.training_utils import (
    ExponentialMovingAverage,
    build_densenet,
    build_transforms,
    focal_loss,
    mixup_data,
    select_device,
)
from brain_mri.ml.datasets import MRIDataset

from sklearn.model_selection import train_test_split, GridSearchCV, GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    mean_absolute_error,
    mean_squared_error,
    precision_score,
    r2_score,
    recall_score,
)
from sklearn.svm import SVC
import xgboost as xgb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

# Evita explosão de threads em CPUs/ARM (mesma lógica do app original)
import os
# NOTE: numpy (pulled in by pandas) and torch are already imported at this
# point, so these variables arrive too late to size *their* native thread
# pools. They still reach libraries initialised later, and any process spawned
# afterwards inherits os.environ.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
# Apply the limit to torch explicitly, since its intra-op pool was configured
# when torch was imported. Best-effort: a non-integer user-supplied
# OMP_NUM_THREADS is left alone instead of aborting the notebook.
try:
    torch.set_num_threads(max(1, int(os.environ["OMP_NUM_THREADS"])))
except ValueError:
    pass

def save_experiment(data, path=HISTORY_JSON):
    """Append one experiment record (plus a timestamp) to a JSON history file.

    The history file holds a JSON list of records. A missing file starts a new
    list. If the existing file cannot be decoded, or does not contain a list,
    a timestamped backup copy is written next to it before a fresh history is
    started, so previous results are never silently discarded.

    Args:
        data: mapping with the experiment metrics; copied, not modified.
        path: history file (``str`` or ``Path``); defaults to ``HISTORY_JSON``.

    Returns:
        The ``Path`` that was written.
    """
    path = Path(path)
    payload = dict(data)
    payload["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    history = []
    if path.exists():
        try:
            loaded = json.loads(path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            loaded = None
        if isinstance(loaded, list):
            history = loaded
        else:
            # Keep the unreadable/unexpected content instead of overwriting it.
            stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            backup = path.with_name(f"{path.stem}.corrupt-{stamp}{path.suffix}")
            shutil.copyfile(path, backup)
            print(f"Aviso: histórico inválido em {path}; cópia salva em {backup}")

    history.append(payload)
    path.write_text(json.dumps(history, indent=2), encoding="utf-8")
    return path

def plot_confusion_matrix(ax, cm, classes, title):
    """Draw confusion matrix ``cm`` as an annotated blue heatmap on ``ax``.

    ``classes`` labels both axes. Annotations use integer formatting
    (``fmt='d'``), so ``cm`` is expected to hold counts. The colour bar is
    hidden and tick labels are shrunk (x rotated 45°) to fit small subplots.
    """
    style = dict(annot=True, fmt='d', cmap='Blues', cbar=False,
                 xticklabels=classes, yticklabels=classes)
    sns.heatmap(cm, ax=ax, **style)
    ax.set_title(title, fontsize=10, fontweight='bold')

    x_labels = ax.get_xticklabels()
    ax.set_xticklabels(x_labels, rotation=45, ha='right', fontsize=8)
    y_labels = ax.get_yticklabels()
    ax.set_yticklabels(y_labels, rotation=0, fontsize=8)

def resolve_original_path(mri_id: str) -> str:
    """Locate the NIfTI file ``<mri_id>_axl.nii.gz`` (preferred) or ``.nii``.

    Returns the path relative to ``OUTPUT_DIR.parent`` with forward slashes.
    If the file lies outside that folder, the plain path is returned instead.
    Returns ``""`` when neither file exists.
    """
    candidates = (DATASET_DIR / f"{mri_id}_axl{suffix}" for suffix in (".nii.gz", ".nii"))
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is None:
        return ""
    try:
        relative = found.relative_to(OUTPUT_DIR.parent)
    except ValueError:
        return str(found)
    return str(relative).replace("\\", "/")

def calc_longitudinal(df: pd.DataFrame) -> pd.DataFrame:
    """Fill per-subject, visit-to-visit change columns in a descriptors table.

    The frame is modified in place and also returned. An empty frame is
    returned untouched. Otherwise:

    * ``viable`` (default True), every ``*_change`` column,
      ``area_change_percent`` and ``visit_number`` are created if missing;
    * ``visit_number`` is parsed from the ``MR<n>`` part of ``MRI_ID``;
    * if ``Subject_ID`` exists, the derived change columns are cleared, then
      for each subject the exams are ordered by ``visit_number``. For every
      pair of consecutive exams that are both viable, ``current - previous``
      is stored on the later exam for each ventricle descriptor, together
      with the relative area change in percent.

    Clearing before recomputing ensures values from an earlier call (e.g.
    before an exam was marked non-viable) do not survive.
    """
    if df.empty:
        return df
    if 'viable' not in df.columns:
        df['viable'] = True

    # Source descriptor column -> derived change column.
    change_map = {
        'ventricle_area': 'area_change',
        'ventricle_perimeter': 'perimeter_change',
        'ventricle_circularity': 'circularity_change',
        'ventricle_eccentricity': 'eccentricity_change',
        'ventricle_solidity': 'solidity_change',
        'ventricle_major_axis_length': 'major_axis_change',
        'ventricle_minor_axis_length': 'minor_axis_change'
    }
    required_cols = [
        'area_change', 'area_change_percent', 'perimeter_change', 'circularity_change',
        'eccentricity_change', 'solidity_change', 'major_axis_change',
        'minor_axis_change', 'visit_number'
    ]
    for col in required_cols:
        if col not in df.columns:
            df[col] = np.nan

    if 'MRI_ID' in df.columns:
        visit_re = re.compile(r'MR(\d+)')
        for idx, mri_id in df['MRI_ID'].items():
            match = visit_re.search(str(mri_id))
            if match:
                df.at[idx, 'visit_number'] = int(match.group(1))

    if 'Subject_ID' not in df.columns:
        return df

    # Recompute derived values from scratch so no stale change survives.
    for col in (*change_map.values(), 'area_change_percent'):
        df[col] = np.nan

    src_cols = [col for col in change_map if col in df.columns]
    has_area = 'ventricle_area' in df.columns
    with_subject = df[df['Subject_ID'].notna()]
    # One pass over the groups instead of filtering the whole frame per subject.
    for _, subj_df in with_subject.groupby('Subject_ID', sort=False):
        ordered = subj_df.sort_values('visit_number').index
        for prev_idx, idx in zip(ordered[:-1], ordered[1:]):
            # Only compare consecutive visits that are both viable
            # (NaN counts as viable here because bool(nan) is True).
            if not (bool(df.at[idx, 'viable']) and bool(df.at[prev_idx, 'viable'])):
                continue
            for src_col in src_cols:
                prev_val = df.at[prev_idx, src_col]
                cur_val = df.at[idx, src_col]
                if pd.notna(prev_val) and pd.notna(cur_val):
                    df.at[idx, change_map[src_col]] = cur_val - prev_val
            if has_area:
                prev_area = df.at[prev_idx, 'ventricle_area']
                cur_area = df.at[idx, 'ventricle_area']
                # Skip the percentage when the previous area is 0 (division by zero).
                if pd.notna(prev_area) and pd.notna(cur_area) and prev_area:
                    df.at[idx, 'area_change_percent'] = ((cur_area - prev_area) / prev_area) * 100
    return df

def update_descriptors_csv(mri_id: str, descriptors: dict, seg_path: str):
    """Upsert one exam's ventricle descriptors into ``DESCRIPTORS_CSV``.

    Any ``_axl`` in ``mri_id`` is removed to obtain the exam ID. The subject ID
    is the part before ``_MR`` (or the whole ID when ``_MR`` is absent). A row
    with the same ``MRI_ID`` is overwritten (and set back to viable); otherwise
    a new row is appended. Longitudinal change columns are then recomputed with
    ``calc_longitudinal`` and the whole CSV is rewritten.

    ``descriptors`` must provide the keys area, perimeter, circularity,
    eccentricity, solidity, major_axis_length and minor_axis_length.

    Returns the updated DataFrame.
    """
    exam_id = mri_id.replace('_axl', '')
    subject_id = exam_id.split('_MR')[0] if '_MR' in exam_id else exam_id
    shape_keys = ('area', 'perimeter', 'circularity', 'eccentricity', 'solidity',
                  'major_axis_length', 'minor_axis_length')
    record = {
        'MRI_ID': exam_id,
        'Subject_ID': subject_id,
        'viable': True,
        'segmented_path': seg_path,
        # Every descriptor is stored as "ventricle_<key>".
        **{f'ventricle_{key}': descriptors[key] for key in shape_keys},
    }

    if DESCRIPTORS_CSV.exists():
        table = pd.read_csv(DESCRIPTORS_CSV)
    else:
        table = pd.DataFrame(columns=record.keys())

    # Make sure the table has every column the record writes.
    if 'viable' not in table.columns:
        table['viable'] = True
    for column in record:
        if column not in table.columns:
            table[column] = np.nan

    already_present = 'MRI_ID' in table.columns and exam_id in table['MRI_ID'].values
    if not already_present:
        table = pd.concat([table, pd.DataFrame([record])], ignore_index=True)
    else:
        matches = table['MRI_ID'] == exam_id
        for column, value in record.items():
            table.loc[matches, column] = value

    table = calc_longitudinal(table)
    table.to_csv(DESCRIPTORS_CSV, index=False)
    return table



### Segmentação headless

In [None]:
# Notebook auto-notes: keep split consistent; original_path rebuilt in code.
# Funções de segmentação headless (carrega fatia axial, segmenta, salva PNG e descritores).

def load_axial_slice(path: Path):
    """Load a NIfTI file and return one normalised 2-D slice.

    Singleton dimensions are squeezed away. If the result is still 3-D, the
    slice at index 0 of the last axis is taken. Otherwise the squeezed array is
    used as-is. The slice goes through ``ImageUtils.normalize_array``.

    Raises:
        ImportError: when nibabel is not installed.
    """
    if nib is None:
        raise ImportError("Instale nibabel para ler arquivos NIfTI (pip install nibabel).")
    volume = nib.load(str(path)).get_fdata()
    volume = np.squeeze(volume)
    # NOTE(review): index 0 is the first slice, not the central one — fine for
    # single-slice "axl" files, but confirm this is intended for real volumes.
    slice_2d = volume[..., 0] if volume.ndim == 3 else volume
    return ImageUtils.normalize_array(slice_2d)

def segment_image(path: Path):
    """Segment one exam, save its mask as a PNG and record its descriptors.

    The mask (first value returned by ``ImageUtils.grow_region``) is scaled by
    255 and saved as ``OUTPUT_DIR/<name>_segmented.png``. Here ``<name>`` is the
    file stem without ``.nii``. Descriptors computed from the mask are upserted
    into the descriptors CSV.

    Returns:
        dict with keys ``image`` (input path), ``png`` (saved PNG path) and
        ``descriptors``.
    """
    exam_name = path.stem.replace('.nii', '')
    slice_img = load_axial_slice(path)
    mask, _, _, _ = ImageUtils.grow_region(slice_img)

    png_path = OUTPUT_DIR / f"{exam_name}_segmented.png"
    mask_u8 = (mask * 255).astype(np.uint8)
    Image.fromarray(mask_u8).save(png_path)

    descriptors = ImageUtils.calculate_descriptors(mask)
    update_descriptors_csv(exam_name, descriptors, f"output/{png_path.name}")
    return {'image': path, 'png': png_path, 'descriptors': descriptors}

def segment_all_images(overwrite=False, limit=None, paths=None):
    """Segment a batch of NIfTI exams with ``segment_image``.

    Args:
        overwrite: re-segment even when ``<name>_segmented.png`` already exists
            in ``OUTPUT_DIR``; otherwise such files are skipped.
        limit: process at most this many files (after sorting). ``None`` means
            no limit; ``0`` means process nothing.
        paths: explicit iterable of files (``str`` or ``Path``). ``None`` means
            every ``*.nii*`` in ``DATASET_DIR``. An empty iterable processes
            nothing; it no longer falls back to the whole dataset.

    Returns:
        List of ``segment_image`` results for the files actually segmented.
    """
    source = DATASET_DIR.glob('*.nii*') if paths is None else paths
    targets = sorted(Path(p) for p in source)
    if limit is not None:
        targets = targets[:limit]

    total = len(targets)
    results = []
    for i, p in enumerate(targets, 1):
        out_png = OUTPUT_DIR / f"{p.stem.replace('.nii', '')}_segmented.png"
        if overwrite or not out_png.exists():
            results.append(segment_image(p))
        # Report progress every 10 files, whether segmented or skipped.
        if i % 10 == 0:
            print(f"Processados {i}/{total} arquivos")
    print(f"Concluído: {len(results)} novas segmentações salvas em {OUTPUT_DIR}")
    return results


### Criar dataset combinado (descritores + demografia)

In [None]:
# Notebook auto-notes: keep split consistent; original_path rebuilt in code.
# Constrói dataset combinado (descritores + demografia) e faz split por sujeito.

def _load_demographics():
    """Read ``CSV_DEMOGRAPHIC`` (';'-separated, ',' decimals) and normalise it.

    Renames 'MRI ID'/'Subject ID' to 'MRI_ID'/'Subject_ID'. Adds lowercase
    numeric copies of the clinical columns (unparseable values become NaN) and
    a 0/1 ``sex`` column mapped from 'M/F'.
    """
    df_demo = pd.read_csv(CSV_DEMOGRAPHIC, sep=';', decimal=',')
    df_demo.columns = [c.strip() for c in df_demo.columns]
    df_demo = df_demo.rename(columns={'MRI ID': 'MRI_ID', 'Subject ID': 'Subject_ID'})

    numeric_map = {
        'Age': 'age',
        'EDUC': 'education',
        'MMSE': 'mmse',
        'CDR': 'cdr',
        'eTIV': 'etiv',
        'nWBV': 'nwbv',
        'ASF': 'asf'
    }
    for src, dst in numeric_map.items():
        if src in df_demo.columns:
            # Columns read as text may still carry ',' decimals: normalise first.
            as_text = df_demo[src].astype(str).str.replace(',', '.').str.strip()
            df_demo[dst] = pd.to_numeric(as_text, errors='coerce')

    if 'M/F' in df_demo.columns:
        df_demo['sex'] = df_demo['M/F'].map({'M': 0, 'F': 1})
    return df_demo


def _resolve_final_group(row):
    """Map 'Converted' to 'Demented' when CDR > 0 (else 'Nondemented'); keep other groups."""
    grp = row.get('Group')
    if isinstance(grp, str) and grp == 'Converted':
        cdr_val = row.get('cdr') if 'cdr' in row else row.get('CDR')
        if pd.notna(cdr_val) and float(cdr_val) > 0:
            return 'Demented'
        return 'Nondemented'
    return grp


def _split_subjects(subjects, random_state):
    """Return ``(val_ids, test_ids)``: 20% of subjects for test, then 20% of the rest for validation."""
    train_ids, test_ids = train_test_split(subjects, test_size=0.2, random_state=random_state)
    _, val_ids = train_test_split(train_ids, test_size=0.2, random_state=random_state)
    return set(val_ids), set(test_ids)


def create_exam_level_dataset(random_state=42):
    """Merge ventricle descriptors with demographics and split by subject.

    Only viable exams (missing ``viable`` counts as viable) with an existing
    original NIfTI file and a known ``Subject_ID`` are kept. ``Final_Group``
    resolves 'Converted' via CDR. The train/validation/test assignment is done
    per subject, so all exams of one subject share a split. Subjects are sorted
    and split with a fixed ``random_state``, so re-running on the same data
    reproduces the same split. The result is written to ``EXAM_SPLIT_CSV``.

    Args:
        random_state: seed for the subject split (default 42, same seed the
            DenseNet training uses).

    Returns:
        The merged DataFrame that was saved.

    Raises:
        FileNotFoundError: if the descriptors CSV does not exist.
        ValueError: if the descriptors CSV is empty or fewer than 3 subjects remain.
    """
    if not DESCRIPTORS_CSV.exists():
        raise FileNotFoundError("Gere descritores primeiro (segmentação).")
    df_desc = pd.read_csv(DESCRIPTORS_CSV)
    if df_desc.empty:
        raise ValueError("CSV de descritores está vazio.")
    if 'viable' not in df_desc.columns:
        df_desc['viable'] = True

    df_demo = _load_demographics()
    merged = pd.merge(df_desc, df_demo, on='MRI_ID', how='left', suffixes=('', '_demo'))
    merged['viable'] = merged['viable'].fillna(True)
    merged = merged[merged['viable'] == True].copy()  # noqa: E712 (column may be object dtype)
    if merged.empty:
        raise ValueError("Dados insuficientes para split (mínimo 3 sujeitos).")

    # With suffixes=('', '_demo') a clash on Subject_ID yields 'Subject_ID_demo'
    # (never _x/_y): use it to fill exams whose descriptor row lacks a subject.
    if 'Subject_ID_demo' in merged.columns:
        merged['Subject_ID'] = merged['Subject_ID'].fillna(merged['Subject_ID_demo'])
        merged = merged.drop(columns=['Subject_ID_demo'])

    merged['Original_Group'] = merged.get('Group')
    merged['Final_Group'] = merged.apply(_resolve_final_group, axis=1)
    merged['Final_Group'] = merged['Final_Group'].fillna(merged['Original_Group'])

    merged['original_path'] = merged['MRI_ID'].apply(resolve_original_path)
    merged = merged[(merged['original_path'] != "") & merged['Subject_ID'].notna()].copy()

    # Sorting makes the split independent of the descriptor CSV's row order.
    subjects = sorted(merged['Subject_ID'].unique(), key=str)
    if len(subjects) < 3:
        raise ValueError("Dados insuficientes para split (mínimo 3 sujeitos).")
    val_ids, test_ids = _split_subjects(subjects, random_state)

    def get_split(sid):
        if sid in val_ids:
            return 'validation'
        if sid in test_ids:
            return 'test'
        return 'train'

    merged['split'] = merged['Subject_ID'].apply(get_split)

    # Raw demographic columns are replaced by their normalised lowercase copies.
    cols_to_drop = ['Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF', 'Visit', 'MR Delay', 'M/F']
    cols_to_drop = [c for c in cols_to_drop if c in merged.columns]
    if cols_to_drop:
        merged = merged.drop(columns=cols_to_drop)

    merged.to_csv(EXAM_SPLIT_CSV, index=False)
    print(f"Dataset salvo em {EXAM_SPLIT_CSV} | Total: {len(merged)} exames")
    return merged


### Treinos clássicos: SVM (classificação) e XGBoost (regressão de idade)

In [None]:
# Notebook auto-notes: keep split consistent; original_path rebuilt in code.
# Treinos clássicos (SVM classificação e XGBoost regressão) com grids e salvamento de métricas.

# Ventricle shape descriptors shared by both classical models.
_VENTRICLE_SHAPE_FEATURES = (
    'ventricle_area', 'ventricle_perimeter',
    'ventricle_circularity', 'ventricle_eccentricity',
)

# NOTE(review): 'cdr' (and 'mmse') are clinical scores closely tied to the
# Demented label — create_exam_level_dataset even derives Final_Group for
# 'Converted' subjects from CDR — so using them as SVM inputs may leak the
# target. Confirm this is intended.
DEFAULT_SVM_FEATURES = [*_VENTRICLE_SHAPE_FEATURES, 'mmse', 'cdr', 'age']

# Inputs for age regression; 'age' itself is the target, so it is excluded.
DEFAULT_XGB_FEATURES = [
    *_VENTRICLE_SHAPE_FEATURES,
    'mmse', 'cdr', 'nwbv', 'etiv', 'asf', 'sex', 'education',
]

def train_svm_classifier(features=None):
    """Train an SVM that classifies exams as Demented (1) vs. not (0).

    Uses the ``split`` column of ``EXAM_SPLIT_CSV``. Missing feature values are
    imputed with training-set means. Hyperparameters are tuned by grid search
    on the train rows, with CV folds grouped by ``Subject_ID`` and stratified by
    label. Exams of one subject therefore never appear on both sides of a fold,
    which matches the grouped search used for XGBoost. Accuracy is reported on
    train and validation. When the test split is non-empty, accuracy,
    precision, recall, F1 and a confusion-matrix PNG are added. The scaler and
    model are pickled to ``OUTPUT_DIR`` and the metrics are appended to the
    experiment history.

    Args:
        features: feature column names; defaults to ``DEFAULT_SVM_FEATURES``.

    Returns:
        The metrics dict that was saved.

    Raises:
        FileNotFoundError: if the combined dataset does not exist yet.
        ValueError: if feature columns are missing or the validation split is empty.
    """
    import pickle
    from sklearn.model_selection import StratifiedGroupKFold

    if not EXAM_SPLIT_CSV.exists():
        raise FileNotFoundError("Crie o dataset combinado primeiro.")
    df = pd.read_csv(EXAM_SPLIT_CSV)
    features = features or DEFAULT_SVM_FEATURES

    tmp = df.copy()
    if 'sex' in features and 'sex' not in tmp.columns:
        if 'M/F' in tmp.columns:
            tmp['sex'] = tmp['M/F'].map({'M': 0, 'F': 1})
        else:
            tmp['sex'] = np.nan

    missing = [f for f in features if f not in tmp.columns]
    if missing:
        raise ValueError(f"Colunas ausentes no dataset: {missing}")

    train_mask = (df['split'] == 'train').to_numpy()
    val_mask = (df['split'] == 'validation').to_numpy()
    test_mask = (df['split'] == 'test').to_numpy()
    if not val_mask.any():
        raise ValueError("Split de validação vazio.")

    # Impute with *training* means only, so validation/test statistics do not
    # leak into the features.
    train_means = tmp.loc[train_mask, features].mean()
    X = tmp[features].fillna(train_means).to_numpy()
    y = (tmp['Final_Group'] == 'Demented').astype(int).to_numpy()
    y_train, y_val, y_test = y[train_mask], y[val_mask], y[test_mask]

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X[train_mask])
    X_val = scaler.transform(X[val_mask])
    X_test = scaler.transform(X[test_mask]) if test_mask.any() else None

    grid = {
        'C': [0.1, 1, 10, 100],
        'gamma': ['scale', 'auto', 0.001, 0.01, 0.1],
        'kernel': ['rbf', 'linear']
    }
    gs = GridSearchCV(SVC(), grid, cv=StratifiedGroupKFold(n_splits=3),
                      scoring='accuracy', n_jobs=-1, verbose=1)
    gs.fit(X_train, y_train, groups=tmp.loc[train_mask, 'Subject_ID'].to_numpy())
    clf = gs.best_estimator_

    metrics = {
        'train_accuracy': float(accuracy_score(y_train, clf.predict(X_train))),
        'val_accuracy': float(accuracy_score(y_val, clf.predict(X_val))),
        'best_params': gs.best_params_,
    }
    if X_test is not None:
        y_test_pred = clf.predict(X_test)
        metrics.update({
            'test_accuracy': float(accuracy_score(y_test, y_test_pred)),
            'test_precision': float(precision_score(y_test, y_test_pred, zero_division=0)),
            'test_recall': float(recall_score(y_test, y_test_pred, zero_division=0)),
            'test_f1': float(f1_score(y_test, y_test_pred, zero_division=0)),
        })
        # labels=[0, 1] keeps the matrix 2x2 even if the test split holds one class.
        test_cm = confusion_matrix(y_test, y_test_pred, labels=[0, 1])
        fig_cm, ax = plt.subplots(figsize=(6, 5))
        plot_confusion_matrix(ax, test_cm, ['0', '1'], "Teste")
        fig_cm.tight_layout()
        cm_path = OUTPUT_DIR / "confusion_svm.png"
        fig_cm.savefig(cm_path, dpi=300, bbox_inches='tight')
        plt.close(fig_cm)
        metrics['confusion_matrix'] = cm_path.name

    with open(OUTPUT_DIR / "svm_scaler.pkl", "wb") as f:
        pickle.dump(scaler, f)
    with open(OUTPUT_DIR / "svm_model.pkl", "wb") as f:
        pickle.dump(clf, f)

    metrics['model'] = 'SVM'
    metrics['features'] = features
    save_experiment(metrics)
    print("SVM treinado.", metrics)
    return metrics

def train_xgboost_regressor(features=None):
    """Train an XGBoost regressor that predicts ``age`` from the features.

    Exams without a known age are excluded, since they cannot serve as targets.
    Missing feature values are imputed with training-set means. Hyperparameters
    are tuned by grid search on the train rows, with ``GroupKFold`` over
    ``Subject_ID``. MAE/MSE/RMSE/R² are reported on validation and, when that
    split is non-empty, on test. The model is pickled to ``OUTPUT_DIR`` and the
    metrics are appended to the experiment history.

    Args:
        features: feature column names; defaults to ``DEFAULT_XGB_FEATURES``.

    Returns:
        The metrics dict that was saved.

    Raises:
        FileNotFoundError: if the combined dataset does not exist yet.
        ValueError: if feature columns are missing or no validation exam has an age.
    """
    import pickle

    if not EXAM_SPLIT_CSV.exists():
        raise FileNotFoundError("Crie o dataset combinado primeiro.")
    df = pd.read_csv(EXAM_SPLIT_CSV)
    features = features or DEFAULT_XGB_FEATURES

    tmp = df.copy()
    if 'sex' in features and 'sex' not in tmp.columns and 'M/F' in tmp.columns:
        tmp['sex'] = tmp['M/F'].map({'M': 0, 'F': 1})

    missing = [f for f in features if f not in tmp.columns]
    if missing:
        raise ValueError(f"Colunas ausentes no dataset: {missing}")

    # XGBoost cannot train on NaN targets: keep only exams with a known age.
    has_age = tmp['age'].notna()
    train_mask = ((df['split'] == 'train') & has_age).to_numpy()
    val_mask = ((df['split'] == 'validation') & has_age).to_numpy()
    test_mask = ((df['split'] == 'test') & has_age).to_numpy()
    if not val_mask.any():
        raise ValueError("Split de validação vazio.")

    # Impute with *training* means only, so validation/test statistics do not leak.
    train_means = tmp.loc[train_mask, features].mean()
    X = tmp[features].fillna(train_means).to_numpy()
    y = tmp['age'].to_numpy()

    base = xgb.XGBRegressor(objective='reg:squarederror', tree_method='hist', n_jobs=1, verbosity=0)
    grid = {
        'n_estimators': [200, 300, 500],
        'max_depth': [6, 8, 10],
        'learning_rate': [0.05, 0.1, 0.15],
        'min_child_weight': [1, 3, 5],
        'subsample': [0.8, 0.9],
        'colsample_bytree': [0.8, 0.9]
    }
    # Folds grouped by subject so visits of one subject never straddle a fold.
    gs = GridSearchCV(base, grid, cv=GroupKFold(n_splits=3),
                      scoring='neg_mean_absolute_error', n_jobs=-1, verbose=1)
    gs.fit(X[train_mask], y[train_mask], groups=tmp.loc[train_mask, 'Subject_ID'].to_numpy())
    model = gs.best_estimator_

    def _evaluate(mask):
        """Regression metrics of ``model`` on the rows selected by ``mask``."""
        preds = model.predict(X[mask])
        mse = mean_squared_error(y[mask], preds)
        return {
            'mae': float(mean_absolute_error(y[mask], preds)),
            'mse': float(mse),
            'rmse': float(np.sqrt(mse)),
            'r2': float(r2_score(y[mask], preds)),
        }

    with open(OUTPUT_DIR / "xgb_age.pkl", "wb") as f:
        pickle.dump(model, f)

    metrics = {'model': 'XGBoost', 'target': 'age', 'features': features}
    metrics.update({f'val_{name}': value for name, value in _evaluate(val_mask).items()})
    metrics['best_params'] = gs.best_params_
    if test_mask.any():
        metrics.update({f'test_{name}': value for name, value in _evaluate(test_mask).items()})

    save_experiment(metrics)
    print("XGBoost treinado.", metrics)
    return metrics


### Treino DenseNet (classificação ou regressão)

In [None]:
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's RNGs and request deterministic cuDNN.

    CUDA generators are seeded only when a GPU is available. cuDNN is switched
    to deterministic mode with benchmarking disabled whenever the backend exists.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = getattr(torch.backends, "cudnn", None)
    if cudnn is not None:
        cudnn.deterministic = True
        cudnn.benchmark = False
# Notebook auto-notes: keep split consistent; original_path rebuilt in code.
# Treino DenseNet (classificação/regressão) com mixup/focal/EMA, curvas, métricas e export de embeddings.
import time

def train_densenet(mode='classification', hparams=None, export_embeddings=True, max_epochs=None):
    set_seed(42)
    if not EXAM_SPLIT_CSV.exists():
        raise FileNotFoundError("Crie o dataset combinado primeiro.")
    df = pd.read_csv(EXAM_SPLIT_CSV)
    device = select_device()
    print(f"Dispositivo: {device} | Torch threads: {torch.get_num_threads()}")
    start_time = time.time()

    defaults = {
        "lr": 1e-4 if mode == 'classification' else 0.001,
        "weight_decay": 1e-4 if mode == 'classification' else 0.0,
        "dropout": 0.3,
        "label_smoothing": 0.05 if mode == 'classification' else 0.0,
        "mixup_alpha": 0.4 if mode == 'classification' else 0.0,
        "freeze_backbone": False,
        "freeze_warmup_epochs": 0,
        "warmup_lr": None,
        "class_balance": False,
        "balance_penalty": 0.25,
        "thresholds_eval": [0.5, 0.6, 0.4, 0.7],
    }
    if hparams:
        for k, v in hparams.items():
            if k in defaults:
                defaults[k] = v

    lr = defaults["lr"]
    weight_decay = defaults["weight_decay"]
    dropout_rate = defaults["dropout"]
    label_smoothing = defaults["label_smoothing"]
    mixup_alpha = defaults["mixup_alpha"]
    freeze_backbone = bool(defaults["freeze_backbone"])
    freeze_warmup_epochs = int(defaults.get("freeze_warmup_epochs", 0) or 0)
    warmup_lr = defaults.get("warmup_lr", None)
    use_class_balance = bool(defaults["class_balance"])
    balance_penalty = defaults.get("balance_penalty", 0.0)
    thresholds_eval = defaults.get("thresholds_eval", [0.5])

    age_scaler = None
    if mode == 'regression':
        age_scaler = StandardScaler()
        df_train = df[df['split']=='train'].copy()
        df_val = df[df['split']=='validation'].copy()
        df_test = df[df['split']=='test'].copy()
        df_train['age_normalized'] = age_scaler.fit_transform(df_train[['age']])
        df_val['age_normalized'] = age_scaler.transform(df_val[['age']])
        df_test['age_normalized'] = age_scaler.transform(df_test[['age']])
        df.loc[df['split']=='train', 'age_normalized'] = df_train['age_normalized']
        df.loc[df['split']=='validation', 'age_normalized'] = df_val['age_normalized']
        df.loc[df['split']=='test', 'age_normalized'] = df_test['age_normalized']
        print(f"Idade normalizada: [{df_train['age_normalized'].min():.2f}, {df_train['age_normalized'].max():.2f}]")

    train_tf, val_tf = build_transforms()
    lbl_col = 'age_normalized' if mode == 'regression' else 'Final_Group'
    train_ds = MRIDataset(df[df['split']=='train'], train_tf, DATASET_DIR.parent, 'original_path', lbl_col)
    val_ds = MRIDataset(df[df['split']=='validation'], val_tf, DATASET_DIR.parent, 'original_path', lbl_col)
    test_ds = MRIDataset(df[df['split']=='test'], val_tf, DATASET_DIR.parent, 'original_path', lbl_col)
    if len(val_ds) == 0:
        raise ValueError("Split de validação vazio.")

    epochs = max_epochs or (40 if mode == 'classification' else 20)
    batch_size = 16
    early_stop_patience = 7 if mode == 'classification' else None
    use_mixup = mode == 'classification' and mixup_alpha > 0
    use_focal = mode == 'classification' and not use_mixup
    focal_gamma = 1.5
    use_ema = mode == 'classification'
    ema_decay = 0.999

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    model = build_densenet(mode=mode, dropout_rate=dropout_rate).to(device)
    warmup_epochs_remaining = freeze_warmup_epochs
    if warmup_epochs_remaining > 0 and hasattr(model, "features"):
        for p in model.features.parameters():
            p.requires_grad = False
        current_lr = warmup_lr if warmup_lr is not None else lr * 0.5
    else:
        current_lr = lr
    if freeze_backbone and warmup_epochs_remaining == 0 and hasattr(model, "features"):
        for p in model.features.parameters():
            p.requires_grad = False

    optimizer = optim.Adam(model.parameters(), lr=current_lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr * 0.1)
    class_weights = None
    if mode == 'classification' and use_class_balance and 'Final_Group' in df.columns:
        counts = df[df['split'] == 'train']['Final_Group'].value_counts()
        if len(counts) >= 1:
            total = counts.sum()
            w0 = total / (2 * counts.get('Nondemented', max(counts.max(), 1)))
            w1 = total / (2 * counts.get('Demented', max(counts.max(), 1)))
            class_weights = torch.tensor([w0, w1], dtype=torch.float32, device=device)

    criterion = nn.MSELoss() if mode == 'regression' else nn.CrossEntropyLoss(
        label_smoothing=label_smoothing,
        weight=class_weights
    )
    ema = ExponentialMovingAverage(model, decay=ema_decay) if use_ema else None

    history_train_loss, history_val_loss = [], []
    history_train_acc, history_val_acc = [], []
    history_train_mae, history_val_mae = [], []
    best_state, best_epoch = None, 0
    best_val_metric = -float('inf') if mode == 'classification' else float('inf')
    no_improve = 0
    val_metric_value = None
    test_cm = None
    test_cm_list = None
    test_acc = test_precision = test_recall = test_f1 = None

    for epoch in range(epochs):
        model.train()
        running_loss = 0
        total_train = 0
        correct_train, total_train_cls = 0, 0
        mae_sum_train, total_train_reg = 0.0, 0

        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            if mode == 'regression':
                out = model(imgs)
                preds_batch = out.squeeze()
                loss = criterion(preds_batch, lbls)
                mae_sum_train += torch.abs(preds_batch - lbls).sum().item()
                total_train_reg += lbls.size(0)
            else:
                if use_mixup:
                    imgs_mix, targets_a, targets_b, lam = mixup_data(imgs, lbls.long(), mixup_alpha)
                    out = model(imgs_mix)
                    loss = lam * criterion(out, targets_a) + (1 - lam) * criterion(out, targets_b)
                    preds_batch = out.argmax(dim=1)
                    if balance_penalty > 0:
                        p_mean = torch.softmax(out, dim=1)[:, 1].mean()
                        loss = loss + balance_penalty * torch.abs(p_mean - 0.5)
                    correct_train += (
                        lam * (preds_batch == targets_a).sum().item()
                        + (1 - lam) * (preds_batch == targets_b).sum().item()
                    )
                else:
                    out = model(imgs)
                    loss = focal_loss(out, lbls.long(), gamma=focal_gamma) if use_focal else criterion(out, lbls.long())
                    preds_batch = out.argmax(dim=1)
                    correct_train += (preds_batch == lbls.long()).sum().item()
                total_train_cls += lbls.size(0)

            loss.backward()
            optimizer.step()
            if ema: ema.update(model)
            running_loss += loss.item() * imgs.size(0)
            total_train += imgs.size(0)

        model.eval()
        running_val = 0
        preds_list, targs_list = [], []
        correct_val, total_val = 0, 0
        if ema: ema.apply_shadow(model)
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                out = model(imgs)
                if mode == 'regression':
                    loss = criterion(out.squeeze(), lbls)
                    running_val += loss.item() * imgs.size(0)
                    preds_list.append(out.squeeze().cpu().numpy())
                    targs_list.append(lbls.cpu().numpy())
                else:
                    loss = focal_loss(out, lbls.long(), gamma=focal_gamma) if use_focal else criterion(out, lbls.long())
                    running_val += loss.item() * imgs.size(0)
                    preds = out.argmax(dim=1)
                    correct_val += (preds == lbls.long()).sum().item()
                    total_val += lbls.size(0)
        if ema: ema.restore(model)

        train_loss = running_loss / max(total_train, 1)
        val_loss = running_val / max(len(val_ds), 1)
        history_train_loss.append(train_loss)
        history_val_loss.append(val_loss)

        if mode == 'regression':
            train_mae = mae_sum_train / max(total_train_reg, 1)
            history_train_mae.append(train_mae)
            if preds_list:
                preds = np.concatenate(preds_list)
                targets = np.concatenate(targs_list)
                val_metric_value = mean_absolute_error(targets, preds)
                history_val_mae.append(val_metric_value)
        else:
            train_acc = correct_train / max(total_train_cls, 1) if total_train_cls else 0.0
            history_train_acc.append(train_acc)
            if total_val:
                val_metric_value = correct_val / total_val
                history_val_acc.append(val_metric_value)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss {train_loss:.4f}, Val Loss {val_loss:.4f}")

        improved = False
        if mode == 'classification' and val_metric_value is not None:
            if val_metric_value > best_val_metric:
                improved = True
                best_val_metric = val_metric_value
        elif mode == 'regression' and val_loss < best_val_metric:
            improved = True
            best_val_metric = val_loss

        if improved:
            best_epoch = epoch + 1
            no_improve = 0
            best_state = {k: v.cpu() for k, v in (ema.shadow if (ema and ema.shadow) else model.state_dict()).items()}
        else:
            no_improve += 1
            if early_stop_patience and no_improve >= early_stop_patience:
                print(f"Early stopping na época {epoch+1}. Melhor época: {best_epoch}")
                break
        scheduler.step()

        model.load_state_dict(best_state, strict=False)

    history_train_mae_denorm = history_train_mae
    history_val_mae_denorm = history_val_mae
    if mode == 'regression' and age_scaler is not None:
        mae_scale = age_scaler.scale_[0]
        history_train_mae_denorm = [mae * mae_scale for mae in history_train_mae]
        history_val_mae_denorm = [mae * mae_scale for mae in history_val_mae]
        print(f"MAE (orig): train={history_train_mae_denorm[-1]:.4f} | val={history_val_mae_denorm[-1]:.4f}")

    if history_train_loss:
        epochs_range = range(1, len(history_train_loss) + 1)
        fig = plt.figure(figsize=(10, 4))
        ax1 = fig.add_subplot(121)
        ax1.plot(epochs_range, history_train_loss, 'b-', label='Treino')
        ax1.plot(epochs_range, history_val_loss, 'r-', label='Validação')
        ax1.set_title("Loss")
        ax1.set_xlabel("Época")
        ax1.legend(); ax1.grid(True, alpha=0.3)

        ax2 = fig.add_subplot(122)
        if mode == 'classification':
            if history_train_acc: ax2.plot(epochs_range, history_train_acc, 'b-', label='Treino')
            if history_val_acc: ax2.plot(epochs_range, history_val_acc, 'r-', label='Validação')
            ax2.set_title("Acurácia")
        else:
            if history_train_mae_denorm: ax2.plot(epochs_range, history_train_mae_denorm, 'b-', label='Treino')
            if history_val_mae_denorm: ax2.plot(epochs_range, history_val_mae_denorm, 'r-', label='Validação')
            ax2.set_title("MAE (anos)")
        ax2.set_xlabel("Época")
        ax2.legend(); ax2.grid(True, alpha=0.3)

        curves_name = f"densenet_{mode}_learning_curves.png"
        fig.tight_layout()
        fig.savefig(OUTPUT_DIR / curves_name, dpi=300, bbox_inches='tight')
        plt.close(fig)

    torch.save(model.state_dict(), OUTPUT_DIR / f"densenet_{mode}.pth")
    if best_state is not None:
        torch.save(best_state, OUTPUT_DIR / f"densenet_{mode}_bestval.pth")

    val_metric_value = None
    test_cm = None
    train_mae_orig = val_mae_orig = test_mae_orig = None
    train_r2 = val_r2 = test_r2 = train_rmse = val_rmse = test_rmse = None

    if mode == 'regression' and age_scaler is not None:
        model.eval()
        all_preds = {'train': [], 'val': [], 'test': []}
        all_true = {'train': [], 'val': [], 'test': []}
        loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
        with torch.no_grad():
            for split, loader in loaders.items():
                for imgs, ages in loader:
                    imgs = imgs.to(device)
                    preds = model(imgs).squeeze()
                    all_preds[split].extend(np.atleast_1d(preds.cpu().numpy()))
                    all_true[split].extend(np.atleast_1d(ages.numpy()))

        for k in all_preds:
            all_preds[k] = age_scaler.inverse_transform(np.array(all_preds[k]).reshape(-1, 1)).flatten()
            all_true[k] = age_scaler.inverse_transform(np.array(all_true[k]).reshape(-1, 1)).flatten()

        train_mae_orig = mean_absolute_error(all_true['train'], all_preds['train'])
        train_r2 = r2_score(all_true['train'], all_preds['train'])
        train_rmse = np.sqrt(mean_squared_error(all_true['train'], all_preds['train']))
        val_mae_orig = mean_absolute_error(all_true['val'], all_preds['val'])
        val_r2 = r2_score(all_true['val'], all_preds['val'])
        val_rmse = np.sqrt(mean_squared_error(all_true['val'], all_preds['val']))
        test_mae_orig = mean_absolute_error(all_true['test'], all_preds['test'])
        test_r2 = r2_score(all_true['test'], all_preds['test'])
        test_rmse = np.sqrt(mean_squared_error(all_true['test'], all_preds['test']))
        val_metric_value = test_mae_orig

        fig_scatter = plt.figure(figsize=(8, 7))
        ax = fig_scatter.add_subplot(111)
        ax.scatter(all_true['test'], all_preds['test'], alpha=0.6, s=80, c='green',
                   edgecolors='darkgreen', linewidths=0.5)
        min_val = min(all_true['test'].min(), all_preds['test'].min())
        max_val = max(all_true['test'].max(), all_preds['test'].max())
        ax.text(0.05, 0.95,
                f'R² = {test_r2:.4f}\n'
                f'MAE = {test_mae_orig:.2f} anos\n'
                f'RMSE = {test_rmse:.2f} anos\n'
                f'N = {len(all_true['test'])} amostras',
                transform=ax.transAxes, fontsize=12, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.9,
                         edgecolor='darkgreen', linewidth=2))
        ax.set_xlabel('Idade Real (anos)', fontsize=13, fontweight='bold')
        ax.set_ylabel('Idade Predita (anos)', fontsize=13, fontweight='bold')
        ax.set_title('Teste: Predito vs Real', fontsize=14, fontweight='bold', pad=15)
        ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
    if mode == 'classification':
        model.eval()
        test_cm_list = None
        y_true_test, y_pred_test = [], []
        with torch.no_grad():
            for imgs, lbls in test_loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                out = model(imgs)
                preds = out.argmax(dim=1)
                y_true_test.append(lbls.cpu().numpy())
                y_pred_test.append(preds.cpu().numpy())
        if y_true_test:
            y_true_test = np.concatenate(y_true_test)
            y_pred_test = np.concatenate(y_pred_test)
            test_cm = confusion_matrix(y_true_test, y_pred_test)
            test_acc = accuracy_score(y_true_test, y_pred_test)
            test_cm_list = test_cm.tolist()
            test_precision = precision_score(y_true_test, y_pred_test, average='binary', zero_division=0)
            test_recall = recall_score(y_true_test, y_pred_test, average='binary', zero_division=0)
            test_f1 = f1_score(y_true_test, y_pred_test, average='binary', zero_division=0)
            val_metric_value = test_acc

            fig_cm, ax = plt.subplots(figsize=(6,5))
            plot_confusion_matrix(ax, test_cm, ['Nondemented', 'Demented'], "Teste")
            fig_cm.tight_layout()
            fig_cm.savefig(OUTPUT_DIR / f"confusion_densenet_{mode}.png", dpi=300, bbox_inches='tight')
            plt.close(fig_cm)

    learning_curves = {
        'train_loss': history_train_loss,
        'val_loss': history_val_loss,
    }
    if mode == 'classification':
        learning_curves['train_acc'] = history_train_acc
        learning_curves['val_acc'] = history_val_acc
    else:
        learning_curves['train_mae'] = history_train_mae_denorm
        learning_curves['val_mae'] = history_val_mae_denorm

    training_time = time.time() - start_time

    exp_payload = {
        'model': f'DenseNet_{mode}',
        'epochs': len(history_train_loss),
        'batch_size': batch_size,
        'learning_rate': lr,
        'train_loss': float(history_train_loss[-1]) if history_train_loss else None,
        'val_loss': float(history_val_loss[-1]) if history_val_loss else None,
        'learning_curves': learning_curves,
        'training_time_seconds': float(training_time),
        'best_params': {
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': lr,
        }
    }
    if mode == 'classification':
        exp_payload['best_val_accuracy'] = float(best_val_metric) if best_val_metric != -float('inf') else None
        exp_payload['best_epoch'] = best_epoch
        if val_metric_value is not None:
            exp_payload['test_accuracy'] = float(val_metric_value)
        if test_precision is not None:
            exp_payload['test_precision'] = float(test_precision)
            exp_payload['test_recall'] = float(test_recall)
            exp_payload['test_f1'] = float(test_f1)
        if test_cm_list is not None:
            exp_payload['test_confusion_matrix'] = test_cm_list
            exp_payload['test_classes'] = ['Nondemented', 'Demented']
        if 'best_thr' in locals() and best_thr:
            exp_payload['best_threshold'] = best_thr
            exp_payload['threshold_candidates'] = threshold_metrics
    else:
        exp_payload.update({
            'type': 'regression',
            'test_mae': float(test_mae_orig) if test_mae_orig is not None else None,
            'train_mae': float(train_mae_orig) if train_mae_orig is not None else None,
            'val_mae': float(val_mae_orig) if val_mae_orig is not None else None,
            'train_r2': float(train_r2) if train_r2 is not None else None,
            'val_r2': float(val_r2) if val_r2 is not None else None,
            'test_r2': float(test_r2) if test_r2 is not None else None,
            'train_rmse': float(train_rmse) if train_rmse is not None else None,
            'val_rmse': float(val_rmse) if val_rmse is not None else None,
            'test_rmse': float(test_rmse) if test_rmse is not None else None,
        })

    if export_embeddings:
        try:
            def _export_embeddings(split_name, dataset_obj):
                if len(dataset_obj) == 0:
                    return
                loader = DataLoader(dataset_obj, batch_size=batch_size, shuffle=False)
                emb_list, target_list, ids = [], [], []
                model.eval()
                idx_offset = 0
                with torch.no_grad():
                    for imgs, lbls in loader:
                        imgs = imgs.to(device)
                        feats = model.features(imgs)
                        feats = F.relu(feats, inplace=False)
                        feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(feats.size(0), -1)
                        emb_list.append(feats.cpu().numpy())
                        target_list.append(lbls.cpu().numpy())
                        rows = dataset_obj.df.iloc[idx_offset: idx_offset + len(lbls)]
                        ids.extend(rows.get('MRI_ID', rows.index).tolist())
                        idx_offset += len(lbls)
                emb_arr = np.concatenate(emb_list)
                tgt_arr = np.concatenate(target_list)
                df_emb = pd.DataFrame(emb_arr)
                df_emb.insert(0, 'MRI_ID', ids)
                if mode == 'regression' and 'age' in dataset_obj.df.columns:
                    df_emb['target'] = dataset_obj.df.loc[:len(df_emb)-1, 'age'].values
                else:
                    df_emb['target'] = tgt_arr
                out_path = OUTPUT_DIR / f"densenet_embeddings_{mode}_{split_name}.csv"
                df_emb.to_csv(out_path, index=False)
            _export_embeddings('train', train_ds)
            _export_embeddings('val', val_ds)
            _export_embeddings('test', test_ds)
        except Exception as e:
            print(f"Falha ao exportar embeddings: {e}")

    save_experiment(exp_payload)
    print("Treino DenseNet finalizado.", exp_payload)
    return exp_payload







### Refinamento com RL (DenseNet)

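Porta a lógica de `refine_densenet_with_rl` para rodar sem Tkinter: define um ambiente de ajuste de hiperparâmetros e um agente PPO simplificado, treina em micro-épocas sobre um subconjunto dos dados e salva o melhor checkpoint (`densenet_classification_rl_best.pth`), a política (`rl_policy_densenet.pth`) e o histórico em JSON (`train_history_rl_<timestamp>.json`). Ao final, treina uma DenseNet no split completo com os melhores hiperparâmetros encontrados.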

In [None]:
# Refinamento por RL da DenseNet: PPO simplificado, ambiente de micro-treinos e treino final com melhores hiperparâmetros.

import math
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

class TrainHistoryWriter:
    """Persist RL training histories as timestamped JSON files.

    Each call to :meth:`save` writes ``train_history_rl_<YYYYmmdd_HHMMSS>.json``
    inside ``out_dir``. If that name is already taken (e.g. two saves within the
    same second), a ``_<n>`` counter is appended instead of overwriting.
    """

    def __init__(self, out_dir):
        self.out_dir = Path(out_dir)
        # parents=True so a nested output directory that does not exist yet
        # does not make construction fail.
        self.out_dir.mkdir(parents=True, exist_ok=True)

    def save(self, history):
        """Serialize ``history`` as indented JSON and return the path written.

        Raises:
            TypeError: (from ``json.dumps``) if ``history`` contains values that
                are not JSON-serializable.
        """
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = self.out_dir / f"train_history_rl_{ts}.json"
        # The timestamp only has one-second resolution; never clobber an
        # earlier history written in the same second.
        counter = 1
        while path.exists():
            path = self.out_dir / f"train_history_rl_{ts}_{counter}.json"
            counter += 1
        path.write_text(json.dumps(history, indent=2), encoding="utf-8")
        return path

class ActorCritic(torch.nn.Module):
    """Two-headed MLP used by the PPO agent.

    A shared trunk of two 64-unit ReLU layers feeds a policy head (one logit per
    action) and a value head (a single scalar estimate per input row).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 64
        # Attribute names and Sequential indices stay fixed: saved policies
        # (state_dict) are keyed by them.
        trunk_layers = [
            torch.nn.Linear(state_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
        ]
        self.shared = torch.nn.Sequential(*trunk_layers)
        self.policy_head = torch.nn.Linear(hidden, action_dim)
        self.value_head = torch.nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(logits, value)`` for the batch of states ``x``."""
        features = self.shared(x)
        logits = self.policy_head(features)
        value = self.value_head(features)
        return logits, value

class PPOAgent:
    """Minimal PPO agent: clipped surrogate objective + value loss + entropy bonus.

    Usage: call :meth:`select_action` and :meth:`store` for each transition,
    then :meth:`update` once per rollout. ``update`` performs ``epochs``
    full-buffer gradient steps and then clears the buffer.
    """

    def __init__(self, state_dim, action_dim, device, lr=3e-4, gamma=0.99, clip_eps=0.2, epochs=4, batch_size=64):
        self.device = device
        self.gamma = gamma
        self.clip_eps = clip_eps
        self.epochs = epochs
        # Kept for interface compatibility; update() optimizes on the whole buffer at once.
        self.batch_size = batch_size
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.memory = []

    def select_action(self, state):
        """Sample an action for ``state``.

        Returns ``(action_index, log_prob, value)``; ``log_prob`` and ``value``
        are 1-element tensors on the agent's device.
        """
        state_t = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        logits, value = self.policy(state_t)
        dist = Categorical(logits=logits)
        action = dist.sample()
        return int(action.item()), dist.log_prob(action), value.squeeze(0)

    def store(self, state, action, log_prob, value, reward, done):
        """Append one transition to the rollout buffer.

        Everything is detached and kept on CPU. ``update`` moves tensors to the
        agent's device, so a CUDA/MPS policy never mixes devices when returns
        (built from CPU rewards) are compared with the stored values.
        """
        self.memory.append({
            'state': torch.tensor(state, dtype=torch.float32),
            'action': torch.tensor(action),
            'log_prob': log_prob.detach().cpu(),
            'value': value.detach().cpu(),
            'reward': torch.tensor(reward, dtype=torch.float32),
            'done': torch.tensor(done, dtype=torch.float32),
        })

    def _compute_returns_adv(self, gamma):
        """Return discounted returns and normalized advantages as 1-D CPU tensors."""
        returns = []
        running = 0
        for step in reversed(self.memory):
            # done == 1 cuts the bootstrap from later steps.
            running = step['reward'] + gamma * running * (1 - step['done'])
            returns.append(running)
        returns.reverse()
        returns = torch.stack(returns)
        values = torch.stack([m['value'].detach().cpu() for m in self.memory]).reshape(-1)
        advs = returns - values
        if advs.numel() > 1:
            advs = (advs - advs.mean()) / (advs.std() + 1e-8)
        else:
            # std() of a single sample is NaN (Bessel correction); centering
            # alone yields a zero advantage instead of poisoning the loss.
            advs = advs - advs.mean()
        return returns.detach(), advs.detach()

    def update(self):
        """Run the PPO optimization on the buffered rollout and clear it.

        Returns ``{'loss': mean_loss}`` or ``{}`` when the buffer is empty.
        """
        if not self.memory:
            return {}
        states = torch.stack([m['state'] for m in self.memory]).to(self.device)
        actions = torch.stack([m['action'] for m in self.memory]).to(self.device)
        # Each stored log_prob has shape (1,); flatten to (T,) so it aligns
        # element-wise with the (T,) log-probs recomputed below. A (T, 1)
        # tensor would silently broadcast the ratio into a (T, T) matrix.
        old_log_probs = torch.stack([m['log_prob'] for m in self.memory]).reshape(-1).to(self.device)
        returns, advantages = self._compute_returns_adv(self.gamma)
        returns = returns.to(self.device)
        advantages = advantages.to(self.device)
        losses = []
        for _ in range(self.epochs):
            logits, values = self.policy(states)
            dist = Categorical(logits=logits)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()
            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(values.squeeze(-1), returns)
            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
        self.memory = []
        return {'loss': float(sum(losses) / max(len(losses), 1))}

class DenseNetRefineEnv:
    """RL environment whose actions tweak DenseNet training hyperparameters.

    Each ``step`` applies one hyperparameter action, rebuilds the DenseNet from
    ``base_checkpoint`` (so training does not carry over between steps), runs a
    short training pass and a capped validation pass, and returns the new
    hyperparameter vector plus a reward of ``val_acc - 0.1 * val_loss``.

    The best validation accuracy seen so far, with its weights
    (``best_state``) and hyperparameters (``best_hparams``), is tracked across
    steps and is not cleared by ``reset``.
    """
    def __init__(self, train_loader, val_loader, device, base_checkpoint, class_weights=None, micro_epochs=1, max_batches_per_epoch=3):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.base_checkpoint = Path(base_checkpoint)
        # Passed as `weight=` to CrossEntropyLoss in training and evaluation.
        self.class_weights = class_weights
        self.micro_epochs = micro_epochs
        # Caps the number of batches per pass, both in training and evaluation.
        self.max_batches = max_batches_per_epoch
        # Discrete action space: each entry describes one hyperparameter tweak
        # interpreted by _apply_action().
        self.actions = [
            {'name': 'lr_up', 'lr_scale': 1.5},
            {'name': 'lr_down', 'lr_scale': 0.7},
            {'name': 'dropout_up', 'dropout_delta': 0.05},
            {'name': 'dropout_down', 'dropout_delta': -0.05},
            {'name': 'mixup_toggle', 'mixup_toggle': True},
            {'name': 'label_smoothing_toggle', 'ls_toggle': True},
        ]
        # Must match the length of the vector returned by _as_vec().
        self.state_dim = 5
        self.action_dim = len(self.actions)
        self.best_val_acc = 0.0
        self.last_val_acc = 0.0
        self.last_val_loss = 0.0
        self.state = {}
        self.model = None
        # best_state / best_hparams are only created once step() sees an
        # improvement; callers should read them with getattr(..., None).
        self.reset()
    def _build_model(self):
        """Build a classification DenseNet using the current dropout setting.

        If ``base_checkpoint`` exists, its weights are loaded with
        ``strict=False``, which tolerates missing/unexpected keys. Any load
        error is printed and the freshly built weights are kept. When
        ``freeze_backbone`` is set, parameters of ``model.features`` stop
        receiving gradients.
        """
        model = build_densenet(mode='classification', dropout_rate=self.state['dropout']).to(self.device)
        if self.base_checkpoint.exists():
            try:
                state_dict = torch.load(self.base_checkpoint, map_location=self.device)
                model.load_state_dict(state_dict, strict=False)
            except Exception as e:
                print(f"Falha ao carregar checkpoint base: {e}")
        if self.state.get('freeze_backbone') and hasattr(model, 'features'):
            for p in model.features.parameters():
                p.requires_grad = False
        return model
    def reset(self):
        """Restore the default hyperparameters, rebuild the model and return the state vector."""
        # NOTE(review): 'class_balance' is stored here but not read anywhere in
        # this class; 'freeze_backbone' is never toggled by any action.
        self.state = {
            'lr': 1e-4,
            'weight_decay': 1e-4,
            'dropout': 0.3,
            'label_smoothing': 0.05,
            'mixup_alpha': 0.4,
            'freeze_backbone': False,
            'class_balance': False,
        }
        self.model = self._build_model()
        return self._as_vec()
    def _apply_action(self, action_idx):
        """Apply ``self.actions[action_idx]`` to the hyperparameters and rebuild the model."""
        act = self.actions[action_idx]
        if 'lr_scale' in act:
            # Learning rate is clamped to [1e-6, 5e-3].
            self.state['lr'] = float(min(5e-3, max(1e-6, self.state['lr'] * act['lr_scale'])))
        if 'dropout_delta' in act:
            # Dropout is clamped to [0.05, 0.8].
            self.state['dropout'] = float(min(0.8, max(0.05, self.state['dropout'] + act['dropout_delta'])))
        if act.get('mixup_toggle'):
            self.state['mixup_alpha'] = 0.0 if self.state['mixup_alpha'] > 0 else 0.4
        if act.get('ls_toggle'):
            self.state['label_smoothing'] = 0.0 if self.state['label_smoothing'] > 0 else 0.05
        # Rebuilding reloads base_checkpoint, so every step trains from the same
        # starting weights rather than continuing the previous step's training.
        self.model = self._build_model()
    def _as_vec(self):
        """Return the 5 continuous hyperparameters used as the agent's observation.

        ``freeze_backbone`` and ``class_balance`` are not included. Values are
        passed unscaled (e.g. lr is around 1e-4).
        NOTE(review): consider normalizing these before feeding the policy network.
        """
        return [
            self.state['lr'],
            self.state['weight_decay'],
            self.state['dropout'],
            self.state['label_smoothing'],
            self.state['mixup_alpha'],
        ]
    def _run_micro_train(self, model):
        """Train ``model`` for ``micro_epochs`` passes of at most ``max_batches`` batches.

        Uses Adam with the current lr/weight_decay and class-weighted
        cross-entropy with the current label smoothing. When mixup is enabled,
        accuracy counts each prediction as ``lam``-weighted agreement with both
        mixed targets, so ``correct`` can be fractional.

        Returns:
            (train_loss, train_acc): per-sample mean loss and accuracy over all
            samples seen across the micro-epochs.
        """
        optimizer = optim.Adam(model.parameters(), lr=self.state['lr'], weight_decay=self.state['weight_decay'])
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.state['label_smoothing'], weight=self.class_weights)
        use_mixup = self.state['mixup_alpha'] > 0
        train_loss = 0
        total, correct = 0, 0
        model.train()
        for epoch in range(self.micro_epochs):
            for batch_idx, (imgs, lbls) in enumerate(self.train_loader):
                if batch_idx >= self.max_batches:
                    break
                imgs, lbls = imgs.to(self.device), lbls.to(self.device)
                optimizer.zero_grad()
                if use_mixup:
                    # presumably mixup_data returns (mixed_inputs, targets_a, targets_b, lam)
                    # — inferred from this unpacking; confirm against its definition.
                    imgs_mix, targets_a, targets_b, lam = mixup_data(imgs, lbls.long(), self.state['mixup_alpha'])
                    out = model(imgs_mix)
                    loss = lam * criterion(out, targets_a) + (1 - lam) * criterion(out, targets_b)
                    preds = out.argmax(dim=1)
                    correct += (
                        lam * (preds == targets_a).sum().item() +
                        (1 - lam) * (preds == targets_b).sum().item()
                    )
                    total += lbls.size(0)
                else:
                    out = model(imgs)
                    loss = criterion(out, lbls.long())
                    preds = out.argmax(dim=1)
                    correct += (preds == lbls.long()).sum().item()
                    total += lbls.size(0)
                loss.backward()
                optimizer.step()
                # Weight the batch-mean loss by batch size to get a per-sample mean later.
                train_loss += loss.item() * imgs.size(0)
        train_loss = train_loss / max(total, 1)
        train_acc = correct / max(total, 1)
        return train_loss, train_acc
    def _eval(self, model, loader):
        """Evaluate ``model`` on at most ``max_batches`` batches of ``loader``.

        Uses class-weighted cross-entropy without label smoothing.

        Returns:
            (acc, loss_mean). Note the order differs from _run_micro_train.
        """
        model.eval()
        criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        loss_sum = 0
        total, correct = 0, 0
        with torch.no_grad():
            for batch_idx, (imgs, lbls) in enumerate(loader):
                if batch_idx >= self.max_batches:
                    break
                imgs, lbls = imgs.to(self.device), lbls.to(self.device)
                out = model(imgs)
                loss = criterion(out, lbls.long())
                preds = out.argmax(dim=1)
                correct += (preds == lbls.long()).sum().item()
                total += lbls.size(0)
                loss_sum += loss.item() * imgs.size(0)
        acc = correct / max(total, 1)
        loss_mean = loss_sum / max(total, 1)
        return acc, loss_mean
    def step(self, action_idx):
        """Apply an action, micro-train, evaluate and return ``(state_vec, reward, info)``.

        The reward is ``val_acc - 0.1 * val_loss``. When validation accuracy
        beats ``best_val_acc``, the weights and a copy of the hyperparameters
        are stored in ``best_state`` / ``best_hparams``.
        """
        self._apply_action(action_idx)
        train_loss, train_acc = self._run_micro_train(self.model)
        val_acc, val_loss = self._eval(self.model, self.val_loader)
        self.last_val_acc = val_acc
        self.last_val_loss = val_loss
        reward = val_acc - 0.1 * val_loss
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            # On a CPU device `.cpu()` returns the same tensors (no copy). The
            # snapshot stays valid only because the next _apply_action/reset
            # replaces self.model instead of training it further.
            best_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
            self.best_state = best_state
            self.best_hparams = dict(self.state)
        info = {
            'train_loss': float(train_loss),
            'train_acc': float(train_acc),
            'val_loss': float(val_loss),
            'val_acc': float(val_acc),
            'state': dict(self.state),
            'action': self.actions[action_idx]['name'],
        }
        return self._as_vec(), float(reward), info

def evaluate_full_model(model, loader, device):
    """Evaluate a classifier over every batch of ``loader``.

    Uses unweighted cross-entropy. The loss is averaged per sample (each
    batch-mean loss weighted by its batch size).

    Returns:
        dict with ``'acc'`` and ``'loss'``; both are 0.0 for an empty loader.
    """
    model.eval()
    loss_fn = torch.nn.CrossEntropyLoss()
    n_seen = 0
    n_hits = 0
    weighted_loss = 0.0
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            targets = batch_labels.to(device).long()
            logits = model(batch_imgs)
            weighted_loss += loss_fn(logits, targets).item() * batch_imgs.size(0)
            n_hits += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)
    denom = max(n_seen, 1)
    return {'acc': n_hits / denom, 'loss': weighted_loss / denom}

def refine_densenet_with_rl(episodes=4, horizon=4, micro_epochs=1, train_subset=120, val_subset=80, base_checkpoint=None):
    """Refine the DenseNet classifier's hyperparameters with the simplified PPO loop.

    Runs ``episodes`` episodes of ``horizon`` steps on small train/validation
    subsets (``DenseNetRefineEnv``). It then evaluates the best snapshot on the
    full validation and test splits and saves the best weights, the policy and
    a JSON history to ``OUTPUT_DIR``. Finally it retrains a DenseNet on the full
    split with the best hyperparameters found.

    Args:
        episodes: number of PPO episodes (one policy update per episode).
        horizon: environment steps per episode.
        micro_epochs: passes over the capped training subset per step.
        train_subset: max rows sampled (seed 42) from the train split.
        val_subset: max rows sampled (seed 42) from the validation split.
        base_checkpoint: weights every step starts from; defaults to
            ``OUTPUT_DIR / "densenet_classification.pth"``.

    Returns:
        dict: experiment payload (also passed to ``save_experiment``).

    Raises:
        FileNotFoundError: if ``EXAM_SPLIT_CSV`` has not been created.
        ValueError: if the train or validation split is empty.
    """
    # Class order used for the class weights and confusion matrices.
    # Assumes MRIDataset encodes Nondemented as 0 and Demented as 1 — TODO confirm.
    class_names = ['Nondemented', 'Demented']
    class_ids = list(range(len(class_names)))

    if not EXAM_SPLIT_CSV.exists():
        raise FileNotFoundError("Crie o dataset (create_exam_level_dataset) antes de rodar o RL.")
    if base_checkpoint is None:
        base_checkpoint = OUTPUT_DIR / "densenet_classification.pth"
    df = pd.read_csv(EXAM_SPLIT_CSV)
    df_train = df[df['split'] == 'train'].copy()
    df_val = df[df['split'] == 'validation'].copy()
    df_test = df[df['split'] == 'test'].copy()
    if df_train.empty or df_val.empty:
        raise ValueError("Splits de treino/validação vazios para classificação.")

    device = select_device()
    train_tf, val_tf = build_transforms()

    def _sample(df_split, n):
        # Deterministic subsample so repeated runs see the same subset.
        if len(df_split) <= n:
            return df_split
        return df_split.sample(n=n, random_state=42)

    df_train_small = _sample(df_train, train_subset)
    df_val_small = _sample(df_val, val_subset)

    batch_small = 8
    train_loader_small = DataLoader(
        MRIDataset(df_train_small, train_tf, DATASET_DIR.parent, 'original_path', 'Final_Group'),
        batch_size=batch_small, shuffle=True
    )
    val_loader_small = DataLoader(
        MRIDataset(df_val_small, val_tf, DATASET_DIR.parent, 'original_path', 'Final_Group'),
        batch_size=batch_small, shuffle=False
    )

    # Inverse-frequency class weights computed on the full train split; a class
    # missing from the split falls back to the largest class count.
    class_weights = None
    class_counts = df_train['Final_Group'].value_counts()
    if len(class_counts) >= 1:
        total = class_counts.sum()
        w0 = total / (2 * class_counts.get('Nondemented', max(class_counts.max(), 1)))
        w1 = total / (2 * class_counts.get('Demented', max(class_counts.max(), 1)))
        class_weights = torch.tensor([w0, w1], dtype=torch.float32, device=device)

    env = DenseNetRefineEnv(
        train_loader=train_loader_small,
        val_loader=val_loader_small,
        device=device,
        base_checkpoint=base_checkpoint,
        class_weights=class_weights,
        micro_epochs=micro_epochs,
        max_batches_per_epoch=3,
    )
    agent = PPOAgent(state_dim=env.state_dim, action_dim=env.action_dim, device=device)
    history = {"episodes": [], "actions": env.actions}

    state = env.reset()
    for ep in range(episodes):
        ep_reward = 0.0
        ep_steps = []
        for _ in range(horizon):
            action_idx, log_prob, value_est = agent.select_action(state)
            next_state, reward, info = env.step(action_idx)
            agent.store(state, action_idx, log_prob, value_est, reward, done=False)
            ep_reward += reward
            ep_steps.append(info)
            state = next_state
        update_stats = agent.update()
        history["episodes"].append({
            "episode": ep + 1,
            "reward_sum": float(ep_reward),
            "last_val_acc": float(env.last_val_acc),
            "last_val_loss": float(env.last_val_loss),
            "best_val_acc": float(env.best_val_acc),
            "steps": ep_steps,
            "update": update_stats,
        })
        state = env.reset()

    # env only sets best_state/best_hparams after an improvement; otherwise
    # fall back to the current model and a copy (not an alias) of its hparams.
    best_state = getattr(env, 'best_state', None)
    best_hparams = getattr(env, 'best_hparams', None)
    if best_state is None:
        best_state = {k: v.cpu() for k, v in env.model.state_dict().items()}
        best_hparams = dict(env.state)

    # Built with env's current (default) dropout. NOTE(review): assumes dropout
    # rate does not change parameter shapes and is inactive in eval mode.
    eval_model = env._build_model()
    eval_model.load_state_dict(best_state, strict=False)
    eval_model = eval_model.to(device)

    val_loader_full = DataLoader(
        MRIDataset(df_val, val_tf, DATASET_DIR.parent, 'original_path', 'Final_Group'),
        batch_size=16, shuffle=False
    )
    test_loader_full = DataLoader(
        MRIDataset(df_test, val_tf, DATASET_DIR.parent, 'original_path', 'Final_Group'),
        batch_size=16, shuffle=False
    )

    val_metrics = evaluate_full_model(eval_model, val_loader_full, device)
    test_metrics = evaluate_full_model(eval_model, test_loader_full, device)

    def _collect_preds(loader):
        """Return (y_true, y_pred) arrays for ``loader``; empty arrays if it yields nothing."""
        eval_model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.to(device)
                labels = labels.to(device).long()
                logits = eval_model(imgs)
                preds = logits.argmax(dim=1)
                y_true.append(labels.cpu().numpy())
                y_pred.append(preds.cpu().numpy())
        if not y_true:
            return np.array([]), np.array([])
        return np.concatenate(y_true), np.concatenate(y_pred)

    y_val, y_val_pred = _collect_preds(val_loader_full)
    y_test, y_test_pred = _collect_preds(test_loader_full)
    # Pin labels so the matrix is always 2x2 and matches the class names saved
    # below, even when a split contains (or predicts) only one class.
    val_cm = confusion_matrix(y_val, y_val_pred, labels=class_ids) if y_val.size else None
    test_cm = confusion_matrix(y_test, y_test_pred, labels=class_ids) if y_test.size else None

    best_model_path = OUTPUT_DIR / "densenet_classification_rl_best.pth"
    torch.save(best_state, best_model_path)
    policy_path = OUTPUT_DIR / "rl_policy_densenet.pth"
    torch.save(agent.policy.state_dict(), policy_path)

    history_writer = TrainHistoryWriter(OUTPUT_DIR)
    history["meta"] = {
        "episodes": episodes,
        "horizon": horizon,
        "micro_epochs": micro_epochs,
        "train_subset": len(df_train_small),
        "val_subset": len(df_val_small),
        "best_val_acc": float(env.best_val_acc),
        "base_checkpoint": Path(base_checkpoint).name,
    }
    history_file = history_writer.save(history)

    exp_payload = {
        'model': 'DenseNet_classification_RL',
        'episodes': episodes,
        'horizon': horizon,
        'micro_epochs': micro_epochs,
        'train_subset': len(df_train_small),
        'val_subset': len(df_val_small),
        'best_val_acc': float(env.best_val_acc),
        'val_accuracy': float(val_metrics.get("acc", 0.0)),
        'test_accuracy': float(test_metrics.get("acc", 0.0)),
        'val_loss': float(val_metrics.get("loss", 0.0)),
        'test_loss': float(test_metrics.get("loss", 0.0)),
        'history_file': history_file.name,
        'best_model_path': best_model_path.name,
        'policy_path': policy_path.name,
        'best_hparams': best_hparams,
    }
    if val_cm is not None:
        exp_payload['val_confusion_matrix'] = val_cm.tolist()
        exp_payload['val_classes'] = list(class_names)
    if test_cm is not None:
        exp_payload['test_confusion_matrix'] = test_cm.tolist()
        exp_payload['test_classes'] = list(class_names)

    if best_hparams:
        # Best-effort: a failure here is reported but does not discard the RL results.
        try:
            print("Treinando DenseNet final com hiperparâmetros do RL (split completo)...")
            train_densenet(mode='classification', hparams=best_hparams)
        except Exception as e:
            print(f"Falha ao treinar modelo final com hparams do RL: {e}")

    save_experiment(exp_payload)
    print("RL concluído.", exp_payload)
    return exp_payload


### Visualizações offline (heatmap, scatterplots, t-SNE/UMAP)

Funções para gerar gráficos sem Tkinter: scatterplots de descritores, heatmap de correlação e reduções 2D de embeddings (t-SNE e, se o pacote `umap-learn` estiver instalado, UMAP). Por padrão, as figuras são salvas em `output/`.

In [None]:
# Visualizações offline: scatterplots de descritores, heatmap de correlação, t-SNE/UMAP de embeddings.

from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
try:
    import umap  # type: ignore
    UMAP_AVAILABLE = True
except ImportError:
    UMAP_AVAILABLE = False

def _scatter_by_group(df, x_col, y_col, hue_col, title, out_path):
    """Save a scatter plot of ``y_col`` vs ``x_col`` colored by ``hue_col``.

    Rows missing either coordinate are dropped. A numeric hue with more than
    10 distinct values gets a continuous ``viridis`` colormap and a colorbar.
    Any other hue is drawn by seaborn with a categorical legend.

    Returns:
        ``out_path`` after saving, or ``None`` when no row has both coordinates.
    """
    data = df.dropna(subset=[x_col, y_col])
    if data.empty:
        return None

    fig, ax = plt.subplots(figsize=(8, 6))
    hue_values = data[hue_col]
    continuous_hue = pd.api.types.is_numeric_dtype(hue_values) and hue_values.nunique() > 10
    if not continuous_hue:
        sns.scatterplot(data=data, x=x_col, y=y_col, hue=hue_col, palette='tab10', alpha=0.85, s=50, edgecolor='white', linewidth=0.3, ax=ax)
        ax.legend(title=hue_col)
    else:
        points = ax.scatter(data[x_col], data[y_col], c=hue_values, cmap='viridis', alpha=0.8, s=40)
        colorbar = fig.colorbar(points, ax=ax)
        colorbar.set_label(hue_col)

    for setter, text in ((ax.set_xlabel, x_col), (ax.set_ylabel, y_col), (ax.set_title, title)):
        setter(text)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    return out_path

def generate_descriptor_scatterplots(limit_pairs=None, hue_col='Group'):
    """Plot every pair of numeric ventricle descriptors, colored by a demographic column.

    Descriptors (``DESCRIPTORS_CSV``) are left-joined on ``MRI_ID`` with the
    ``hue_col`` column of the semicolon/comma-decimal demographic CSV. Each
    pair is saved as ``scatter_<x>_vs_<y>.png`` in ``OUTPUT_DIR``.

    Args:
        limit_pairs: if truthy, only the first ``limit_pairs`` pairs are plotted.
        hue_col: demographic column used to color the points.

    Returns:
        list of paths of the figures that were written.

    Raises:
        FileNotFoundError: if either input CSV is missing.
        ValueError: if ``hue_col`` is not present after the merge.
    """
    for required_path, missing_msg in (
        (DESCRIPTORS_CSV, f"Descritores não encontrados: {DESCRIPTORS_CSV}"),
        (CSV_DEMOGRAPHIC, f"CSV demográfico não encontrado: {CSV_DEMOGRAPHIC}"),
    ):
        if not required_path.exists():
            raise FileNotFoundError(missing_msg)

    descriptors = pd.read_csv(DESCRIPTORS_CSV)
    demographics = pd.read_csv(CSV_DEMOGRAPHIC, sep=';', decimal=',')
    demographics = demographics.rename(columns=lambda name: name.strip())
    if 'MRI ID' in demographics.columns:
        demographics = demographics.rename(columns={'MRI ID': 'MRI_ID'})
    join_cols = [name for name in demographics.columns if name in ('MRI_ID', hue_col)]
    merged = pd.merge(descriptors, demographics[join_cols], on='MRI_ID', how='left')
    if hue_col not in merged.columns:
        raise ValueError(f"Coluna de agrupamento '{hue_col}' não encontrada após merge.")

    excluded = ['viable', 'MRI_ID', 'Subject_ID', 'segmented_path']
    numeric_cols = [
        name for name in descriptors.columns
        if name not in excluded and merged[name].dtype.kind in 'biufc'
    ]
    # Unordered pairs (i < j), in column order.
    pairs = [
        (first, second)
        for idx, first in enumerate(numeric_cols)
        for second in numeric_cols[idx + 1:]
    ]
    if limit_pairs:
        pairs = pairs[:limit_pairs]

    saved = []
    for x_col, y_col in pairs:
        for col in (x_col, y_col):
            merged[col] = pd.to_numeric(merged[col], errors='coerce')
        written = _scatter_by_group(
            merged, x_col, y_col, hue_col,
            f"{y_col} vs {x_col}",
            OUTPUT_DIR / f"scatter_{x_col}_vs_{y_col}.png",
        )
        if written:
            saved.append(written)
    print(f"Scatterplots gerados: {len(saved)} em {OUTPUT_DIR}")
    return saved

def plot_correlation_heatmap(split_path=EXAM_SPLIT_CSV, out_path=None):
    """Save a correlation heatmap of the numeric columns of the combined dataset.

    The ``converted`` column is excluded and all-NaN columns are dropped. When
    present, ``Final_Group`` is one-hot encoded (``Class_*``) and included.

    Args:
        split_path: CSV produced by the dataset-creation step.
        out_path: output image; defaults to ``OUTPUT_DIR / "correlation_heatmap.png"``.

    Returns:
        Path of the saved figure.

    Raises:
        FileNotFoundError: if ``split_path`` does not exist.
        ValueError: if there is no numeric column to correlate.
    """
    split_path = Path(split_path)
    if out_path:
        out_path = Path(out_path)
    else:
        out_path = OUTPUT_DIR / "correlation_heatmap.png"
    if not split_path.exists():
        raise FileNotFoundError(f"Crie o dataset combinado primeiro: {split_path}")

    table = pd.read_csv(split_path)
    numeric_cols = [
        name for name in table.select_dtypes(include=[np.number]).columns
        if name != 'converted'
    ]
    if not numeric_cols:
        raise ValueError("Nenhuma coluna numérica para correlacionar.")

    features = table[numeric_cols].dropna(axis=1, how='all')
    if 'Final_Group' in table.columns:
        features = pd.concat([features, pd.get_dummies(table['Final_Group'], prefix='Class')], axis=1)

    corr = features.corr()
    # Scale the square figure with the number of features, within [8, 20] inches.
    side = max(8, min(20, len(corr.columns) * 0.6))
    fig, ax = plt.subplots(figsize=(side, side))
    sns.heatmap(corr, cmap='coolwarm', center=0, square=True, ax=ax, cbar_kws={'shrink': 0.5})
    ax.set_title('Heatmap de Correlação')
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Heatmap salvo em {out_path}")
    return out_path

def run_tsne_umap(emb_path, target_col='target', out_prefix=None):
    """Project an embeddings CSV to 2-D with t-SNE, and with UMAP when available.

    Feature selection and scaling:
        Every numeric column other than ``MRI_ID`` and ``target_col`` is a
        feature. Other non-numeric columns are reported and skipped. Features
        are standardized before projection.

    Outputs:
        Points are coloured by ``target_col``. Plots are written to
        ``<out_prefix>.tsne.png`` and, if UMAP is installed,
        ``<out_prefix>.umap.png``.

    Args:
        emb_path: CSV with one row per sample.
        target_col: Column used to group/colour the points.
        out_prefix: Output path prefix. Defaults to
            ``OUTPUT_DIR / emb_path.stem``.

    Returns:
        The output prefix as a ``Path``.

    Raises:
        FileNotFoundError: if ``emb_path`` does not exist.
        ValueError: if the target column is missing, no numeric feature
            column exists, or there are fewer than 2 samples.
    """
    emb_path = Path(emb_path)
    out_prefix = Path(out_prefix) if out_prefix else OUTPUT_DIR / emb_path.stem
    if not emb_path.exists():
        raise FileNotFoundError(f"CSV de embeddings não encontrado: {emb_path}")

    df = pd.read_csv(emb_path)
    if target_col not in df.columns:
        raise ValueError(f"Coluna alvo '{target_col}' não encontrada no CSV.")
    meta_cols = {'MRI_ID', target_col}
    candidate_cols = [c for c in df.columns if c not in meta_cols]
    # Non-numeric columns (e.g. file paths) would make StandardScaler crash.
    feature_cols = [c for c in candidate_cols if pd.api.types.is_numeric_dtype(df[c])]
    skipped_cols = [c for c in candidate_cols if not pd.api.types.is_numeric_dtype(df[c])]
    if skipped_cols:
        print(f"Ignorando colunas não numéricas: {', '.join(map(str, skipped_cols))}")
    if not feature_cols:
        raise ValueError("Nenhuma coluna de feature encontrada.")

    X = df[feature_cols].values
    y = df[target_col].values
    n_samples = len(X)
    if n_samples < 2:
        raise ValueError("t-SNE requer pelo menos 2 amostras.")
    X_scaled = StandardScaler().fit_transform(X)

    # sklearn requires 0 < perplexity < n_samples. Cap at 30. The previous
    # max(5, ...) floor produced an invalid perplexity for <= 5 samples and
    # had no effect otherwise.
    perplexity = min(30, n_samples - 1)
    # PCA initialisation needs at least 2 features for a 2-D projection.
    init = 'pca' if len(feature_cols) >= 2 else 'random'

    # Append suffixes to the full name: with_suffix() would replace any dotted
    # part already present in the prefix (e.g. "emb.v2" -> "emb.tsne.png").
    out_prefix.parent.mkdir(parents=True, exist_ok=True)

    print(f"Executando t-SNE (perplexity={perplexity}) para {n_samples} amostras...")
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, init=init)
    X_tsne = tsne.fit_transform(X_scaled)
    df_tsne = pd.DataFrame({'x': X_tsne[:, 0], 'y': X_tsne[:, 1], 'target': y})
    tsne_path = out_prefix.with_name(out_prefix.name + '.tsne.png')
    _scatter_by_group(df_tsne, 'x', 'y', 'target', 't-SNE dos embeddings', tsne_path)

    if UMAP_AVAILABLE:
        print("Executando UMAP...")
        reducer = umap.UMAP(n_components=2, random_state=42)
        X_umap = reducer.fit_transform(X_scaled)
        df_umap = pd.DataFrame({'x': X_umap[:, 0], 'y': X_umap[:, 1], 'target': y})
        umap_path = out_prefix.with_name(out_prefix.name + '.umap.png')
        _scatter_by_group(df_umap, 'x', 'y', 'target', 'UMAP dos embeddings', umap_path)
    else:
        print("UMAP não disponível (instale 'umap-learn' para habilitar).")

    return out_prefix


### Exemplos de uso (execute conforme necessidade)

In [None]:

# 1) Segmentar (ajuste limit para um dry-run rápido)
# segment_all_images(limit=2, overwrite=False)

# 2) Criar dataset combinado
# merged = create_exam_level_dataset()
# display(merged.head())

# 3) Treinos clássicos
# svm_metrics = train_svm_classifier()
# xgb_metrics = train_xgboost_regressor()

# 4) DenseNet (reduza max_epochs para teste rápido)
# densenet_cls = train_densenet(mode='classification', max_epochs=2)
# densenet_reg = train_densenet(mode='regression', max_epochs=2)


# 5) Visualizações
# generate_descriptor_scatterplots(limit_pairs=12)
# plot_correlation_heatmap()
# run_tsne_umap('output/densenet_embeddings_classification_test.csv', target_col='target')

# 6) Refinamento via RL (usa checkpoint de classificação)
# refine_densenet_with_rl(episodes=2, horizon=2, micro_epochs=1, train_subset=40, val_subset=30)
