# Imports and Helper Functions

In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import os
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from ffm import compute_ffm

In [2]:
# Convert any non-RGB PIL image (grayscale 'L', palette 'P', 'CMYK', ...) to RGB
class GrayscaleToRGB(object):
    """Ensure a PIL image has three RGB channels.

    Adapted from https://www.kaggle.com/code/cafalena/caltech101-pytorch-deep-learning

    The downstream ``Normalize`` uses 3-channel ImageNet statistics, so every
    image must be RGB. The original check only handled grayscale (mode ``'L'``);
    any mode other than ``'RGB'`` is now converted, so e.g. ``'CMYK'`` or
    palette (``'P'``) images do not break normalization.

    Images already in ``'RGB'`` mode are returned unchanged (the same object).
    """
    def __call__(self, img):
        if img.mode != 'RGB':
            img = img.convert("RGB")
        return img

# ImageNet per-channel statistics, matching the ImageNet-pretrained ResNet backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Stage 1: resize (aspect ratio is NOT preserved) and force RGB. The output is
# still a PIL image, so samples can be inspected before tensor conversion.
transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     GrayscaleToRGB()]
)

# Stage 2: PIL image -> float tensor, then ImageNet normalization.
tranform_preprocess = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)]
)
# Correctly spelled alias; the misspelled name is kept for existing callers.
transform_preprocess = tranform_preprocess

# Full pipeline: stage 1 followed by stage 2. It is built from the two stages
# above so the steps cannot drift apart.
transform_all = transforms.Compose([transform, tranform_preprocess])

# Load Caltech101 dataset
def load_data(root='caltech_data', transform=transform, download=False):
    """Load the Caltech101 dataset stored under ``root``.

    Args:
        root: Directory that holds (or, with ``download=True``, will receive)
            the Caltech101 files.
        transform: Per-image transform passed to ``datasets.Caltech101``. The
            default is the module-level ``transform`` (resize + RGB conversion)
            as it was bound when this function was defined.
        download: Forwarded to ``datasets.Caltech101``. Defaults to ``False``,
            so the data must already be present under ``root``.

    Returns:
        The ``datasets.Caltech101`` dataset instance.
    """
    return datasets.Caltech101(root=root, download=download, transform=transform)

# Split dataset into training and testing
def split_data(dataset, train_frac=0.8, seed=42):
    """Deterministically split ``dataset`` into train and test subsets.

    Args:
        dataset: Any sized, indexable dataset (supports ``len`` and indexing).
        train_frac: Fraction of samples placed in the training split. The
            training size is ``int(train_frac * len(dataset))``; all remaining
            samples go to the test split.
        seed: Seed for a dedicated ``torch.Generator``. The same indices are
            drawn on every run, independent of the global torch RNG state.

    Returns:
        ``(train_dataset, test_dataset)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: If ``train_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")
    train_size = int(train_frac * len(dataset))
    test_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size], generator=generator
    )
    return train_dataset, test_dataset

# Load pretrained ResNet model
def load_model(model_name: str, num_classes: int = 101, pretrained: bool = True):
    """Build a torchvision ResNet and replace its ``fc`` head with a new linear layer.

    Args:
        model_name: Name of a constructor in ``torchvision.models`` whose model
            has an ``fc`` attribute (e.g. ``'resnet18'``, ``'resnet50'``).
        num_classes: Output size of the replacement ``fc`` layer. The default
            of 101 matches the Caltech101 categories.
        pretrained: If True, load torchvision's default pretrained weights for
            the backbone. If False, the backbone is randomly initialized, which
            avoids a download when a fine-tuned checkpoint is loaded next.

    Returns:
        The model. Its ``fc`` layer is always newly created and untrained.
    """
    # `weights=` replaces the deprecated `pretrained=` argument (torchvision >= 0.13).
    weights = "DEFAULT" if pretrained else None
    model = getattr(models, model_name)(weights=weights)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# FFM Analysis

In [None]:
# Load data
dataset = load_data(transform=transform)
train_dataset, test_dataset = split_data(dataset)

# Map each architecture to its checkpoint directory names under ./checkpoints.
# os.listdir returns entries in arbitrary order, so names are sorted newest-first.
# This assumes the '<arch>-YYYYMMDD-HHMMSS' naming (e.g. 'resnet18-20240917-161535'),
# under which index 0 is the most recent run; confirm for other naming schemes.
CHECKPOINT_DIR = "checkpoints"
checkpoint_names = (
    sorted(os.listdir(CHECKPOINT_DIR), reverse=True)
    if os.path.isdir(CHECKPOINT_DIR)
    else []  # no checkpoint directory: every architecture falls back to pretrained weights
)
model_checkpoints = {
    arch: [name for name in checkpoint_names if arch in name]
    for arch in ('resnet18', 'resnet34', 'resnet50', 'resnet101')
}

In [87]:
DEVICE = 'cuda'
i = 8  # Index (into test_dataset) of the image to be used for FFM computation
MODELS_TO_EVAL = ['resnet18']  # also available: 'resnet34', 'resnet50', 'resnet101'

# Fetch the sample once: every test_dataset[i] access re-reads and re-transforms the image.
sample_img, sample_label = test_dataset[i]
batch = tranform_preprocess(sample_img).unsqueeze(0).to(DEVICE)  # (1, C, H, W) batch of one
gt_labels = [sample_label]

# Load each model, then compute FFM on the selected image
for MODEL in MODELS_TO_EVAL:
    model = load_model(MODEL).to(DEVICE)
    if model_checkpoints[MODEL]:
        ckpt_name = model_checkpoints[MODEL][0]
        # map_location lets checkpoints saved on another device load onto DEVICE.
        state_dict = torch.load(f"checkpoints/{ckpt_name}/model.pth", map_location=DEVICE)
        model.load_state_dict(state_dict)
        print(f"Loaded model from {ckpt_name}")
    else:
        # load_model always replaces `fc`, so only the backbone is pretrained here.
        print(f"No checkpoint found for {MODEL}; using pretrained backbone with an untrained fc head")

    model = model.eval()
    # FFM is computed on the last residual block of the final ResNet stage.
    target_layer = model.layer4[-1]

    # Compute FFM
    ffm_val = compute_ffm(
        model=model,
        img=batch,
        gt_label=gt_labels,
        target_layer=target_layer,
        reduce='none',
        top_k=5,
        device=DEVICE
    )

    print(f'{i}th FFM for {MODEL}: {ffm_val}')
    print('*'*50)




Loaded model from resnet18-20240917-161535


100%|██████████| 16/16 [00:00<00:00, 122.36it/s]
100%|██████████| 16/16 [00:00<00:00, 132.08it/s]
100%|██████████| 16/16 [00:00<00:00, 132.43it/s]
100%|██████████| 16/16 [00:00<00:00, 134.89it/s]
100%|██████████| 16/16 [00:00<00:00, 131.88it/s]
100%|██████████| 16/16 [00:00<00:00, 127.42it/s]

[0.46576792] (1,) (1, 5)
8th FFM for resnet18: [0.5]
**************************************************



