In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

from deep_logic.utils.base import validate_network, set_seed, tree_to_formula
from deep_logic.utils.relu_nn import get_reduced_model, prune_features
from deep_logic.utils.psi_nn import prune_equal_fanin
from deep_logic import logic

# Folder that receives the per-method result CSVs and the final summary.
results_dir = './results/mnist'
os.makedirs(results_dir, exist_ok=True)

# Use the first GPU when CUDA is usable, otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# One symbolic concept name per digit class: c0 ... c9.
concepts = ['c' + str(digit) for digit in range(10)]
n_rep = 10           # number of seeds each method is repeated with
tot_epochs = 4001    # epochs used by train_nn / train_psi_nn
prune_epochs = 2000  # pruning is triggered once the epoch index exceeds this
concepts

['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']

In [2]:
# MNIST problem
num_workers = 0   # passed to every DataLoader below
batch_size = 128
valid_size = 0.2  # fraction of the training set held out for validation

# Preprocessing for the training images: tensor conversion followed by
# normalisation with mean 0.5 and std 0.5 (no augmentation is applied).
train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
# The test images go through the same two steps.
test_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Download (if needed) and wrap the official train / test splits.
train_data = datasets.MNIST('./data', train=True, transform=train_transforms, download=True)
test_data = datasets.MNIST('./data', train=False, transform=test_transforms, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

KeyboardInterrupt: 

In [None]:
# Build the train / validation split of the MNIST training set.
num_train = len(train_data)
indices = list(range(num_train))
# Shuffle with a dedicated seeded generator so that the split (and every
# result derived from it) is the same after each kernel restart.
np.random.RandomState(0).shuffle(indices)
split = int(np.floor(num_train * valid_size))
# The first `split` shuffled indices form the validation set (kept under the
# historical name `test_index`); the remainder is used for training.
train_index, test_index = indices[split:], indices[:split]

# Samplers that draw only from the respective index subset.
train_sampler = SubsetRandomSampler(train_index)
valid_sampler = SubsetRandomSampler(test_index)

# Data loaders: train and validation both read from `train_data`, the test
# loader iterates over the untouched official test split.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)

In [None]:
class Net(torch.nn.Module):
    """Small CNN digit classifier for single-channel 28x28 images.

    Two conv + ReLU + 2x2 max-pool stages reduce the image to 16 feature maps
    of size 7x7, which are flattened and passed through four fully connected
    layers with dropout between them.

    ``forward`` returns the raw class scores (logits) of shape ``(N, 10)``.
    ``torch.nn.CrossEntropyLoss`` applies log-softmax itself, so feeding it
    already-normalised probabilities would apply softmax twice and flatten
    the gradients. ``argmax`` / ``max`` over the output select the same class
    as they would on probabilities; call ``F.softmax(out, dim=1)`` when
    probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        # convolutional layers (padding=1 keeps the spatial size)
        self.conv1 = torch.nn.Conv2d(1, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1)
        # linear layers; 784 = 16 channels * 7 * 7 after the two poolings
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 128)
        self.fc3 = torch.nn.Linear(128, 64)
        self.fc4 = torch.nn.Linear(64, 10)
        # dropout
        self.dropout = torch.nn.Dropout(p=0.2)
        # max pooling (halves height and width)
        self.pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Map a batch of images ``(N, 1, 28, 28)`` to logits ``(N, 10)``."""
        # convolutional layers with ReLU and pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten everything but the batch dimension; unlike view(-1, 784)
        # this cannot silently mix samples when the input size is wrong
        x = torch.flatten(x, 1)
        # linear layers
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        # no softmax here: the loss expects unnormalised scores
        return self.fc4(x)
        
# Instantiate the CNN and place it on the device selected in the setup cell
# (a bare `.cuda()` raises on machines without a GPU).
model = Net()
print(model)
model.to(device)

In [None]:
# Train the CNN, or reuse a cached checkpoint when one exists.
criterion = torch.nn.CrossEntropyLoss()
checkpoint_path = './models/mnist/trained_model.pt'
if os.path.isfile(checkpoint_path):
    # map_location lets a checkpoint saved on a GPU be read on a CPU-only host
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

else:
    # the checkpoint folder is not created anywhere else; without it the
    # first torch.save below fails
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # epochs to train for
    epochs = 25
    set_seed(0)

    # tracks validation loss change after each epoch
    minimum_validation_loss = np.inf

    #optimizer = torch.optim.SGD(model.parameters(), lr = 0.001)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    for epoch in range(1, epochs+1):

        train_loss = 0
        valid_loss = 0

        # training steps
        model.train()
        for data, target in train_loader:
            # moves tensors to the selected device
            data, target = data.to(device), target.to(device)
            # clears gradients
            optimizer.zero_grad()
            # forward pass
            output = model(data)
            # loss in batch
            loss = criterion(output, target)
            # backward pass for loss gradient
            loss.backward()
            # update parameters
            optimizer.step()
            # update training loss (sum over samples, averaged below)
            train_loss += loss.item()*data.size(0)

        # validation steps: no gradients are needed, so skip graph building
        model.eval()
        with torch.no_grad():
            for data, target in valid_loader:
                # moves tensors to the selected device
                data, target = data.to(device), target.to(device)
                # forward pass
                output = model(data)
                # loss in batch
                loss = criterion(output, target)
                # update validation loss
                valid_loss += loss.item()*data.size(0)

        # average loss calculations
        train_loss = train_loss/len(train_loader.sampler)
        valid_loss = valid_loss/len(valid_loader.sampler)

        # Display loss statistics
        print(f'Current Epoch: {epoch}\nTraining Loss: {round(train_loss, 6)}\nValidation Loss: {round(valid_loss, 6)}')

        # Checkpoints are only considered after epoch 20; from then on the
        # model is saved every time the validation loss does not get worse.
        # NOTE: `epochs` must stay above 20, otherwise nothing is saved and
        # the reload below fails.
        if valid_loss <= minimum_validation_loss and epoch > 20:
            print(f'Validation loss decreased from {round(minimum_validation_loss, 6)} to {round(valid_loss, 6)}')
            torch.save(model.state_dict(), checkpoint_path)
            minimum_validation_loss = valid_loss
            print('Saving New Model')

    # continue with the best checkpoint rather than the last-epoch weights
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Evaluate the CNN on the official test split: average loss, per-digit
# accuracy and overall accuracy.
test_loss = 0.0
class_correct = [0.] * 10
class_total = [0.] * 10

model.eval()
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
with torch.no_grad():  # inference only
    for data, target in test_loader:
        # move tensors to the selected device
        data, target = data.to(device), target.to(device)
        # forward pass
        output = model(data)
        # batch loss
        loss = criterion(output, target)
        # test loss update
        test_loss += loss.item()*data.size(0)
        # index of the highest score = predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = pred.eq(target.view_as(pred)).cpu().numpy()
        # Count EVERY sample of the batch. Iterating over range(10) only
        # looked at the first ten samples of each batch, so the reported
        # accuracies covered a small fraction of the test set.
        for i in range(target.size(0)):
            label = target[i].item()
            class_correct[label] += correct[i].item()
            class_total[label] += 1

# average test loss
test_loss = test_loss/len(test_loader.dataset)
print(f'Test Loss: {round(test_loss, 6)}')

for i in range(10):
    if class_total[i] > 0:
        print(f'Test Accuracy of {classes[i]}: {round(100*class_correct[i]/class_total[i], 2)}%')
    else:
        print(f'Test Accuracy of {classes[i]}s: N/A (no training examples)')


print(f'Full Test Accuracy: {round(100. * np.sum(class_correct) / np.sum(class_total), 2)}% {np.sum(class_correct)} out of {np.sum(class_total)}')

In [None]:
# Run the trained CNN over the training subset and keep, batch by batch, the
# predicted digit (the "concept") together with the true digit.
pred_c_train = []
true_c_train = []
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    for data, target in train_loader:
        # moves tensors to the selected device (falls back to CPU)
        data, target = data.to(device), target.to(device)
        # forward pass
        output = model(data)
        pred_c_train.append(torch.argmax(output, 1))
        true_c_train.append(target)

In [None]:
# Merge the per-batch tensors into one tensor along the sample axis.
pred_c_train = torch.cat(pred_c_train, dim=0)
true_c_train = torch.cat(true_c_train, dim=0)
pred_c_train.shape

In [None]:
# Task label: 1 when the true digit is odd, 0 when it is even.
y_trainval = true_c_train.remainder(2).eq(1).long()
y_trainval.shape

In [None]:
# One-hot encode the predicted digits. The width is fixed to the number of
# concepts so that column i always means "predicted digit i", even if some
# digit is never predicted (a fitted encoder would silently drop that column).
x_trainval = F.one_hot(pred_c_train.cpu(), num_classes=len(concepts)).to(torch.float)
# Hold out 300 samples; they are later handed to the explanation routines.
x_train, x_val, y_train, y_val = train_test_split(x_trainval, y_trainval, test_size=300, random_state=42)
print(x_train.shape)
print(x_val.shape)

In [None]:
# Run the trained CNN over the test split and keep, batch by batch, the
# predicted digit together with the true digit.
pred_c_test = []
true_c_test = []
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    for data, target in test_loader:
        # moves tensors to the selected device (falls back to CPU)
        data, target = data.to(device), target.to(device)
        # forward pass
        output = model(data)
        pred_c_test.append(torch.argmax(output, 1))
        true_c_test.append(target)

In [None]:
# Merge the per-batch tensors into one tensor along the sample axis.
pred_c_test = torch.cat(pred_c_test, dim=0)
true_c_test = torch.cat(true_c_test, dim=0)
pred_c_test.shape

In [None]:
# Task label: 1 when the true digit is odd, 0 when it is even.
y_test = true_c_test.remainder(2).eq(1).long()
y_test.shape

In [None]:
# Test features are the concepts PREDICTED by the CNN, exactly like the
# training features. Encoding the true digits instead would hand the
# classifier the very quantity the odd/even label is computed from and make
# the test metrics meaningless. The width is fixed to the number of concepts
# so the columns always line up with the training matrix.
x_test = F.one_hot(pred_c_test.cpu(), num_classes=len(concepts)).to(torch.float)
x_test.shape

In [None]:
def train_nn(x_train, y_train, need_pruning, seed, device, relu=False, verbose=False):
    """Train a four-layer fully connected classifier with full-batch updates.

    The network has hidden sizes 100, 50 and 30, two outputs and a final
    ``Softmax`` layer. It is optimised for ``tot_epochs`` epochs (module-level
    constant) with AdamW and cross-entropy plus an L1 penalty on the weights
    and bias of the first linear layer only.

    Args:
        x_train: 2-D input tensor; its second dimension sets the input width.
        y_train: target tensor passed to ``CrossEntropyLoss``.
        need_pruning: when true, ``prune_features`` is called once, at the
            first epoch whose index exceeds ``prune_epochs``.
        seed: value handed to ``set_seed`` before the layers are created.
        device: device the data and the model are moved to.
        relu: use ``ReLU`` activations instead of ``LeakyReLU``.
        verbose: print the training accuracy every 500 epochs.

    Returns:
        The trained ``torch.nn.Sequential`` model (still in training mode).
    """
    set_seed(seed)
    x_train, y_train = x_train.to(device), y_train.to(device)

    activation = torch.nn.ReLU if relu else torch.nn.LeakyReLU
    model = torch.nn.Sequential(
        torch.nn.Linear(x_train.size(1), 100),
        activation(),
        torch.nn.Linear(100, 50),
        activation(),
        torch.nn.Linear(50, 30),
        activation(),
        torch.nn.Linear(30, 2),
        # NOTE(review): CrossEntropyLoss applies log-softmax on top of this
        # Softmax; the layer is kept because the deep_logic helpers may rely
        # on it - confirm before removing.
        torch.nn.Softmax(dim=1),
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    loss_form = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass on the whole training set
        optimizer.zero_grad()
        y_pred = model(x_train)
        loss = loss_form(y_pred, y_train)

        # L1 regularisation restricted to the first linear layer
        first_linear = next(
            (layer for layer in model.children() if isinstance(layer, torch.nn.Linear)),
            None,
        )
        if first_linear is not None:
            loss += 0.001 * torch.norm(first_linear.weight, 1)
            loss += 0.001 * torch.norm(first_linear.bias, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        # one-off pruning once the warm-up epochs are over
        if need_pruning and epoch > prune_epochs:
            prune_features(model, n_classes=1, device=device)
            need_pruning = False

        # accuracy of the predictions made before this epoch's update
        if verbose and epoch % 500 == 0:
            predicted_labels = torch.argmax(y_pred, dim=1)
            accuracy = predicted_labels.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [None]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, verbose=False):
    """Train a small sigmoid network (10 and 4 hidden units, one output).

    The network is optimised with full-batch updates for ``tot_epochs`` epochs
    (module-level constant) using Adam and binary cross-entropy plus a light
    L1 penalty on the weights of every linear layer.

    Args:
        x_train: 2-D input tensor; its second dimension sets the input width.
        y_train: binary targets; they are cast to float for ``BCELoss``.
        need_pruning: when true, ``prune_equal_fanin`` is called once, at the
            first epoch whose index exceeds ``prune_epochs``.
        seed: value handed to ``set_seed`` before the layers are created.
        device: device the data and the model are moved to.
        verbose: print the training accuracy every 500 epochs.

    Returns:
        The trained model as returned by the last assignment to ``model``.
    """
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).float()

    model = torch.nn.Sequential(
        torch.nn.Linear(x_train.size(1), 10),
        torch.nn.Sigmoid(),
        torch.nn.Linear(10, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
        torch.nn.Sigmoid(),
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_form = torch.nn.BCELoss()
    model.train()
    for epoch in range(tot_epochs):
        # forward pass on the whole training set
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        loss = loss_form(y_pred, y_train)

        # L1 penalty on the weights (not the biases) of all linear layers
        linear_layers = [layer for layer in model.children() if isinstance(layer, torch.nn.Linear)]
        for layer in linear_layers:
            loss += 0.00001 * torch.norm(layer.weight, 1)

        # backward pass
        loss.backward()
        optimizer.step()

        # one-off pruning once the warm-up epochs are over
        # NOTE(review): the optimizer was built from the parameters of the
        # model that existed before this call; if prune_equal_fanin returns a
        # new module, later steps would not update it - verify.
        if need_pruning and epoch > prune_epochs:
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False

        # accuracy of the predictions made before this epoch's update
        if verbose and epoch % 500 == 0:
            y_pred_d = y_pred > 0.5
            accuracy = y_pred_d.eq(y_train).sum().item() / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [None]:
def c_to_y(method, need_pruning, relu, verbose=False):
    """Run the concepts-to-task experiment ``n_rep`` times for one method.

    For every seed in ``range(n_rep)`` a classifier is fitted on the
    module-level ``x_train`` / ``y_train`` tensors, scored on ``x_test`` /
    ``y_test`` and asked for a global explanation of class 1 and of class 0.
    The per-seed measurements are collected in a DataFrame that is also
    written to ``results_{method}.csv`` inside ``results_dir``.

    Args:
        method: ``'tree'`` fits a ``DecisionTreeClassifier``; ``'psi'`` trains
            one network per class with ``train_psi_nn``; any other value
            trains a network with ``train_nn`` and is forwarded as ``method``
            to ``logic.relu_nn.combine_local_explanations``.
        need_pruning: forwarded to the network training routine.
        relu: forwarded to ``train_nn``.
        verbose: print the per-seed results; also forwarded to the training
            routines.

    Returns:
        pandas.DataFrame with one row per seed and the columns ``method``,
        ``split``, ``explanation``, ``explanation_inv``, ``model_accuracy``,
        ``explanation_accuracy``, ``explanation_accuracy_inv``,
        ``elapsed_time`` and ``elapsed_time_inv``.
    """
    is_psi = method == 'psi'

    def _extract(model, class_id):
        """Return (global explanation, seconds spent extracting it)."""
        tic = time.time()
        if is_psi:
            formula = logic.generate_fol_explanations(model, device)[0]
        else:
            formula, _, _ = logic.relu_nn.combine_local_explanations(
                model,
                x_val.to(device),
                y_val.to(device),
                topk_explanations=5,
                target_class=class_id,
                method=method, device=device)
        return formula, time.time() - tic

    def _evaluate(formula, class_id):
        """Return (readable explanation, test accuracy); ('', 0) if empty."""
        if not formula:
            return '', 0
        score, _ = logic.base.test_explanation(formula, class_id, x_test, y_test)
        return logic.base.replace_names(formula, concepts), score

    # one list per output column, in the order the columns appear in the CSV
    history = {
        'method': [],
        'split': [],
        'explanation': [],
        'explanation_inv': [],
        'model_accuracy': [],
        'explanation_accuracy': [],
        'explanation_accuracy_inv': [],
        'elapsed_time': [],
        'elapsed_time_inv': [],
    }

    for seed in range(n_rep):
        print(f'Seed [{seed+1}/{n_rep}]')

        target_class, target_class_inv = 1, 0

        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_train.cpu().detach().numpy(), y_train.cpu().detach().numpy())
            y_preds = classifier.predict(x_test.cpu().detach().numpy())
            model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            # positive class: the formula is read off the tree, so its
            # accuracy is reported as the tree's own accuracy
            tic = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - tic
            explanation_accuracy = model_accuracy

            # negative class
            tic = time.time()
            explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
            elapsed_time_inv = time.time() - tic
            explanation_accuracy_inv = model_accuracy

        else:
            if is_psi:
                # the psi network is binary: train it on "is positive class"
                model = train_psi_nn(x_train, y_train.eq(target_class), need_pruning, seed, device, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds>0.5)
            else:
                model = train_nn(x_train, y_train, need_pruning, seed, device, relu, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))

            # positive class
            global_explanation, elapsed_time = _extract(model, target_class)
            explanation, explanation_accuracy = _evaluate(global_explanation, target_class)

            # negative class: psi needs a network trained on the other class;
            # the training time is not part of the measured elapsed time
            if is_psi:
                model = train_psi_nn(x_train, y_train.eq(target_class_inv), need_pruning, seed, device, verbose)
            global_explanation_inv, elapsed_time_inv = _extract(model, target_class_inv)
            explanation_inv, explanation_accuracy_inv = _evaluate(global_explanation_inv, target_class_inv)

        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
            print(f'\t Class {target_class_inv} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
            print(f'\t Elapsed time {elapsed_time_inv}')

        row = (method, seed, explanation, explanation_inv, model_accuracy,
               explanation_accuracy, explanation_accuracy_inv,
               elapsed_time, elapsed_time_inv)
        for column, value in zip(history, row):
            history[column].append(value)

    results = pd.DataFrame(history)
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))

    return results

# General pruning

In [None]:
# LeakyReLU network with feature pruning during training.
method, need_pruning, relu = 'pruning', True, False
results_pruning = c_to_y(method=method, need_pruning=need_pruning, relu=relu)
results_pruning

# LIME

In [None]:
# method = 'lime'
# need_pruning = False
# relu = False
# results_lime = c_to_y(method, need_pruning, relu)
# results_lime

# ReLU

In [None]:
# ReLU network without pruning; explanations use the 'weights' method.
method, need_pruning, relu = 'weights', False, True
results_weights = c_to_y(method=method, need_pruning=need_pruning, relu=relu)
results_weights

# Psi network

In [None]:
# Psi network with pruning, trained quietly.
method, need_pruning, relu = 'psi', True, False
results_psi = c_to_y(method=method, need_pruning=need_pruning, relu=relu, verbose=False)
results_psi

# Decision tree

In [None]:
# Decision-tree baseline; the pruning and relu flags are not used by it.
method, need_pruning, relu = 'tree', False, False
results_tree = c_to_y(method=method, need_pruning=need_pruning, relu=relu)
results_tree

# Summary

In [None]:
# Aggregate every method's per-seed results into one table: the mean and the
# standard error of the mean of each metric, one row per method.
cols = ['model_accuracy', 'explanation_accuracy', 'explanation_accuracy_inv', 'elapsed_time', 'elapsed_time_inv']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]


def summarize_results(results, name):
    """Collapse one method's per-seed results into a single named Series.

    Args:
        results: DataFrame returned by ``c_to_y``; must contain ``cols``.
        name: label of the method; becomes the row label in ``summary``.

    Returns:
        pandas.Series indexed by ``mean_cols + sem_cols``.
    """
    means = results[cols].mean()
    sems = results[cols].sem()
    # mean() / sem() return Series, so the labels live in `.index`;
    # assigning to `.columns` would leave them unchanged.
    means.index = mean_cols
    sems.index = sem_cols
    row = pd.concat([means, sems])
    row.name = name
    return row


summary_pruning = summarize_results(results_pruning, 'pruning')
# summary_lime = summarize_results(results_lime, 'lime')  # LIME run is disabled
summary_weights = summarize_results(results_weights, 'weights')
summary_psi = summarize_results(results_psi, 'psi')
summary_tree = summarize_results(results_tree, 'tree')

summary = pd.concat([summary_pruning,
#                      summary_lime,
                     summary_weights, summary_psi, summary_tree], axis=1).T
summary

In [None]:
# Persist the aggregated table next to the per-method result files.
summary.to_csv(os.path.join(results_dir, 'summary.csv'))