In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, LabelEncoder
from sklearn.model_selection import StratifiedKFold

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relunn import get_reduced_model, prune_features
from deep_logic.utils.sigmoidnn import prune_equal_fanin
from deep_logic import logic

results_dir = 'results/omalizumab'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_rep = 10
tot_epochs = 6001
prune_epochs = 3001

In [2]:
gene_expression_matrix = pd.read_csv('data/omalizumab/reduced_w_1/data.csv', index_col=None, header=None)
labels = pd.read_csv('data/omalizumab/reduced_w_1/tempLabels_W-1.csv', index_col=None, header=None)
genes = pd.read_csv('data/omalizumab/reduced_w_1/features.csv', index_col=None, header=None)
gene_expression_matrix

Unnamed: 0,0,1,2,3,4
0,3.320000,3.320000,3.32000,6.941536,6.590419
1,4.232978,3.320000,3.32000,7.279548,6.476784
2,3.320000,4.200609,3.32000,7.741600,4.643134
3,3.320000,3.320000,3.32000,7.276600,5.953452
4,3.320000,3.320000,3.32000,7.224628,6.555227
...,...,...,...,...,...
56,3.320000,3.320000,3.32000,7.660182,6.128603
57,3.320000,3.700430,3.45131,7.809826,6.153968
58,3.320000,3.320000,3.32000,7.580588,6.134398
59,4.174319,3.320000,3.32000,7.016004,7.124143


In [3]:
encoder = LabelEncoder()
labels_encoded = encoder.fit_transform(labels.values)
labels_encoded_noncontrols = labels_encoded[labels_encoded!=0] - 1

data_controls = gene_expression_matrix[labels_encoded==0]
data = gene_expression_matrix[labels_encoded!=0]

gene_signature = data_controls.mean(axis=0)
data_scaled = data - gene_signature

scaler = MinMaxScaler((0, 1))
scaler.fit(data_scaled)
data_normalized = scaler.transform(data_scaled)

x = torch.FloatTensor(data_normalized)
y = torch.LongTensor(labels_encoded_noncontrols)
print(x.shape)
print(y.shape)

torch.Size([40, 5])
torch.Size([40])


  return f(*args, **kwargs)


In [4]:
concepts = list(genes.values.squeeze())
concepts

['ILMN_3286286',
 'ILMN_1775520',
 'ILMN_1656849',
 'ILMN_1781198',
 'ILMN_1665457']

In [5]:
n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

In [28]:
def train_nn(x_train, y_train, need_pruning, seed, device, relu=False):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    layers = [
        torch.nn.Linear(x_train.size(1), 50),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(50, 20),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(20, 5),
        torch.nn.ReLU() if relu else torch.nn.LeakyReLU(),
        torch.nn.Linear(5, 2),
        torch.nn.Softmax(dim=1),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train)
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.005 * torch.norm(module.weight, 1)
                loss += 0.005 * torch.norm(module.bias, 1)
                break

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > prune_epochs and need_pruning and epoch % 1000 == 1:
            prune_features(model, n_classes=1, device=device)
            need_pruning = True
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = torch.argmax(y_pred, dim=1)
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

# General pruning

In [31]:
need_pruning = True
method = 'pruning'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
#     if split not in [5]: continue
    
    model = train_nn(x_train, y_train, need_pruning, split, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, counter = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device), 
                                                                       target_class=target_class,
                                                                       topk_explanations=2,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, counter_inv = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           target_class=target_class,
                                                                           topk_explanations=2,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.8056
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Model's accuracy: 0.7500
	 Class 1 - Global explanation: "~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849" - Accuracy: 1.0000
	 Elapsed time 0.0608365535736084
	 Class 0 - Global explanation: "(ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 & ~ILMN_1656849) | (ILMN_1775520 & ILMN_1781198 & ~ILMN_3286286 & ~ILMN_1656849)" - Accuracy: 0.7500
	 Elapsed time 0.039893388748168945
Split [2/10]
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: t

	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "ILMN_1665457 & ~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849" - Accuracy: 0.2500
	 Elapsed time 0.07280468940734863
	 Class 0 - Global explanation: "(ILMN_1775520 & ILMN_1781198 & ILMN_1665457 & ~ILMN_3286286 & ~ILMN_1656849) | (ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 & ~ILMN_1656849 & ~ILMN_1665457)" - Accuracy: 0.7500
	 Elapsed time 0.04787302017211914


In [32]:
results_pruning = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_pruning.to_csv(os.path.join(results_dir, 'results_pruning.csv'))
results_pruning

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,pruning,0,~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 &...,0.75,1.0,0.75,0.060837,0.039893
1,pruning,1,~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 &...,1.0,1.0,0.75,0.072808,0.036901
2,pruning,2,(ILMN_1781198 & ILMN_1665457 & ~ILMN_3286286 &...,(ILMN_3286286 & ~ILMN_1665457) | (ILMN_1775520...,1.0,0.25,1.0,0.075798,0.054854
3,pruning,3,ILMN_1665457 & ~ILMN_3286286 & ~ILMN_1775520 &...,~ILMN_1665457,1.0,0.5,0.5,0.076303,0.042885
4,pruning,4,~ILMN_3286286 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1656849) ...,1.0,1.0,1.0,0.054853,0.027891
5,pruning,5,~ILMN_1775520 & ~ILMN_1656849,ILMN_1656849 | (ILMN_3286286 & ILMN_1781198 & ...,1.0,1.0,0.75,0.058842,0.040891
6,pruning,6,(ILMN_1781198 & ILMN_1665457 & ~ILMN_3286286 &...,(ILMN_1656849 & ILMN_1781198 & ILMN_1665457 & ...,1.0,0.5,1.0,0.068755,0.03641
7,pruning,7,~ILMN_3286286 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1656849) ...,1.0,0.5,0.75,0.050862,0.024933
8,pruning,8,(ILMN_1781198 & ILMN_1665457 & ~ILMN_3286286 &...,(ILMN_1775520 & ILMN_1781198 & ILMN_1665457 & ...,1.0,0.75,0.75,0.096743,0.058842
9,pruning,9,ILMN_1665457 & ~ILMN_3286286 & ~ILMN_1775520 &...,(ILMN_1775520 & ILMN_1781198 & ILMN_1665457 & ...,1.0,0.25,0.75,0.072805,0.047873


# LIME

In [34]:
need_pruning = False
method = 'lime'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    print(f'Seed [{seed+1}/{n_rep}]')
    
    model = train_nn(x_train, y_train, need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device),
                                                                       topk_explanations=2,
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           topk_explanations=2,
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Seed [1/10]
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.8056
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Model's accuracy: 0.7500
	 Class 1 - Global explanation: "~ILMN_3286286 & ~ILMN_1656849" - Accuracy: 1.0000
	 Elapsed time 23.859111070632935
	 Class 0 - Global explanation: "~ILMN_1665457" - Accuracy: 0.2500
	 Elapsed time 10.126502513885498
Seed [2/10]
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000

In [35]:
results_lime = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_lime.to_csv(os.path.join(results_dir, 'results_lime.csv'))
results_lime

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,lime,0,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,0.75,1.0,0.25,23.859111,10.126503
1,lime,1,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,0.75,1.0,0.25,23.804507,10.15255
2,lime,2,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,0.75,1.0,0.25,24.821045,10.699512
3,lime,3,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,1.0,1.0,0.25,23.93372,10.767164
4,lime,4,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,0.75,1.0,0.25,24.41352,10.604317
5,lime,5,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,1.0,1.0,0.25,24.18871,10.189684
6,lime,6,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,0.75,1.0,0.25,23.701943,10.182435
7,lime,7,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,1.0,1.0,0.25,23.755335,10.342746
8,lime,8,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,0.75,1.0,0.25,23.90011,10.251186
9,lime,9,~ILMN_3286286 & ~ILMN_1656849,~ILMN_1665457,1.0,1.0,0.25,24.399475,10.399325


# Weights

In [36]:
need_pruning = False
method = 'weights'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = torch.FloatTensor(x[train_index]), torch.FloatTensor(x[test_index])
    y_train, y_test = torch.LongTensor(y[train_index]), torch.LongTensor(y[test_index])
    
    model = train_nn(x_train, y_train, need_pruning, split, device, relu=True)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    # positive class
    target_class = 1
    start = time.time()
    global_explanation, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                       x_train.to(device), y_train.to(device),
                                                                       topk_explanations=2, 
                                                                       target_class=target_class,
                                                                       method=method, device=device)
    elapsed_time = time.time() - start
    if global_explanation:
        explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
        explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    start = time.time()
    global_explanation_inv, _, _ = logic.relunn.combine_local_explanations(model, 
                                                                           x_train.to(device), y_train.to(device), 
                                                                           topk_explanations=2, 
                                                                           target_class=target_class,
                                                                           method=method, device=device)
    elapsed_time_inv = time.time() - start
    if global_explanation_inv:
        explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, target_class, x_test, y_test)
        explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Split [1/10]
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.7778
	 Epoch 1000: train accuracy: 0.9444
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 1.0000
	 Epoch 2500: train accuracy: 1.0000
	 Epoch 3000: train accuracy: 1.0000
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849" - Accuracy: 1.0000
	 Elapsed time 0.12666106224060059
	 Class 0 - Global explanation: "(ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 & ~ILMN_1656849) | (ILMN_1775520 & ILMN_1781198 & ~ILMN_3286286 & ~ILMN_1656849)" - Accuracy: 0.7500
	 Elapsed time 0.07678532600402832
Split [2/10]
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 0.9444
	 Epoch 1500: t

In [37]:
results_weights = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_weights.to_csv(os.path.join(results_dir, 'results_weights.csv'))
results_weights

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,weights,0,~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 &...,1.0,1.0,0.75,0.126661,0.076785
1,weights,1,(ILMN_1665457 & ~ILMN_3286286) | (ILMN_1781198...,(ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849)...,1.0,0.75,0.75,0.106716,0.053855
2,weights,2,(ILMN_1781198 & ILMN_1665457 & ~ILMN_3286286 &...,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1656849) ...,1.0,0.25,1.0,0.111238,0.04305
3,weights,3,(ILMN_1665457 & ~ILMN_1656849) | (~ILMN_328628...,(ILMN_3286286 & ~ILMN_1656849) | (ILMN_1656849...,1.0,0.75,0.5,0.111702,0.0369
4,weights,4,ILMN_1665457 | (ILMN_1781198 & ~ILMN_3286286 &...,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1656849) ...,1.0,0.75,1.0,0.084776,0.047873
5,weights,5,~ILMN_1775520 & ~ILMN_1656849,ILMN_1656849 | (ILMN_3286286 & ILMN_1781198 & ...,1.0,1.0,0.75,0.106715,0.05286
6,weights,6,~ILMN_3286286 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1656849) ...,1.0,1.0,1.0,0.106983,0.041888
7,weights,7,~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849,(ILMN_3286286 & ILMN_1781198 & ~ILMN_1775520 &...,1.0,0.75,0.75,0.102726,0.049865
8,weights,8,~ILMN_1775520 & ~ILMN_1656849,~ILMN_1665457,1.0,1.0,1.0,0.124366,0.054853
9,weights,9,ILMN_1665457 & ~ILMN_3286286 & ~ILMN_1775520 &...,~ILMN_1665457,1.0,0.25,0.25,0.109707,0.051861


# Psi network

In [38]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device):
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)
    layers = [
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        # Compute Loss
        loss = loss_form(y_pred, y_train)

        for module in model.children():
            if isinstance(module, torch.nn.Linear):
                loss += 0.0001 * torch.norm(module.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        if epoch > 1500 and need_pruning:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False
            
        # compute accuracy
        if epoch % 500 == 0:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')
    
    return model

In [39]:
need_pruning = True
method = 'psi'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []
for seed in range(n_rep):
    print(f'Seed [{seed+1}/{n_rep}]')
    
    # positive class
    target_class = 1
    model = train_psi_nn(x_train, y_train, need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation = logic.generate_fol_explanations(model, device)[0]
    elapsed_time = time.time() - start
    explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
    explanation = logic.base.replace_names(global_explanation, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
    print(f'\t Elapsed time {elapsed_time}')
        
    # negative class
    target_class = 0
    model = train_psi_nn(x_train, y_train.eq(target_class), need_pruning, seed, device)
    
    y_preds = model(x_test.to(device)).cpu().detach().numpy()
    model_accuracy = accuracy_score(y_test.eq(target_class).cpu().detach().numpy(), y_preds > 0.5)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    start = time.time()
    global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
    elapsed_time_inv = time.time() - start
    explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv, 
                                                              target_class, x_test, y_test)
    explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)
    print(f'\t Class {target_class} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(seed)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(explanation_accuracy)
    explanation_accuracies_inv.append(explanation_accuracy_inv)
    elapsed_times.append(elapsed_time)
    elapsed_times_inv.append(elapsed_time_inv)

Seed [1/10]
	 Epoch 0: train accuracy: 0.2500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 0.7500
	 Epoch 1500: train accuracy: 0.9444
	 Epoch 2000: train accuracy: 0.7500
	 Epoch 2500: train accuracy: 0.8056
	 Epoch 3000: train accuracy: 0.9444
	 Epoch 3500: train accuracy: 1.0000
	 Epoch 4000: train accuracy: 1.0000
	 Epoch 4500: train accuracy: 1.0000
	 Epoch 5000: train accuracy: 1.0000
	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: "(~ILMN_1656849 & (~ILMN_3286286 | ~ILMN_1775520))" - Accuracy: 1.0000
	 Elapsed time 0.042885780334472656
	 Epoch 0: train accuracy: 0.7500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 1.0000
	 Epoch 1500: train accuracy: 1.0000
	 Epoch 2000: train accuracy: 0.7500
	 Epoch 2500: train accuracy: 0.8333
	 Epoch 3000: train accuracy: 0.8611
	 Epoch 3500: train accuracy: 0.9167
	 Epoch 4000: train accuracy: 0.9167
	 Epoch 4500: tr

	 Epoch 5500: train accuracy: 1.0000
	 Epoch 6000: train accuracy: 1.0000
	 Model's accuracy: 0.7500
	 Class 1 - Global explanation: "(~ILMN_1775520 & ~ILMN_1656849 & (~ILMN_3286286 | ~ILMN_1781198))" - Accuracy: 0.7500
	 Elapsed time 0.02991938591003418
	 Epoch 0: train accuracy: 0.2500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 0.7500
	 Epoch 1500: train accuracy: 0.7500
	 Epoch 2000: train accuracy: 0.7500
	 Epoch 2500: train accuracy: 0.7500
	 Epoch 3000: train accuracy: 0.8889
	 Epoch 3500: train accuracy: 0.8889
	 Epoch 4000: train accuracy: 0.9167
	 Epoch 4500: train accuracy: 0.8889
	 Epoch 5000: train accuracy: 0.8889
	 Epoch 5500: train accuracy: 0.8889
	 Epoch 6000: train accuracy: 0.9167
	 Model's accuracy: 0.7500
	 Class 0 - Global explanation: "(ILMN_1775520 | ILMN_1656849)" - Accuracy: 0.7500
	 Elapsed time 0.014959573745727539
Seed [8/10]
	 Epoch 0: train accuracy: 0.2500
	 Epoch 500: train accuracy: 0.7500
	 Epoch 1000: train accuracy: 0.7500
	 E

In [40]:
results_psi = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_psi.to_csv(os.path.join(results_dir, 'results_psi.csv'))
results_psi

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,psi,0,(~ILMN_1656849 & (~ILMN_3286286 | ~ILMN_1775520)),(ILMN_3286286 | ILMN_1775520 | ILMN_1656849),0.75,1.0,0.75,0.042886,0.030916
1,psi,1,(~ILMN_1775520 & ~ILMN_1656849 & (ILMN_1665457...,(ILMN_1656849 | (ILMN_3286286 & ILMN_1775520)),1.0,1.0,1.0,0.048871,0.02892
2,psi,2,(~ILMN_3286286 & ~ILMN_1775520 & ~ILMN_1656849),(ILMN_3286286 | ILMN_1775520),0.75,0.75,0.75,0.026927,0.031913
3,psi,3,(~ILMN_1775520 & ~ILMN_1656849 & (ILMN_1665457...,(ILMN_1775520 | (ILMN_3286286 & ILMN_1781198) ...,0.75,1.0,0.75,0.031914,0.031913
4,psi,4,(~ILMN_1656849 & (~ILMN_3286286 | ~ILMN_1775520)),(ILMN_1775520 | ILMN_1656849 | (ILMN_3286286 &...,0.75,1.0,0.75,0.034911,0.027925
5,psi,5,(~ILMN_1656849 & (~ILMN_3286286 | ~ILMN_1775520)),(ILMN_1775520 | ILMN_1656849),0.75,1.0,0.75,0.028922,0.028922
6,psi,6,(~ILMN_1775520 & ~ILMN_1656849 & (~ILMN_328628...,(ILMN_1775520 | ILMN_1656849),0.75,0.75,0.75,0.029919,0.01496
7,psi,7,(~ILMN_1775520 & ~ILMN_1656849),((ILMN_3286286 & ILMN_1775520) | (ILMN_3286286...,1.0,0.75,1.0,0.039895,0.05286
8,psi,8,(~ILMN_1656849 & (~ILMN_3286286 | ~ILMN_1775520)),(ILMN_3286286 | ILMN_1775520 | (ILMN_1656849 &...,0.75,1.0,0.75,0.035906,0.02593
9,psi,9,(~ILMN_1775520 & ~ILMN_1656849),(ILMN_1775520 | ILMN_1656849),0.75,0.75,0.75,0.02992,0.023935


# Decision tree

In [41]:
need_pruning = False
method = 'decision_tree'
methods = []
splits = []
explanations = []
explanations_inv = []
model_accuracies = []
explanation_accuracies = []
explanation_accuracies_inv = []
elapsed_times = []
elapsed_times_inv = []

for split, (train_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
    print(f'Split [{split+1}/{n_splits}]')
    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    classifier = DecisionTreeClassifier(random_state=split)
    classifier.fit(x_train.cpu().detach().numpy(), y_train.cpu().detach().numpy())
    y_preds = classifier.predict(x_test.cpu().detach().numpy())
    model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)
    print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
    
    target_class = 1
    start = time.time()
    explanation = tree_to_formula(classifier, concepts, target_class)
    elapsed_time = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation}')
    print(f'\t Elapsed time {elapsed_time}')
    
    target_class = 0
    start = time.time()
    explanation_inv = tree_to_formula(classifier, concepts, target_class)
    elapsed_time_inv = time.time() - start
    print(f'\t Class {target_class} - Global explanation: {explanation_inv}')
    print(f'\t Elapsed time {elapsed_time_inv}')
    
    methods.append(method)
    splits.append(split)
    explanations.append(explanation)
    explanations_inv.append(explanation_inv)
    model_accuracies.append(model_accuracy)
    explanation_accuracies.append(model_accuracy)
    explanation_accuracies_inv.append(model_accuracy)
    elapsed_times.append(0)
    elapsed_times_inv.append(0)

Split [1/10]
	 Model's accuracy: 1.0000
	 Class 1 - Global explanation: (ILMN_1775520 <= 0.34 & ILMN_3286286 <= 0.86 & ILMN_1781198 <= 0.90 & ILMN_1665457 <= 0.08 & ILMN_1781198 <= 0.55) | (ILMN_1775520 <= 0.34 & ILMN_3286286 <= 0.86 & ILMN_1781198 <= 0.90 & ILMN_1665457 > 0.08) | (ILMN_1775520 > 0.34 & ILMN_3286286 <= 0.07 & ILMN_1781198 <= 0.80)
	 Elapsed time 0.0
	 Class 0 - Global explanation: (ILMN_1775520 <= 0.34 & ILMN_3286286 <= 0.86 & ILMN_1781198 <= 0.90 & ILMN_1665457 <= 0.08 & ILMN_1781198 > 0.55) | (ILMN_1775520 <= 0.34 & ILMN_3286286 <= 0.86 & ILMN_1781198 > 0.90) | (ILMN_1775520 <= 0.34 & ILMN_3286286 > 0.86) | (ILMN_1775520 > 0.34 & ILMN_3286286 <= 0.07 & ILMN_1781198 > 0.80) | (ILMN_1775520 > 0.34 & ILMN_3286286 > 0.07)
	 Elapsed time 0.0
Split [2/10]
	 Model's accuracy: 0.7500
	 Class 1 - Global explanation: (ILMN_1775520 <= 0.34 & ILMN_1665457 <= 0.08 & ILMN_1781198 <= 0.46) | (ILMN_1775520 <= 0.34 & ILMN_1665457 > 0.08) | (ILMN_1775520 > 0.34 & ILMN_1665457 > 0.56 &

In [42]:
results_tree = pd.DataFrame({
    'method': methods,
    'split': splits,
    'explanation': explanations,
    'explanation_inv': explanations_inv,
    'model_accuracy': model_accuracies,
    'explanation_accuracy': explanation_accuracies,
    'explanation_accuracy_inv': explanation_accuracies_inv,
    'elapsed_time': elapsed_times,
    'elapsed_time_inv': elapsed_times_inv,
})
results_tree.to_csv(os.path.join(results_dir, 'results_tree.csv'))
results_tree

Unnamed: 0,method,split,explanation,explanation_inv,model_accuracy,explanation_accuracy,explanation_accuracy_inv,elapsed_time,elapsed_time_inv
0,decision_tree,0,(ILMN_1775520 <= 0.34 & ILMN_3286286 <= 0.86 &...,(ILMN_1775520 <= 0.34 & ILMN_3286286 <= 0.86 &...,1.0,1.0,1.0,0,0
1,decision_tree,1,(ILMN_1775520 <= 0.34 & ILMN_1665457 <= 0.08 &...,(ILMN_1775520 <= 0.34 & ILMN_1665457 <= 0.08 &...,0.75,0.75,0.75,0,0
2,decision_tree,2,(ILMN_1656849 <= 0.18 & ILMN_3286286 <= 0.46) ...,(ILMN_1656849 <= 0.18 & ILMN_3286286 > 0.46 & ...,1.0,1.0,1.0,0,0
3,decision_tree,3,(ILMN_3286286 <= 0.07 & ILMN_1656849 <= 0.35) ...,(ILMN_3286286 <= 0.07 & ILMN_1656849 > 0.35) |...,0.5,0.5,0.5,0,0
4,decision_tree,4,(ILMN_1656849 <= 0.18 & ILMN_3286286 <= 0.46) ...,(ILMN_1656849 <= 0.18 & ILMN_3286286 > 0.46 & ...,1.0,1.0,1.0,0,0
5,decision_tree,5,(ILMN_1656849 <= 0.17 & ILMN_3286286 <= 0.54) ...,(ILMN_1656849 <= 0.17 & ILMN_3286286 > 0.54 & ...,0.75,0.75,0.75,0,0
6,decision_tree,6,(ILMN_1656849 <= 0.18 & ILMN_3286286 <= 0.46) ...,(ILMN_1656849 <= 0.18 & ILMN_3286286 > 0.46 & ...,1.0,1.0,1.0,0,0
7,decision_tree,7,(ILMN_1656849 <= 0.21 & ILMN_3286286 <= 0.46) ...,(ILMN_1656849 <= 0.21 & ILMN_3286286 > 0.46 & ...,1.0,1.0,1.0,0,0
8,decision_tree,8,(ILMN_1656849 <= 0.18 & ILMN_3286286 <= 0.46) ...,(ILMN_1656849 <= 0.18 & ILMN_3286286 > 0.46 & ...,1.0,1.0,1.0,0,0
9,decision_tree,9,(ILMN_1656849 <= 0.18 & ILMN_1665457 <= 0.08 &...,(ILMN_1656849 <= 0.18 & ILMN_1665457 <= 0.08 &...,0.75,0.75,0.75,0,0


# Summary

In [43]:
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]

# pruning
df_mean = results_pruning[cols].mean()
df_sem = results_pruning[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_pruning = pd.concat([df_mean, df_sem])
summary_pruning.name = 'pruning'

# lime
df_mean = results_lime[cols].mean()
df_sem = results_lime[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_lime = pd.concat([df_mean, df_sem])
summary_lime.name = 'lime'

# weights
df_mean = results_weights[cols].mean()
df_sem = results_weights[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_weights = pd.concat([df_mean, df_sem])
summary_weights.name = 'weights'

# psi
df_mean = results_psi[cols].mean()
df_sem = results_psi[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_psi = pd.concat([df_mean, df_sem])
summary_psi.name = 'psi'

# tree
df_mean = results_tree[cols].mean()
df_sem = results_tree[cols].sem()
df_mean.columns = mean_cols
df_sem.columns = sem_cols
summary_tree = pd.concat([df_mean, df_sem])
summary_tree.name = 'tree'

summary = pd.concat([summary_pruning, 
                     summary_lime, 
                     summary_weights, 
                     summary_psi, 
                     summary_tree], axis=1).T
summary.columns = mean_cols + sem_cols
summary

Unnamed: 0,model_accuracy_mean,explanation_accuracy_mean,explanation_accuracy_inv_mean,elapsed_time_mean,elapsed_time_inv_mean,model_accuracy_sem,explanation_accuracy_sem,explanation_accuracy_inv_sem,elapsed_time_sem,elapsed_time_inv_sem
pruning,0.975,0.675,0.8,0.068861,0.041138,0.025,0.098953,0.05,0.004212,0.003383
lime,0.85,1.0,0.25,24.077748,10.371542,0.040825,0.0,0.0,0.115518,0.075365
weights,1.0,0.75,0.775,0.109159,0.050979,0.0,0.091287,0.078617,0.003653,0.003413
psi,0.8,0.9,0.8,0.035007,0.029819,0.033333,0.040825,0.033333,0.002218,0.003012
tree,0.875,0.875,0.875,0.0,0.0,0.055902,0.055902,0.055902,0.0,0.0


In [44]:
summary.to_csv(os.path.join(results_dir, 'summary.csv'))