# Imports and Helper Functions

In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import os
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from ffm import *


In [2]:
# Grayscale to RGB transform
class GrayscaleToRGB(object):
    """From https://www.kaggle.com/code/cafalena/caltech101-pytorch-deep-learning"""
    def __call__(self, img):
        if img.mode == 'L':
            img = img.convert("RGB")
        return img

transform = transforms.Compose(
    [transforms.Resize((224, 224)),
    GrayscaleToRGB(),]
    )

tranform_preprocess = transforms.Compose(
    [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    )

transform_all = transforms.Compose(
    [transforms.Resize((224, 224)),
    GrayscaleToRGB(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    )

# Load Caltech101 dataset
def load_data(root='caltech_data', transform=transform):
    dataset = datasets.Caltech101(root='caltech_data', download=False, transform=transform)

    return dataset

# Split dataset into training and testing
def split_data(dataset):
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))
    return train_dataset, test_dataset

# Load pretrained ResNet model
def load_model(model_name: str):
    model = getattr(models, model_name)(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 101)
    return model

# FFM Analysis

In [3]:
# Load data
dataset = load_data(transform=transform)
train_dataset, test_dataset = split_data(dataset)

model_checkpoints = os.listdir("checkpoints")
model_checkpoints = {
    'resnet18': [f for f in model_checkpoints if 'resnet18' in f],
}

In [4]:
# Some examples from the dataset where our model should ideally show:
idxs = {
    'low_ffm': 190,  # low ffm score (highly confident, incorrect prediction, with consistent explanation for the incorrect prediction)
    'moderate_ffm': 8,  # moderate ffm score (correct prediction, but with incorrect/inconsistent explanation for ground truth)
    'high_ffm': 6,  # high ffm score (highly confident, correct prediction, with consistent explanation for the ground truth)
}

In [5]:
# Load model
for key, i in idxs.items():
    batch = tranform_preprocess(test_dataset[i][0]).unsqueeze(0).to('cuda')
    gt_labels = [test_dataset[i][1]]
    batch = torch.vstack([batch]*4)
    gt_labels = gt_labels*4
    for MODEL in ['resnet18']: #, 'resnet34', 'resnet50', 'resnet101']:
        model = load_model(MODEL).to('cuda')
        if model_checkpoints[MODEL]:
            model.load_state_dict(torch.load(f"checkpoints/{model_checkpoints[MODEL][0]}/model.pth"))
            print(f"Loaded model from {model_checkpoints[MODEL][0]}")
        else:
            print("No model checkpoint found, loading pretrained model")

        model = model.eval()
        target_layer = model.layer4[-1]

        # Compute FFM
        ffm_val = compute_ffm(
            model=model,
            img=batch,
            gt_label=gt_labels,
            target_layer=target_layer,
            reduce='none',
            top_k=5,
            output_softmax=True,
            device='cuda'
        )

        print(f'{i}th FFM for {MODEL}: {ffm_val}')
        print(f'(is {key})')
        print('*'*50)




Loaded model from resnet18-20240917-161535


100%|██████████| 16/16 [00:00<00:00, 80.64it/s]
100%|██████████| 16/16 [00:00<00:00, 92.33it/s]
100%|██████████| 16/16 [00:00<00:00, 69.92it/s] 
100%|██████████| 16/16 [00:00<00:00, 129.25it/s]
100%|██████████| 16/16 [00:00<00:00, 53.61it/s]
100%|██████████| 16/16 [00:00<00:00, 55.81it/s]
100%|██████████| 16/16 [00:00<00:00, 132.79it/s]
100%|██████████| 16/16 [00:00<00:00, 73.43it/s]
100%|██████████| 16/16 [00:00<00:00, 94.55it/s]
100%|██████████| 16/16 [00:00<00:00, 77.73it/s]
100%|██████████| 16/16 [00:00<00:00, 95.84it/s]
100%|██████████| 16/16 [00:00<00:00, 81.03it/s]
100%|██████████| 16/16 [00:00<00:00, 60.15it/s]
100%|██████████| 16/16 [00:00<00:00, 59.71it/s]
100%|██████████| 16/16 [00:00<00:00, 71.24it/s]
100%|██████████| 16/16 [00:00<00:00, 72.86it/s]
100%|██████████| 16/16 [00:00<00:00, 71.50it/s]
100%|██████████| 16/16 [00:00<00:00, 59.02it/s]
100%|██████████| 16/16 [00:00<00:00, 44.54it/s]
100%|██████████| 16/16 [00:00<00:00, 59.72it/s]
100%|██████████| 16/16 [00:00<00:00, 

190th FFM for resnet18: [0.12140695 0.12140695 0.12140695 0.12140695]
(is low_ffm)
**************************************************
Loaded model from resnet18-20240917-161535


100%|██████████| 16/16 [00:00<00:00, 56.18it/s] 
100%|██████████| 16/16 [00:00<00:00, 64.29it/s] 
100%|██████████| 16/16 [00:00<00:00, 55.48it/s]
100%|██████████| 16/16 [00:00<00:00, 55.87it/s]
100%|██████████| 16/16 [00:00<00:00, 64.29it/s]
100%|██████████| 16/16 [00:00<00:00, 57.90it/s]
100%|██████████| 16/16 [00:00<00:00, 92.61it/s]
100%|██████████| 16/16 [00:00<00:00, 51.45it/s]
100%|██████████| 16/16 [00:00<00:00, 74.99it/s]
100%|██████████| 16/16 [00:00<00:00, 94.46it/s]
100%|██████████| 16/16 [00:00<00:00, 74.47it/s] 
100%|██████████| 16/16 [00:00<00:00, 69.52it/s]
100%|██████████| 16/16 [00:00<00:00, 51.25it/s]
100%|██████████| 16/16 [00:00<00:00, 75.61it/s]
100%|██████████| 16/16 [00:00<00:00, 133.47it/s]
100%|██████████| 16/16 [00:00<00:00, 53.69it/s]
100%|██████████| 16/16 [00:00<00:00, 96.74it/s] 
100%|██████████| 16/16 [00:00<00:00, 78.37it/s]
100%|██████████| 16/16 [00:00<00:00, 74.19it/s]
100%|██████████| 16/16 [00:00<00:00, 89.25it/s]
100%|██████████| 16/16 [00:00<00:00

8th FFM for resnet18: [0.5 0.5 0.5 0.5]
(is moderate_ffm)
**************************************************
Loaded model from resnet18-20240917-161535


100%|██████████| 16/16 [00:00<00:00, 94.27it/s] 
100%|██████████| 16/16 [00:00<00:00, 133.12it/s]
100%|██████████| 16/16 [00:00<00:00, 97.00it/s]
100%|██████████| 16/16 [00:00<00:00, 40.73it/s]
100%|██████████| 16/16 [00:00<00:00, 127.44it/s]
100%|██████████| 16/16 [00:00<00:00, 132.18it/s]
100%|██████████| 16/16 [00:00<00:00, 93.20it/s]
100%|██████████| 16/16 [00:00<00:00, 74.71it/s]
100%|██████████| 16/16 [00:00<00:00, 96.88it/s] 
100%|██████████| 16/16 [00:00<00:00, 80.18it/s]
100%|██████████| 16/16 [00:00<00:00, 63.18it/s]
100%|██████████| 16/16 [00:00<00:00, 78.88it/s] 
100%|██████████| 16/16 [00:00<00:00, 130.49it/s]
100%|██████████| 16/16 [00:00<00:00, 33.19it/s]
100%|██████████| 16/16 [00:00<00:00, 94.81it/s]
100%|██████████| 16/16 [00:00<00:00, 96.27it/s]
100%|██████████| 16/16 [00:00<00:00, 44.80it/s]
100%|██████████| 16/16 [00:00<00:00, 134.82it/s]
100%|██████████| 16/16 [00:00<00:00, 95.01it/s]
100%|██████████| 16/16 [00:00<00:00, 61.80it/s]
100%|██████████| 16/16 [00:00<00

6th FFM for resnet18: [0.98805539 0.98805539 0.98805539 0.98805539]
(is high_ffm)
**************************************************





In [8]:
model(batch)

tensor([[ -5.4258,  -4.2526,  -6.8620,  -2.8648,  -7.3341,  -2.6373,  -5.6172,
          -8.0879,  -5.0348,  -8.5480,  -6.0507,  -6.2677,  -5.7136,  -7.2534,
          -6.9979,  -5.4202,  -7.2120,  -5.1941,  -2.6403,  10.6820,  -5.9640,
          -4.6383,  -6.0485,  -5.3464,  -5.2907,  -5.8720,  -7.9378,  -7.5920,
          -6.5769,  -5.8074,  -6.5410,  -7.6310,  -6.2503,  -5.0728,  -6.7955,
          -6.6904,  -6.3245,  -5.9103,  -4.1487,  -8.3158,  -2.8990,  -5.8587,
          -8.3853, -10.0053,  -6.4801,  -6.6867,  -6.0426,  -7.0230,  -5.7453,
          -6.4381,  -4.1114,  -7.5998,  -5.5940,  -7.9528,  -5.5267,  -2.1868,
          -8.2595,  -6.3859,  -5.5202,  -7.5451,  -7.2721,  -6.3467,  -7.2371,
          -7.1876,  -6.0976,  -6.0309,  -7.8099,  -7.1929,  -6.7158,  -7.9910,
          -7.5961,  -6.8759,  -7.4579,  -9.2192,  -6.1122,  -5.8065,  -8.4852,
          -9.6949,  -7.9894,  -6.4940,  -7.1667,  -7.0323,  -8.5821,  -8.3945,
          -4.6447,  -6.3502,  -7.4648,  -8.3188,  -4