In [None]:
import sys
import time


sys.path.append('..')
import os
import torch
import pandas as pd
import numpy as np
from torch.nn import CrossEntropyLoss
from tqdm import trange

from lens.models.relu_nn import XReluNN
from lens.models.logic import XLogicNN
from lens.models.ent_nn import XEntNN
from lens.models.psi_nn import XPsiNetwork
from lens.models.tree import XDecisionTreeClassifier
from lens.models.brl import XBRLClassifier
from lens.models.logistic_regression import XLogisticRegressionClassifier
from lens.models.deep_red import XDeepRedClassifier
from lens.utils.base import set_seed, ClassifierNotTrainedError, IncompatibleClassifierError
from lens.utils.metrics import Accuracy, F1Score
from lens.models.mu_nn import XMuNN
from lens.utils.datasets import ConceptToTaskDataset
from lens.utils.data import get_splits_train_val_test
from lens.logic.eval import test_explanation
from lens.logic.metrics import complexity, fidelity, formula_consistency
from lens.models.random_forest import XRandomForestClassifier
from lens.utils.datasets import StructuredDataset

# Output folder for trained models and result CSVs; created on first run.
results_dir = 'results/cub_pipnet'
os.makedirs(results_dir, exist_ok=True)

## Loading CUB data

In [None]:
# Feature matrices (X) and labels (y); the "train" files are split into train/val later, per seed.
dataset_root = "../data/cub_pipnet/"
X_trainval = np.load(f"{dataset_root}cub_pipnet_X_train.npy")
y_trainval = np.load(f"{dataset_root}cub_pipnet_y_train.npy")
X_test = np.load(f"{dataset_root}cub_pipnet_X_test.npy")
y_test = np.load(f"{dataset_root}cub_pipnet_y_test.npy")

In [None]:
import torch
# Informational only: shows whether a CUDA device is visible to this kernel.
has_cuda = torch.cuda.is_available()
has_cuda

## Defining concept and class names

In [None]:
# One concept per feature column, and one class per integer label 0..max(label).
n_features = X_trainval.shape[1]
concept_names = [f'prototype_{idx}' for idx in range(n_features)]
n_classes = int(y_trainval.max()) + 1
class_names = list(range(n_classes))

print("Concept names", concept_names)
print("Number of features", n_features)
print("Class names", class_names)
print("Number of classes", n_classes)

## Define loss, metrics and methods

In [None]:
# `metric` scores the trained models; `expl_metric` scores the extracted explanations.
loss = CrossEntropyLoss()
metric, expl_metric = Accuracy(), F1Score()
# Other branches in the training loop: 'DTree', 'BRL', 'DeepRed', 'Psi', 'General', 'Relu', 'Logic', 'RandomForest'
method_list = ['Ent']
print("Methods", method_list)

## Setting training hyperparameters

In [None]:
# Training hyperparameters shared by the methods in the training loop.
epochs = 1000
n_processes = 1
timeout = 6 * 60 * 60  # 6 h timeout
l_r = 1e-3
lr_scheduler = False
top_k_explanations = None
simplify = True

# Training is pinned to the CPU. Set `force_cpu = False` to use the first CUDA GPU when one
# is available. This replaces an unconditional override that made the CUDA selection dead code.
force_cpu = True
if force_cpu or not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = torch.device("cuda:0")
print("Device", device)

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Dataset identifier passed to every StructuredDataset built in the training loop.
dataset_name = "cub_pipnet"

## Training

In [None]:
# For each method in `method_list`: train on several seeded train/val splits, evaluate on the
# fixed test set, extract one global explanation per class, and write the per-seed results to
# `results_dir/results_{method}.csv`.
for method in method_list:

    # Per-seed accumulators for this method (one entry per seed).
    methods = []
    splits = []
    model_explanations = []
    model_accuracies = []
    explanation_accuracies = []
    elapsed_times = []
    explanation_fidelities = []
    explanation_complexities = []

    seeds = [*range(6)] if method != "BRL" else [*range(15)]
    print("Seeds", seeds)

    for seed in seeds:
        set_seed(seed)
        name = os.path.join(results_dir, f"{method}_{seed}")

        # The train/val split changes with the seed; the test set stays fixed.
        X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, random_state=seed, shuffle=True)

        train_data = StructuredDataset(X_train, y_train, dataset_name, concept_names, class_names)
        val_data = StructuredDataset(X_val, y_val, dataset_name, concept_names, class_names)
        test_data = StructuredDataset(X_test, y_test, dataset_name, concept_names, class_names)
        X_train = torch.tensor(X_train)
        X_val = torch.tensor(X_val)
        y_train = torch.tensor(y_train)
        y_val = torch.tensor(y_val)
        # Keep the test-set tensors under new names. Rebinding the global numpy X_test / y_test
        # would make later seeds (and re-runs of this cell) build `test_data` from tensors.
        X_test_t = torch.as_tensor(X_test)
        y_test_t = torch.as_tensor(y_test)
        print(f"Training {name} classifier...")
        start_time = time.time()

        if method == 'DTree':
            max_depth = 30
            name += f"_{max_depth}"
            model = XDecisionTreeClassifier(name=name, n_classes=n_classes,
                                            n_features=n_features, max_depth=max_depth)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, metric=metric, save=True)
            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(i, concept_names)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t, metric=expl_metric,
                                                                 concept_names=concept_names, inequalities=True)
                # Fidelity is not measured for this model; it is recorded as 100.
                exp_fidelity = 100
                explanation_complexity = complexity(explanation)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)

        elif method == 'BRL':
            train_sample_rate = 1.0
            model = XBRLClassifier(name=name, n_classes=n_classes, n_features=n_features, n_processes=n_processes,
                                   feature_names=concept_names, class_names=class_names, discretize=True)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, metric=metric, train_sample_rate=train_sample_rate, verbose=False, eval=False)
            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(i, concept_names)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t, metric=expl_metric,
                                                                 concept_names=concept_names)
                # Fidelity is not measured for this model; it is recorded as 100.
                exp_fidelity = 100
                explanation_complexity = complexity(explanation, to_dnf=True)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)

        elif method == 'DeepRed':
            train_idx = train_data.indices
            test_idx = test_data.indices
            train_sample_rate = 0.05
            model = XDeepRedClassifier(n_classes, n_features, name=name)
            # NOTE(review): `dataset` is not defined anywhere in this notebook, so this branch fails
            # as written (DeepRed is also not in `method_list`). Define it before enabling DeepRed.
            model.prepare_data(dataset, dataset_name, seed, train_idx, test_idx, train_sample_rate)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(epochs=epochs, seed=seed, metric=metric)
            outputs, labels = model.predict(train=False, device=device)
            accuracy = model.evaluate(train=False, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            print("Extracting rules...")
            t = time.time()
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(i, concept_names, simplify=simplify)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t,
                                                                 metric=expl_metric,
                                                                 concept_names=concept_names, inequalities=True)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(outputs.argmax(dim=1) == i)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                explanation_complexity = complexity(explanation)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)
                print(f"{i + 1}/{n_classes} Rules extracted. Time {time.time() - t}")

        elif method == 'Psi':
            # Network structures
            l1_weight = 1e-4
            hidden_neurons = [10]  # [50, 20, 10]
            fan_in = 4
            lr_psi = 1e-2
            print("L1 weight", l1_weight)
            print("Hidden neurons", hidden_neurons)
            print("Fan in", fan_in)
            model = XPsiNetwork(n_classes, n_features, hidden_neurons, loss, l1_weight, name=name, fan_in=fan_in)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, epochs=epochs, l_r=lr_psi, verbose=True,
                          metric=metric, lr_scheduler=lr_scheduler, device=device, save=True)
            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(i, concept_names, simplify=simplify, X_train=X_train)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t,
                                                                 metric=expl_metric, concept_names=concept_names)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(outputs.argmax(dim=1) == i)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                explanation_complexity = complexity(explanation, to_dnf=True)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)

        elif method == 'General':
            # Network structures
            l1_weight = 1e-4
            hidden_neurons = [20]
            print("L1 weight", l1_weight)
            print("Hidden neurons", hidden_neurons)
            model = XMuNN(n_classes=n_classes, n_features=n_features, hidden_neurons=hidden_neurons,
                          loss=loss, name=name, l1_weight=l1_weight, fan_in=10, )
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, epochs=epochs, l_r=l_r, metric=metric,
                          lr_scheduler=lr_scheduler, device=device, save=True, verbose=True)
            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(X_train, y_train, i, top_k_explanations=top_k_explanations,
                                                           concept_names=concept_names, simplify=simplify,
                                                           metric=expl_metric, x_val=X_val, y_val=y_val)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t,
                                                                 metric=expl_metric, concept_names=concept_names)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(outputs.argmax(dim=1) == i)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                explanation_complexity = complexity(explanation)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)

        elif method == 'Relu':
            # Network structures
            l1_weight = 1e-7
            hidden_neurons = [300, 200]
            dropout_rate = 0.
            print("l1 weight", l1_weight)
            print("hidden neurons", hidden_neurons)
            model = XReluNN(n_classes=n_classes, n_features=n_features, name=name, dropout_rate=dropout_rate,
                            hidden_neurons=hidden_neurons, loss=loss, l1_weight=l1_weight)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, epochs=epochs, l_r=l_r, verbose=True,
                          metric=metric, lr_scheduler=lr_scheduler, device=device, save=True)
            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(X_train, y_train, i,
                                                           top_k_explanations=top_k_explanations,
                                                           concept_names=concept_names,
                                                           metric=expl_metric, x_val=X_val, y_val=y_val)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t,
                                                                 metric=expl_metric, concept_names=concept_names)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(outputs.argmax(dim=1) == i)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                explanation_complexity = complexity(explanation)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)

        elif method == 'Ent':
            # Network structures
            l1_weight = 1e-2
            hidden_neurons = [10]
            dropout_rate = 0.
            print("l1 weight", l1_weight)
            print("hidden neurons", hidden_neurons)
            model = XEntNN(n_classes=n_classes, n_features=n_features, name=name, dropout_rate=dropout_rate,
                           hidden_neurons=hidden_neurons, loss=loss, l1_weight=l1_weight, temperature=0.7)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, epochs=epochs, l_r=l_r, verbose=True,
                          metric=metric, lr_scheduler=lr_scheduler, device=device, save=True)
            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(X_train, y_train, i,
                                                           top_k_explanations=20,
                                                           concept_names=concept_names,
                                                           metric=expl_metric, x_val=X_val, y_val=y_val)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t,
                                                                 metric=expl_metric, concept_names=concept_names)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(outputs.argmax(dim=1) == i)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                explanation_complexity = complexity(explanation)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)

        elif method == 'Logic':
            # Network structures. Logic-specific epochs / learning rate use local names so they do
            # not overwrite the global `epochs` / `l_r` used by the other methods.
            l1_weight = 1e-5
            epochs_logic = 1000
            lr_logic = 0.005
            hidden_neurons = []
            print("l1 weight", l1_weight)
            print("hidden neurons", hidden_neurons)
            model = XLogicNN(n_classes=n_classes, n_features=n_features, name=name, dummy_phi_in=False,
                             hidden_neurons=hidden_neurons, loss=torch.nn.BCELoss(), l1_weight=l1_weight,
                             prune_quantile=0.99)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, epochs=epochs_logic, l_r=lr_logic, verbose=True, batch_size=128,
                          num_workers=2, metric=metric, lr_scheduler=lr_scheduler, device=device, save=True)
            # NOTE(review): sets a very large tau on the first layer's phi_in before evaluation;
            # the intent is not visible in this notebook — confirm.
            model.model[0].phi_in.tau = 100000

            outputs, labels = model.predict(test_data, device=device)
            accuracy = model.evaluate(test_data, metric=metric, outputs=outputs, labels=labels)
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [], [], [], []
            for i in trange(n_classes, desc=f"{method} extracting explanations"):
                explanation = model.get_global_explanation(X_train, y_train, i,
                                                           top_k_explanations=top_k_explanations,
                                                           concept_names=concept_names,
                                                           metric=expl_metric, x_val=X_val, y_val=y_val)
                exp_accuracy, exp_predictions = test_explanation(explanation, i, X_test_t, y_test_t,
                                                                 metric=expl_metric, concept_names=concept_names,
                                                                 inequalities=True)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(outputs.argmax(dim=1) == i)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                explanation_complexity = complexity(explanation)
                explanations.append(explanation), exp_accuracies.append(exp_accuracy)
                exp_fidelities.append(exp_fidelity), exp_complexities.append(explanation_complexity)
            # Diagnostic only: fidelity against each class output thresholded at 0.5 (instead of
            # argmax). The forward pass does not depend on the class, so it runs once, without autograd.
            with torch.no_grad():
                logic_outputs = model(X_test_t.to(device)).cpu()
            for i in range(n_classes):
                exp_accuracy, exp_predictions = test_explanation(explanations[i], i, X_test_t, y_test_t,
                                                                 metric=expl_metric, concept_names=concept_names,
                                                                 inequalities=True)
                exp_predictions = torch.as_tensor(exp_predictions)
                class_output = torch.as_tensor(logic_outputs[:, i] > 0.5)
                exp_fidelity = fidelity(exp_predictions, class_output, expl_metric)
                print(i, exp_fidelity)

        elif method == 'RandomForest':
            set_seed(seed)
            model = XRandomForestClassifier(name=name, n_classes=n_classes,
                                            n_features=n_features)
            try:
                model.load(device)
                print(f"Model {name} already trained")
            except (ClassifierNotTrainedError, IncompatibleClassifierError):
                model.fit(train_data, val_data, epochs=epochs, l_r=l_r, metric=metric,
                          lr_scheduler=lr_scheduler, device=device, save=True, verbose=True)
            accuracy = model.evaluate(test_data, metric=metric)
            # No explanations are extracted for the random forest; placeholders keep the table shape.
            explanations, exp_accuracies, exp_fidelities, exp_complexities = [""], [0], [0], [0]

        else:
            raise NotImplementedError(f"{method} not implemented")

        if model.time is None:
            elapsed_time = time.time() - start_time
            # BRL training runs on n_processes workers, so wall time is scaled by n_processes.
            if method == "BRL":
                elapsed_time = elapsed_time * n_processes
            model.time = elapsed_time
            # To save the elapsed time and the explanations
            model.save(device)
        else:
            elapsed_time = model.time

        # Restore original folder
        if method == "DeepRed":
            model.finish()

        methods.append(method)
        splits.append(seed)
        # Only the class-0 explanation is kept per seed; formula_consistency compares these across seeds.
        model_explanations.append(explanations[0])
        model_accuracies.append(accuracy)
        elapsed_times.append(elapsed_time)
        explanation_accuracies.append(np.mean(exp_accuracies))
        explanation_fidelities.append(np.mean(exp_fidelities))
        explanation_complexities.append(np.mean(exp_complexities))
        print("Test model accuracy", accuracy)
        print("Explanation time", elapsed_time)
        print("Explanation accuracy mean", np.mean(exp_accuracies))
        print("Explanation fidelity mean", np.mean(exp_fidelities))
        print("Explanation complexity mean", np.mean(exp_complexities))

    explanation_consistency = formula_consistency(model_explanations)
    print(f'Consistency of explanations: {explanation_consistency:.4f}')

    results = pd.DataFrame({
        'method': methods,
        'split': splits[:len(model_explanations)],
        'explanation': model_explanations,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
        'explanation_fidelity': explanation_fidelities,
        'explanation_complexity': explanation_complexities,
        'explanation_consistency': [explanation_consistency] * len(seeds),
        'elapsed_time': elapsed_times,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))
    print(results)

## Summary

In [None]:
# Aggregate the per-seed CSVs written by the training loop into per-method mean / standard error.
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_fidelity', 'explanation_complexity', 'elapsed_time',
        'explanation_consistency']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]
method_list = ['Ent']
results_per_method = {}
summaries = {}
for m in method_list:
    results_per_method[m] = pd.read_csv(os.path.join(results_dir, f"results_{m}.csv"))
    # `.mean()` / `.sem()` return a Series indexed by column name. Suffix the index labels so the
    # concatenated Series has unique labels. A Series has no `.columns`; assigning one only adds an
    # unrelated attribute and leaves the labels unchanged.
    df_mean = results_per_method[m][cols].mean().add_suffix('_mean')
    df_sem = results_per_method[m][cols].sem().add_suffix('_sem')
    summaries[m] = pd.concat([df_mean, df_sem])
    summaries[m].name = m

results_df = pd.concat([results_per_method[method] for method in method_list])
results_df.to_csv(os.path.join(results_dir, f'results_{method_list}.csv'))

summary = pd.concat([summaries[method] for method in method_list], axis=1).T
# Selecting (rather than renaming) fails loudly if the suffixed labels ever drift from the expected order.
summary = summary[mean_cols + sem_cols]
summary.to_csv(os.path.join(results_dir, f'summary_{method_list}.csv'))
print(summary)