# Soft-GCExplainer

In [1]:
import os,sys

parent = os.path.dirname(os.getcwd())
if parent not in sys.path:
    sys.path.append(parent)
print(parent)
%matplotlib widget

d:\Documents\git\XAI-Cancer-Diagnosis\XAI-Cancer-Diagnosis


---

In [2]:
notebook_settings = {'evaluate_model':False, 'save_concepts': False, 'tsne':False}

### Model & Data

In [17]:
# Data

from src.datasets.BACH import BACH
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose, KNNGraph
import torch
import numpy as np


graph_aug_val = Compose([KNNGraph(k=6)])


print('CUDA available: ', torch.cuda.is_available())
print('CUDA device count: ', torch.cuda.device_count())

src_folder = os.path.join(
    "C://Users", "aless", "Documents", "FtT", "data", "BACH_TRAIN")
tid,vid = BACH.get_train_val_ids(src_folder,"graph_ind_FtT_19_11_1.txt")
train_set = BACH(src_folder,graph_augmentation=graph_aug_val,ids=tid,pre_encoded=True)
val_set = BACH(src_folder,graph_augmentation=graph_aug_val,ids=vid,pre_encoded=True)

train_loader = DataLoader(train_set, batch_size=4, shuffle=False)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)

print(len(train_loader))
print(len(val_loader))

CUDA available:  True
CUDA device count:  1
77
20


In [4]:
# Model

from src.deep_learning.architectures.cancer_prediction.cancer_gnn import CancerGNN

model = CancerGNN.load_from_checkpoint(os.path.join(parent,"experiments","checkpoints", "FtT_FtT_19_11_1.ckpt"),WIDTH=64,HEIGHT=7)

print(model)
print(model.predictor)

CancerGNN(
  (gnn): GCNx(
    (conv): ModuleList(
      (0): GCNConv(64, 64)
      (1): GCNConv(64, 64)
      (2): GCNConv(64, 64)
      (3): GCNConv(64, 64)
      (4): GCNConv(64, 64)
      (5): GCNConv(64, 64)
      (6): GCNConv(64, 64)
    )
    (transform): ModuleList(
      (0): Sequential(
        (0): BatchNorm1d(312, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=312, out_features=64, bias=True)
      )
      (1): Sequential(
        (0): BatchNorm1d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Dropout(p=0, inplace=False)
        (3): Linear(in_features=64, out_features=64, bias=True)
      )
      (2): Sequential(
        (0): BatchNorm1d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Dropout(p=0, inplace=False)
        (3): Linear(in_features=64, out_feat

---

### Get Raw Activations

In [5]:
model_width = model.width
print("Model Width: ",model_width)

Model Width:  64


In [30]:
from tqdm import tqdm
from torch import softmax

class RawActivationHook:
    def __init__(self,width,device='cuda'):
        self.activations = torch.zeros(0,width).to(device)
    
    def append_activations(self, model,input, output):
        self.activations = torch.cat((self.activations,output),dim=0)
        
    def __enter__(self):
        self.remove_handle = model.gnn.conv[-1].register_forward_hook(lambda m,i,o:(self.append_activations(m,i,o)))
        return self
    
    def __exit__(self, type, value, traceback):
        self.remove_handle.remove()
        
    def get_activations(self):
        return self.activations

class ModelForward:
    def __init__(self,model):
        self.model = model
        self.model_width  = model.width
        self.model.eval()
        self.model.to('cuda')
        
    def forward_node_level(self,loader):
        # Acquire the activations
        hook = RawActivationHook(self.model_width)


        
        predictions = torch.zeros(0).to('cuda')
        ground = torch.zeros(0).to('cuda')
        
        with hook:
            for batch in tqdm(loader):
                batch = batch.to('cuda')
                output = model.forward(batch.x,batch.edge_index,batch.batch)
                
                node_probs = softmax(output,dim=1)[batch.batch]
                node_ground = batch.y[batch.batch]
                
                
                predictions = torch.cat([predictions,node_probs])
                ground = torch.cat([ground,node_ground])
            
        raw_activations = hook.get_activations()
        return raw_activations.cpu().detach().numpy(),predictions.cpu().detach().numpy(),ground.cpu().detach().numpy().astype(int)
    
    def forward_graph_level(self,loader):
        predictions = torch.zeros(0).to('cuda')
        ground = torch.zeros(0).to('cuda')
        
        for batch in tqdm(loader):
            batch = batch.to('cuda')
            output = model.forward(batch.x,batch.edge_index,batch.batch)
            
            graph_probs = softmax(output,dim=1)
            graph_ground = batch.y
            
            predictions = torch.cat([predictions,graph_probs])
            ground = torch.cat([ground,graph_ground])
            
        return predictions.cpu().detach().numpy(),ground.cpu().detach().numpy().astype(int)

forwarder = ModelForward(model)

graph_prob_predictions_train,graph_ground_train = forwarder.forward_graph_level(train_loader)
graph_prob_predictions_val,graph_ground_val = forwarder.forward_graph_level(val_loader)

activations_train,node_predictions_train,node_ground_train  = forwarder.forward_node_level(train_loader)
activations_val,node_predictions_val,node_ground_val = forwarder.forward_node_level(val_loader)

100%|██████████| 77/77 [00:16<00:00,  4.66it/s]
100%|██████████| 20/20 [00:04<00:00,  4.71it/s]
100%|██████████| 77/77 [00:15<00:00,  4.94it/s]
100%|██████████| 20/20 [00:04<00:00,  4.46it/s]


In [33]:
graph_predictions_train = np.argmax(graph_prob_predictions_train,axis=1)
graph_predictions_val = np.argmax(graph_prob_predictions_val,axis=1)


In [34]:
print(f"Model accuracy on validation set: {np.sum(graph_predictions_val == graph_ground_val)/len(graph_ground_val)}")
print(f"Model accuracy on training set: {np.sum(graph_predictions_train == graph_ground_train)/len(graph_ground_train)}")

Model accuracy on validation set: 0.7662337662337663
Model accuracy on training set: 0.935064935064935


---
# K Means

In [35]:
import numpy as np
from sklearn.mixture import GaussianMixture

class ConceptDiscoverer:
    def __init__(self,num_concepts,verbose=True) -> None:
        self.k = num_concepts
        self.gm = GaussianMixture(n_components=self.k,verbose=verbose)
        
    def fit(self,activations: np.ndarray):
        self.mu = activations.mean(axis=0)
        self.sigma = activations.std(axis=0)
        
        whitten_activations = ConceptDiscoverer.whiten(activations,self.mu,self.sigma)
        
        self.gm.fit(whitten_activations)
    
    def predict(self,activations):
        whitten_activations = ConceptDiscoverer.whiten(activations,self.mu,self.sigma)
        return self.gm.predict(whitten_activations)
        
    def predict_proba(self,activations):
        whitten_activations = ConceptDiscoverer.whiten(activations,self.mu,self.sigma)
        return self.gm.predict_proba(whitten_activations)
        
        
    @staticmethod
    def whiten(obs,mu,sigma):
        return (obs - mu)/sigma


In [36]:
sub_sample_size = 10000
sub_sample = np.random.choice(activations_train.shape[0],sub_sample_size,replace=False)

num_concepts =64

cd = ConceptDiscoverer(num_concepts=num_concepts)
cd.fit(activations_train[sub_sample,:])

Initialization 0
  Iteration 10
  Iteration 20
  Iteration 30
Initialization converged: True


In [37]:
train_concept_probs = cd.predict_proba(activations_train)
val_concept_probs = cd.predict_proba(activations_val)

## Rejection Sampling
What if we consider only those activations that are we have a certain degree of confidence is in that concept?

In [38]:
def within_acceptance(probabilities:np.ndarray,threshold = 0.9):
    max_prob = probabilities.max(axis=1)
    assert(max_prob.shape[0] == probabilities.shape[0])
    number_surpass = (max_prob > threshold).sum()
    return number_surpass / probabilities.shape[0]

print("Train: ",within_acceptance(train_concept_probs))
print("Val: ",within_acceptance(val_concept_probs))

Train:  0.9669711264771256
Val:  0.9590673302261895


In [39]:
if notebook_settings['save_concepts']:
    np.save("mu.npy",mu)
    np.save("sigma.npy",sigma)

---
# Vizualize Explanations

In [40]:

from sklearn.manifold import TSNE


if notebook_settings['tsne']:
    tsne_sample_size  = 50000
    tsne_sample= np.random.choice(train_obs_white.shape[0],tsne_sample_size,replace=False)
    position_tsne_sample =TSNE(n_components=2,verbose=2,perplexity = 1000,n_iter=400).fit_transform(train_obs_white[tsne_sample])


KeyError: 'tsne'

In [None]:

import matplotlib.pyplot as plt
from matplotlib import cm

def plot_tsne_points_with_category(points,classes):
    f = plt.figure(figsize=(10,10))
    # How many different classes are there
    unique_classes = np.unique(classes)
    colours = cm.rainbow(np.linspace(0,1,len(unique_classes)))
    
    # Group the points by their class and then plot all points of the same class at a time
    for i,cl in enumerate(unique_classes):
        points_with_class = (classes==cl)
        v = points[points_with_class]
        x,y = v[:,0],v[:,1]
        plt.scatter(x,y,color=colours[i])
    plt.show()

In [None]:
if notebook_settings['tsne']:
    prediction_tsne_sample = train_predictions[tsne_sample]
    ground_tsne_sample = train_ground[tsne_sample]
    hard_concept_tsne_sample = train_probs[tsne_sample].argmax(axis=1)

    plot_tsne_points_with_category(position_tsne_sample,prediction_tsne_sample)
    plot_tsne_points_with_category(position_tsne_sample,ground_tsne_sample)
    plot_tsne_points_with_category(position_tsne_sample,hard_concept_tsne_sample)

NameError: name 'train_predictions' is not defined

---
# Concept Completeness

In [None]:
def one_hot_concept(concept_ids,num_concepts):
    one_hot = np.zeros((len(concept_ids),num_concepts))
    one_hot[np.arange(len(concept_ids)),concept_ids] = 1
    return one_hot

In [None]:

def concept_vs_prediction(graph,model,gaussian_mixture):
    # Get prediction and activation
    global raw_activations
    raw_activations = torch.zeros(0,model_width)
    prediction = model.predict(graph).argmax().item()
    # Get the concepts
    concept_probs = gaussian_mixture.predict_proba(whiten(raw_activations.detach().numpy()))
    return prediction,concept_probs

# Takes every observed concept along with the eventual graph prediction and then generates X,Y pairs for training a classifier
def generate_concept_to_prediction_map(model,gaussian_mixture,loader,hard=False, ignore_below=0.99):
    X = np.zeros((0,num_concepts))
    Y = np.zeros(0)
    for graph in tqdm(loader):
        # assert only one graph in batch
        assert graph.batch.max() == 0
        prediction,concepts = concept_vs_prediction(graph,model,gaussian_mixture)
        if hard:
            best_concept = concepts.argmax(axis=1)
            concepts = one_hot_concept(best_concept,num_concepts)
        concepts = concepts[concepts.max(axis=1) > ignore_below]
        X = np.concatenate((X,concepts))
        Y = np.concatenate((Y,np.zeros(concepts.shape[0])+prediction))
    return X,Y
    
    
X_train,Y_train = generate_concept_to_prediction_map(model,gm,train_loader,True)
X_val,Y_val = generate_concept_to_prediction_map(model,gm,val_loader,True)

  0%|          | 0/308 [00:00<?, ?it/s]

100%|██████████| 308/308 [00:40<00:00,  7.58it/s]
100%|██████████| 77/77 [00:08<00:00,  9.04it/s]


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Create a decision tree classifier and fit it to the training data
clf = DecisionTreeClassifier()
clf.fit(X_train,Y_train)


# Now clf can be used to make predictions on unseen data
# For example, to predict the class of the samples in the test set, you can use:
Y_pred = clf.predict(X_val)


In [None]:
def get_concept_prediction_prob(concept_id,num_concepts,dt):
    # if concept_id is an int, convert to np array
    if isinstance(concept_id,int):
        concept_id = np.array([concept_id])
    # Convert to one hot
    one_hot = one_hot_concept(concept_id,num_concepts)
    # Predict
    return dt.predict_proba(one_hot)

In [None]:
# Get prediction from one hot of each concept
def print_concept_completeness(concept_id, num_concepts,dt):
    probs = get_concept_prediction_prob(concept_id,num_concepts,dt)
    # Turn to percentages
    probs = (probs*100).round(2).max()
    print(f"For concept id {concept_id}, the predictive power is: {probs}")
    
    
for i in range(num_concepts):
   # print_concept_completeness(i,num_concepts,clf)
    pass 
# Print avergae predictive power
print("Average predictive power: ",(clf.predict(X_train) == Y_train).sum() / Y_train.shape[0])

For concept id 0, the predictive power is: 45.75
For concept id 1, the predictive power is: 94.53
For concept id 2, the predictive power is: 49.95
For concept id 3, the predictive power is: 56.39
For concept id 4, the predictive power is: 67.42
For concept id 5, the predictive power is: 55.4
For concept id 6, the predictive power is: 80.87
For concept id 7, the predictive power is: 84.55
For concept id 8, the predictive power is: 54.8
For concept id 9, the predictive power is: 75.97
For concept id 10, the predictive power is: 48.79
For concept id 11, the predictive power is: 82.91
For concept id 12, the predictive power is: 53.38
For concept id 13, the predictive power is: 54.37
For concept id 14, the predictive power is: 66.09
For concept id 15, the predictive power is: 50.6
For concept id 16, the predictive power is: 57.85
For concept id 17, the predictive power is: 32.09
For concept id 18, the predictive power is: 91.73
For concept id 19, the predictive power is: 83.79
For concept i

Q: For a particular graph, what concepts are present, how many times for each, and what is the contribution of that concept

In [None]:
def get_concept_distribution_from_graph(graph,model,gaussian_mixture,inclusion_threshold = 0.9):
    _,concepts = concept_vs_prediction(graph,model,gaussian_mixture)
    
    # Count the number of times each concept is present above a certain threshold
    most_popular_concept = concepts.argmax(axis=1)
    one_hot_concepts = one_hot_concept(most_popular_concept,num_concepts)
    
    concept_count = np.zeros((num_concepts))
    reject_below_threshold = concepts.max(axis=1) > inclusion_threshold
    
    concept_count = one_hot_concepts*reject_below_threshold.reshape(-1,1)
    return concept_count.sum(axis=0)

def weighted_predictor(concept_dist,concept_probs):
    # Take a weighted sum of the concept_probs with the concept_dist
    weighted_sum = concept_probs*concept_dist.reshape(-1,1)
    return weighted_sum.sum(axis=0)

def sum_predictor(concept_dist,concept_probs):
    # If there is a concept present, include its prob
    concept_present = concept_dist > 0
    concept_probs = concept_probs*concept_present.reshape(-1,1)
    # Normalize
    return concept_probs.sum(axis=0) / concept_present.sum()

    
    
def concept_vote_graph(graph,model,gaussian_mixture,decision_tree,predictor_rule = weighted_predictor,inclusion_threshold = 0.999):
    concept_dist = get_concept_distribution_from_graph(graph,model,gaussian_mixture,inclusion_threshold)
    concept_dist = concept_dist / concept_dist.sum()
    
    # Take a weighted sum of the concept_probs with the concept_dist
    concept_probs = get_concept_prediction_prob(np.arange(num_concepts),num_concepts,decision_tree)
    
    return predictor_rule(concept_dist,concept_probs)
    
    
random_graph_id = np.random.randint(len(val_loader))
graph = train_loader.dataset[random_graph_id]
graph_ground = graph.y.item()
graph_prediction = model.predict_proba(graph)
graph_concept_vote = concept_vote_graph(graph,model,gm,clf)

print("Ground: ",graph_ground)
print("Prediction: ",graph_prediction.argmax().item(),graph_prediction)
print("Concept vote: ",graph_concept_vote.argmax(),graph_concept_vote)


Ground:  0
Prediction:  0 tensor([[0.9620, 0.0097, 0.0182, 0.0101]])
Concept vote:  0 [0.52333928 0.08914891 0.27693084 0.11058097]


In [None]:
# Import cross entropy and accuracy
from sklearn.metrics import accuracy_score,log_loss

# Across a whole dataset, let us evaluate how good the concept_voter is


def evaluate_concept_voter(model,gaussian_mixture,decision_tree,loader,predictor_rule = weighted_predictor,inclusion_threshold = 0.999):
    # Get the predictions
    predictions = np.zeros((0,4))
    concept_votes = np.zeros((0,4))
    ground = np.zeros((0))
    for graph in tqdm(loader):
        ground = np.concatenate((ground,graph.y.detach().numpy()))
        prediction = model.predict_proba(graph)
        concept_vote = concept_vote_graph(graph,model,gaussian_mixture,decision_tree,predictor_rule,inclusion_threshold)
        predictions = np.concatenate((predictions,prediction))
        concept_votes = np.concatenate((concept_votes,concept_vote.reshape(1,-1)))
    
    # Caluclate cross entropy loss 
    print("Accuracy: ",accuracy_score(ground,predictions.argmax(axis=1)))
    print("Concept voter accuracy: ",accuracy_score(ground,concept_votes.argmax(axis=1)))

#
#for pred_name, predictor in zip(["Weighted","Sum"],[weighted_predictor,sum_predictor]):
#    for threshold in [0,0.9,0.99,0.999,0.9999]:
#        print(f"Evaluating {pred_name} predictor with threshold {threshold}")
#        evaluate_concept_voter(model,gm,clf,train_loader,predictor,inclusion_threshold=threshold)
#        print("")


Evaluating Weighted predictor with threshold 0


100%|██████████| 308/308 [00:39<00:00,  7.85it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.827922077922078

Evaluating Weighted predictor with threshold 0.9


100%|██████████| 308/308 [00:38<00:00,  8.06it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.8311688311688312

Evaluating Weighted predictor with threshold 0.99


100%|██████████| 308/308 [00:40<00:00,  7.64it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.827922077922078

Evaluating Weighted predictor with threshold 0.999


100%|██████████| 308/308 [00:39<00:00,  7.81it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.827922077922078

Evaluating Weighted predictor with threshold 0.9999


100%|██████████| 308/308 [00:41<00:00,  7.50it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.827922077922078

Evaluating Sum predictor with threshold 0


100%|██████████| 308/308 [00:39<00:00,  7.83it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.7337662337662337

Evaluating Sum predictor with threshold 0.9


100%|██████████| 308/308 [00:40<00:00,  7.52it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.737012987012987

Evaluating Sum predictor with threshold 0.99


100%|██████████| 308/308 [00:40<00:00,  7.67it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.75

Evaluating Sum predictor with threshold 0.999


100%|██████████| 308/308 [00:39<00:00,  7.71it/s]


Accuracy:  0.9058441558441559
Concept voter accuracy:  0.7727272727272727

Evaluating Sum predictor with threshold 0.9999


100%|██████████| 308/308 [00:38<00:00,  8.02it/s]

Accuracy:  0.9058441558441559
Concept voter accuracy:  0.788961038961039






# Determine K
- Perform GM
- Get Concept Completeness
- Plot

---
### CBE