In [1]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sympy import simplify_logic

from lens.utils.relu_nn import get_reduced_model, prune_features
from lens import logic
from lens.utils.base import collect_parameters

torch.manual_seed(0)
np.random.seed(0)

In [2]:
gene_expression_matrix = pd.read_csv('w_1/data_0.csv', index_col=None, header=None)
labels = pd.read_csv('w_1/tempLabels_W-1.csv', index_col=None, header=None)
genes = pd.read_csv('w_1/features_0.csv', index_col=None, header=None)

In [3]:
gene_expression_matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,28392,28393,28394,28395,28396,28397,28398,28399,28400,28401
0,14.622486,11.162004,3.320000,3.320000,3.320000,12.788433,6.143456,3.320000,4.876620,3.320000,...,3.32,3.320000,3.885589,3.914260,3.320000,3.32,3.320000,4.465420,3.320000,4.973620
1,14.398743,11.000080,3.320000,3.320000,3.320000,12.845914,6.147482,3.320000,4.484223,3.320000,...,3.32,3.575025,4.236519,4.047825,3.320000,3.32,4.176269,3.320000,4.553796,4.967418
2,14.692079,11.100175,3.320000,4.171535,3.320000,12.712544,5.583210,3.320000,3.478171,3.320000,...,3.32,3.320000,3.992331,4.865538,3.320000,3.32,3.488281,3.406285,3.320000,6.676063
3,14.613382,11.023209,3.320000,3.320000,3.320000,12.750496,5.688023,3.320000,4.464426,3.320000,...,3.32,3.855643,3.320000,4.905350,3.320000,3.32,4.158393,4.433457,3.874214,5.981160
4,14.482065,10.989851,3.320000,3.992726,4.574745,12.878702,6.195418,4.177962,3.872567,3.320000,...,3.32,3.320000,3.320000,4.879493,3.320000,3.32,4.571869,3.320000,4.982136,6.145585
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56,14.565031,11.699843,3.320000,3.320000,3.320000,12.789212,6.504027,3.320000,3.320000,6.182912,...,3.32,3.320000,3.320000,4.338456,3.771718,3.32,3.320000,3.320000,3.320000,6.100644
57,14.624502,11.918757,3.320000,4.292406,3.430485,10.728709,6.197159,3.320000,4.089918,5.201608,...,3.32,3.320000,3.320000,4.790134,3.320000,3.32,4.985474,4.444057,3.580523,6.301926
58,14.585190,11.090112,3.320000,3.674768,3.320000,12.877485,6.326960,3.320000,3.320000,3.320000,...,3.32,3.320000,3.320000,4.547342,3.320000,3.32,4.064473,3.320000,4.254152,5.964505
59,14.449554,10.805855,3.320000,3.320000,3.320000,12.660038,6.261395,3.320000,4.125096,3.320000,...,3.32,3.784260,3.644823,4.546974,3.427441,3.32,4.666265,3.888525,3.765754,5.452018


In [4]:
labels

Unnamed: 0,0
0,diagnosis: healthy control
1,diagnosis: healthy control
2,diagnosis: healthy control
3,diagnosis: healthy control
4,diagnosis: healthy control
...,...
56,omalizumab responder status: Responder
57,omalizumab responder status: Responder
58,omalizumab responder status: Responder
59,omalizumab responder status: Responder


In [5]:
encoder = LabelEncoder()
labels_encoded = encoder.fit_transform(labels.values)
labels_encoded_noncontrols = labels_encoded[labels_encoded!=0] - 1

data_controls = gene_expression_matrix[labels_encoded==0]
data = gene_expression_matrix[labels_encoded!=0]

gene_signature = data_controls.mean(axis=0)
data_scaled = data - gene_signature

scaler = MinMaxScaler((0, 1))
scaler.fit(data_scaled)
data_normalized = scaler.transform(data_scaled)

x_train = torch.FloatTensor(data_normalized)
y_train = torch.FloatTensor(labels_encoded_noncontrols).unsqueeze(1)
print(x_train.shape)
print(y_train.shape)

  return f(*args, **kwargs)


torch.Size([40, 28402])
torch.Size([40, 1])


In [6]:
torch.manual_seed(0)
np.random.seed(0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
x_train = x_train.to(device)
y_train = y_train.to(device)

layers = [
    torch.nn.Linear(x_train.size(1), 10, bias=False),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(10, 5, bias=False),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(5, 1, bias=False),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*layers).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.00007)
model.train()
need_pruning = True
for epoch in range(1, 30001):
    # forward pass
    optimizer.zero_grad()
    y_pred = model(x_train)
    # Compute Loss
    loss = torch.nn.functional.binary_cross_entropy(y_pred, y_train)

    for module in model.children():
        if isinstance(module, torch.nn.Linear):
            loss += 0.008 * torch.norm(module.weight, 1)

    # backward pass
    loss.backward()
    optimizer.step()

    # compute accuracy
    if epoch % 1000 == 0:
        y_pred_d = (y_pred > 0.5)
        accuracy = (y_pred_d.eq(y_train).sum(dim=1) == y_train.size(1)).sum().item() / y_train.size(0)
        print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')
        
    if epoch > 8000 and need_pruning and epoch % 3000 == 0:
        prune_features(model, 1, device)
        need_pruning = True

Epoch 1000: train accuracy: 0.7500
Epoch 2000: train accuracy: 0.7500
Epoch 3000: train accuracy: 0.9750
Epoch 4000: train accuracy: 0.7750
Epoch 5000: train accuracy: 1.0000
Epoch 6000: train accuracy: 1.0000
Epoch 7000: train accuracy: 0.9750
Epoch 8000: train accuracy: 1.0000
Epoch 9000: train accuracy: 1.0000
Epoch 10000: train accuracy: 0.8250
Epoch 11000: train accuracy: 0.8750
Epoch 12000: train accuracy: 0.9250
Epoch 13000: train accuracy: 0.9500
Epoch 14000: train accuracy: 0.9500
Epoch 15000: train accuracy: 1.0000
Epoch 16000: train accuracy: 1.0000
Epoch 17000: train accuracy: 0.9750
Epoch 18000: train accuracy: 0.9750
Epoch 19000: train accuracy: 0.9500
Epoch 20000: train accuracy: 0.9500
Epoch 21000: train accuracy: 0.9750
Epoch 22000: train accuracy: 0.9500
Epoch 23000: train accuracy: 0.9500
Epoch 24000: train accuracy: 0.9250
Epoch 25000: train accuracy: 0.9250
Epoch 26000: train accuracy: 0.9250
Epoch 27000: train accuracy: 0.9250
Epoch 28000: train accuracy: 0.9250
E

## Local explanations

In [7]:
np.set_printoptions(precision=2, suppress=True)
outputs = []
for i, (xin, yin) in enumerate(zip(x_train, y_train)):
    model_reduced = get_reduced_model(model, xin.to(device), bias=False).to(device)
    for module in model_reduced.children():
        if isinstance(module, torch.nn.Linear):
            wa = module.weight.cpu().detach().numpy()
            break
    output = model_reduced(xin)
    
    pred_class = torch.argmax(output)
    true_class = torch.argmax(y_train[i])

    # generate local explanation only if the prediction is correct
    if pred_class.eq(true_class):
        local_explanation = logic.relu_nn.explain_local(model.to(device), x_train, y_train, xin, yin, device=device)
        print(f'Input {(i+1)}')
        print(f'\tx={xin.cpu().detach().numpy()}')
        print(f'\ty={y_train[i].cpu().detach().numpy()}')
        print(f'\ty={output.cpu().detach().numpy()}')
        #print(f'\tw={wa}')
        print(f'\tExplanation: {local_explanation}')
        print()
    outputs.append(output)
    if i > 1:
        break

Input 1
	x=[0.94 0.31 0.   ... 0.26 0.8  0.68]
	y=[0.]
	y=[0.82]
	Explanation: ~feature0000006749 & ~feature0000013464 & ~feature0000015033 & ~feature0000025379

Input 2
	x=[0.84 0.71 0.   ... 0.26 0.   0.62]
	y=[0.]
	y=[0.5]
	Explanation: ~feature0000006749 & feature0000013464 & ~feature0000015033 & ~feature0000025379

Input 3
	x=[0.62 0.36 0.   ... 0.   0.   0.71]
	y=[0.]
	y=[0.5]
	Explanation: ~feature0000006749 & ~feature0000013464 & ~feature0000015033 & ~feature0000025379



# Combine local explanations

In [8]:
global_explanation, predictions, counter = logic.combine_local_explanations(model, x=x_train, y=y_train, 
                                                                          target_class=0, topk_explanations=10, 
                                                                          device=device)

ynp = y_train.cpu().detach().numpy()[:, 0]
accuracy = np.sum(predictions == ynp) / len(ynp)
print(f'Accuracy of when using the formula "{global_explanation}": {accuracy:.4f}')

Accuracy of when using the formula "~feature0000013464 | (~feature0000006749 & ~feature0000025379)": 0.7250


In [9]:
global_explanation = logic.relu_nn.explain_global(model, n_classes=1, target_class=0, device=device)
explanation = logic.relu_nn.explain_global(model, n_classes=1, target_class=0, device=device)
if explanation not in ['False', 'True', 'The formula is too complex!']:
    accuracy, _ = logic.relu_nn.test_explanation(explanation, target_class=0, x=x_train.cpu(), y=y_train.cpu())
    print(f'Class {0} - Global explanation: "{global_explanation}" - Accuracy: {accuracy:.4f}')

Class 0 - Global explanation: "feature0000006749 | feature0000025379 | (feature0000015033 & ~feature0000013464)" - Accuracy: 0.9250


In [10]:
w, b = collect_parameters(model, device)
feature_weights = w[0]
feature_used_bool = np.sum(np.abs(feature_weights), axis=0) > 0
feature_used = np.nonzero(feature_used_bool)[0]
genes.iloc[feature_used]

Unnamed: 0,0
6749,ILMN_1708983
13464,ILMN_1775520
15033,ILMN_1791569
25379,ILMN_3228700


ILMN_3286286, ILMN_1775520, ILMN_1656849, ILMN_1781198, ILMN_1665457

In [11]:
sum(y_train == 0).item() / len(y_train)

0.25