In [None]:
%%writefile helpers.py


import json
import os
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


def save_experiment_results(
    experiment_name,
    model_name,
    metric_name,
    normalization,
    y_true,
    y_pred,
    class_names,
    n_train,
    n_test,
    device,
    results_dir="/content/drive/MyDrive/pdi/resultados"
):
    """Save all artifacts of a single few-shot learning (FSL) experiment.

    Creates ``<results_dir>/<experiment_name>_<N>way_<K>shot`` (appending
    ``_v1``, ``_v2``, ... when that path already exists) and writes into it:
    ``config.json``, ``metrics.json``, ``report.txt``, ``predictions.json``,
    ``confusion_matrix.png`` and ``accuracy_per_class.png``. A short summary
    is printed to stdout.

    Args:
        experiment_name: Prefix of the output directory name; also stored in
            the config and the text report.
        model_name: Model label used in the config, report and plot titles.
        metric_name: Metric label used in the config, report and plot titles.
        normalization: Normalization label stored in the config and report.
        y_true: Ground-truth labels, given as integer indices into
            ``class_names`` (array-like).
        y_pred: Predicted labels, same shape and encoding as ``y_true``.
        class_names: Sequence of class names; its length is the N of N-way.
        n_train: Total number of training images; K-shot is computed as
            ``n_train // len(class_names)``.
        n_test: Number of test images (only recorded, never used to slice).
        device: Device identifier; stored as ``str(device)``.
        results_dir: Root directory under which the experiment directory is
            created.

    Returns:
        Path of the directory that received the files.

    Raises:
        ValueError: If ``class_names`` is empty or ``y_true`` and ``y_pred``
            have different shapes.
    """
    # Normalize inputs once: the code below relies on element-wise ``==``,
    # boolean masks and ``.tolist()``, none of which work on plain lists.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    class_names = list(class_names)

    num_classes = len(class_names)
    if num_classes == 0:
        raise ValueError("class_names must contain at least one class")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )

    # N-way / K-shot derived from the inputs.
    shots_per_class = n_train // num_classes

    # Explicit label set so that a class missing from both y_true and y_pred
    # still gets its row/column; otherwise sklearn infers fewer labels than
    # there are class names and the report/heatmap no longer line up.
    labels = list(range(num_classes))

    # Create the results root, then claim a fresh experiment directory.
    # Creating with exist_ok=False and retrying avoids the gap between an
    # "exists?" check and the actual creation.
    os.makedirs(results_dir, exist_ok=True)
    base_dir = os.path.join(
        results_dir, f"{experiment_name}_{num_classes}way_{shots_per_class}shot"
    )
    exp_dir = base_dir
    version = 0
    while True:
        try:
            os.makedirs(exp_dir)
            break
        except FileExistsError:
            version += 1
            exp_dir = f"{base_dir}_v{version}"

    accuracy = float((y_pred == y_true).mean())
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # 1. Detailed per-class metrics (dict form for JSON, text form for report).
    report_kwargs = dict(
        labels=labels,
        target_names=class_names,
        digits=3,
        zero_division=0,
    )
    report_dict = classification_report(
        y_true, y_pred, output_dict=True, **report_kwargs
    )
    report_text = classification_report(y_true, y_pred, **report_kwargs)

    # 2. Experiment configuration.
    config = {
        "experiment_name": experiment_name,
        "timestamp": timestamp,
        "n_way": num_classes,
        "k_shot": shots_per_class,
        "model": model_name,
        "metric": metric_name,
        "normalization": normalization,
        "n_classes": num_classes,
        "n_train": n_train,
        "n_test": n_test,
        "classes": class_names,
        "accuracy": accuracy,
        "device": str(device)
    }

    with open(os.path.join(exp_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)

    # 3. Detailed metrics.
    with open(os.path.join(exp_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(report_dict, f, indent=4)

    # 4. Human-readable report. The text contains non-ASCII characters, so the
    # encoding is pinned instead of depending on the platform default.
    separator = "=" * 60
    with open(os.path.join(exp_dir, "report.txt"), "w", encoding="utf-8") as f:
        f.write(f"Experimento: {experiment_name}\n")
        f.write(f"Data: {timestamp}\n")
        f.write(f"{separator}\n\n")
        f.write("Configuração:\n")
        f.write(f"  Modelo: {model_name}\n")
        f.write(f"  Métrica: {metric_name}\n")
        f.write(f"  Normalização: {normalization}\n")
        f.write(f"  N-way: {num_classes} classes\n")
        f.write(f"  K-shot: {shots_per_class} exemplos/classe\n")
        f.write(f"  Treino: {n_train} imagens ({shots_per_class} por classe)\n")
        f.write(f"  Teste: {n_test} imagens\n\n")
        f.write(f"{separator}\n")
        f.write(f"Acurácia: {accuracy*100:.2f}%\n")
        f.write(f"{separator}\n\n")
        f.write("Relatório de Classificação:\n\n")
        f.write(report_text)

    # 5. Confusion matrix, sized num_classes x num_classes to match the ticks.
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.ylabel('Real')
    plt.xlabel('Predito')
    plt.title(f'{model_name} + {metric_name}\n{num_classes}-way {shots_per_class}-shot - Acurácia: {accuracy*100:.2f}%')
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

    # 6. Raw predictions, with both label indices and class names.
    predictions = {
        "true_label": y_true.tolist(),
        "predicted_label": y_pred.tolist(),
        "true_class": [class_names[i] for i in y_true],
        "predicted_class": [class_names[i] for i in y_pred],
        "correct": (y_true == y_pred).tolist()
    }

    with open(os.path.join(exp_dir, "predictions.json"), "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)

    # 7. Per-class accuracy; classes with no test samples are reported as 0.
    per_class_acc = []
    for label in labels:
        mask = y_true == label
        if mask.any():
            per_class_acc.append(float((y_pred[mask] == label).mean()))
        else:
            per_class_acc.append(0)

    fig = plt.figure(figsize=(12, 6))
    plt.bar(class_names, per_class_acc, color='steelblue')
    plt.xlabel('Classe')
    plt.ylabel('Acurácia')
    plt.title(f'Acurácia por Classe - {model_name} + {metric_name}\n{num_classes}-way {shots_per_class}-shot')
    plt.xticks(rotation=45, ha='right')
    plt.ylim([0, 1])
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "accuracy_per_class.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

    # 8. Console summary.
    print(f"\n{separator}")
    print(f"Resultados salvos em: {exp_dir}")
    print(f"{separator}")
    print(f"Configuração: {num_classes}-way {shots_per_class}-shot")
    print("Arquivos gerados:")
    for filename in (
        "config.json",
        "metrics.json",
        "report.txt",
        "confusion_matrix.png",
        "accuracy_per_class.png",
        "predictions.json",
    ):
        print(f"  - {filename}")
    print(f"{separator}\n")

    return exp_dir


def save_fewshot_results(
    experiment_name,
    model_name,
    metric_name,
    normalization,
    accuracies,
    n_way,
    n_shot,
    n_query,
    n_episodes,
    device,
    all_predictions=None,
    all_labels=None,
    results_dir="/content/drive/MyDrive/pdi/resultados"
):
    """Save the results of a multi-episode few-shot experiment.

    Creates ``<results_dir>/<experiment_name>_<n_way>way_<n_shot>shot_<n_episodes>ep``
    (appending ``_v1``, ``_v2``, ... when that path already exists) and writes
    ``config.json``, ``report.txt``, ``accuracy_per_episode.png``,
    ``accuracy_distribution.png`` and ``accuracy_boxplot.png`` into it.
    When both ``all_predictions`` and ``all_labels`` are given it additionally
    computes macro precision/recall/F1 and writes ``confusion_matrix.png`` and
    ``predictions.json``. A short summary is printed to stdout.

    Args:
        experiment_name: Prefix of the output directory name; also stored in
            the config and the text report.
        model_name: Model label used in the config, report and plot titles.
        metric_name: Metric label used in the config, report and plot titles.
        normalization: Normalization label stored in the config and report.
        accuracies: Per-episode accuracies as fractions (they are multiplied
            by 100 for display); must be non-empty.
        n_way: Number of classes per episode (recorded and used in titles).
        n_shot: Number of support examples per class (recorded).
        n_query: Number of query examples per class (recorded).
        n_episodes: Number of episodes, as recorded in the directory name,
            config and titles. The plots use ``len(accuracies)`` for the
            episode axis, so a mismatch does not break plotting.
        device: Device identifier; stored as ``str(device)``.
        all_predictions: Optional predicted labels aggregated over episodes.
        all_labels: Optional true labels aggregated over episodes.
        results_dir: Root directory under which the experiment directory is
            created.

    Returns:
        Path of the directory that received the files.

    Raises:
        ValueError: If ``accuracies`` is empty.
    """
    accuracies = np.asarray(accuracies, dtype=float)
    if accuracies.size == 0:
        raise ValueError("accuracies must contain at least one episode")

    # Prediction-based artifacts are optional: both arrays must be supplied.
    has_predictions = all_predictions is not None and all_labels is not None
    if has_predictions:
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)

    # Create the results root, then claim a fresh experiment directory.
    # Creating with exist_ok=False and retrying avoids the gap between an
    # "exists?" check and the actual creation.
    os.makedirs(results_dir, exist_ok=True)
    base_dir = os.path.join(
        results_dir, f"{experiment_name}_{n_way}way_{n_shot}shot_{n_episodes}ep"
    )
    exp_dir = base_dir
    version = 0
    while True:
        try:
            os.makedirs(exp_dir)
            break
        except FileExistsError:
            version += 1
            exp_dir = f"{base_dir}_v{version}"

    # Summary statistics over episodes.
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    min_acc = np.min(accuracies)
    max_acc = np.max(accuracies)
    median_acc = np.median(accuracies)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Aggregated macro metrics, only when predictions are available.
    metrics_dict = None
    if has_predictions:
        from sklearn.metrics import precision_score, recall_score, f1_score

        precision = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_predictions, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)

        metrics_dict = {
            "precision_macro": float(precision),
            "recall_macro": float(recall),
            "f1_macro": float(f1)
        }

    # Tracks what was actually written so the console summary stays truthful.
    generated_files = []

    # 1. Experiment configuration.
    config = {
        "experiment_name": experiment_name,
        "timestamp": timestamp,
        "model": model_name,
        "metric": metric_name,
        "normalization": normalization,
        "n_way": n_way,
        "n_shot": n_shot,
        "n_query": n_query,
        "n_episodes": n_episodes,
        "mean_accuracy": float(mean_acc),
        "std_accuracy": float(std_acc),
        "min_accuracy": float(min_acc),
        "max_accuracy": float(max_acc),
        "median_accuracy": float(median_acc),
        "device": str(device),
        "all_accuracies": accuracies.tolist()
    }

    if metrics_dict:
        config.update(metrics_dict)

    with open(os.path.join(exp_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    generated_files.append("config.json")

    # 2. Human-readable report. The text contains non-ASCII characters, so the
    # encoding is pinned instead of depending on the platform default.
    separator = "=" * 60
    with open(os.path.join(exp_dir, "report.txt"), "w", encoding="utf-8") as f:
        f.write(f"Experimento: {experiment_name}\n")
        f.write(f"Data: {timestamp}\n")
        f.write(f"{separator}\n\n")
        f.write("Configuração:\n")
        f.write(f"  Modelo: {model_name}\n")
        f.write(f"  Métrica: {metric_name}\n")
        f.write(f"  Normalização: {normalization}\n")
        f.write(f"  N-way: {n_way} classes por episódio\n")
        f.write(f"  K-shot: {n_shot} exemplos/classe\n")
        f.write(f"  Query: {n_query} exemplos/classe para teste\n")
        f.write(f"  Episódios: {n_episodes}\n")
        f.write(f"  Device: {device}\n\n")
        f.write(f"{separator}\n")
        f.write("Resultados:\n")
        f.write(f"  Acurácia Média: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%\n")
        f.write(f"  Acurácia Mínima: {min_acc*100:.2f}%\n")
        f.write(f"  Acurácia Máxima: {max_acc*100:.2f}%\n")
        f.write(f"  Mediana: {median_acc*100:.2f}%\n")
        if metrics_dict:
            f.write("\nMétricas Agregadas (macro):\n")
            f.write(f"  Precision: {metrics_dict['precision_macro']*100:.2f}%\n")
            f.write(f"  Recall: {metrics_dict['recall_macro']*100:.2f}%\n")
            f.write(f"  F1-Score: {metrics_dict['f1_macro']*100:.2f}%\n")
        f.write(f"{separator}\n\n")
        f.write("Acurácias por episódio:\n")
        for i, acc in enumerate(accuracies, 1):
            f.write(f"  Episódio {i:2d}: {acc*100:.2f}%\n")
    generated_files.append("report.txt")

    if has_predictions:
        # 3. Confusion matrix over every label seen in truth or predictions.
        unique_labels = np.unique(np.concatenate([all_labels, all_predictions]))

        # Generic names, since the concrete classes change between episodes.
        class_names_generic = [f"Classe {i}" for i in unique_labels]

        cm = confusion_matrix(all_labels, all_predictions, labels=unique_labels)

        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names_generic,
                    yticklabels=class_names_generic)
        plt.ylabel('Real', fontsize=12)
        plt.xlabel('Predito', fontsize=12)
        plt.title(f'{model_name} + {metric_name}\n{n_way}-way {n_shot}-shot - Acurácia: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%', fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(exp_dir, "confusion_matrix.png"), dpi=300, bbox_inches='tight')
        plt.close(fig)
        generated_files.append("confusion_matrix.png")

        # 4. Raw predictions.
        predictions = {
            "true_label": all_labels.tolist(),
            "predicted_label": all_predictions.tolist(),
            "correct": (all_labels == all_predictions).tolist()
        }

        with open(os.path.join(exp_dir, "predictions.json"), "w", encoding="utf-8") as f:
            json.dump(predictions, f, indent=4)
        generated_files.append("predictions.json")

    # 5. Accuracy per episode. The x-axis follows the number of recorded
    # accuracies so x and y always have matching lengths.
    episodes = np.arange(1, len(accuracies) + 1)
    fig = plt.figure(figsize=(12, 6))
    plt.plot(episodes, accuracies * 100, marker='o', linewidth=2, markersize=6)
    plt.axhline(y=mean_acc*100, color='r', linestyle='--', label=f'Média: {mean_acc*100:.2f}%')
    plt.fill_between(episodes,
                     (mean_acc - std_acc)*100,
                     (mean_acc + std_acc)*100,
                     alpha=0.2, color='red', label=f'±1 std: {std_acc*100:.2f}%')
    plt.xlabel('Episódio', fontsize=12)
    plt.ylabel('Acurácia (%)', fontsize=12)
    plt.title(f'{model_name} + {metric_name}\n{n_way}-way {n_shot}-shot - {n_episodes} episódios', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "accuracy_per_episode.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)
    generated_files.append("accuracy_per_episode.png")

    # 6. Histogram of accuracies.
    fig = plt.figure(figsize=(10, 6))
    plt.hist(accuracies, bins=15, color='steelblue', edgecolor='black', alpha=0.7)
    plt.axvline(mean_acc, color='r', linestyle='--', linewidth=2, label=f'Média: {mean_acc*100:.2f}%')
    plt.axvline(median_acc, color='g', linestyle='--', linewidth=2, label=f'Mediana: {median_acc*100:.2f}%')
    plt.xlabel('Acurácia', fontsize=12)
    plt.ylabel('Frequência', fontsize=12)
    plt.title(f'Distribuição de Acurácias\n{model_name} - {n_way}-way {n_shot}-shot', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "accuracy_distribution.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)
    generated_files.append("accuracy_distribution.png")

    # 7. Boxplot of accuracies.
    fig = plt.figure(figsize=(8, 6))
    plt.boxplot(accuracies, vert=True, patch_artist=True,
                boxprops=dict(facecolor='lightblue', color='blue'),
                medianprops=dict(color='red', linewidth=2),
                whiskerprops=dict(color='blue'),
                capprops=dict(color='blue'))
    plt.ylabel('Acurácia', fontsize=12)
    plt.title(f'Distribuição de Acurácias - Boxplot\n{model_name} - {n_way}-way {n_shot}-shot', fontsize=14)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "accuracy_boxplot.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)
    generated_files.append("accuracy_boxplot.png")

    # 8. Console summary.
    print(f"\n{separator}")
    print(f"Resultados salvos em: {exp_dir}")
    print(f"{separator}")
    print(f"Configuração: {n_way}-way {n_shot}-shot ({n_episodes} episódios)")
    print(f"Acurácia: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%")
    if metrics_dict:
        print(f"Precision (macro): {metrics_dict['precision_macro']*100:.2f}%")
        print(f"Recall (macro): {metrics_dict['recall_macro']*100:.2f}%")
        print(f"F1-Score (macro): {metrics_dict['f1_macro']*100:.2f}%")
    print("Arquivos gerados:")
    for filename in generated_files:
        print(f"  - {filename}")
    print(f"{separator}\n")

    return exp_dir

Overwriting helpers.py
