# Introduction

In this tutorial, we will explain the basics of extending Data Detective to new data types, new validator methods, new validators, and new transforms. We will build up to performing anomaly detection on the MNIST dataset using PCA anomaly scoring.

This tutorial assumes familiarity with everything covered in the Data Detective Basics tutorial.

Let's get started!

In [None]:
!pip install --upgrade torchvision
!pip install --upgrade pyod

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyod.models.pca  # import the submodule explicitly; a bare `import pyod` does not load it
import torch
import torchvision.transforms as transforms
import typing

from torchvision.datasets import MNIST
from typing import Dict, Union, Set, Type

from constants import FloatTensor
from src.aggregation.rankings import RankingAggregator, RankingAggregationMethod
from src.data_detective_engine import DataDetectiveEngine
from src.enums.enums import DataType, ValidatorMethodParameter
from src.validators.data_validator import DataValidator


# Dataset Construction

In [None]:
DATASET_SIZE = 50 

class TutorialDataset(MNIST):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        dataset_size = self.__len__()
    
    def __getitem__(self, idx: Union[int, slice, list]) -> Dict[str, Union[FloatTensor, int]]:
        """
        Returns a dictionary containing the image and its label.
        """
        sample = super().__getitem__(idx)
        return {
            "mnist_image": sample[0],
            "label": sample[1],
        }
    
    #TODO: remove
    def __len__(self) -> int: 
        return DATASET_SIZE

    def datatypes(self) -> Dict[str, DataType]:
        return {
            "mnist_image": DataType.IMAGE,
            "label": DataType.CATEGORICAL,
        }
    
    def show_datapoint(self, idx: int):
        """
        Shows data point from tutorial.
        """
        # src: https://stackoverflow.com/questions/31556446/how-to-draw-axis-in-the-middle-of-the-figure
        sample = self[idx]
        print(sample["label"])
        plt.imshow(sample["mnist_image"].squeeze())
        plt.show()
        
    
dataset = TutorialDataset(
    root='./data/MNIST',
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

# Creating a Validator Method

## The Structure of Validator Methods

Each validator method is a static class that has 4 static, functional methods:
1. The `datatype` method, which returns a set of datatypes that are supported by the validator method. 
<!---The set is considered an "OR" relation (that is, if any of the datatypes in the set are present in the dataset, the validator method will be applied). --->
2. The `param_keys` method, which returns a set containing the data splits that the method applies to. 
3. The `validate` method, which returns some type of actionable result. 
4. The `get_method_kwargs` method, which takes in the data object and validator kwargs and sets up the calls to `validate`. 

Let's go through a simple validator method and examine all of these components. The example validator method that we will be examining determines the principal components of a multidimensional column over a dataset and uses the reconstruction error over the fitted principal components as an anomaly score. 
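As a point of reference, one common way to write a PCA reconstruction-error score is sketched below. The symbols are assumptions introduced here for illustration: $x$ is one sample of the column, $\mu$ is the column mean over the dataset, and $W_k$ is the $d \times k$ matrix whose columns are the top $k$ principal directions. The exact score computed by `pyod` may differ in its details.

$$
s(x) = \left\lVert (x - \mu) - W_k W_k^\top (x - \mu) \right\rVert_2^2
$$

Under this formulation, samples that lie far from the subspace spanned by the fitted components receive a high score.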

## Example: PCA Anomaly Validator Method

The first step in writing a new validator method is creating a good test for it using synthetic data. Tests are a crucial part of building a Data Detective validator method for three reasons: 
1. They are helpful early on in the design process for considering and enforcing sensible top-down interface decisions.
2. They are a useful piece of documentation to both yourself as you write the method and to an end user in understanding how to use your method.
3. They verify correctness of your implementation. 

For our example, we will be constructing a 10-dimensional synthetic normal dataset with 99% of samples drawn from N(μ=0, σ=1) and 1% of samples drawn from N(μ=10, σ=1). To check correctness, we will look at the AUROC score between the true anomaly labels and the predicted anomaly scores. The test that we will be using is shown below. [TODO: write test]
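Until that test is written, here is a minimal sketch of what it could look like, using only NumPy. Everything named here is an assumption for illustration: `variance_weighted_pca_scores` is a stand-in scorer (squared projections onto each principal component divided by that component's variance), not the `pyod` detector used later, and `auroc` and `test_pca_anomaly_scores` are hypothetical helpers.

```python
import numpy as np


def variance_weighted_pca_scores(data_matrix: np.ndarray) -> np.ndarray:
    """Stand-in scorer (not pyod): squared projections onto each principal
    component, divided by that component's variance, summed per sample."""
    centered = data_matrix - data_matrix.mean(axis=0)
    # Rows of vt are the principal directions; singular values give their scale.
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    variances = singular_values ** 2 / (len(data_matrix) - 1)
    projections = centered @ vt.T
    return (projections ** 2 / variances).sum(axis=1)


def auroc(labels: np.ndarray, scores: np.ndarray) -> float:
    """AUROC via the rank-sum (Mann-Whitney U) formulation; assumes no tied scores."""
    ranks = np.empty(len(scores))
    ranks[np.argsort(scores)] = np.arange(1, len(scores) + 1)
    n_pos = int(labels.sum())
    n_neg = len(labels) - n_pos
    return float((ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg))


def test_pca_anomaly_scores(score_fn=variance_weighted_pca_scores) -> float:
    """Builds the 99% / 1% synthetic dataset and checks that anomalies rank highest."""
    rng = np.random.default_rng(42)
    n_normal, n_anomalous, dim = 990, 10, 10
    normal = rng.normal(loc=0.0, scale=1.0, size=(n_normal, dim))
    anomalous = rng.normal(loc=10.0, scale=1.0, size=(n_anomalous, dim))
    data_matrix = np.vstack([normal, anomalous])
    labels = np.concatenate([np.zeros(n_normal), np.ones(n_anomalous)])

    score = auroc(labels, score_fn(data_matrix))
    assert score > 0.95, f"expected anomalies to rank highest, got AUROC={score:.3f}"
    return score


test_auroc = test_pca_anomaly_scores()
```

Because the scorer is passed in as `score_fn`, the same test could later be pointed at the real method by passing a callable that wraps its `validate()` function.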

### 1. `datatype()` method

We would like our PCA method to take in only multidimensional data, so let's specify that in the `datatype()` method. We specify this by returning a set of `DataType` objects. If we had wanted to start extending support to new datatypes, we would at this point extend the `DataType` enumeration and specify the new data type in the `datatype` method. 

In [None]:
def datatype() -> Set[DataType]:
    """
    @return: the set of datatypes the validator method operates on.
    """
    return { DataType.MULTIDIMENSIONAL }

### 2. `param_keys()` method

PCA anomaly validation is an unsupervised method, so it fits and evaluates the model on the entire dataset.

In [None]:
def param_keys() -> Set[ValidatorMethodParameter]:
    """
    Lists the data splits that the validator method operates on.
    
    @return: a set of data splits for the .validate() method.
    """
    return { ValidatorMethodParameter.ENTIRE_SET }

### 3. `validate()` method

Our `validate()` method will map us from some representation of the data to a single result. For the PCA `validate()` method, let's choose to take in the entire n x d data matrix for a given column of data and an option indicating the number of components to keep for computation of outlier scores. In the method body, we will fit an existing PCA anomaly detection method from `pyod` and use that model to give us a set of anomaly scores based on reconstruction loss.


In [None]:
def validate(
    data_matrix: np.ndarray = None,  # n x d data matrix for a given column
    n_components=None,
) -> np.ndarray:
    """
    Runs PCA anomaly detection.

    @return: an array of n anomaly scores, one per sample. 
    """
    model = pyod.models.pca.PCA(n_components=n_components)
    model.fit(data_matrix)

    anomaly_scores = model.decision_function(data_matrix)

    return anomaly_scores

### 4. `get_method_kwargs()` method

In the `get_method_kwargs()` method, we will be taking the set of options passed in the validation schema as well as the data object and setting up the calls to the `validate()` method. This method should return a dictionary where each value contains the kwargs for a `validate()` call and each key reflects where the `validate()` call will store its results in the final method results dictionary. 

Every `get_method_kwargs()` method accepts two things: the data object and the validator kwargs from the validation schema. For our PCA anomaly example, we want one `validate()` call per column in the dataset, which gives us one anomaly score per sample for each column. 
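To make the return shape concrete, the sketch below builds such a dictionary for a made-up dataset. The column names (`embedding_a`, `embedding_b`) and their values are hypothetical, and plain lists stand in for the n x d arrays.

```python
from typing import Any, Dict, List

# Hypothetical columns of a tiny dataset: three samples, two features per column.
matrix_dict: Dict[str, List[List[float]]] = {
    "embedding_a": [[0.1, 0.2], [0.3, 0.4], [9.0, 9.5]],
    "embedding_b": [[1.0, 1.1], [1.2, 1.3], [1.4, 1.5]],
}
validator_kwargs: Dict[str, Any] = {"n_components": 2}

# One entry per column: the key names where the result is stored,
# the value holds the kwargs for a single validate() call.
kwargs_dict = {
    f"{column}_results": {
        "data_matrix": column_data,
        "n_components": validator_kwargs.get("n_components"),
    }
    for column, column_data in matrix_dict.items()
}

for results_key, call_kwargs in kwargs_dict.items():
    print(results_key, sorted(call_kwargs))
```

Each key ends in `_results`, and each value can be unpacked directly into a `validate()` call.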

<!--Every `get_method_kwargs()` method accepts two things: the validation schema and the (filtered) data object. The data object is preliminarily filtered in two ways: 
1. The `include` option in the validation schema accepts a list of regular expressions under each 
2. The `datatype()` method results in the data object being filtered to only include columns in that data object.-->


First, we will include a helper method to get the data matrix from a dataset:

In [None]:
def _get_data_matrix_dict(dataset: torch.utils.data.Dataset = None) -> Dict[str, np.ndarray]:
    """
    Takes a dataset and returns a dictionary mapping from each column in the dataset to an n x d 
    numpy array, where n is the number of entries in the dataset and d is the column's dimension
    in the dataset. 
        
    @return: a dictionary mapping each column name to its n x d numpy array.
    """
    matrix_dict = {
        column: [] for column in dataset.datatypes().keys()
    }

    for idx in range(len(dataset)):
        sample = dataset[idx]
        for column, column_data in sample.items():
            matrix_dict[column].append(column_data)

    for column in dataset.datatypes().keys():
        matrix_dict[column] = np.vstack(matrix_dict[column])
            
    return matrix_dict

Now, let's use the above method to write our `get_method_kwargs()` function, which needs to retrieve our `data_matrix` and `n_components` parameters. 

In [None]:
def get_method_kwargs(data_object: Dict[str, torch.utils.data.Dataset], validator_kwargs: Dict = None) -> Dict:
    """
    Gets the arguments for each run of the validator_method, and what to store the results under.

    @param data_object: the datasets object containing the datasets (train, test, entire, etc.)
    @param validator_kwargs: the kwargs from the validation schema.
    @return: a dict mapping each results key to the kwargs for the corresponding .validate() call.
    """
    entire_dataset: torch.utils.data.Dataset = data_object["entire_set"]
    matrix_dict = _get_data_matrix_dict(entire_dataset)
        
    kwargs_dict = {
        f"{column}_results": {
            "data_matrix": column_data,
            # defaults to None if n_components (or validator_kwargs itself) is not provided
            "n_components": (validator_kwargs or {}).get("n_components")
        } for column, column_data in matrix_dict.items()
    }

    return kwargs_dict

Great! Let's wrap all of the methods we have written in a single class.

In [None]:
import typing
from typing import Set, Dict, Type

import numpy as np
import pyod.models.pca
from torch.utils.data import Dataset

from src.enums.enums import DataType, ValidatorMethodParameter
from src.validator_methods.data_validator_method import DataValidatorMethod

class MyPCAAnomalyValidatorMethod(DataValidatorMethod):
    """
    A method for determining anomalies on multidimensional data. Operates on continuous datasets.
    """
    @staticmethod
    def name() -> str: 
        return "my_pca_validator_method"

    @staticmethod
    def datatype() -> Set[DataType]:
        return datatype()


    @staticmethod
    def param_keys() -> Set[ValidatorMethodParameter]:
        """
        @return: a set of data splits for the data object to include.
        """
        return param_keys()

    @staticmethod
    def get_method_kwargs(data_object: typing.Dict[str, Dataset], validator_kwargs: Dict = None) -> Dict:
        """
        Gets the arguments for each run of the validator_method, and what to store the results under.

        @param data_object: the datasets object containing the datasets (train, test, entire, etc.)
        @param validator_kwargs: the kwargs from the validation schema.
        @return: a dict mapping each results key to the kwargs for the corresponding .validate() call.
        """
        return get_method_kwargs(data_object, validator_kwargs)

    @staticmethod
    def validate(
        data_matrix: np.ndarray = None,  # n x d data matrix for a given column
        n_components=None,
    ) -> np.ndarray:
        """
        Runs PCA anomaly detection.

        @return: an array of n anomaly scores, one per sample. 
        """
        return validate(data_matrix, n_components)

# Creating a Validator

Validators are simply sets of validator methods. Creating a new one is relatively straightforward. They consist of the set of validator methods that they include as well as a `default` attribute indicating whether the validator should be included in all validation schemas. Let's create a non-default validator for our PCA anomaly method.

In [None]:
class MyUnsupervisedAnomalyDataValidator(DataValidator):
    """
    A dataset has many features/columns, and each column has many ValidatorMethods that apply to it, depending on the
    datatype. A DataValidator is a collection of ValidatorMethods for a unique purpose.
    """
    @staticmethod
    def name() -> str: 
        return "my_unsupervised_anomaly_data_validator"
    
    @staticmethod
    def is_default():
        return False

    @staticmethod
    def validator_methods() -> Set[Type[DataValidatorMethod]]:
        return {
            MyPCAAnomalyValidatorMethod
        }

# Creating a Transform

There are two steps to creating a new transform: 

1. Creating a higher order transformation function
2. Registering the new transform in the transform library. 

Let's look at an example of a simple transform that maps images to their ResNet50 embeddings.

In [None]:
from src.transforms.embedding_transformer import Transform


def get_resnet50(**kwargs):
    import torchvision.models
    resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2, **kwargs)
    modules = list(resnet.children())[:-1]
    backbone = torch.nn.Sequential(torch.nn.Upsample((224, 224)), *modules)
    backbone.eval()  # inference mode: BatchNorm uses its running statistics, not per-batch statistics
    def full_impl(x):
        if len(x.shape) == 3:
            # need a 4th dimension
            x = torch.unsqueeze(x, 0)
        if x.shape[-3] == 1:
            # need to map to multiple channels
            x2 = torch.zeros((x.shape[0], 3, x.shape[2], x.shape[3]))
            x2[:, 0, :, :] = x[:, 0, :, :]
            x2[:, 1, :, :] = x[:, 0, :, :]
            x2[:, 2, :, :] = x[:, 0, :, :]
            x = x2

        with torch.no_grad():  # we only need embeddings, so skip gradient tracking
            x = backbone(x)
        x = x.squeeze()
        x = x.reshape((-1, 2048))
        x = x.detach().numpy()

        return x

    return full_impl

my_resnet50_transform = Transform(
        transform_class=get_resnet50,
        new_column_name_fn=lambda name: f"resnet50_backbone_{name}",
        new_column_datatype=DataType.MULTIDIMENSIONAL
    )

Two patterns in the above implementation are worth noting. First, the `get_{transform}` higher-order function always accepts kwargs, which are passed through from the `options` field of the transform's entry in the validation schema. 

Second, and most importantly, it returns an inner helper function (in this case, `full_impl`). Returning an inner function means that the backbone is initialized and the options in kwargs are parsed only once, rather than on every call. 
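The one-time initialization pattern can be seen in isolation in the toy sketch below. The names `get_scaler`, `factor`, and `init_calls` are hypothetical and not part of Data Detective; the counter only records how often the setup code runs.

```python
from typing import Callable, Dict, List

init_calls: Dict[str, int] = {"count": 0}  # records how often the setup code runs


def get_scaler(factor: float = 1.0, **kwargs) -> Callable[[List[float]], List[float]]:
    """Hypothetical get_{transform}-style factory; `factor` stands in for a parsed option."""
    # Expensive setup (e.g. loading model weights) would happen here, exactly once.
    init_calls["count"] += 1

    def full_impl(x: List[float]) -> List[float]:
        # Runs once per sample and reuses the state captured from the enclosing call.
        return [factor * value for value in x]

    return full_impl


scale = get_scaler(factor=2.0)
outputs = [scale([1.0, 2.0]), scale([3.0])]
print(outputs, init_calls["count"])  # [[2.0, 4.0], [6.0]] 1
```

The transform is applied twice, but the setup inside `get_scaler` runs only once, mirroring how `get_resnet50` builds its backbone a single time.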

Now that we have our higher order transformation function, let's add it to the transform library.

In [None]:
from src.transforms.embedding_transformer import Transform

TRANSFORM_LIBRARY = {
    # our new lovely transform!
    "resnet50": Transform(
        transform_class=get_resnet50,
        new_column_name_fn=lambda name: f"resnet50_backbone_{name}",
        new_column_datatype=DataType.MULTIDIMENSIONAL
    ),
}

Great! We have successfully added the transform to the transform library. 

In [None]:
# todo: this is what we want the implementation of the data_detective_engine to look like. 
# we need to parse this into two tutorials: one for contributions to the existing codebase, and one for modular extensions.

data_detective_engine = DataDetectiveEngine()

data_detective_engine.register_validator(MyUnsupervisedAnomalyDataValidator)
data_detective_engine.register_transform(my_resnet50_transform, "my_resnet50")

validation_schema = {
    "default_inclusion": False,
    "validators": {
        "my_unsupervised_anomaly_data_validator": {},
    },
    "transforms": {
        "image": [{
            "name": "my_resnet50",
            "in_place": "False",
            "options": {},
        }],
    }
}

data_object = {
    "entire_set": dataset
}

# Running the Data Detective Engine

Now that the full validation schema and data object are prepared, we are ready to run the Data Detective Engine.

In [None]:
results = data_detective_engine.validate_from_schema(validation_schema, data_object)

In [None]:
results

Great! Let's start to look at and analyze the results we've collected.
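Before aggregating, it can help to inspect the raw scores for a single column. The sketch below assumes the results are nested as validator name, then validator method name, then `{column}_results`, which is inferred from how the `RankingAggregator` in the next section indexes them. The scores themselves are made up for illustration, and a higher score is taken to mean a more anomalous sample.

```python
from typing import Dict, List

# Mock results with the assumed nesting; the scores are invented for illustration.
mock_results: Dict[str, Dict[str, Dict[str, List[float]]]] = {
    "my_unsupervised_anomaly_data_validator": {
        "my_pca_validator_method": {
            "resnet50_backbone_mnist_image_results": [0.8, 3.1, 0.5, 2.2],
        }
    }
}

method_results = mock_results["my_unsupervised_anomaly_data_validator"]["my_pca_validator_method"]
scores = method_results["resnet50_backbone_mnist_image_results"]

# Sort sample indices from most to least anomalous.
most_anomalous = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
print(most_anomalous[:2])  # [1, 3]
```

The resulting indices could then be passed to `dataset.show_datapoint` to look at the highest-scoring samples.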

# Interpreting Results using the Built-In Rank Aggregator

In [None]:
from enum import Enum

import pandas as pd
import scipy
from typing import List

from pyrankagg.rankagg import FullListRankAggregator

class RankingAggregationMethod(Enum):
    MEDIAN_AGGREGATION = "median_aggregation"
    HIGHEST_RANK = "highest_rank"
    LOWEST_RANK = "lowest_rank"
    STABILITY_SELECTION = "stability_selection"
    EXPONENTIAL_WEIGHTING = "exponential_weighting"
    STABILITY_ENHANCED_BORDA = "stability_enhanced_borda"
    EXPONENTIAL_ENHANCED_BORDA = "exponential_enhanced_borda"
    ROBUST_AGGREGATION = "robust_aggregation"
    ROUND_ROBIN = "round_robin"


class RankingAggregator:
    FLRA = FullListRankAggregator()

    def __init__(self, results_object):
        self.results_object = results_object

    @staticmethod
    def list_is_full_ranking(lst):
        lst_len = len(lst)
        return list(range(lst_len)) == sorted(lst)

    @staticmethod
    def convert_to_scorelist(dataframe):
        """ scorelist = [{'milk':1.4,'cheese':2.6,'eggs':1.2,'bread':3.0},
                         {'milk':2.0,'cheese':3.2,'eggs':2.7,'bread':2.9},
                         {'milk':2.7,'cheese':3.0,'eggs':2.5,'bread':3.5}]"""
        scorelist = []
        for col in dataframe.columns:
            tmp_dict = {f"item {idx}": val for idx, val in zip(dataframe.index, dataframe[col])}
            scorelist.append(tmp_dict)
        return scorelist

    @staticmethod
    def get_rankings(scores):
        return {f"item {k}": v for k, v in RankingAggregator.FLRA.convert_to_ranks(dict(enumerate(scores))).items()}

    def construct_rankings_df(self, validator_name, given_validator_method: str = None, given_data_modality: str = None):
        validator_results = self.results_object[validator_name]
        results_obj = {}

        for validator_method, results_dict in validator_results.items():
            if given_validator_method and (validator_method != given_validator_method):
                continue
            for data_modality, scores in results_dict.items():
                if given_data_modality and (data_modality.replace("_results", "") != given_data_modality):
                    continue
                rankings = RankingAggregator.get_rankings(scores)
                results_obj[f"{data_modality}_{validator_method}_rank"] = rankings
        
        rankings_df = pd.DataFrame(results_obj)
        return rankings_df.sort_index()

    def aggregate_modal_rankings(self, validator_name: str, aggregation_methods: List[RankingAggregationMethod], given_data_modality: str = None, invert=False): 
        rankings_df = self.construct_rankings_df(validator_name, given_data_modality=given_data_modality)
        output_df = rankings_df.copy()
        
        for aggregation_method in aggregation_methods:
            aggregation_method_name = aggregation_method.value
            scorelist = self.convert_to_scorelist(rankings_df)
            agg_method = getattr(RankingAggregator.FLRA, aggregation_method_name)
            agg_rankings = agg_method(scorelist)[1]
            output_df[f"{aggregation_method_name}_agg_rank"] = list(agg_rankings.values())

        return output_df
        
    def aggregate_rankings(self, validator_name: str, aggregation_methods: List[RankingAggregationMethod]):
        rankings_df = self.construct_rankings_df(validator_name)
        output_df = rankings_df.copy()
        
        for aggregation_method in aggregation_methods:
            aggregation_method_name = aggregation_method.value
            scorelist = self.convert_to_scorelist(rankings_df)
            agg_method = getattr(RankingAggregator.FLRA, aggregation_method_name)
            agg_rankings = agg_method(scorelist)[1]

            output_df[f"{aggregation_method_name}_agg_rank"] = list(agg_rankings.values()) 

        return output_df

aggregator = RankingAggregator(results_object=results)
input_df = aggregator.aggregate_modal_rankings(
    "my_unsupervised_anomaly_data_validator",
    [
        RankingAggregationMethod.LOWEST_RANK,
        RankingAggregationMethod.HIGHEST_RANK,
        RankingAggregationMethod.ROUND_ROBIN,
    ],
    given_data_modality="resnet50_backbone_mnist_image",
)
input_df

In [None]:
dataset.show_datapoint(0)