# Imports and Helper Functions

In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import os
from torchvision import datasets, transforms
from ffm import *


In [2]:
# Grayscale to RGB transform
class GrayscaleToRGB(object):
    """Convert any non-RGB PIL image to 3-channel RGB.

    Adapted from https://www.kaggle.com/code/cafalena/caltech101-pytorch-deep-learning

    Only converting mode ``'L'`` is not enough: other modes (e.g. ``'1'``,
    ``'LA'``, ``'P'``, ``'RGBA'``, ``'CMYK'``) would reach ``ToTensor`` /
    ``Normalize`` with a channel count other than 3 and clash with the
    3-element mean/std used in this notebook. So every non-RGB mode is
    converted, and RGB images are returned untouched.
    """

    def __call__(self, img):
        """Return ``img`` unchanged if it is already RGB, otherwise an RGB copy.

        Args:
            img: A PIL-like image exposing ``.mode`` and ``.convert(mode)``.

        Returns:
            An image whose ``mode`` is ``'RGB'``.
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img

    def __repr__(self):
        # Readable repr so printing a transforms.Compose shows this step by name.
        return f"{self.__class__.__name__}()"

# Shared preprocessing constants (standard ImageNet statistics, matching the
# pretrained backbone loaded by load_model).
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL-level steps: resize and force 3 channels. These run inside the dataset,
# so dataset items remain PIL images.
_pil_steps = [
    transforms.Resize(IMAGE_SIZE),
    GrayscaleToRGB(),
]
# Tensor-level steps: convert to a float tensor and normalise per channel.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]

transform = transforms.Compose(_pil_steps)
tranform_preprocess = transforms.Compose(_tensor_steps)
# Correctly spelled alias; the misspelled name is kept for existing callers.
transform_preprocess = tranform_preprocess
# Full pipeline: equivalent to applying `transform` then `tranform_preprocess`.
transform_all = transforms.Compose(_pil_steps + _tensor_steps)

# Load Caltech101 dataset
def load_data(root='caltech_data', transform=transform, download=False):
    """Load the Caltech101 dataset from ``root``.

    Args:
        root: Directory holding the Caltech101 files (or receiving them when
            ``download`` is True).
        transform: Transform applied to each image. NOTE: this default is bound
            to the module-level ``transform`` when the function is defined, so
            rebinding ``transform`` afterwards does not change the default.
        download: Forwarded to ``datasets.Caltech101``; ``False`` by default.

    Returns:
        A ``torchvision.datasets.Caltech101`` instance.
    """
    # Forward `root` instead of hard-coding the directory, so callers can point
    # at a different location.
    return datasets.Caltech101(root=root, download=download, transform=transform)

# Split dataset into training and testing
def split_data(dataset, train_frac=0.8, seed=42):
    """Randomly split ``dataset`` into train and test subsets, reproducibly.

    Args:
        dataset: A sized, indexable dataset (e.g. a torch ``Dataset``).
        train_frac: Fraction of samples in the training split, strictly between
            0 and 1. The train size is ``int(train_frac * len(dataset))``; the
            test split receives the remainder.
        seed: Seed for a dedicated ``torch.Generator``. This keeps the split
            identical across runs, which matters because later cells refer to
            fixed indices into the test split.

    Returns:
        ``(train_dataset, test_dataset)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: If ``train_frac`` is not in the open interval (0, 1).
    """
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac!r}")
    train_size = int(train_frac * len(dataset))
    test_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size], generator=generator
    )
    return train_dataset, test_dataset

# Load pretrained ResNet model
def load_model(model_name: str, num_classes: int = 101, pretrained: bool = True):
    """Build a torchvision classifier and replace its final ``fc`` layer.

    Args:
        model_name: Name of a model builder in ``torchvision.models``
            (e.g. ``'resnet18'``). The model must expose an ``fc`` attribute with
            ``in_features``, as the ResNet family used in this notebook does.
        num_classes: Output size of the new classification head (101 for
            Caltech101).
        pretrained: Forwarded to the builder.
            NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
            favour of ``weights=``; it is kept for compatibility with older
            versions.

    Returns:
        The model with a freshly initialised ``nn.Linear(in_features, num_classes)``
        head.

    Raises:
        AttributeError: If ``model_name`` is not a builder in ``torchvision.models``,
            or the resulting model has no ``fc`` layer.
    """
    model = getattr(models, model_name)(pretrained=pretrained)
    if not hasattr(model, 'fc'):
        raise AttributeError(
            f"{model_name!r} has no 'fc' layer; load_model only supports models "
            "with a final `fc` Linear head (e.g. ResNets)"
        )
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# FFM Analysis

In [3]:
# Load data. split_data uses a fixed seed, so the hand-picked test indices used
# later always refer to the same images.
dataset = load_data(transform=transform)
train_dataset, test_dataset = split_data(dataset)

# Collect checkpoint directories per architecture.
# os.listdir returns entries in arbitrary order, so sort explicitly. With names
# like "resnet18-YYYYMMDD-HHMMSS", a reverse lexicographic sort puts the most
# recent checkpoint first, making `model_checkpoints[...][0]` deterministic.
# NOTE(review): the "newest first" ordering assumes that naming scheme — confirm.
# A missing directory yields no checkpoints instead of raising
# FileNotFoundError, so the "no checkpoint found" path can handle it.
CHECKPOINT_DIR = "checkpoints"
all_checkpoints = (
    sorted(os.listdir(CHECKPOINT_DIR), reverse=True)
    if os.path.isdir(CHECKPOINT_DIR)
    else []
)
model_checkpoints = {
    'resnet18': [f for f in all_checkpoints if 'resnet18' in f],
}

In [4]:
# Hand-picked indices into `test_dataset` illustrating three FFM regimes that
# the model should ideally show:
#   low_ffm      -> highly confident, incorrect prediction, with an explanation
#                   that consistently supports the incorrect prediction
#   moderate_ffm -> correct prediction, but with an incorrect/inconsistent
#                   explanation for the ground truth
#   high_ffm     -> highly confident, correct prediction, with a consistent
#                   explanation for the ground truth
idxs = dict(
    low_ffm=190,
    moderate_ffm=8,
    high_ffm=6,
)

# Add more idxs here if you want to see more examples

In [5]:
# Load model
MODEL = 'resnet18'
# Fall back to CPU so the notebook still runs on machines without a GPU.
# NOTE(review): assumes compute_ffm also works with device='cpu' — confirm.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = load_model(MODEL).to(DEVICE)
if model_checkpoints[MODEL]:
    checkpoint_name = model_checkpoints[MODEL][0]
    checkpoint_path = os.path.join("checkpoints", checkpoint_name, "model.pth")
    # map_location lets a GPU-saved checkpoint load on CPU as well.
    # weights_only avoids unpickling arbitrary objects; this assumes the file
    # holds a plain state_dict (it is passed straight to load_state_dict).
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Loaded model from {checkpoint_name}")
else:
    print("No model checkpoint found, loading pretrained model")

model = model.eval()  # Set model to evaluation mode
target_layer = model.layer4[-1]  # Last block of ResNet18's layer4

# Compute FFM for each example
for key, i in idxs.items():
    # Index the dataset once instead of twice; presumably each access reloads
    # and re-transforms the image.
    img, label = test_dataset[i]
    batch = tranform_preprocess(img).unsqueeze(0).to(DEVICE)
    gt_labels = [label]

    # Uncomment the following lines to test on a batch of images
    # batch = torch.vstack([batch]*4)
    # gt_labels = gt_labels*4

    # Compute FFM
    ffm_val = compute_ffm(
        model=model,
        img=batch,
        gt_label=gt_labels,
        target_layer=target_layer,
        reduce='none',
        top_k=5,
        output_softmax=True,
        device=DEVICE,
    )

    print(f'{i}th FFM for {MODEL}: {ffm_val}')
    print(f'(is {key})')
    print('*'*50)




Loaded model from resnet18-20240917-161535


100%|██████████| 16/16 [00:00<00:00, 47.05it/s]
100%|██████████| 16/16 [00:00<00:00, 130.62it/s]
100%|██████████| 16/16 [00:00<00:00, 60.11it/s]
100%|██████████| 16/16 [00:00<00:00, 140.15it/s]
100%|██████████| 16/16 [00:00<00:00, 49.93it/s]
100%|██████████| 16/16 [00:00<00:00, 44.33it/s]


190th FFM for resnet18: [0.1214497]
(is low_ffm)
**************************************************


100%|██████████| 16/16 [00:00<00:00, 58.92it/s]
100%|██████████| 16/16 [00:00<00:00, 93.34it/s]
100%|██████████| 16/16 [00:00<00:00, 46.25it/s]
100%|██████████| 16/16 [00:00<00:00, 93.92it/s] 
100%|██████████| 16/16 [00:00<00:00, 131.15it/s]
100%|██████████| 16/16 [00:00<00:00, 65.10it/s]


8th FFM for resnet18: [0.5]
(is moderate_ffm)
**************************************************


100%|██████████| 16/16 [00:00<00:00, 74.95it/s]
100%|██████████| 16/16 [00:00<00:00, 84.30it/s]
100%|██████████| 16/16 [00:00<00:00, 78.22it/s]
100%|██████████| 16/16 [00:00<00:00, 85.45it/s]
100%|██████████| 16/16 [00:00<00:00, 132.73it/s]
100%|██████████| 16/16 [00:00<00:00, 78.44it/s]

6th FFM for resnet18: [0.98805539]
(is high_ffm)
**************************************************



