In [9]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sympy import simplify_logic

from lens.utils.base import validate_network
from lens.utils.relu_nn import get_reduced_model, prune_features
from lens import logic
import lens

# Pin the global PyTorch and NumPy random generators to one shared seed so
# that repeated runs of the notebook start from the same random state.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [10]:
# Read the two training CSV files into DataFrames: `x` from the "_c_" file and
# `y` from the "_y_" file. The first column of each file becomes the row index.
x, y = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('dsprites_c_train.csv', 'dsprites_y_train.csv')
)

In [11]:
# Names of the six base concept groups; the bare expression on the last line
# displays the list as the cell output.
base_concepts = 'color shape scale rotation x_pos y_pos'.split()
base_concepts

['color', 'shape', 'scale', 'rotation', 'x_pos', 'y_pos']

In [109]:
# Human-readable concept names, grouped by factor.
colors = ['white']
shapes = ['square', 'ellipse', 'heart']
scale = ['very small', 'small', 's-medium', 'b-medium', 'big', 'very big']
# Rotation labels from 0° to 35° in steps of 5°.
rotation = [f'{angle}°' for angle in range(0, 40, 5)]
# Position labels for the even coordinates 0, 2, ..., 30 on each axis.
x_pos = [f'x{pos}' for pos in range(0, 32, 2)]
y_pos = [f'y{pos}' for pos in range(0, 32, 2)]
# Flat list of all 50 names, in the order colour, shape, scale, rotation, x, y.
concepts = [*colors, *shapes, *scale, *rotation, *x_pos, *y_pos]

In [110]:
# Convert the loaded DataFrames to float tensors.
x_train = torch.tensor(x.values, dtype=torch.float)
# The targets and the class count are defined in this cell because the prints
# below read them; without these two lines the cell only works when a later
# cell has already been executed (it fails on Restart & Run All).
y_train = torch.tensor(y.values, dtype=torch.float)
n_classes = y_train.size(1)
print(x_train.shape)
print(y_train.shape)
print(n_classes)
# Display the raw input DataFrame as the cell output.
x

torch.Size([5530, 50])
torch.Size([5530, 18])
18


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5525,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5526,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5527,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5528,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [111]:
# Targets as a float tensor, built directly from the label DataFrame.
# (A zero tensor of the same shape used to be allocated first and was
# immediately overwritten; that dead allocation has been dropped.)
y_train = torch.tensor(y.values, dtype=torch.float)
# No separate evaluation split is created here: x_test is the very same
# tensor object as x_train.
x_test = x_train
# Number of target columns.
n_classes = y_train.size(1)
print(n_classes)
# Display the target tensor as the cell output.
y_train

18


tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [112]:
# Column-wise totals of the target matrix (one value per target column).
torch.sum(y_train, dim=0)

tensor([321., 311., 295., 324., 286., 295., 325., 302., 302., 319., 293., 297.,
        302., 301., 329., 326., 321., 281.])

In [113]:
# Re-seed both RNGs so the weight initialisation below does not depend on how
# much randomness earlier cells consumed.
torch.manual_seed(0)
np.random.seed(0)

N_EPOCHS = 6000           # number of full-batch optimisation steps
L1_STRENGTH = 0.001       # weight of the L1 penalty on the first Linear layer
PRUNE_AFTER_EPOCH = 3000  # pruning is applied on every epoch after this one
LOG_EVERY = 500           # training accuracy is printed every LOG_EVERY epochs

layers = [
    torch.nn.Linear(x_train.size(1), 20 * n_classes),
    torch.nn.LeakyReLU(),
    lens.nn.XLinear(20, 10, n_classes),
    torch.nn.LeakyReLU(),
    lens.nn.XLinear(10, 5, n_classes),
    torch.nn.LeakyReLU(),
    lens.nn.XLinear(5, 1, n_classes),
    # Explicit dim silences PyTorch's implicit-dimension warning. For inputs
    # with up to two dimensions, dim=-1 selects the same axis that the
    # implicit choice did, so the outputs are unchanged.
    torch.nn.Softmax(dim=-1),
]
model = torch.nn.Sequential(*layers)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_form = torch.nn.BCELoss()
model.train()
# NOTE: this flag is never cleared, so prune_features is called on every
# epoch after PRUNE_AFTER_EPOCH, not just once.
need_pruning = True
# The integer class labels do not change during training; compute them once.
y_train_d = torch.argmax(y_train, dim=1)
for epoch in range(N_EPOCHS):
    # forward pass
    optimizer.zero_grad()
    y_pred = model(x_train)
    # Compute Loss
    loss = loss_form(y_pred, y_train)

    # L1 penalty on the weights of the first Linear layer only (the loop
    # stops at the first match).
    for module in model.children():
        if isinstance(module, torch.nn.Linear):
            loss += L1_STRENGTH * torch.norm(module.weight, 1)
            break

    # backward pass
    loss.backward()
    optimizer.step()

    if epoch > PRUNE_AFTER_EPOCH and need_pruning:
        prune_features(model, n_classes)

    # compute accuracy (uses the predictions made before this epoch's update)
    if epoch % LOG_EVERY == 0:
        y_pred_d = torch.argmax(y_pred, dim=1)
        accuracy = y_pred_d.eq(y_train_d).sum().item() / y_train.size(0)
        print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')

  input = module(input)


Epoch 0: train accuracy: 0.0588
Epoch 500: train accuracy: 0.8325
Epoch 1000: train accuracy: 0.9467
Epoch 1500: train accuracy: 0.9467
Epoch 2000: train accuracy: 0.9467
Epoch 2500: train accuracy: 1.0000
Epoch 3000: train accuracy: 1.0000
Epoch 3500: train accuracy: 1.0000
Epoch 4000: train accuracy: 1.0000
Epoch 4500: train accuracy: 1.0000
Epoch 5000: train accuracy: 1.0000
Epoch 5500: train accuracy: 1.0000


# Local explanations

In [114]:
# Print compact arrays: two decimals, no scientific notation.
np.set_printoptions(precision=2, suppress=True)
outputs = []
for i, (xin, yin) in enumerate(zip(x_train, y_train)):
    model_reduced = get_reduced_model(model, xin)
    # Keep the weights of the first Linear layer of the reduced model in `wa`
    # (not printed in this cell).
    for module in model_reduced.children():
        if not isinstance(module, torch.nn.Linear):
            continue
        wa = module.weight.detach().numpy()
        break
    output = model_reduced(xin)

    pred_class = torch.argmax(output)
    true_class = torch.argmax(yin)

    # A local explanation is produced only for correctly classified inputs.
    if pred_class.eq(true_class):
        local_explanation = logic.relu_nn.explain_local(model, x_train, y_train, xin, concepts)
        # One print call with a newline separator emits the same lines as
        # individual prints; the trailing '' yields the blank separator line.
        print(
            f'Input {(i+1)}',
            f'\tx={xin.detach().numpy()}',
            f'\ty={yin.detach().numpy()}',
            f'\ty={output.detach().numpy()}',
            f'\tExplanation: {local_explanation}',
            '',
            sep='\n',
        )
    outputs.append(output)
    # Only the first three inputs are processed.
    if i >= 2:
        break

Input 1
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0.]
	y=[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
	y=[1.   0.02 0.04 0.02 0.03 0.46 0.56 0.   0.   0.   0.   0.   0.48 0.
 0.62 0.   0.04 0.  ]
	Explanation: square & very small

Input 2
	x=[1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0.]
	y=[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
	y=[0.72 0.   0.   0.   0.   0.   1.   0.47 0.36 0.55 0.48 0.33 0.47 0.
 0.62 0.   0.04 0.  ]
	Explanation: ellipse & very small

Input 3
	x=[1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
	y=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
	y=[0.   0.   0.   0.   0.55 0.   0.   0.   0.   0.   0.53 0.   0.   

# Combine local explanations

In [116]:
# Integer class label of every training sample, computed once and reused for
# each target class below (it used to be recomputed on every iteration while
# this variable went unused).
y_train_d = torch.argmax(y_train, dim=1)
for target_class in range(n_classes):
    global_explanation, predictions, counter = logic.combine_local_explanations(
        model, x_train, y_train,
        topk_explanations=10,
        target_class=target_class,
        concept_names=concepts,
    )

    # One-vs-rest ground truth: True where the sample's label is target_class.
    y2 = y_train_d == target_class
    # Fraction of samples on which `predictions` agrees with that ground truth.
    accuracy = sum(predictions == y2.detach().numpy().squeeze()) / len(predictions)
    print(f'Class {target_class} - Global explanation: "{global_explanation}" - Accuracy: {accuracy:.4f}')

Class 0 - Global explanation: "square & very small" - Accuracy: 1.0000
Class 1 - Global explanation: "square & small" - Accuracy: 1.0000
Class 2 - Global explanation: "square & s-medium" - Accuracy: 1.0000
Class 3 - Global explanation: "square & b-medium" - Accuracy: 1.0000
Class 4 - Global explanation: "square & big" - Accuracy: 1.0000
Class 5 - Global explanation: "square & very big" - Accuracy: 1.0000
Class 6 - Global explanation: "ellipse & very small" - Accuracy: 1.0000
Class 7 - Global explanation: "ellipse & small" - Accuracy: 1.0000
Class 8 - Global explanation: "ellipse & s-medium" - Accuracy: 1.0000
Class 9 - Global explanation: "ellipse & b-medium" - Accuracy: 1.0000
Class 10 - Global explanation: "ellipse & big" - Accuracy: 1.0000
Class 11 - Global explanation: "ellipse & very big" - Accuracy: 1.0000
Class 12 - Global explanation: "heart & very small" - Accuracy: 1.0000
Class 13 - Global explanation: "heart & small" - Accuracy: 1.0000
Class 14 - Global explanation: "heart" 