In [2]:
import numpy as np
import os
import struct

def read_idx(filename):
    """Load an IDX-format file (the container format used by MNIST) as an array.

    The header is two zero bytes, one type-code byte and one byte giving the
    number of dimensions, followed by one big-endian uint32 per dimension.
    The remaining bytes are the payload, stored big-endian in row-major order.

    Args:
        filename: Path to an *uncompressed* IDX file.

    Returns:
        A NumPy array with the shape declared in the header. Unsigned-byte
        files yield a read-only ``uint8`` array backed by the bytes read from
        disk; multi-byte element types are converted to native byte order.

    Raises:
        ValueError: If the magic prefix is not zero (for example the file is
            still gzip-compressed), the type code is unknown, or the payload
            size does not match the declared shape.
        struct.error: If the file is too short to contain the header.
    """
    # IDX type code -> NumPy dtype string (multi-byte types are big-endian).
    idx_dtypes = {
        0x08: 'u1',   # unsigned byte
        0x09: 'i1',   # signed byte
        0x0B: '>i2',  # short
        0x0C: '>i4',  # int
        0x0D: '>f4',  # float
        0x0E: '>f8',  # double
    }

    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(
                f'{filename!r} is not an IDX file: the first two bytes must be '
                f'zero, got 0x{zero:04x} (is the file still compressed?)')
        if data_type not in idx_dtypes:
            raise ValueError(
                f'{filename!r}: unsupported IDX type code 0x{data_type:02x}')
        # All dimension sizes are read in a single unpack call.
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        payload = f.read()

    dtype = np.dtype(idx_dtypes[data_type])
    expected_bytes = int(np.prod(shape, dtype=np.int64)) * dtype.itemsize
    if len(payload) != expected_bytes:
        raise ValueError(
            f'{filename!r}: header declares shape {shape} '
            f'({expected_bytes} bytes) but the payload has {len(payload)} bytes')

    data = np.frombuffer(payload, dtype=dtype).reshape(shape)
    if dtype.itemsize > 1:
        # Hand back native-endian data so downstream code never has to deal
        # with byte-swapped arrays.
        data = data.astype(dtype.newbyteorder('='))
    return data

def load_mnist(image_path, label_path):
    """Read one dataset split from a pair of IDX files.

    Args:
        image_path: Path to the IDX file holding the images.
        label_path: Path to the IDX file holding the labels.

    Returns:
        An ``(images, labels)`` tuple of the arrays produced by ``read_idx``
        for the two paths, in that order.
    """
    # The image file is read first, then the label file.
    return read_idx(image_path), read_idx(label_path)

# Every IDX file sits inside a directory that carries the same name as the
# file itself, i.e. ./MNIST/<name>/<name>.
_MNIST_PATH_TEMPLATE = './MNIST/{0}/{0}'

train_image_path = _MNIST_PATH_TEMPLATE.format('train-images-idx3-ubyte')
train_label_path = _MNIST_PATH_TEMPLATE.format('train-labels-idx1-ubyte')
test_image_path = _MNIST_PATH_TEMPLATE.format('t10k-images-idx3-ubyte')
test_label_path = _MNIST_PATH_TEMPLATE.format('t10k-labels-idx1-ubyte')

In [7]:
# Load both splits from disk.
train_images, train_labels = load_mnist(train_image_path, train_label_path)
test_images, test_labels = load_mnist(test_image_path, test_label_path)

# Report the array shapes, one per line, as a quick sanity check.
print(
    f'Train images shape: {train_images.shape}',
    f'Train labels shape: {train_labels.shape}',
    f'Test images shape: {test_images.shape}',
    f'Test labels shape: {test_labels.shape}',
    sep='\n',
)

Train images shape: (60000, 28, 28)
Train labels shape: (60000,)
Test images shape: (10000, 28, 28)
Test labels shape: (10000,)


In [None]:
from extractor.avg import AVG
from extractor.resnet import ResNet
from torch_explain.nn.concepts import ConceptReasoningLayer, ConceptEmbedding
import torch.nn.functional as F
import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# --- Configuration ------------------------------------------------------
SEED = 42
N_SAMPLES = 10          # rows kept from each split; deliberately tiny
N_EPOCHS = 1
LEARNING_RATE = 0.01
TASK_LOSS_WEIGHT = 0.5  # weight of the task loss relative to the concept loss
embedding_size = 8

# Seed torch so weight initialisation (and anything else drawing from the
# torch RNG) is repeatable across kernel restarts.
torch.manual_seed(SEED)

# --- Data ---------------------------------------------------------------
x, c, y = datasets.xor(500)

# Take the class count from the *full* label set. Letting F.one_hot infer it
# from a 10-row slice yields too few columns whenever the slice happens to
# miss the highest class, and can give train/test targets different widths.
# NOTE(review): assumes labels are consecutive integers starting at 0 - confirm.
n_classes = int(y.max().item()) + 1

x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
    x, c, y, test_size=0.33, random_state=SEED)

x_train, x_test = x_train[:N_SAMPLES], x_test[:N_SAMPLES]
c_train, c_test = c_train[:N_SAMPLES], c_test[:N_SAMPLES]
y_train, y_test = y_train[:N_SAMPLES], y_test[:N_SAMPLES]

y_train = F.one_hot(y_train.long().ravel(), num_classes=n_classes).float()
y_test = F.one_hot(y_test.long().ravel(), num_classes=n_classes).float()

# --- Model --------------------------------------------------------------
concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    te.nn.ConceptEmbedding(10, c.shape[1], embedding_size),
)
task_predictor = ConceptReasoningLayer(
    embedding_size, y_train.shape[1], log=True)
# `model` exists so that one optimizer covers both stages; the forward pass
# below calls the two stages separately because the predictor takes two inputs.
model = torch.nn.Sequential(concept_encoder, task_predictor)

# --- Training -----------------------------------------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_form = torch.nn.BCELoss()
model.train()
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb, c_pred)

    # compute loss
    concept_loss = loss_form(c_pred, c_train)
    task_loss = loss_form(y_pred, y_train)
    loss = concept_loss + TASK_LOSS_WEIGHT * task_loss

    loss.backward()
    optimizer.step()

# --- Explanations -------------------------------------------------------
# Re-encode the training batch with the *updated* weights. The activations
# left over from the loop were produced before the final optimizer step and
# still carry autograd history.
model.eval()
with torch.no_grad():
    c_emb, c_pred = concept_encoder(x_train)
    local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
    global_explanations = task_predictor.explain(c_emb, c_pred, 'global')

# Only the global explanations are shown; local_explanations stays available
# for inspection.
print(global_explanations)


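# Reference only (disabled): a longer, 501-epoch run of the same pipeline,
# using ConceptReasoningLayer without log=True.
# task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1])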
# model = torch.nn.Sequential(concept_encoder, task_predictor)

# optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# loss_form = torch.nn.BCELoss()
# model.train()
# for epoch in range(501):
#     optimizer.zero_grad()

#     # generate concept and task predictions
#     c_emb, c_pred = concept_encoder(x_train)
#     y_pred = task_predictor(c_emb, c_pred)

#     # compute loss
#     concept_loss = loss_form(c_pred, c_train)
#     task_loss = loss_form(y_pred, y_train)
#     loss = concept_loss + 0.5*task_loss

#     loss.backward()
#     optimizer.step()

# local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
# global_explanations = task_predictor.explain(c_emb, c_pred, 'global')

# # print(local_explanations)
# print(global_explanations)


In [4]:
# Collapse every image into a single feature vector; the leading (sample)
# axis is kept and all remaining axes are merged into one.
train_images_flat = train_images.reshape(len(train_images), -1)
test_images_flat = test_images.reshape(len(test_images), -1)

# Conventional X/y aliases for the flattened images and their labels.
# Note: this rebinds y_train / y_test, which the XOR cell also assigns.
X_train, y_train = train_images_flat, train_labels
X_test, y_test = test_images_flat, test_labels

In [6]:
# Display the shapes of the four split arrays as one tuple.
tuple(split.shape for split in (X_train, y_train, X_test, y_test))

((60000, 784), (60000,), (10000, 784), (10000,))