In [1]:
import pickle
import numpy as np

def _load_pickle_pair(first_path, second_path):
    """Open two pickle files, then unpickle and return both objects in order.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    trusted files such as the ones produced by our own training runs.
    """
    with open(first_path, "rb") as first_fh, open(second_path, "rb") as second_fh:
        return pickle.load(first_fh), pickle.load(second_fh)


# Each (predictions, probabilities) pair comes from one group of models; the
# file names suggest models 1-4, 5-8 and 9-12 (TODO confirm against training).
predictions_1_4, probabilities_1_4 = _load_pickle_pair(
    "/kaggle/input/vgg16-models-predictions-test-set/model_1_4_predictions.pkl",
    "/kaggle/input/vgg16-models-predictions-test-set/model_1_4_probabilities.pkl",
)
predictions_5_8, probabilities_5_8 = _load_pickle_pair(
    "/kaggle/input/vgg16-models-predictions-test-set/model_5_8_predictions.pkl",
    "/kaggle/input/vgg16-models-predictions-test-set/model_5_8_probabilities.pkl",
)
predictions_9_12, probabilities_9_12 = _load_pickle_pair(
    "/kaggle/input/vgg16-models-predictions-test-set/model_9_12_predictions.pkl",
    "/kaggle/input/vgg16-models-predictions-test-set/model_9_12_probabilities.pkl",
)

In [6]:
from dataset.dataset import PlanetaryDataset
from utilities.training import VAL_TEST_TRANSFORM
from torch.utils.data import DataLoader

# Test split restricted to channels 30 through 100 (inclusive), preprocessed
# with the shared validation/test transform.
test_dataset = PlanetaryDataset(
    csv_file="/kaggle/input/gsoc-protoplanetary-disks/test_info.csv",
    data_dir="/kaggle/input/gsoc-protoplanetary-disks/Test_Clean",
    transform=VAL_TEST_TRANSFORM,
    channels=[*range(30, 101)],
)

# One sample per batch and no shuffling, so batch i corresponds to position i
# in the stored per-model predictions/probabilities.
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1)

In [9]:
# Merge the per-group dicts into single {model key: per-sample outputs} maps.
prediction_chunks = (predictions_1_4, predictions_5_8, predictions_9_12)
probability_chunks = (probabilities_1_4, probabilities_5_8, probabilities_9_12)
all_predictions = {**predictions_1_4, **predictions_5_8, **predictions_9_12}
all_probabilities = {**probabilities_1_4, **probabilities_5_8, **probabilities_9_12}

# Dict unpacking keeps only the last value for a repeated key, which would
# silently drop models from the ensemble; make sure the groups are disjoint.
assert len(all_predictions) == sum(len(chunk) for chunk in prediction_chunks), (
    "duplicate model keys across prediction files"
)
assert len(all_probabilities) == sum(len(chunk) for chunk in probability_chunks), (
    "duplicate model keys across probability files"
)

# Per-sample accumulators filled by the voting loop.
all_labels = []
all_majority_preds = []
all_soft_probs = []

In [None]:
# Reset the accumulators here so re-running this cell does not append a second
# copy of every sample to lists that already hold results.
all_labels = []
all_majority_preds = []
all_soft_probs = []

# Per-model outputs do not change across samples; build the lists once.
model_predictions = list(all_predictions.values())
model_probabilities = list(all_probabilities.values())

for i, (_, labels) in enumerate(test_loader):
    # The loader yields one sample per batch, so `labels` holds one element.
    all_labels.append(labels.item())

    # Hard voting: most frequent predicted label across models. np.bincount
    # needs non-negative integer labels; np.argmax breaks ties toward the
    # smallest label (it returns the first maximum).
    stacked_preds = np.stack([preds[i] for preds in model_predictions])
    majority_pred = np.argmax(np.bincount(stacked_preds.flatten()))
    all_majority_preds.append(majority_pred)

    # Soft voting: average the models' probability outputs for this sample;
    # the label is taken later with an argmax over the class axis.
    stacked_probs = np.stack([probs[i] for probs in model_probabilities])
    avg_probs = np.mean(stacked_probs, axis=0)
    all_soft_probs.append(avg_probs)

### Hard- and soft-voting accuracy

In [None]:
labels_array = np.array(all_labels)

# Hard voting accuracy: fraction of samples whose majority-vote label is correct.
accuracy_hard = np.mean(np.array(all_majority_preds) == labels_array)
print(f"Hard Voting Accuracy: {accuracy_hard * 100:.2f}%")

# Soft voting: each entry of `all_soft_probs` is already the model-averaged
# probability output for ONE sample, so the predicted label is the argmax of
# each entry. Do not average over axis 0 first, because that would pool the
# whole test set into a single probability vector.
all_soft_probs_array = np.array(all_soft_probs)
# argmax over the last axis, then flatten. This handles per-sample entries
# shaped (num_classes,) or (1, num_classes). It assumes the class axis is last
# (TODO confirm).
soft_pred_labels = np.argmax(all_soft_probs_array, axis=-1).reshape(-1)
accuracy_soft = np.mean(soft_pred_labels == labels_array)
print(f"Soft Voting Accuracy: {accuracy_soft * 100:.2f}%")

### Save indices of misclassified samples

In [14]:
# 0-based test-set positions where each voting scheme disagrees with the label.
true_labels = np.array(all_labels)
incorrect_samples_hard = np.flatnonzero(np.array(all_majority_preds) != true_labels)
incorrect_samples_soft = np.flatnonzero(soft_pred_labels != true_labels)

# Persist both index arrays so the visualization cell can reload them after a
# kernel restart. These are relative paths; the reload cell reads them from
# /kaggle/working, presumably the working directory on Kaggle.
with open("incorrect_samples_hard.pkl", "wb") as hard_out:
    with open("incorrect_samples_soft.pkl", "wb") as soft_out:
        pickle.dump(incorrect_samples_hard, hard_out)
        pickle.dump(incorrect_samples_soft, soft_out)

### Wrong predictions visualization

In [None]:
import pickle

import matplotlib.pyplot as plt

# Reload the misclassified-sample indices saved earlier, so this visualization
# section can run on its own without recomputing the ensemble.
hard_path = "/kaggle/working/incorrect_samples_hard.pkl"
soft_path = "/kaggle/working/incorrect_samples_soft.pkl"
with open(hard_path, "rb") as hard_fh, open(soft_path, "rb") as soft_fh:
    incorrect_samples_hard = pickle.load(hard_fh)
    incorrect_samples_soft = pickle.load(soft_fh)


def plot_images(dataset, indices, title, num_cols=5):
    """Show the samples at ``indices`` from ``dataset`` in a grid of subplots.

    Args:
        dataset: indexable object whose items are ``(image, label)`` pairs.
            ``image`` must support ``.permute(1, 2, 0)``, ``.cpu()`` and
            ``.numpy()``; presumably a channels-first (C, H, W) torch tensor
            (TODO confirm).
        indices: sequence of integer sample indices to display.
        title: figure-level title.
        num_cols: number of subplot columns.
    """
    num_images = len(indices)
    # Ceiling division avoids a spurious empty row when num_images is a
    # multiple of num_cols. Keep at least one row so an empty selection still
    # renders a titled figure.
    num_rows = max(1, (num_images + num_cols - 1) // num_cols)
    # squeeze=False keeps `axes` 2-D even with a single row, so the
    # [row, col] indexing below never fails.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(15, 3 * num_rows), squeeze=False
    )

    # Hide every cell up front so unused trailing cells do not show empty frames.
    for ax in axes.flat:
        ax.axis("off")

    for i, idx in enumerate(indices):
        image, _ = dataset[idx]
        ax = axes[i // num_cols, i % num_cols]
        # NOTE(review): imshow only accepts 1, 3 or 4 channels. If the
        # transform keeps all selected channels, this needs a channel reduction.
        ax.imshow(image.permute(1, 2, 0).cpu().numpy())
        ax.set_title(f"Sample {idx}")

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


# One figure per voting scheme, hard first, then soft.
for misclassified, heading in (
    (incorrect_samples_hard, "Hard Voting Misclassifications"),
    (incorrect_samples_soft, "Soft Voting Misclassifications"),
):
    plot_images(test_dataset, misclassified, heading)