In [1]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sympy import simplify_logic

from lens.utils.relu_nn import get_reduced_model, prune_features
from lens import logic
from lens.utils.base import collect_parameters

# Seed the global PyTorch and NumPy random number generators so that
# reruns of the notebook draw the same random numbers.
torch.manual_seed(0)
np.random.seed(0)

In [2]:
# The three CSV files are read with the same options: no header row and no
# index column, so pandas assigns default integer row/column labels.
_read_opts = {'index_col': None, 'header': None}

gene_expression_matrix = pd.read_csv('reduced_w_1/data.csv', **_read_opts)
labels = pd.read_csv('reduced_w_1/tempLabels_W-1.csv', **_read_opts)
genes = pd.read_csv('reduced_w_1/features.csv', **_read_opts)

In [3]:
# Preview the loaded expression matrix (bare last expression -> notebook display).
gene_expression_matrix

Unnamed: 0,0,1,2,3,4
0,3.320000,3.320000,3.32000,6.941536,6.590419
1,4.232978,3.320000,3.32000,7.279548,6.476784
2,3.320000,4.200609,3.32000,7.741600,4.643134
3,3.320000,3.320000,3.32000,7.276600,5.953452
4,3.320000,3.320000,3.32000,7.224628,6.555227
...,...,...,...,...,...
56,3.320000,3.320000,3.32000,7.660182,6.128603
57,3.320000,3.700430,3.45131,7.809826,6.153968
58,3.320000,3.320000,3.32000,7.580588,6.134398
59,4.174319,3.320000,3.32000,7.016004,7.124143


In [4]:
# Preview the raw string labels (bare last expression -> notebook display).
labels

Unnamed: 0,0
0,diagnosis: healthy control
1,diagnosis: healthy control
2,diagnosis: healthy control
3,diagnosis: healthy control
4,diagnosis: healthy control
...,...
56,omalizumab responder status: Responder
57,omalizumab responder status: Responder
58,omalizumab responder status: Responder
59,omalizumab responder status: Responder


In [5]:
# Encode the string labels as integers. LabelEncoder expects a 1-D array of
# labels; `labels.values` is an (n, 1) column, which makes scikit-learn emit a
# DataConversionWarning, so the single label column is flattened first.
encoder = LabelEncoder()
labels_encoded = encoder.fit_transform(labels.values.ravel())

# Code 0 is treated as the control group below.
# NOTE(review): this relies on the control label receiving code 0 from the
# encoder -- confirm against `encoder.classes_`.
# The remaining codes are shifted down by one so they start at 0.
labels_encoded_noncontrols = labels_encoded[labels_encoded != 0] - 1

# Split the samples into controls and non-controls.
data_controls = gene_expression_matrix[labels_encoded == 0]
data = gene_expression_matrix[labels_encoded != 0]

# Express every non-control sample relative to the mean control profile
# (column-wise mean over the control samples).
gene_signature = data_controls.mean(axis=0)
data_scaled = data - gene_signature

# Rescale each feature to the [0, 1] range; `scaler` stays fitted for reuse.
scaler = MinMaxScaler((0, 1))
data_normalized = scaler.fit_transform(data_scaled)

# Build float32 tensors; the targets get a trailing dimension of size 1.
x_train = torch.tensor(data_normalized, dtype=torch.float32)
y_train = torch.tensor(labels_encoded_noncontrols, dtype=torch.float32).unsqueeze(1)
print(x_train.shape)
print(y_train.shape)

torch.Size([40, 5])
torch.Size([40, 1])


  return f(*args, **kwargs)


In [6]:
# Re-seed here so that weight initialisation is reproducible even when this
# cell is rerun on its own.
torch.manual_seed(0)
np.random.seed(0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
x_train = x_train.to(device)
y_train = y_train.to(device)

# Training hyper-parameters.
n_epochs = 3000
learning_rate = 0.01
l1_weight = 0.0           # coefficient of the L1 weight penalty (0.0 = disabled)
prune_after_epoch = 8000  # pruning is only considered after this epoch

# Small fully connected network with a single sigmoid output.
layers = [
    torch.nn.Linear(x_train.size(1), 10, bias=True),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(10, 5, bias=True),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(5, 1, bias=True),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*layers).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
model.train()
need_pruning = True
for epoch in range(1, n_epochs + 1):
    # forward pass
    optimizer.zero_grad()
    y_pred = model(x_train)

    # binary cross-entropy plus the (weighted) L1 norm of every linear layer
    loss = torch.nn.functional.binary_cross_entropy(y_pred, y_train)
    for module in model.children():
        if isinstance(module, torch.nn.Linear):
            loss = loss + l1_weight * torch.norm(module.weight, 1)

    # backward pass
    loss.backward()
    optimizer.step()

    # report the training accuracy every 1000 epochs
    if epoch % 1000 == 0:
        y_pred_d = (y_pred > 0.5)
        accuracy = (y_pred_d.eq(y_train).sum(dim=1) == y_train.size(1)).sum().item() / y_train.size(0)
        print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')

    # Prune once, after `prune_after_epoch`. The flag is cleared afterwards so
    # pruning cannot run a second time (it used to be reset to True, which made
    # the flag meaningless).
    # NOTE(review): with n_epochs = 3000 this branch is never reached, so the
    # model is currently trained without pruning -- raise n_epochs or lower
    # prune_after_epoch if pruning is intended.
    if epoch > prune_after_epoch and need_pruning and epoch % 3000 == 0:
        prune_features(model, 1, device)
        need_pruning = False

Epoch 1000: train accuracy: 1.0000
Epoch 2000: train accuracy: 1.0000
Epoch 3000: train accuracy: 1.0000


## Local explanations

In [7]:
np.set_printoptions(precision=2, suppress=True)

# Only the first few training samples are explained to keep the output short.
n_explained = 3

outputs = []
for i, (xin, yin) in enumerate(zip(x_train[:n_explained], y_train[:n_explained])):
    # Sample-specific reduced model built by the project helper.
    model_reduced = get_reduced_model(model, xin.to(device), bias=False).to(device)

    # Weights of the first linear layer of the reduced model (kept for the
    # optional debug print below).
    for module in model_reduced.children():
        if isinstance(module, torch.nn.Linear):
            wa = module.weight.cpu().detach().numpy()
            break
    output = model_reduced(xin)

    # The network has a single output, so argmax over it is always 0 and
    # cannot tell the classes apart. Threshold at 0.5 instead, as the training
    # cell does when it computes accuracy.
    # NOTE(review): assumes the reduced model's output is on the same [0, 1]
    # scale as the full model's output -- confirm.
    pred_class = output.item() > 0.5
    true_class = yin.item() > 0.5

    # generate local explanation only if the prediction is correct
    if pred_class == true_class:
        local_explanation = logic.relu_nn.explain_local(model, x_train, y_train, xin, yin, device=device)
        print(f'Input {(i+1)}')
        print(f'\tx={xin.cpu().detach().numpy()}')
        # first "y" line: target label; second "y" line: reduced-model output
        print(f'\ty={yin.cpu().detach().numpy()}')
        print(f'\ty={output.cpu().detach().numpy()}')
        #print(f'\tw={wa}')
        print(f'\tExplanation: {local_explanation}')
        print()
    outputs.append(output)

Input 1
	x=[1.   0.   0.06 0.55 0.07]
	y=[0.]
	y=[0.]
	Explanation: feature0000000000 & ~feature0000000001 & ~feature0000000002 & feature0000000003 & ~feature0000000004

Input 2
	x=[0.13 0.89 0.35 0.49 0.47]
	y=[0.]
	y=[0.]
	Explanation: ~feature0000000000 & feature0000000001 & ~feature0000000002 & ~feature0000000003 & ~feature0000000004

Input 3
	x=[0.72 0.38 0.   0.69 0.  ]
	y=[0.]
	y=[0.]
	Explanation: feature0000000000 & ~feature0000000001 & ~feature0000000002 & feature0000000003 & ~feature0000000004



## Combine local explanations

In [8]:
# Merge the per-sample explanations for class 0 into one formula using the
# project helper; it also returns the formula's predictions and a counter.
# NOTE(review): `topk_explanations=10` presumably caps how many local
# explanations are combined -- confirm against the helper's documentation.
global_explanation, predictions, counter = logic.combine_local_explanations(
    model,
    x=x_train,
    y=y_train,
    target_class=0,
    topk_explanations=10,
    device=device,
)

# Share of samples for which the formula's prediction equals the stored label.
ynp = y_train.detach().cpu().numpy()[:, 0]
n_matching = np.sum(predictions == ynp)
accuracy = n_matching / len(ynp)
print(f'Accuracy of when using the formula "{global_explanation}": {accuracy:.4f}')

Accuracy of when using the formula "(~feature0000000000 & ~feature0000000001 & ~feature0000000002) | (feature0000000000 & feature0000000004 & ~feature0000000002 & ~feature0000000003) | (feature0000000003 & feature0000000004 & ~feature0000000000 & ~feature0000000001) | (feature0000000003 & feature0000000004 & ~feature0000000000 & ~feature0000000002) | (feature0000000003 & ~feature0000000001 & ~feature0000000002 & ~feature0000000004) | (~feature0000000000 & ~feature0000000002 & ~feature0000000003 & ~feature0000000004)": 0.7750


In [9]:
# Extract the global explanation for class 0 once. It used to be computed
# twice with identical arguments, with one copy tested and the other printed;
# a single call guarantees the reported accuracy belongs to the printed formula.
explanation = logic.relu_nn.explain_global(model, n_classes=1, target_class=0, device=device)
global_explanation = explanation  # both names stay defined, as before

# The three sentinel strings are excluded from evaluation: they are not
# formulas that can be tested against the data.
if explanation not in ['False', 'True', 'The formula is too complex!']:
    accuracy, _ = logic.relu_nn.test_explanation(explanation, target_class=0, x=x_train.cpu(), y=y_train.cpu())
    print(f'Class {0} - Global explanation: "{global_explanation}" - Accuracy: {accuracy:.4f}')

Class 0 - Global explanation: "(feature0000000004 & ~feature0000000000 & ~feature0000000002) | (~feature0000000000 & ~feature0000000001 & ~feature0000000002) | (~feature0000000000 & ~feature0000000002 & ~feature0000000003) | (feature0000000004 & ~feature0000000001 & ~feature0000000002 & ~feature0000000003)" - Accuracy: 0.9250


In [10]:
# Collect the model parameters with the project helper and take the first
# entry of the weight list.
# NOTE(review): presumably w[0] is the first linear layer's weight matrix with
# one column per input feature -- confirm against `collect_parameters`.
w, b = collect_parameters(model, device)
feature_weights = w[0]

# A feature counts as "used" when the absolute weights in its column do not
# all sum to zero.
column_magnitude = np.abs(feature_weights).sum(axis=0)
feature_used_bool = column_magnitude > 0
feature_used = np.nonzero(feature_used_bool)[0]

# Look up the identifiers of the used features.
genes.iloc[feature_used]

Unnamed: 0,0
0,ILMN_3286286
1,ILMN_1775520
2,ILMN_1656849
3,ILMN_1781198
4,ILMN_1665457


ILMN_3286286, ILMN_1775520, ILMN_1656849, ILMN_1781198, ILMN_1665457

In [11]:
# Fraction of training samples whose label equals 0.
(y_train == 0).sum().item() / len(y_train)

0.25