In [5]:
# Load library
import sys
import os
PROJ_DIR = os.path.realpath(os.path.dirname(os.path.dirname(os.path.abspath(''))))
sys.path.append(os.path.join(PROJ_DIR,'src'))

import xai_faithfulness_experiments_lib_edits as ff
import numpy as np

In [6]:
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
import PIL
batch_size = 256
IMAGENETTE_PATH = os.path.join(PROJ_DIR, 'data', 'imagenette')
IMAGENETTE_CLASS_DICT = {'n01440764':0, 'n02102040':217, 'n02979186':481, 'n03000684':491, 'n03028079':497, 'n03394916':566, 'n03417042':569, 'n03425413':571, 'n03445777':574, 'n03888257':701}
IMAGENETTE_CLASS_DIRS = sorted(list(IMAGENETTE_CLASS_DICT.keys()))

def get_imagenette_dataset(is_test=False, project_path:str='../'):
    ''' Loads the imagenette dataset. By default it loads the train partition, unless otherwise indicated'''
    def transform_labels(l):
        print('l', l)
        new_l = IMAGENETTE_CLASS_DICT[IMAGENETTE_CLASS_DIRS[l]]
        print('new_l', new_l)
        return new_l

    def load_sample(path: str) -> dict:
        """Read data as image and path. """
        return PIL.Image.open(path).convert("RGB")


    DATA_TRAIN_PATH = os.path.join(project_path, IMAGENETTE_PATH, 'train')
    DATA_TEST_PATH = os.path.join(project_path, IMAGENETTE_PATH, 'val')

    transform = transforms.Compose([
                    transforms.Resize(256), 
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ])
    # Load test data and make loaders.
    dataset = torchvision.datasets.DatasetFolder(DATA_TEST_PATH if is_test else DATA_TRAIN_PATH, 
                                                loader=load_sample, 
                                                is_valid_file=lambda path: path[-5:]==".JPEG",
                                                transform=transform, # Should we do this here or work with the full images for the RL process??
                                                target_transform=transform_labels)
    return dataset

def get_imagenette_train_loader(batch_size:int = 24, project_path:str='../') -> torch.utils.data.DataLoader:
    dataset = get_imagenette_dataset(project_path=project_path)
    train_loader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=batch_size)
    return train_loader

def get_imagenette_test_loader(batch_size:int = 24, project_path:str='../') -> torch.utils.data.DataLoader:
    dataset = get_imagenette_dataset(True, project_path=project_path)
    test_loader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=batch_size)
    return test_loader

train_loader = get_imagenette_train_loader(52, PROJ_DIR)

examples = enumerate(train_loader)
batch_idx, (x_train, y_train) = next(examples)

MODEL_NAME = 'resnet50w'
# Load model
class ResNet50Wrapper(torch.nn.Module):
    def __init__(self, weights='DEFAULT', device='cpu'):
        super(ResNet50Wrapper, self).__init__()
        # Load the pre-trained ResNet50 model
        self.resnet50 = torchvision.models.resnet50(weights=weights).to(device)
        # Set the model to evaluation mode
        self.resnet50.eval()

    def forward(self, x):
        # Forward pass through the pre-trained ResNet50
        logits = self.resnet50(x)
        # Apply softmax to convert logits to probabilities
        probabilities = F.softmax(logits, dim=1)
        return probabilities

# Use the wrapper
network = ResNet50Wrapper(weights="DEFAULT", device=device).eval()

Using cuda:0
l 8
new_l 574
l 2
new_l 481
l 8
new_l 574
l 5
new_l 566
l 2
new_l 481
l 1
new_l 217
l 2
new_l 481
l 8
new_l 574
l 9
new_l 701
l 8
new_l 574
l 7
new_l 571
l 6
new_l 569
l 0
new_l 0
l 2
new_l 481
l 3
new_l 491
l 5
new_l 566
l 2
new_l 481
l 3
new_l 491
l 5
new_l 566
l 0
new_l 0
l 0
new_l 0
l 1
new_l 217
l 9
new_l 701
l 8
new_l 574
l 5
new_l 566
l 5
new_l 566
l 6
new_l 569
l 5
new_l 566
l 0
new_l 0
l 4
new_l 497
l 1
new_l 217
l 5
new_l 566
l 5
new_l 566
l 4
new_l 497
l 7
new_l 571
l 1
new_l 217
l 9
new_l 701
l 1
new_l 217
l 4
new_l 497
l 6
new_l 569
l 7
new_l 571
l 2
new_l 481
l 2
new_l 481
l 4
new_l 497
l 0
new_l 0
l 7
new_l 571
l 9
new_l 701
l 8
new_l 574
l 4
new_l 497
l 0
new_l 0
l 9
new_l 701
l 0
new_l 0


In [10]:
SAMPLE_NUM = 10
print('Processing', SAMPLE_NUM)
row = x_train[SAMPLE_NUM].clone().detach().to(device)
label = y_train[SAMPLE_NUM].clone().detach().to(device)

print(network(row.unsqueeze(dim=0)).argmax(), network(row.unsqueeze(dim=0)).max())
print(label)

Processing 10
tensor(571, device='cuda:0') tensor(0.1561, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(571, device='cuda:0')
