Process inputs for different types of explainers
* pixel-based (of normalized images)
    * RGB ((3 x H x W)-dimension vectors)
    * intensity ((H x W)-dimension vectors)
* SIFT feature based
    * set parameters: 
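        * sigma : scale of the Gaussian blur (a smaller sigma means a smaller Gaussian blur)
        * K : number of KMeans clusters, i.e. the number of bins in each image's histogram
        * stride : spacing between the densely sampled SIFT keypoints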
* concept-based
    * labelled
        * set parameters:
            * tau : occurrence threshold (minimum number of images in the training set that a concept must occur in)
    * discovery
        * currently not implemented

In [1]:
import os, sys
import numpy as np
import torch
from tqdm import tqdm

sys.path.insert(0, 'src')
import datasets.datasets as module_data
from utils.utils import ensure_dir

In [2]:
DATASETS_AVAILABLE = ['cifar', 'ADE20K']
# Alias: some cells reference this name when formatting error messages; without it a failed
# dataset check raises NameError instead of the intended AssertionError/ValueError.
DATASET_TYPES_AVAILABLE = DATASETS_AVAILABLE
INPUT_TYPES_AVAILABLE = ['pixel', 'SIFT', 'concept']
PIXEL_TYPES_AVAILBLE = ['RGB', 'intensity']
SAVE_ROOT = 'data/explainer_inputs'

## Set variables

In [3]:
dataset_type = 'cifar'  # one of DATASETS_AVAILABLE
input_type = 'SIFT'     # one of INPUT_TYPES_AVAILABLE

# Params for pixel-based explainers
pixel_type = 'intensity'  # 'RGB' or 'intensity'

# Params for SIFT-based explainers
sigma = 1.6  # Gaussian blur scale the precomputed descriptors were built with
K = 500      # number of KMeans clusters (= histogram bins)
stride = 2   # keypoint spacing of the precomputed dense SIFT descriptors
SIFT_descriptors_path = 'saved/cifar10/sift_32_32_sigma1.6/dense_stride_2/sift_keypoints_descriptors.pth'
SIFT_KMeans_path = 'saved/cifar10/sift_32_32_sigma1.6/dense_stride_2/minibatch_kmeans/500means/descriptor_kmeans.pth'

save_dir = os.path.join(SAVE_ROOT, dataset_type, input_type)
# Check inputs are valid
assert dataset_type in DATASETS_AVAILABLE, "Dataset type '{}' not supported. Try one of {}".format(
    dataset_type, DATASETS_AVAILABLE)
assert input_type in INPUT_TYPES_AVAILABLE, "Explainer type '{}' not supported. Try one of {}".format(
    input_type, INPUT_TYPES_AVAILABLE)

if input_type == 'pixel':
    assert pixel_type in PIXEL_TYPES_AVAILBLE, "Pixel type '{}' not supported. Try one of {}".format(
        pixel_type, PIXEL_TYPES_AVAILBLE)
    print("Pixel-based explainer using {} inputs".format(pixel_type))
    save_dir = os.path.join(save_dir, pixel_type)
elif input_type == 'SIFT':
    # Make sure the hyperparameters agree with the naming scheme of the precomputed files.
    # Match whole tokens ('stride_2', '500means') rather than bare numbers: str(stride) == '2'
    # would also match the '32' in 'sift_32_32', making the check a no-op.
    sigma_token = 'sigma{}'.format(sigma)
    stride_token = 'stride_{}'.format(stride)
    kmeans_token = '{}means'.format(K)
    for sift_path in (SIFT_descriptors_path, SIFT_KMeans_path):
        assert sigma_token in sift_path, "Expected '{}' in path {}".format(sigma_token, sift_path)
        assert stride_token in sift_path, "Expected '{}' in path {}".format(stride_token, sift_path)
    assert kmeans_token in SIFT_KMeans_path, "Expected '{}' in path {}".format(kmeans_token, SIFT_KMeans_path)
    print("SIFT-based explainer using sigma={}; K={}; stride={}".format(sigma, K, stride))
    print("Loading descriptors from {}".format(SIFT_descriptors_path))
    print("Loading KMeans from {}".format(SIFT_KMeans_path))
    save_dir = os.path.join(save_dir, 'sigma_{}'.format(sigma), 'stride_{}'.format(stride), 'K_{}'.format(K))
else:
    raise ValueError("Explainer type '{}' not supported. Try one of {}".format(
        input_type, INPUT_TYPES_AVAILABLE))

# Create save directory
ensure_dir(save_dir)
print("Created save directory at {}".format(save_dir))

# Obtain paths for data to be processed
if dataset_type == 'cifar':
    dataset_dir = os.path.join('data', 'cifar10-processed')
    if input_type == 'pixel':
        paths = {
            'dataset_dir': dataset_dir
        }
    elif input_type == 'SIFT':
        print("Loading descriptors and KMeans...")
        # NOTE(review): torch.load unpickles these files, so only load trusted paths. Recent torch
        # releases default to weights_only=True, which may reject non-tensor objects such as the
        # KMeans model - TODO confirm against the installed torch version.
        SIFT_descriptors = torch.load(SIFT_descriptors_path)
        SIFT_KMeans = torch.load(SIFT_KMeans_path)
    IMAGE_HEIGHT = 32
    IMAGE_WIDTH = 32
    N_TRAIN = 50000
    N_TEST = 10000
else:
    # Only CIFAR is implemented so far, even though other names pass the DATASETS_AVAILABLE check
    raise ValueError("Dataset type '{}' not supported. Try one of {}".format(dataset_type, ['cifar']))

SIFT-based explainer using sigma=1.6; K=500; stride=2
Loading descriptors from saved/cifar10/sift_32_32_sigma1.6/dense_stride_2/sift_keypoints_descriptors.pth
Loading KMeans from saved/cifar10/sift_32_32_sigma1.6/dense_stride_2/minibatch_kmeans/500means/descriptor_kmeans.pth
Created save directory at data/explainer_inputs/cifar/SIFT/sigma_1.6/stride_2/K_500
Loading descriptors and KMeans...


## Process the inputs

### Define functions

In [59]:
def get_cifar_dataloaders(dataset_dir,
                          normalize,
                          batch_size,
                          num_workers):
    '''
    Build unshuffled train and test DataLoaders over the processed CIFAR-10 dataset.

    Arg(s):
        dataset_dir : str
            Directory passed to module_data.CIFAR10TorchDataset
        normalize : bool
            Forwarded to the dataset constructor
        batch_size : int
            Number of images per batch
        num_workers : int
            Number of DataLoader worker processes
    Returns:
        tuple(torch.utils.data.DataLoader, torch.utils.data.DataLoader) : train and test dataloaders
    '''
    def _loader_for(split):
        # One dataset + loader per split; both splits share every other setting
        split_dataset = module_data.CIFAR10TorchDataset(
            dataset_dir=dataset_dir,
            split=split,
            normalize=normalize)
        # shuffle=False keeps batches in dataset order (presumably so saved rows align with dataset indices)
        return torch.utils.data.DataLoader(
            split_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers)

    train_loader = _loader_for('train')
    test_loader = _loader_for('test')
    return train_loader, test_loader

def flatten_images(train_dataloader,
                   test_dataloader,
                   pixel_type):
    '''
    Flatten every image yielded by the train and test dataloaders into a 1D vector.

    For 'intensity', images are first averaged over dim 1 (the channel axis) so each pixel
    contributes one value; for 'RGB', all channels are kept.

    Arg(s):
        train_dataloader : torch.utils.data.DataLoader
            Dataloader for training dataset; yields (images, labels) batches
        test_dataloader : torch.utils.data.DataLoader
            Dataloader for testing dataset; yields (images, labels) batches
        pixel_type : str
            'RGB' or 'intensity'
    Returns:
        tuple(np.array, np.array) : flattened inputs for train and test respectively
    '''
    assert pixel_type in PIXEL_TYPES_AVAILBLE

    def _flatten_split(split_name, dataloader):
        print("Flattening {} data based on {}".format(split_name, pixel_type))
        batch_vectors = []
        for images, _labels in tqdm(dataloader):
            if pixel_type == 'intensity':
                images = images.mean(dim=1)  # collapse the RGB channel axis
            batch_vectors.append(images.flatten(start_dim=1))
        # Stack all batches into one (n_images, dim) array
        return torch.cat(batch_vectors, dim=0).numpy()

    train_vectors = _flatten_split('train', train_dataloader)
    test_vectors = _flatten_split('test', test_dataloader)
    return train_vectors, test_vectors

def process_pixel_inputs(dataset_type,
                         dataset_dir,
                         pixel_type,
                         save_path,
                         batch_size=256,
                         num_workers=8,
                         normalize=True):
    '''
    Given a dataset and a pixel type ('RGB' or 'intensity'), flatten the images in each split into 1D vectors.

    Arg(s):
        dataset_type : str
            Dataset name. Only 'cifar' is currently implemented.
        dataset_dir : str
            Path to directory where images are stored
        pixel_type : str
            'RGB' (keep every channel) or 'intensity' (average over the channel axis)
        save_path : str or None
            File that {'train': ..., 'test': ...} is written to with torch.save. Must not already exist.
            If None, the arrays are only returned, not saved.
        batch_size : int
            DataLoader batch size (default 256)
        num_workers : int
            DataLoader worker processes (default 8)
        normalize : bool
            Forwarded to the dataset constructor (default True)
    Returns:
        tuple(np.array, np.array) : train inputs of shape (N_TRAIN, D) and test inputs of shape (N_TEST, D),
            where D is 3 * IMAGE_HEIGHT * IMAGE_WIDTH for 'RGB' and IMAGE_HEIGHT * IMAGE_WIDTH for 'intensity'
    Raises:
        ValueError : if dataset_type is not implemented
        AssertionError : if save_path already exists or the flattened data has an unexpected shape
    '''
    # Refuse to overwrite previously saved inputs
    if save_path is not None:
        assert not os.path.exists(save_path), "File already exists at {}. Please remove it".format(save_path)

    if dataset_type != 'cifar':
        raise ValueError("Dataset type '{}' not supported. Try one of {}".format(dataset_type, ['cifar']))

    train_dataloader, test_dataloader = get_cifar_dataloaders(
        dataset_dir=dataset_dir,
        normalize=normalize,
        batch_size=batch_size,
        num_workers=num_workers)
    train_flattened, test_flattened = flatten_images(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        pixel_type=pixel_type)

    # Perform checks on shape: 'RGB' keeps 3 values per pixel, 'intensity' keeps 1
    n_channels = 3 if pixel_type == 'RGB' else 1
    expected_dim = n_channels * IMAGE_HEIGHT * IMAGE_WIDTH
    n_train, train_dim = train_flattened.shape
    n_test, test_dim = test_flattened.shape
    assert train_dim == expected_dim, \
        "Expected training data to have {}-dims. Received {}.".format(expected_dim, train_dim)
    assert test_dim == expected_dim, \
        "Expected test data to have {}-dims. Received {}.".format(expected_dim, test_dim)
    # Check number of samples in train/test
    assert n_train == N_TRAIN and n_test == N_TEST, \
        "Expected {} train / {} test samples. Received {} / {}.".format(N_TRAIN, N_TEST, n_train, n_test)

    if save_path is not None:
        explainer_inputs = {
            'train': train_flattened,
            'test': test_flattened
        }
        torch.save(explainer_inputs, save_path)
        print("Saved {} flattened cifar images to {}".format(pixel_type, save_path))
    return train_flattened, test_flattened



In [17]:
def process_SIFT_inputs(dataset_type,
                        descriptors,
                        KMeans,
                        save_path):
    '''
    Given loaded SIFT descriptors and a fitted KMeans clustering, build a normalized
    bag-of-visual-words histogram for every image in train/test.

    Each image's descriptors are assigned clusters with KMeans.predict. The per-cluster counts are
    divided by the image's number of descriptors, so each histogram sums to 1. An image with no
    descriptors gets an all-zero histogram.

    Arg(s):
        dataset_type : str
            Dataset name; only used in the log message
        descriptors : dict
            Must contain keys 'train' and 'test'; descriptors[split]['descriptors'] is a sequence
            holding one array of descriptors per image
        KMeans : fitted clustering object
            Must expose cluster_centers_ (first dimension = number of clusters K) and predict()
        save_path : str or None
            File that {'train': ..., 'test': ...} is written to with torch.save. Must not already exist.
            If None, the histograms are only returned, not saved.
    Returns:
        tuple(np.array, np.array) : train and test histograms, each of shape (n_images, K)
    '''
    # Refuse to overwrite previously saved inputs
    if save_path is not None:
        assert not os.path.exists(save_path), "File already exists at {}. Please remove it".format(save_path)

    assert 'train' in descriptors.keys(), "Expected key 'train' in descriptors. Only found {}".format(descriptors.keys())
    assert 'test' in descriptors.keys(), "Expected key 'test' in descriptors. Only found {}".format(descriptors.keys())

    train_descriptors = descriptors['train']['descriptors']
    test_descriptors = descriptors['test']['descriptors']

    K = KMeans.cluster_centers_.shape[0]
    splits = [('train', train_descriptors), ('test', test_descriptors)]
    histogram_vectors = {}

    for split_name, split_descriptors in splits:
        split_histogram_vectors = []
        for image_descriptors in tqdm(split_descriptors, desc=split_name):
            n_descriptors = len(image_descriptors)
            if n_descriptors == 0:
                # Nothing to count; also avoids calling predict on an empty array and dividing by zero
                split_histogram_vectors.append(np.zeros(K))
                continue
            descriptor_clusters = KMeans.predict(image_descriptors)
            # Count assignments per cluster in one vectorized pass (minlength pads unused clusters
            # with 0), then normalize so the histogram sums to 1
            counts = np.bincount(descriptor_clusters, minlength=K)
            split_histogram_vectors.append(counts / n_descriptors)
        # Stack to an (n_images, K) array; an empty split yields a (0, K) array instead of crashing
        if split_histogram_vectors:
            histogram_vectors[split_name] = np.stack(split_histogram_vectors, axis=0)
        else:
            histogram_vectors[split_name] = np.zeros((0, K))

    if save_path is not None:
        torch.save(histogram_vectors, save_path)
        print("Saved histogram of SIFT features from {} images to {}".format(dataset_type, save_path))
    return histogram_vectors['train'], histogram_vectors['test']
    

### Call process inputs

In [18]:
# Build an output filename that encodes every hyperparameter of the chosen input type
filename = '{}_{}'.format(dataset_type, input_type)
if input_type == 'pixel':
    filename = '{}_{}'.format(filename, pixel_type)
elif input_type == 'SIFT':
    filename = '{}_sigma_{}_stride_{}_K_{}'.format(filename, sigma, stride, K)
else:
    raise ValueError("Explainer type '{}' not supported. Try one of {}".format(
        input_type, INPUT_TYPES_AVAILABLE))
filename = '{}_explainer_inputs.pth'.format(filename)
save_path = os.path.join(save_dir, filename)

if input_type == 'pixel':
    # process_pixel_inputs takes the image directory directly (it has no `paths` parameter)
    train_flattened, test_flattened = process_pixel_inputs(
        dataset_type=dataset_type,
        dataset_dir=dataset_dir,
        pixel_type=pixel_type,
        save_path=save_path)
elif input_type == 'SIFT':
    train_histograms, test_histograms = process_SIFT_inputs(
        dataset_type=dataset_type,
        descriptors=SIFT_descriptors,
        KMeans=SIFT_KMeans,
        save_path=save_path)

100%|████████████████████████████████████| 50000/50000 [00:53<00:00, 927.53it/s]
100%|████████████████████████████████████| 10000/10000 [00:10<00:00, 931.74it/s]


Saved histogram of SIFT features from cifar images to data/explainer_inputs/cifar/SIFT/sigma_1.6/stride_2/K_500/cifar_SIFT_sigma_1.6_stride_2_K_500_explainer_inputs.pth
