In [31]:
# Load error-analysis data (images plus ground-truth labels and model predictions)

import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import os
import json


class ErrorAnalysisDataset(Dataset):
    """
    PyTorch Dataset for loading image data with predictions and ground truth.

    Files read from ``dataset_root``:
        dataset.csv          -- one row per image; ``img_id`` is the image path
                                relative to ``img_root_dir`` and ``gt`` holds an
                                integer class index (``attribute`` and
                                ``gt_code`` are exposed as optional metadata)
        mapping.json         -- JSON object; the i-th value is the label for
                                class index i (key order defines the index)
        pred_splits/<split>  -- one integer class index per line, aligned
                                row-for-row with dataset.csv

    After construction, ``self.data`` holds dataset.csv with the ``gt`` and
    ``pred`` columns translated from class indices to labels.
    """

    def __init__(self, dataset_root, pred_split, img_root_dir=None, transform=None,
                 return_metadata=False):
        """
        Args:
            dataset_root (str): Directory containing ``dataset.csv``,
                ``mapping.json`` and the ``pred_splits/`` folder.
            pred_split (str): File name inside ``pred_splits/`` with the
                predicted class index for each row of ``dataset.csv``.
            img_root_dir (str, optional): Directory that ``img_id`` paths are
                relative to. Only required when indexing the dataset.
            transform (callable, optional): Optional transform applied to each
                loaded RGB PIL image.
            return_metadata (bool): If True, each sample also carries the
                ``attribute``, ``gt_code``, ``gt`` and ``pred`` fields.

        Raises:
            ValueError: If the prediction file does not contain exactly one
                line per row of ``dataset.csv``.
        """
        self.data = pd.read_csv(os.path.join(dataset_root, "dataset.csv"))
        self.img_root_dir = img_root_dir
        self.transform = transform
        self.return_metadata = return_metadata

        with open(os.path.join(dataset_root, 'mapping.json'), 'r') as f:
            self.mapping = json.load(f)
        # Class index i -> label stored under the i-th key of mapping.json.
        idx_to_label = list(self.mapping.values())

        predictions = pd.read_csv(
            os.path.join(dataset_root, "pred_splits", pred_split),
            header=None,
            names=['pred'],
        )
        # Predictions are aligned positionally with dataset.csv, so the counts must match.
        if len(predictions) != len(self.data):
            raise ValueError(
                f"{pred_split} has {len(predictions)} predictions but dataset.csv "
                f"has {len(self.data)} rows"
            )

        self.data['pred'] = [idx_to_label[class_idx] for class_idx in predictions['pred']]
        self.data['gt'] = [idx_to_label[class_idx] for class_idx in self.data['gt']]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            dict: ``{'image': image}``, where ``image`` is the RGB PIL image
            (or whatever ``transform`` returns). When ``return_metadata`` is
            set, it also holds ``attribute``, ``gt_code``, ``gt`` and ``pred``.

        Raises:
            ValueError: If the dataset was built without ``img_root_dir``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if self.img_root_dir is None:
            raise ValueError("img_root_dir must be set to load images from the dataset")

        row = self.data.iloc[idx]

        # Context manager releases the file handle; convert() returns an in-memory copy.
        img_path = os.path.join(self.img_root_dir, row['img_id'])
        with Image.open(img_path) as img:
            image = img.convert('RGB')  # Ensure RGB

        if self.transform:
            image = self.transform(image)

        sample = {'image': image}
        if self.return_metadata:
            for key in ('attribute', 'gt_code', 'gt', 'pred'):
                sample[key] = row[key]
        return sample


In [32]:
# TODO: create a Context class that supports adding entries to it and reading them back

In [65]:
# confusion matrix analysis module
import numpy as np
class InitialAnalysis:
    """
    Confusion-matrix summary of a classifier's errors.

    Builds a confusion matrix from a ground-truth column and a prediction
    column of a DataFrame. It then exposes the most frequent
    (actual, predicted) confusions and the classes with the most misclassified
    samples as human-readable strings.

    Attributes:
        confusion_matrix (pd.DataFrame): Counts with actual classes as rows
            and predicted classes as columns.
        top_k_errors_pred_conditional (list): Up to ``k`` entries of
            ``((actual, predicted), count)`` for non-zero off-diagonal cells,
            most frequent first.
        k (int): Number of entries the summaries report.
        classes (list): Sorted union of ground-truth and predicted labels.
    """

    def __init__(self, df, prediction_col, ground_truth_col, k=5):
        """
        Args:
            df (pd.DataFrame): One row per sample.
            prediction_col (str): Column holding predicted labels.
            ground_truth_col (str): Column holding ground-truth labels.
            k (int): How many top errors to keep and report.
        """
        actual_labels = df[ground_truth_col]
        predicted_labels = df[prediction_col]

        # Include predicted labels as well: a prediction outside the
        # ground-truth label set would otherwise have no column to be counted in.
        classes = sorted(set(actual_labels.unique()) | set(predicted_labels.unique()))
        position = {label: i for i, label in enumerate(classes)}

        counts = np.zeros((len(classes), len(classes)), dtype=int)
        for actual, predicted in zip(actual_labels, predicted_labels):
            counts[position[actual], position[predicted]] += 1
        confusion_matrix = pd.DataFrame(counts, index=classes, columns=classes)

        # Non-zero off-diagonal cells, enumerated in (actual, predicted) class order.
        errors = [
            ((actual, predicted), int(counts[i, j]))
            for i, actual in enumerate(classes)
            for j, predicted in enumerate(classes)
            if i != j and counts[i, j] > 0
        ]
        # sorted() is stable, so ties keep the class order above.
        top_k_errors = sorted(errors, key=lambda x: x[1], reverse=True)[:k]

        self.confusion_matrix = confusion_matrix
        self.top_k_errors_pred_conditional = top_k_errors
        self.k = k
        self.classes = classes

    def human_readable_topk_pred_conditional_errors(self):
        """
        Describe the top-k (actual, predicted) confusions, one per line.

        Lines are joined with a newline followed by a space.
        """
        errors_nl = []
        for (actual, predicted), err_count in self.top_k_errors_pred_conditional:
            errors_nl.append(f"The actual class is '{actual}', however model incorrectly predicts '{predicted}' {err_count} times")

        return '\n '.join(errors_nl)

    def human_readable_topk_errors_gt(self):
        """
        Describe the ground-truth classes with the most misclassified samples.

        For each class, the count is its row total minus its diagonal cell.
        At most ``k`` classes are listed, largest count first.
        """
        matrix = self.confusion_matrix.to_numpy()
        misclassified = matrix.sum(axis=1) - np.diag(matrix)
        marginal_errs = [(gt, int(count)) for gt, count in zip(self.classes, misclassified)]

        # select top k
        marginal_errs = sorted(marginal_errs, key=lambda x: x[1], reverse=True)[:self.k]

        # Report the real number of entries instead of a fixed "five".
        return f"The top {len(marginal_errs)} marginal errors are for these classes: {marginal_errs}"


In [66]:
# Smoke-test the dataset on the mock data and inspect the joined labels/predictions.
MOCK_DATA_ROOT = '../mock_data_creation/mock_data'

dataset = ErrorAnalysisDataset(dataset_root=MOCK_DATA_ROOT, pred_split='split_0.txt')
dataset.data.head()

Unnamed: 0.1,Unnamed: 0,img_id,attribute,gt_code,gt,pred
0,0,n01443537/art_0.jpg,art,n01443537,goldfish,missile
1,1,n01443537/art_1.jpg,art,n01443537,goldfish,goldfish
2,2,n01443537/art_10.jpg,art,n01443537,goldfish,soccer_ball
3,3,n01443537/art_11.jpg,art,n01443537,goldfish,goldfish
4,4,n01443537/art_12.jpg,art,n01443537,goldfish,goldfish


In [67]:
# Build the confusion matrix from the joined frame, then list the classes with the most misclassified samples.
analysis = InitialAnalysis(
    dataset.data,
    ground_truth_col='gt',
    prediction_col='pred',
)
analysis.human_readable_topk_errors_gt()

"The top five marginal errors are for these classes: [('mushroom', 125), ('toucan', 97), ('flamingo', 96), ('bee', 81), ('jellyfish', 80)]"

In [68]:
# print() renders the embedded newlines; the bare string repr would show them as literal "\n".
print(analysis.human_readable_topk_pred_conditional_errors())

"The actual class is 'ant', however model incorrectly predicts 'burrito' 4 times\n The actual class is 'hotdog', however model incorrectly predicts 'west_highland_white_terrier' 4 times\n The actual class is 'mushroom', however model incorrectly predicts 'pig' 4 times\n The actual class is 'tank', however model incorrectly predicts 'parachute' 4 times\n The actual class is 'African_chameleon', however model incorrectly predicts 'shih_tzu' 3 times"