In [1]:
# IPython magic: turn off jedi-based tab completion in the notebook's completer.
%config Completer.use_jedi = False

Evaluate each model and their ensembles (softmax averaging and majority voting) from saved logits files.

## Imports and setup

In [2]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [3]:
import sys
sys.path.append('../')

from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from utils.datasets import dataset_selector
from utils.neural import *
from utils.metrics import compute_accuracy

In [42]:
def ensemble_evaluation(logits_dir="logits", prefix="test", ensemble_strategy=("avg", "vote")):
    """Print per-model and ensemble accuracies computed from saved logits files.

    Every file below ``logits_dir`` (searched recursively) whose *file name*
    contains ``f"{prefix}_logits"`` is loaded with ``torch.load``. Each file must
    hold a dict with a ``"logits"`` tensor and a ``"labels"`` tensor, and all files
    must share identical labels (same samples, same order).

    Args:
        logits_dir: Root directory searched recursively for logits files.
        prefix: Split prefix (e.g. ``"val"`` or ``"test"``); a file is used when
            its name contains ``"<prefix>_logits"``.
        ensemble_strategy: A strategy name or a sequence of names. Supported:
            ``"avg"`` (average the softmax probabilities of all models) and
            ``"vote"`` (majority vote over each model's argmax prediction).

    Returns:
        None. Results are printed, one line per model followed by one line per
        requested strategy.

    Raises:
        ValueError: if no strategy or an unknown strategy is given, if no logits
            file is found, or if logits shapes / labels differ between files.
    """
    # -- Validate requested strategies --
    available_strategies = ("avg", "vote")
    if not ensemble_strategy:
        raise ValueError("Please specify an ensemble strategy")
    strategies = [ensemble_strategy] if isinstance(ensemble_strategy, str) else list(ensemble_strategy)
    for strategy in strategies:
        if strategy not in available_strategies:
            raise ValueError(f"Unknown strategy {strategy}")

    # -- Locate logits files --
    # Match on the file name only: matching on the full path would also pick up
    # every file inside a directory whose name happens to contain the tag.
    # Sorting makes the printout independent of the filesystem's walk order.
    tag = f"{prefix}_logits"
    logits_paths = sorted(
        os.path.join(subdir, file_name)
        for subdir, _, file_names in os.walk(logits_dir)
        for file_name in file_names
        if tag in file_name
    )
    if not logits_paths:
        raise ValueError(f"Could not find any file in subdirectories of '{logits_dir}' with prefix '{prefix}'")

    # -- Load logits and labels, reporting each model's own accuracy --
    names, logits_list, labels_list = [], [], []
    for path in logits_paths:
        # Display "<parent dir>/<file name>" without assuming '/' as separator.
        name = os.path.join(os.path.basename(os.path.dirname(path)), os.path.basename(path))
        # NOTE: torch.load unpickles the file -- only load logits from a trusted source.
        # map_location="cpu" lets tensors saved on a GPU load on CPU-only machines.
        info = torch.load(path, map_location="cpu")
        logits = info["logits"].cpu()
        labels = info["labels"].cpu()
        print(f"{name}: {compute_accuracy(labels, logits)}")
        names.append(name)
        logits_list.append(logits)
        labels_list.append(labels)

    # -- Every file must describe the same samples in the same order --
    # Checked before stacking so mismatches produce a clear error message.
    ref_name, ref_logits, ref_labels = names[0], logits_list[0], labels_list[0]
    for name, logits, labels in zip(names[1:], logits_list[1:], labels_list[1:]):
        if logits.shape != ref_logits.shape:
            raise ValueError(
                f"Logits of '{name}' have shape {tuple(logits.shape)}, expected {tuple(ref_logits.shape)}"
            )
        if labels.shape != ref_labels.shape or not bool((labels == ref_labels).all()):
            raise ValueError(f"Labels list does not match! ('{name}' vs '{ref_name}')")

    # (num_models, num_samples, num_classes), e.g. (N, 10000, 10) for CIFAR10
    logits_stack = torch.stack(logits_list)

    # -- Ensemble strategies --
    if "avg" in strategies:
        probs_avg = torch.softmax(logits_stack, dim=2).mean(dim=0)
        print(f"--- Avg Strategy: {compute_accuracy(ref_labels, probs_avg)} ---")

    if "vote" in strategies:
        num_classes = logits_stack.shape[-1]
        model_preds = logits_stack.argmax(dim=2)  # (num_models, num_samples)
        votes = torch.nn.functional.one_hot(model_preds, num_classes=num_classes).sum(dim=0)
        # On a tie, argmax returns the first maximum, i.e. the lowest class index.
        vote_preds = votes.argmax(dim=1)
        vote_accuracy = (vote_preds == ref_labels).sum().item() / len(ref_labels)
        print(f"--- Vote Strategy: {vote_accuracy} ---")

In [44]:
# Per-model and ensemble accuracies for the files named "*val_logits*".
ensemble_evaluation(logits_dir="../logits", prefix="val", ensemble_strategy=["avg", "vote"])

model1/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9478
model2/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9488
model3/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.95
model4/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9472
model5/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.949
--- Avg Strategy: 0.9582 ---
--- Vote Strategy: 0.9564 ---


In [45]:
# Per-model and ensemble accuracies for the files named "*test_logits*".
ensemble_evaluation(logits_dir="../logits", prefix="test", ensemble_strategy=["avg", "vote"])

model1/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9482
model2/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9438
model3/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9469
model4/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9448
model5/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9458
--- Avg Strategy: 0.9562 ---
--- Vote Strategy: 0.9557 ---
