In [1]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sympy import simplify_logic
from sklearn.datasets import load_digits
from sklearn.preprocessing import OneHotEncoder
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from lens.utils.base import validate_network
from lens.utils.relu_nn import get_reduced_model, prune_features
from lens import logic

# Seed both RNGs used in this notebook (torch for weight initialisation,
# numpy's legacy global generator) so that a fresh Restart & Run All
# reproduces the outputs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Load the scikit-learn digits dataset as a (features, labels) pair.
# The captured output of this cell reports X with shape (1797, 64) and the
# labels taking the ten values 0..9.
X, y = load_digits(return_X_y=True)

# Sanity check: show the feature-matrix shape and the distinct class labels.
print(f'X shape: {X.shape}\nClasses: {np.unique(y)}')

X shape: (1797, 64)
Classes: [0 1 2 3 4 5 6 7 8 9]


In [3]:
# One symbolic concept name per distinct digit label ('f0' ... 'f9').
# These names are what the logic explanations further down are written in.
concepts = [f'f{i}' for i in np.unique(y)]
# Bare expression so the notebook displays the list.
concepts

['f0', 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9']

In [4]:
# One-hot encode the digit labels: one indicator column per distinct digit.
# OneHotEncoder expects a 2-D input, hence the reshape to a single column;
# .toarray() densifies the sparse result.
enc = OneHotEncoder()
y1h = enc.fit_transform(y.reshape(-1, 1)).toarray()

# Binary "is odd" target: 1.0 for odd digits, 0.0 for even digits.
# Computed in one vectorised step instead of filling a two-column array in a
# Python loop and then discarding the "even" column; the resulting values,
# shape (n_samples,) and float64 dtype are the same.
y2 = (y % 2 != 0).astype(float)

# Show the encoded shape and the first ten (label, one-hot row) pairs.
print(f'Target vector shape: {y1h.shape}')
for i in range(10):
    print(f'Example ({y[i]}): {y1h[i]}')

Target vector shape: (1797, 10)
Example (0): [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Example (1): [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
Example (2): [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
Example (3): [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
Example (4): [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
Example (5): [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
Example (6): [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
Example (7): [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
Example (8): [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
Example (9): [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]


In [5]:
# Hold out a third of the samples. Note that the network *inputs* here are the
# one-hot digit indicators (y1h), not the raw pixels, and the target is the
# odd/even flag (y2); the X_*/y_* names refer to that input/target role.
X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
    y1h,
    y2,
    test_size=0.33,
    random_state=42,
)

# Convert to float32 tensors; targets get a trailing axis so their shape
# matches the single-output network, i.e. (n_samples, 1).
x_train, x_test = (torch.FloatTensor(arr) for arr in (X_train_np, X_test_np))
y_train, y_test = (
    torch.FloatTensor(arr).unsqueeze(1) for arr in (y_train_np, y_test_np)
)

# Displayed as the cell result.
y_train.shape

torch.Size([1203, 1])

In [6]:
# Small fully connected ReLU network with a single sigmoid output:
# inputs -> 50 -> 30 -> 1.
layers = [
    torch.nn.Linear(x_train.size(1), 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 30),
    torch.nn.ReLU(),
    torch.nn.Linear(30, 1),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*layers)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
need_pruning = True  # prune_features is applied exactly once (see below)
for epoch in range(6000):
    # Full-batch forward pass.
    optimizer.zero_grad()
    y_pred = model(x_train)

    # MSE data term plus an L1 penalty on every Linear layer's weights,
    # accumulated layer by layer in network order.
    loss = torch.nn.functional.mse_loss(y_pred, y_train)
    linear_modules = (m for m in model.children() if isinstance(m, torch.nn.Linear))
    for module in linear_modules:
        loss = loss + 0.002 * torch.norm(module.weight, 1)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    # Every 100 epochs report training accuracy, using the predictions made
    # before this epoch's update. A sample counts as correct only if all of
    # its thresholded outputs match the target.
    if epoch % 100 == 0:
        y_pred_d = y_pred > 0.5
        n_correct = y_pred_d.eq(y_train).all(dim=1).sum().item()
        accuracy = n_correct / y_train.size(0)
        print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')

    # One-shot feature pruning on the first epoch past 3000; training then
    # continues on the pruned model for the remaining epochs.
    if need_pruning and epoch > 3000:
        prune_features(model)
        need_pruning = False

Epoch 0: train accuracy: 0.3990
Epoch 100: train accuracy: 1.0000
Epoch 200: train accuracy: 1.0000
Epoch 300: train accuracy: 1.0000
Epoch 400: train accuracy: 1.0000
Epoch 500: train accuracy: 1.0000
Epoch 600: train accuracy: 1.0000
Epoch 700: train accuracy: 1.0000
Epoch 800: train accuracy: 1.0000
Epoch 900: train accuracy: 1.0000
Epoch 1000: train accuracy: 1.0000
Epoch 1100: train accuracy: 1.0000
Epoch 1200: train accuracy: 1.0000
Epoch 1300: train accuracy: 1.0000
Epoch 1400: train accuracy: 1.0000
Epoch 1500: train accuracy: 1.0000
Epoch 1600: train accuracy: 1.0000
Epoch 1700: train accuracy: 1.0000
Epoch 1800: train accuracy: 1.0000
Epoch 1900: train accuracy: 1.0000
Epoch 2000: train accuracy: 1.0000
Epoch 2100: train accuracy: 1.0000
Epoch 2200: train accuracy: 1.0000
Epoch 2300: train accuracy: 1.0000
Epoch 2400: train accuracy: 1.0000
Epoch 2500: train accuracy: 1.0000
Epoch 2600: train accuracy: 1.0000
Epoch 2700: train accuracy: 1.0000
Epoch 2800: train accuracy: 1.00

# Local explanations

In [7]:
# Print a local explanation for the correctly classified positive (odd-digit)
# samples among the first 52 training samples.
np.set_printoptions(precision=2, suppress=True)
outputs = []
for i, (xin, yin) in enumerate(zip(x_train, y_train)):
    # Per-sample reduced model built by the project helper
    # (presumably a locally linear collapse of the ReLU network -- TODO confirm).
    model_reduced = get_reduced_model(model, xin)

    # Weights/bias of the first Linear layer of the reduced model. Reset on
    # every iteration so a value left over from a previous sample can never
    # be printed for the current one.
    wa, ba = None, None
    for module in model_reduced.children():
        if isinstance(module, torch.nn.Linear):
            wa = module.weight.detach().numpy()
            ba = module.bias.detach().numpy()
            break

    output = model_reduced(xin)

    # Only explain samples that are labelled positive AND predicted correctly.
    # The previous guard also accepted `yin.eq(output < 0.5)`, which made the
    # correctness check vacuous: any positive sample passed it, including
    # misclassified ones.
    predicted_positive = output > 0.5
    if yin.eq(predicted_positive) and yin > 0:
        local_explanation = logic.relu_nn.explain_local(model, x_train, y_train, xin, concepts)
        print(f'Input {(i+1)}')
        print(f'\tx={xin.detach().numpy()}')
        print(f'\ty={output.detach().numpy()}')
        print(f'\tw={wa}')
        print(f'\tb={ba}')
        print(f'\tExplanation: {local_explanation}')
        print()

    outputs.append(output)
    # Keep the cell output short: stop once the sample at index 51 is done.
    if i > 50:
        break

Input 1
	x=[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.4   3.41 -3.43  3.45 -3.42  3.42]]
	b=[-0.03]
	Explanation: f1

Input 2
	x=[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.39  3.41 -3.42  3.45 -3.41  3.43]]
	b=[-0.03]
	Explanation: f1

Input 4
	x=[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.39  3.41 -3.42  3.45 -3.41  3.43]]
	b=[-0.03]
	Explanation: f9

Input 5
	x=[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.39  3.41 -3.42  3.45 -3.41  3.43]]
	b=[-0.03]
	Explanation: f3

Input 8
	x=[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.39  3.41 -3.42  3.45 -3.41  3.43]]
	b=[-0.03]
	Explanation: f9

Input 9
	x=[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.39  3.41 -3.42  3.45 -3.41  3.43]]
	b=[-0.03]
	Explanation: f1

Input 10
	x=[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
	y=[0.97]
	w=[[-3.42  3.49 -3.43  3.49 -3.39  3.41 -3.42  3.45

# Combine local explanations

In [8]:
# Merge the per-sample explanations into one global formula, keeping the
# top 5 local explanations, expressed with the concept names.
global_explanation, predictions, counter = logic.combine_local_explanations(
    model,
    x_train,
    y_train,
    topk_explanations=5,
    concept_names=concepts,
)

# Compare the formula's predictions against the flattened training labels.
ynp = y_train.detach().numpy()[:, 0]
n_matches = np.sum(predictions == ynp)
accuracy = n_matches / len(ynp)
print(f'Accuracy using the formula "{global_explanation}": {accuracy:.4f}')

Accuracy using the formula "f1 | f3 | f5 | f7 | f9": 1.0000


In [9]:
# Display the third value returned by combine_local_explanations: per the
# captured output, a Counter mapping each local explanation to its count.
counter

Counter({'f1': 127, 'f9': 112, 'f3': 127, 'f7': 117, 'f5': 109})

In [10]:
# Side-by-side table of the formula's predictions and the training labels,
# both flattened to one dimension; displayed as the cell result.
pd.DataFrame(
    dict(
        predictions=predictions.ravel(),
        labels=y_train.detach().numpy().ravel(),
    )
)

Unnamed: 0,predictions,labels
0,True,1.0
1,True,1.0
2,False,0.0
3,True,1.0
4,True,1.0
...,...,...
1198,True,1.0
1199,True,1.0
1200,False,0.0
1201,True,1.0
