# Função Auxiliar para Avaliação Few-Shot

Função para salvar os resultados de experimentos de aprendizado few-shot (configuração, métricas e predições) e gerar gráficos para visualizá-los.

In [None]:
%%writefile helpers.py

import json
import os
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

## Função para Experimentos Episódicos

In [None]:
def _style_body_cells(table, rows, n_cols):
    """Apply the shared cream-fill / light-grey-border style to table body cells.

    Column 0 (the row label) is bold; every cell uses font size 11.
    """
    for i in rows:
        for j in range(n_cols):
            cell = table[(i, j)]
            cell.set_facecolor('#FAF5EF')
            cell.set_edgecolor('#D3D3D3')
            cell.set_linewidth(0.5)
            if j == 0:
                cell.set_text_props(weight='bold', fontsize=11)
            else:
                cell.set_text_props(fontsize=11)


def _write_text_report(path, config, metrics_dict):
    """Write the human-readable summary (``report.txt``) built from *config*.

    *config* is the dict also saved as ``config.json``; *metrics_dict* holds the
    macro precision/recall/F1 values or is ``None`` when no predictions exist.
    """
    sep = "=" * 60
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"Experimento: {config['experiment_name']}\n")
        f.write(f"Data: {config['timestamp']}\n")
        f.write(f"{sep}\n\n")
        f.write("Configuração:\n")
        f.write(f"  Modelo: {config['model']}\n")
        f.write(f"  Métrica: {config['metric']}\n")
        f.write(f"  Normalização: {config['normalization']}\n")
        f.write(f"  N-way: {config['n_way']} classes por episódio\n")
        f.write(f"  K-shot: {config['n_shot']} exemplos/classe\n")
        f.write(f"  Query: {config['n_query']} exemplos/classe para teste\n")
        f.write(f"  Episódios: {config['n_episodes']}\n")
        f.write(f"  Device: {config['device']}\n\n")
        f.write(f"{sep}\n")
        f.write("Resultados:\n")
        f.write(f"  Acurácia Média: {config['mean_accuracy']*100:.2f}% ± {config['std_accuracy']*100:.2f}%\n")
        f.write(f"  Acurácia Mínima: {config['min_accuracy']*100:.2f}%\n")
        f.write(f"  Acurácia Máxima: {config['max_accuracy']*100:.2f}%\n")
        f.write(f"  Mediana: {config['median_accuracy']*100:.2f}%\n")
        if metrics_dict:
            f.write("\nMétricas Agregadas (macro):\n")
            f.write(f"  Precision: {metrics_dict['precision_macro']*100:.2f}%\n")
            f.write(f"  Recall: {metrics_dict['recall_macro']*100:.2f}%\n")
            f.write(f"  F1-Score: {metrics_dict['f1_macro']*100:.2f}%\n")
        f.write(f"{sep}\n\n")
        f.write("Acurácias por episódio:\n")
        for i, acc in enumerate(config['all_accuracies'], 1):
            f.write(f"  Episódio {i:2d}: {acc*100:.2f}%\n")


def _write_per_class_reports(exp_dir, config, report_dict, report_text):
    """Write ``per_class_metrics.json`` (dict report) and ``per_class_report.txt``."""
    with open(os.path.join(exp_dir, "per_class_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(report_dict, f, indent=4)

    sep = "=" * 70
    with open(os.path.join(exp_dir, "per_class_report.txt"), "w", encoding="utf-8") as f:
        f.write(f"{sep}\n")
        f.write("MÉTRICAS DETALHADAS POR CLASSE\n")
        f.write(f"{sep}\n\n")
        f.write(f"Experimento: {config['experiment_name']}\n")
        f.write(f"Data: {config['timestamp']}\n")
        f.write(f"Modelo: {config['model']}\n")
        f.write(f"Configuração: {config['n_way']}-way {config['n_shot']}-shot ({config['n_episodes']} episódios)\n\n")
        f.write(f"{sep}\n\n")
        f.write(report_text)
        f.write(f"\n{sep}\n")
        f.write("INTERPRETAÇÃO:\n")
        f.write(f"{sep}\n\n")
        f.write("• Precision: De todas as predições de uma classe, quantas estavam corretas\n")
        f.write("• Recall: De todas as amostras de uma classe, quantas foram identificadas\n")
        f.write("• F1-Score: Média harmônica entre precision e recall\n")
        f.write("• Support: Número de amostras reais de cada classe testadas\n\n")
        f.write("• macro avg: Média simples das métricas de todas as classes\n")
        f.write("• weighted avg: Média ponderada pelo suporte de cada classe\n\n")
        f.write(f"{sep}\n")


def _plot_metrics_table(path, report_dict, class_names, accuracy_overall, n_samples):
    """Render per-class and summary metric tables (percentages) to a PNG at *path*."""
    col_labels = ['Classe', 'Precision', 'Recall', 'F1-Score', 'Suporte (Imagens)']
    col_widths = [0.28, 0.16, 0.16, 0.16, 0.24]

    def pct(value):
        return f"{value*100:.0f}"

    class_rows = []
    for class_name in class_names:
        metrics = report_dict[class_name]
        # Short names (e.g. acronyms) are upper-cased, longer ones title-cased.
        display_name = class_name.upper() if len(class_name) <= 5 else class_name.title()
        class_rows.append([
            display_name,
            pct(metrics['precision']),
            pct(metrics['recall']),
            pct(metrics['f1-score']),
            f"{int(metrics['support'])}",
        ])

    summary_rows = [['Acurácia (Accuracy)', '-', '-', pct(accuracy_overall), f"{n_samples}"]]
    for row_name, key in (('Macro Average', 'macro avg'), ('Weighted Average', 'weighted avg')):
        avg = report_dict[key]
        summary_rows.append([
            row_name,
            pct(avg['precision']),
            pct(avg['recall']),
            pct(avg['f1-score']),
            f"{n_samples}",
        ])

    fig = plt.figure(figsize=(11, len(class_names) * 0.55 + 3.5))
    ax = fig.add_subplot(111)
    ax.axis('off')

    class_table = ax.table(cellText=class_rows, colLabels=col_labels,
                           cellLoc='center', loc='upper center',
                           colWidths=col_widths,
                           bbox=[0, 0.35, 1, 0.65])
    class_table.auto_set_font_size(False)
    class_table.set_fontsize(12)

    # Header row (row 0) gets a darker fill and a bold, slightly larger font.
    for j in range(len(col_labels)):
        cell = class_table[(0, j)]
        cell.set_facecolor('#E8DCC8')
        cell.set_text_props(weight='bold', color='black', fontsize=12)
        cell.set_edgecolor('#D3D3D3')
        cell.set_linewidth(0.5)
    _style_body_cells(class_table, range(1, len(class_rows) + 1), len(col_labels))

    ax.text(0.5, 0.31, 'MÉDIAS GERAIS',
            ha='center', va='center',
            fontsize=12, weight='bold',
            bbox=dict(boxstyle='round,pad=0.5',
                      facecolor='#E8DCC8',
                      edgecolor='#D3D3D3',
                      linewidth=0.5))

    summary_table = ax.table(cellText=summary_rows,
                             cellLoc='center', loc='lower center',
                             colWidths=col_widths,
                             bbox=[0, 0.0, 1, 0.26])
    summary_table.auto_set_font_size(False)
    summary_table.set_fontsize(11)
    # No header row here, so styling starts at row 0.
    _style_body_cells(summary_table, range(len(summary_rows)), len(col_labels))

    fig.savefig(path, dpi=300, bbox_inches='tight', facecolor='white',
                edgecolor='none', pad_inches=0.3)
    plt.close(fig)


def _plot_confusion_matrix(path, cm, tick_labels):
    """Render confusion matrix *cm* (rows = true, columns = predicted) to *path*."""
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d',
                cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                annot_kws={'size': 11},
                cbar=False,
                square=False,
                linewidths=0.5,
                linecolor='white',
                ax=ax)
    ax.set_ylabel('Real', fontsize=11)
    ax.set_xlabel('Predito', fontsize=11)
    ax.set_title('Matriz de Confusão', fontsize=12, pad=12)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=11)
    plt.setp(ax.get_yticklabels(), rotation=0, fontsize=11)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)


def _plot_accuracy_per_episode(path, accuracies, mean_acc, std_acc, title):
    """Line plot of per-episode accuracy (%) with the mean line and a ±1 std band."""
    # x-axis derived from the data itself so it always matches len(accuracies).
    episodes = np.arange(1, len(accuracies) + 1)
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(episodes, [acc * 100 for acc in accuracies], marker='o', linewidth=2, markersize=4, alpha=0.6)
    ax.axhline(y=mean_acc * 100, color='r', linestyle='--', linewidth=2, label=f'Média: {mean_acc*100:.2f}%')
    # Scalar bounds produce a constant-height band across every episode.
    ax.fill_between(episodes,
                    (mean_acc - std_acc) * 100,
                    (mean_acc + std_acc) * 100,
                    alpha=0.2, color='red', label=f'±1 std: {std_acc*100:.2f}%')
    ax.set_xlabel('Episódio', fontsize=12)
    ax.set_ylabel('Acurácia (%)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def _plot_accuracy_distribution(path, accuracies, mean_acc, median_acc, title):
    """Histogram of per-episode accuracies with mean and median markers."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(accuracies, bins=15, color='steelblue', edgecolor='black', alpha=0.7)
    ax.axvline(mean_acc, color='r', linestyle='--', linewidth=2, label=f'Média: {mean_acc*100:.2f}%')
    ax.axvline(median_acc, color='g', linestyle='--', linewidth=2, label=f'Mediana: {median_acc*100:.2f}%')
    ax.set_xlabel('Acurácia', fontsize=12)
    ax.set_ylabel('Frequência', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def save_fewshot_results(
    experiment_name,
    model_name,
    metric_name,
    normalization,
    accuracies,
    n_way,
    n_shot,
    n_query,
    n_episodes,
    device,
    all_predictions=None,
    all_labels=None,
    class_names=None,
    results_dir="/content/drive/MyDrive/pdi/resultados"
):
    """Save the results of a multi-episode few-shot experiment to a new directory.

    Creates ``<results_dir>/<experiment_name>_<n_way>way_<n_shot>shot_<n_episodes>ep``
    (suffixed ``_v1``, ``_v2``, ... if that path already exists) and writes:

    * ``config.json`` and ``report.txt``: configuration and accuracy statistics;
    * ``accuracy_per_episode.png`` and ``accuracy_distribution.png``: plots;
    * only when predictions and labels are given: ``confusion_matrix.png``,
      ``predictions.json``, and macro precision/recall/F1 in the config;
    * only when ``class_names`` is also given: ``per_class_metrics.json``,
      ``per_class_report.txt`` and ``metrics_table.png``.

    Args:
        experiment_name: Prefix for the output directory and report titles.
        model_name: Model description recorded in config, reports and plots.
        metric_name: Distance/metric description recorded likewise.
        normalization: Normalization description recorded in config/report.
        accuracies: Per-episode accuracies; each must be convertible with
            ``float()``. Treated as fractions (displayed multiplied by 100).
        n_way, n_shot, n_query, n_episodes: Episode configuration, recorded
            as given. The episode plot's x-axis uses ``len(accuracies)``.
        device: Device description, stored as ``str(device)``.
        all_predictions, all_labels: Optional predicted / true labels pooled
            over all episodes. They must be given together and have the same
            shape; multi-dimensional inputs are flattened.
        class_names: Optional class names. Label ``i`` is shown as
            ``class_names[i]``, so labels must lie in ``[0, len(class_names))``.
        results_dir: Parent directory for the experiment directory.

    Returns:
        Path of the experiment directory that was created.

    Raises:
        ValueError: If ``accuracies`` is empty, exactly one of
            ``all_predictions``/``all_labels`` is given, their shapes differ,
            or any label falls outside the index range of ``class_names``.
    """
    from sklearn.metrics import f1_score, precision_score, recall_score

    # --- Validate inputs before anything is written to disk. ---
    accuracies = [float(acc) for acc in accuracies]
    if not accuracies:
        raise ValueError("accuracies must contain at least one episode accuracy")

    if (all_predictions is None) != (all_labels is None):
        raise ValueError("all_predictions and all_labels must be provided together")
    has_predictions = all_predictions is not None

    # Explicit label set for sklearn when class names are known. It keeps
    # reports and the confusion matrix aligned with class_names even when a
    # class never occurs in the pooled predictions.
    labels = None
    if has_predictions:
        all_predictions = np.asarray(all_predictions)
        all_labels = np.asarray(all_labels)
        if all_predictions.shape != all_labels.shape:
            raise ValueError(
                f"all_predictions shape {all_predictions.shape} does not match "
                f"all_labels shape {all_labels.shape}"
            )
        all_predictions = all_predictions.ravel()
        all_labels = all_labels.ravel()
        if class_names is not None:
            n_classes = len(class_names)
            labels = np.arange(n_classes)
            pooled = np.concatenate([all_labels, all_predictions])
            if pooled.size and (pooled.min() < 0 or pooled.max() >= n_classes):
                raise ValueError(
                    f"labels must be integer indices in [0, {n_classes}) to match class_names"
                )
            # NOTE(review): labels are assumed to be global class indices.
            # If callers pass episode-local indices (0..n_way-1 re-drawn each
            # episode), the pooled metrics mix different classes. Confirm with
            # the caller.

    # --- Create a fresh, non-clobbering experiment directory. ---
    os.makedirs(results_dir, exist_ok=True)
    exp_name = f"{experiment_name}_{n_way}way_{n_shot}shot_{n_episodes}ep"
    original_exp_dir = os.path.join(results_dir, exp_name)
    exp_dir = original_exp_dir
    counter = 1
    while os.path.exists(exp_dir):
        exp_dir = f"{original_exp_dir}_v{counter}"
        counter += 1
    os.makedirs(exp_dir, exist_ok=True)

    # --- Aggregate statistics (np.std is the population std, ddof=0). ---
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    min_acc = np.min(accuracies)
    max_acc = np.max(accuracies)
    median_acc = np.median(accuracies)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    metrics_dict = None
    report_dict = None
    report_text = None
    if has_predictions:
        score_kwargs = dict(labels=labels, average='macro', zero_division=0)
        metrics_dict = {
            "precision_macro": float(precision_score(all_labels, all_predictions, **score_kwargs)),
            "recall_macro": float(recall_score(all_labels, all_predictions, **score_kwargs)),
            "f1_macro": float(f1_score(all_labels, all_predictions, **score_kwargs)),
        }
        if class_names is not None:
            report_kwargs = dict(labels=labels, target_names=class_names, digits=4, zero_division=0)
            report_dict = classification_report(all_labels, all_predictions, output_dict=True, **report_kwargs)
            report_text = classification_report(all_labels, all_predictions, **report_kwargs)

    # 1. Experiment configuration
    config = {
        "experiment_name": experiment_name,
        "timestamp": timestamp,
        "model": model_name,
        "metric": metric_name,
        "normalization": normalization,
        "n_way": n_way,
        "n_shot": n_shot,
        "n_query": n_query,
        "n_episodes": n_episodes,
        "mean_accuracy": float(mean_acc),
        "std_accuracy": float(std_acc),
        "min_accuracy": float(min_acc),
        "max_accuracy": float(max_acc),
        "median_accuracy": float(median_acc),
        "device": str(device),
        "all_accuracies": accuracies,
    }
    if metrics_dict:
        config.update(metrics_dict)

    with open(os.path.join(exp_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)

    # 2. Text report
    _write_text_report(os.path.join(exp_dir, "report.txt"), config, metrics_dict)

    # 3. Per-class reports and metrics table (needs predictions and class_names)
    if report_dict is not None:
        _write_per_class_reports(exp_dir, config, report_dict, report_text)
        accuracy_overall = float(np.mean(all_predictions == all_labels))
        _plot_metrics_table(os.path.join(exp_dir, "metrics_table.png"),
                            report_dict, class_names, accuracy_overall, len(all_labels))

    # 4-5. Confusion matrix and raw predictions (need predictions)
    if has_predictions:
        if labels is not None:
            cm_labels = labels
            tick_labels = list(class_names)
        else:
            cm_labels = np.unique(np.concatenate([all_labels, all_predictions]))
            tick_labels = [f"Classe {i}" for i in cm_labels]
        cm = confusion_matrix(all_labels, all_predictions, labels=cm_labels)
        _plot_confusion_matrix(os.path.join(exp_dir, "confusion_matrix.png"), cm, tick_labels)

        predictions = {
            "true_label": all_labels.tolist(),
            "predicted_label": all_predictions.tolist(),
            "correct": (all_labels == all_predictions).tolist(),
        }
        if class_names is not None:
            predictions["true_class"] = [class_names[int(i)] for i in all_labels]
            predictions["predicted_class"] = [class_names[int(i)] for i in all_predictions]
        with open(os.path.join(exp_dir, "predictions.json"), "w", encoding="utf-8") as f:
            json.dump(predictions, f, indent=4)

    # 6-7. Accuracy plots
    _plot_accuracy_per_episode(
        os.path.join(exp_dir, "accuracy_per_episode.png"), accuracies, mean_acc, std_acc,
        f'{model_name} + {metric_name}\n{n_way}-way {n_shot}-shot - {n_episodes} episódios',
    )
    _plot_accuracy_distribution(
        os.path.join(exp_dir, "accuracy_distribution.png"), accuracies, mean_acc, median_acc,
        f'Distribuição de Acurácias\n{model_name} - {n_way}-way {n_shot}-shot',
    )

    # 8. Summary. Only files that were actually written are listed.
    generated = ["config.json", "report.txt"]
    if report_dict is not None:
        generated += ["per_class_metrics.json", "per_class_report.txt", "metrics_table.png"]
    if has_predictions:
        generated += ["confusion_matrix.png", "predictions.json"]
    generated += ["accuracy_per_episode.png", "accuracy_distribution.png"]

    print(f"\n{'='*70}")
    print(f"Resultados salvos em: {exp_dir}")
    print(f"{'='*70}")
    print(f"Configuração: {n_way}-way {n_shot}-shot ({n_episodes} episódios)")
    print(f"Acurácia: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%")
    if metrics_dict:
        print(f"Precision (macro): {metrics_dict['precision_macro']*100:.2f}%")
        print(f"Recall (macro): {metrics_dict['recall_macro']*100:.2f}%")
        print(f"F1-Score (macro): {metrics_dict['f1_macro']*100:.2f}%")
    print("\nArquivos gerados:")
    for filename in generated:
        print(f"  - {filename}")
    print(f"{'='*70}\n")

    return exp_dir