In [1]:
import sys
sys.path.append('..')
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sympy import simplify_logic

from deep_logic import validate_network, prune_equal_fanin, collect_parameters
from deep_logic import fol

# Fix the torch and numpy RNG seeds so weight initialisation is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [11]:
# Load inputs (x) and labels (y); the first CSV column is used as the row index.
x = pd.read_csv('dsprites_c_train.csv', index_col=0)
y = pd.read_csv('dsprites_y_train.csv', index_col=0)

x_train = torch.tensor(x.values, dtype=torch.float)
# Keep only the first label column, as an (n_samples, 1) float target.
y_train = torch.tensor(y.values, dtype=torch.float)[:, :1]
# NOTE: the "test" set is the training set itself; explanations are extracted on it.
x_test = x_train

print(x_train.shape)
print(y_train.shape)
x

torch.Size([5530, 50])
torch.Size([5530, 1])


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5525,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5526,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5527,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5528,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [12]:
# Bias-free MLP (n_features -> 50 -> 30 -> 1) with ReLU hidden units and a sigmoid output.
# No layer has a bias, so the network can later be collapsed into purely linear maps.
layers = [
    torch.nn.Linear(x_train.size(1), 50, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 30, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(30, 1, bias=False),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*layers)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
need_pruning = True
for epoch in range(1000):
    optimizer.zero_grad()
    y_pred = model(x_train)

    # Data term plus an L1 penalty on every linear layer's weights (encourages sparsity).
    loss = torch.nn.functional.mse_loss(y_pred, y_train)
    for module in model.children():
        if isinstance(module, torch.nn.Linear):
            loss += 0.0001 * torch.norm(module.weight, 1)

    loss.backward()
    optimizer.step()

    if epoch % 100:
        continue
    # Report accuracy of this epoch's forward pass (predictions made before the step).
    y_pred_d = y_pred > 0.5
    n_rows_all_correct = (y_pred_d.eq(y_train).sum(dim=1) == y_train.size(1)).sum().item()
    accuracy = n_rows_all_correct / y_train.size(0)
    print(f'Epoch {epoch}: train accuracy: {accuracy:.4f}')

Epoch 0: train accuracy: 0.2873
Epoch 100: train accuracy: 1.0000
Epoch 200: train accuracy: 1.0000
Epoch 300: train accuracy: 1.0000
Epoch 400: train accuracy: 1.0000
Epoch 500: train accuracy: 1.0000
Epoch 600: train accuracy: 1.0000
Epoch 700: train accuracy: 1.0000
Epoch 800: train accuracy: 1.0000
Epoch 900: train accuracy: 1.0000


In [13]:
def reduced_model(xi, w, b):
    """Collapse a bias-free ReLU network into the single linear map it applies to ``xi``.

    For a fixed input each ReLU unit is either active or inactive. Zeroing the weight rows
    of the inactive units layer by layer and multiplying the masked matrices together
    yields one matrix ``wa`` such that ``wa @ xi`` equals each output unit's pre-activation
    where that is positive, and 0 where it is not. The ReLU mask is applied to the last
    layer too, so ``wa @ xi > 0`` mirrors a sigmoid output being above 0.5.

    Args:
        xi: one input sample as a torch tensor (``.detach().numpy()`` is called on it).
        w: sequence of per-layer weight matrices, input layer first, each of shape
            ``(n_out, n_in)``. A plain list of differently-shaped arrays is fine: it is
            iterated directly rather than packed into one (ragged) numpy array.
        b: per-layer biases; unused -- biases are ignored (the model is built without them).

    Returns:
        ``(output, wa)``: ``output`` is a boolean torch tensor (``wa @ xi > 0``) and ``wa``
        is the collapsed numpy weight matrix of shape ``(n_out_last, n_in_first)``.

    Raises:
        ValueError: if ``w`` contains no layers.
    """
    weights = list(w)
    if not weights:
        raise ValueError('reduced_model needs at least one weight matrix')

    x_np = xi.detach().numpy()
    activation = x_np
    masked_weights = []
    for wi in weights:
        wi = np.asarray(wi)
        hi = np.matmul(wi, activation)
        ai = np.maximum(hi, 0)
        # Drop the rows of the units the ReLU switches off for this input.
        wi2 = np.copy(wi)
        wi2[ai == 0] = 0
        # Sanity check: masking must not change this layer's activations.
        ai2 = np.maximum(np.matmul(wi2, activation), 0)
        assert np.all(ai == ai2)
        masked_weights.append(wi2)
        activation = ai2

    # Compose the masked layers: wa = W_L @ ... @ W_2 @ W_1.
    wa = masked_weights[0]
    for wi2 in masked_weights[1:]:
        wa = np.matmul(wi2, wa)

    output = np.matmul(wa, x_np) > 0
    # Return a tensor explicitly; callers use ``output.detach()`` on it.
    return torch.as_tensor(output), wa

In [14]:
# Extract the trained parameters from the model as per-layer weights (w) and biases (b).
# NOTE(review): assumes collect_parameters returns one weight matrix per Linear layer,
# ordered input -> output, which is how reduced_model consumes them -- confirm in deep_logic.
w, b = collect_parameters(model)

# Local explanations

In [60]:
def generate_local_explanations(w, xi, threshold=0.5):
    """Build a conjunctive (AND) explanation of one input from its collapsed weights.

    The weights are normalised by their largest absolute value. Features whose normalised
    weight is above ``threshold`` or below ``-threshold`` are treated as relevant. Each
    relevant feature ``i`` contributes the literal ``f{i}`` when ``xi[i] > 0.5`` and
    ``~f{i}`` otherwise. Literals are joined with ``' & '`` in ascending feature order.

    Args:
        w: 1-D array-like of per-feature weights (e.g. one row of the matrix returned by
            ``reduced_model``).
        xi: the input sample, indexable by feature position.
        threshold: cut-off on the normalised weight magnitude, in ``[0, 1)``. Defaults to
            0.5, the value previously hard-coded here.

    Returns:
        ``(explanation, w_norm, w_bool)``. ``w_norm`` is ``w`` divided by its largest
        absolute value. ``w_bool`` holds +1/-1 for relevant features (the sign of their
        weight) and 0 elsewhere. If ``w`` is empty or its largest magnitude is not positive
        (e.g. all zeros), returns ``('False', w, 0 * w)`` with ``w`` as a numpy array.

    Raises:
        ValueError: if ``threshold`` is outside ``[0, 1)``.
    """
    if not 0 <= threshold < 1:
        raise ValueError(f'threshold must be in [0, 1), got {threshold}')
    # asarray so list inputs also give a zero *vector* (0 * list would be []).
    w = np.asarray(w)
    w_max = np.max(np.abs(w)) if w.size else 0
    if not w_max > 0:
        # No feature carries weight, e.g. the output unit was masked out for this input.
        return 'False', w, 0 * w

    w_norm = w / w_max
    w_bool = np.zeros_like(w_norm)
    w_bool[w_norm > threshold] = 1
    w_bool[w_norm < -threshold] = -1

    literals = [
        f'f{i}' if xi[i] > 0.5 else f'~f{i}'
        for i in map(int, np.flatnonzero(w_bool))
    ]
    return ' & '.join(literals), w_norm, w_bool

# Toy sanity check: two features, the first switched on and the second off.
wa = np.array([0.49, 0.95])
xin = torch.tensor([1, 0], dtype=torch.float32)
print(generate_local_explanations(w=wa, xi=xin))

('f0 & ~f1', array([0.52, 1.  ]), array([1., 1.]))


In [61]:
np.set_printoptions(precision=2, suppress=True)
# For every test sample, collapse the network to its input-specific linear map and print
# the local explanation of each positive prediction. Predictions are recorded as plain
# bools so that `outputs` later becomes a flat 1-D boolean vector.
outputs = []
for i, xin in enumerate(x_test):
    output, wa = reduced_model(xin, w, b)
    if output:
        # Explanations are only shown for positives, so only compute them here.
        local_explanation, w_norm, w_bool = generate_local_explanations(wa[0], xin)
        print(f'Input {(i+1)}')
        print(f'\tx={xin.detach().numpy()}')
        print(f'\ty={output.detach().numpy()}')
        print(f'\tw={wa}')
        print(f'\tw_norm={w_norm}')
        print(f'\tw_bool={w_bool}')
        print(f'\tExplanation: {local_explanation}')
        print()
    outputs.append(bool(output))

Input 1
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0

Input 674
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. 

Input 1277
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

Input 1887
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

Input 2487
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

Input 3067
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 1.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

Input 3647
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

Input 4270
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

Input 4903
	x=[1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0.]
	y=[ True]
	w=[[-7.11  2.43 -5.85 -5.81  8.57 -0.   -0.01 -0.02 -0.   -0.03 -0.01 -0.
   0.01 -0.01 -0.01  0.01 -0.   -0.01 -0.02  0.   -0.01  0.01 -0.01 -0.02
  -0.01 -0.01 -0.    0.02  0.01 -0.01  0.    0.    0.01 -0.    0.01 -0.01
   0.02  0.01 -0.   -0.01 -0.02 -0.    0.01 -0.   -0.   -0.02 -0.02 -0.01
  -0.02 -0.  ]]
	w_norm=[-0.83  0.28 -0.68 -0.68  1.   -0.   -0.   -0.   -0.   -0.   -0.   -0.
  0.   -0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.   -0.
 -0.   -0.   -0.    0.    0.   -0.    0.    0.    0.   -0.    0.   -0.
  0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.   -0.   -0.   -0.
 -0.   -0.  ]
	w_bool=[-1.  0. -1. -1.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

# Combine local explanations into a global explanation

In [63]:
# OR together the local explanations of all positive predictions. Identical local
# explanations are kept once (A | A == A), which keeps the formula small for sympy.
positive_explanations = []
for xin in x_test:
    output, wa = reduced_model(xin, w, b)
    if output:
        local_explanation, _, _ = generate_local_explanations(wa[0], xin)
        positive_explanations.append(local_explanation)
# dict.fromkeys de-duplicates while preserving first-seen order. With no positives the
# rule is 'False', rather than an empty string that simplify_logic cannot parse.
global_explanation = ' | '.join(dict.fromkeys(positive_explanations)) or 'False'

In [64]:
# Minimise the disjunction of local explanations into disjunctive normal form.
simplify_logic(global_explanation, form='dnf')

f0 & f4 & ~f2 & ~f3

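# Evaluate generality of the global explanation

Note that `x_test` is the training set (`x_test = x_train`). This section therefore checks how faithfully the extracted rule reproduces the labels on the data it was extracted from. It does not measure generalization to unseen samples.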

In [65]:
yi = y.iloc[:, 0]
# Flatten the recorded predictions into a 1-D boolean vector aligned with the rows of x.
# This works whether they were stored as 1-element tensors or as plain bools, and
# re-running the cell is harmless.
outputs = np.array([bool(o) for o in outputs], dtype=bool)
assert outputs.shape == (len(x),), 'expected one prediction per row of x'
x_false = x[~outputs]
x_true = x[outputs]
print(x_false.shape)
print(x_true.shape)
x_true

(5209, 50)
(321, 50)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
13,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
21,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
36,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5433,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
5440,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5483,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5484,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [66]:
# Rows satisfying the extracted rule f0 & ~f2 & ~f3 & f4 (columns addressed by position).
x_bool = x.astype('bool')
mask = (
    x_bool.iloc[:, 0]
    & x_bool.iloc[:, 4]
    & ~x_bool.iloc[:, 2]
    & ~x_bool.iloc[:, 3]
)
int(mask.sum())

321

In [76]:
# Use the rule as a classifier: predict 1 where it fires, 0 elsewhere, and score it.
preds = yi * 0
preds[mask] = 1
accuracy = float((preds == yi).mean())
print(f'Accuracy: {accuracy}')

Accuracy: 1.0
