In [1]:
import torch.nn.functional as F
from torch_explain.nn.concepts import ConceptReasoningLayer, ConceptReasoningLayerMod
import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.xor(3000)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
    x, c, y, test_size=0.33, random_state=42)

embedding_size = 8
concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    te.nn.ConceptEmbedding(10, c.shape[1], embedding_size),
)


In [2]:
y_train.shape

torch.Size([2010, 1])

In [3]:

# -------------------------------------#

y_train = F.one_hot(y_train.long().ravel()).float()
y_test = F.one_hot(y_test.long().ravel()).float()

In [4]:

task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1])
model = torch.nn.Sequential(concept_encoder, task_predictor)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form = torch.nn.BCELoss()
model.train()
for epoch in range(601):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb, c_pred)
    

    # compute loss
    concept_loss = loss_form(c_pred, c_train)
    task_loss = loss_form(y_pred, y_train)
    loss = concept_loss + 0.5*task_loss

    loss.backward()
    optimizer.step()

local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
global_explanations = task_predictor.explain(c_emb, c_pred, 'global')

# print(local_explanations)
print(global_explanations)


[{'class': 'y_0', 'explanation': 'c_0 & c_1', 'count': 502}, {'class': 'y_0', 'explanation': '~c_0 & ~c_1', 'count': 511}, {'class': 'y_1', 'explanation': 'c_0 & ~c_1', 'count': 482}, {'class': 'y_1', 'explanation': '~c_0 & c_1', 'count': 515}]


In [5]:
gex0 = [item['explanation'] for item in global_explanations if item['class'] == 'y_0']
gex1 = [item['explanation'] for item in global_explanations if item['class'] == 'y_1']

In [9]:
gex0

['c_0 & c_1', '~c_0 & ~c_1']

In [6]:
aggregated_exp0 = "|".join(gex0)
aggregated_exp1 = "|".join(gex1)

In [7]:
from sympy.logic.boolalg import to_dnf
to_dnf(aggregated_exp0, simplify=True)

(c_0 & c_1) | (~c_0 & ~c_1)

In [8]:
to_dnf(aggregated_exp1, simplify=True)

(c_0 & ~c_1) | (c_1 & ~c_0)