# Soft-GCExplainer

In [15]:
import os,sys

parent = os.path.dirname(os.getcwd())
if parent not in sys.path:
    sys.path.append(parent)
print(parent)
%matplotlib widget

d:\Documents\git\XAI-Cancer-Diagnosis\XAI-Cancer-Diagnosis


---

In [16]:
notebook_settings = {'evaluate_model':False, 'save_concepts': False, 'tsne':False}

### Model & Data

In [17]:
# Data

from src.datasets.BACH import BACH
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose, KNNGraph
import torch
import numpy as np


graph_aug_val = Compose([KNNGraph(k=6)])


print('CUDA available: ', torch.cuda.is_available())
print('CUDA device count: ', torch.cuda.device_count())

src_folder = os.path.join(
    "C://Users", "aless", "Documents", "FtT", "data", "BACH_TRAIN")
tid,vid = BACH.get_train_val_ids(src_folder,"graph_ind_FtT_19_11_dropout_3.txt")
train_set = BACH(src_folder,graph_augmentation=graph_aug_val,ids=tid,pre_encoded=True)
val_set = BACH(src_folder,graph_augmentation=graph_aug_val,ids=vid,pre_encoded=True)

train_loader = DataLoader(train_set, batch_size=4, shuffle=False)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)

print(len(train_loader))
print(len(val_loader))

CUDA available:  True
CUDA device count:  1
78
19


In [18]:
# Model

from src.deep_learning.architectures.cancer_prediction.cancer_gnn import CancerGNN

model = CancerGNN.load_from_checkpoint(os.path.join(parent,"experiments","checkpoints", "FtT_FtT_19_11_dropout_3.ckpt"),WIDTH=64,HEIGHT=7)

print(model)
print(model.predictor)

CancerGNN(
  (gnn): GCNx(
    (conv): ModuleList(
      (0): GCNConv(64, 64)
      (1): GCNConv(64, 64)
      (2): GCNConv(64, 64)
      (3): GCNConv(64, 64)
      (4): GCNConv(64, 64)
      (5): GCNConv(64, 64)
      (6): GCNConv(64, 64)
    )
    (transform): ModuleList(
      (0): Sequential(
        (0): BatchNorm1d(312, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=312, out_features=64, bias=True)
      )
      (1): Sequential(
        (0): BatchNorm1d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Dropout(p=0, inplace=False)
        (3): Linear(in_features=64, out_features=64, bias=True)
      )
      (2): Sequential(
        (0): BatchNorm1d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Dropout(p=0, inplace=False)
        (3): Linear(in_features=64, out_feat

In [45]:
from torch_geometric.data import Batch, Data

def loader_from_one_graph(graph:Data):
    return DataLoader([graph], batch_size=1, shuffle=False)

def extract_graph(batch: Batch, graph_idx: int):
    # Find node indices for the graph
    node_mask = batch.batch == graph_idx

    # Extract the node features for the graph
    x = batch.x[node_mask]

    # Find edge indices for the graph
    edge_mask = node_mask[batch.edge_index[0]] & node_mask[batch.edge_index[1]]
    edge_index = batch.edge_index[:, edge_mask]

    # Re-map edge indices to the new node index space
    edge_index = edge_index - node_mask.nonzero(as_tuple=False).min()

    # If the batch contains other attributes, extract them similarly
    # ...
    y = batch.y[graph_idx]
    pos = batch.pos[node_mask]

    # Create a new Data object for the single graph
    single_graph = Data(x=x, edge_index=edge_index, y=y,pos=pos)
    
    return single_graph

def batch_to_graphs(batch:Batch):
    num_graphs = batch.batch.max().item()
    for ind in range(num_graphs):
        yield extract_graph(batch,ind)

---

### Get Raw Activations

In [20]:
model_width = model.width
print("Model Width: ",model_width)

Model Width:  64


In [21]:
from tqdm import tqdm
from torch import softmax

class RawActivationHook:
    def __init__(self,width,device='cuda'):
        self.activations = torch.zeros(0,width).to(device)
    
    def append_activations(self, model,input, output):
        self.activations = torch.cat((self.activations,output),dim=0)
        
    def __enter__(self):
        self.remove_handle = model.gnn.conv[-1].register_forward_hook(lambda m,i,o:(self.append_activations(m,i,o)))
        return self
    
    def __exit__(self, type, value, traceback):
        self.remove_handle.remove()
        
    def get_activations(self):
        return self.activations

class ModelForward:
    def __init__(self,model,verbose=False):
        self.model = model
        self.model_width  = model.width
        self.model.eval()
        self.model.to('cuda')
        self.verbose = verbose
        
    def forward_node_level(self,loader):
        # Acquire the activations
        hook = RawActivationHook(self.model_width)


        
        predictions = torch.zeros(0).to('cuda')
        ground = torch.zeros(0).to('cuda')
        
        with hook:
            for batch in tqdm(loader,disable=not self.verbose):
                batch = batch.to('cuda')
                output = model.forward(batch.x,batch.edge_index,batch.batch)
                
                node_probs = softmax(output,dim=1)[batch.batch]
                node_ground = batch.y[batch.batch]
                
                
                predictions = torch.cat([predictions,node_probs])
                ground = torch.cat([ground,node_ground])
            
        raw_activations = hook.get_activations()
        return raw_activations.cpu().detach().numpy(),predictions.cpu().detach().numpy(),ground.cpu().detach().numpy().astype(int)
    
    def forward_graph_level(self,loader):
        predictions = torch.zeros(0).to('cuda')
        ground = torch.zeros(0).to('cuda')
        
        for batch in tqdm(loader,disable=not self.verbose):
            batch = batch.to('cuda')
            output = model.forward(batch.x,batch.edge_index,batch.batch)
            
            graph_probs = softmax(output,dim=1)
            graph_ground = batch.y
            
            predictions = torch.cat([predictions,graph_probs])
            ground = torch.cat([ground,graph_ground])
            
        return predictions.cpu().detach().numpy(),ground.cpu().detach().numpy().astype(int)

forwarder = ModelForward(model)

graph_prob_predictions_train,graph_ground_train = forwarder.forward_graph_level(train_loader)
graph_prob_predictions_val,graph_ground_val = forwarder.forward_graph_level(val_loader)

activations_train,node_predictions_train,node_ground_train  = forwarder.forward_node_level(train_loader)
activations_val,node_predictions_val,node_ground_val = forwarder.forward_node_level(val_loader)

In [22]:
graph_predictions_train = np.argmax(graph_prob_predictions_train,axis=1)
graph_predictions_val = np.argmax(graph_prob_predictions_val,axis=1)


In [23]:
print(f"Model accuracy on validation set: {np.sum(graph_predictions_val == graph_ground_val)/len(graph_ground_val)}")
print(f"Model accuracy on training set: {np.sum(graph_predictions_train == graph_ground_train)/len(graph_ground_train)}")

Model accuracy on validation set: 0.7236842105263158
Model accuracy on training set: 0.8673139158576052


---
# K Means

In [24]:
import numpy as np
from sklearn.mixture import GaussianMixture,BayesianGaussianMixture
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans, MiniBatchKMeans

class ConceptDiscoverer:
    def __init__(self,num_concepts,verbose=True,whiten=True,**kwargs) -> None:
        self.k = num_concepts
        self.gm = MiniBatchKMeans(n_clusters=self.k,verbose=verbose,**kwargs)
        self.whiten_act = whiten
        
    def fit(self,activations: np.ndarray):
        self.mu = activations.mean(axis=0)
        self.sigma = activations.std(axis=0)
        
        whitten_activations = self.whiten(activations)
        
        self.gm.fit(whitten_activations)
    
    def predict(self,activations):
        whitten_activations = self.whiten(activations)
        return self.gm.predict(whitten_activations)
        
    def get_concept_distances(self,activations):
        whitten_activations = self.whiten(activations)
        # Get the distances from the cluster centers
        return self.gm.transform(whitten_activations)
    
    def silouhette_score(self,activations,**kwargs):
        whitten_activations = self.whiten(activations)
        return silhouette_score(whitten_activations,self.gm.predict(whitten_activations),**kwargs)
        
    def whiten(self,obs):
        if(self.whiten_act):
            return (obs - self.mu)/self.sigma
        else:
            return obs



In [25]:
sub_sample_size = 10000
sub_sample = np.random.choice(activations_train.shape[0],sub_sample_size,replace=False)

num_concepts =64

cd = ConceptDiscoverer(num_concepts=num_concepts)
cd.fit(activations_train)

Init 1/3 with method: k-means++
Inertia for init 1/3: 3030.372559
Init 2/3 with method: k-means++
Inertia for init 2/3: 3582.701660
Init 3/3 with method: k-means++
Inertia for init 3/3: 2914.075195
Minibatch iteration 1/283100: mean batch inertia: 13.445592, ewa inertia: 13.445592 
Minibatch iteration 2/283100: mean batch inertia: 16.711725, ewa inertia: 13.447900 
Minibatch iteration 3/283100: mean batch inertia: 18.364779, ewa inertia: 13.451374 
Minibatch iteration 4/283100: mean batch inertia: 14.972295, ewa inertia: 13.452449 
Minibatch iteration 5/283100: mean batch inertia: 15.425530, ewa inertia: 13.453843 
Minibatch iteration 6/283100: mean batch inertia: 15.789164, ewa inertia: 13.455493 
Minibatch iteration 7/283100: mean batch inertia: 15.005930, ewa inertia: 13.456589 
Minibatch iteration 8/283100: mean batch inertia: 15.489277, ewa inertia: 13.458025 
Minibatch iteration 9/283100: mean batch inertia: 18.087726, ewa inertia: 13.461296 
[MiniBatchKMeans] Reassigning 1 clust

In [26]:
train_concept_probs = cd.predict_proba(activations_train)
val_concept_probs = cd.predict_proba(activations_val)

train_concepts = cd.predict(activations_train)
val_concepts = cd.predict(activations_val)

Computing label assignment and total inertia
Computing label assignment and total inertia


## Rejection Sampling
What if we consider only those activations that are we have a certain degree of confidence is in that concept?

In [27]:
def within_acceptance(probabilities:np.ndarray,threshold = 0.99):
    max_prob = probabilities.max(axis=1)
    assert(max_prob.shape[0] == probabilities.shape[0])
    number_surpass = (max_prob > threshold).sum()
    return number_surpass / probabilities.shape[0]

#print("Train: ",within_acceptance(train_concept_probs))
#print("Val: ",within_acceptance(val_concept_probs))

In [28]:
if notebook_settings['save_concepts']:
    np.save("mu.npy",mu)
    np.save("sigma.npy",sigma)

---
# Vizualize Explanations

In [29]:

from sklearn.manifold import TSNE


if notebook_settings['tsne']:
    tsne_sample_size  = 50000
    tsne_sample= np.random.choice(activations_train.shape[0],tsne_sample_size,replace=False)
    position_tsne_sample =TSNE(n_components=2,verbose=2,perplexity = 2000,n_iter=500).fit_transform(activations_train[tsne_sample])


In [30]:

import matplotlib.pyplot as plt
from matplotlib import cm

def plot_tsne_points_with_category(points,classes):
    f = plt.figure(figsize=(10,10))
    # How many different classes are there
    unique_classes = np.unique(classes)
    colours = cm.rainbow(np.linspace(0,1,len(unique_classes)))
    
    # Group the points by their class and then plot all points of the same class at a time
    for i,cl in enumerate(unique_classes):
        points_with_class = (classes==cl)
        v = points[points_with_class]
        x,y = v[:,0],v[:,1]
        plt.scatter(x,y,color=colours[i])

    plt.show()

In [31]:
if notebook_settings['tsne']:
    prediction_tsne_sample = node_predictions_train[tsne_sample].argmax(axis=1)
    ground_tsne_sample = node_ground_train[tsne_sample]
    
    k = 64
    cd = ConceptDiscoverer(num_concepts=k,verbose=False)
    cd.fit(activations_train)
    hard_concept_tsne_sample = cd.predict(activations_train[tsne_sample])

    print(prediction_tsne_sample.shape)
    plot_tsne_points_with_category(position_tsne_sample,ground_tsne_sample)
    plot_tsne_points_with_category(position_tsne_sample,hard_concept_tsne_sample)

---
# Concept Completeness

In [32]:
def one_hot_concept(concept_ids,num_concepts):
    one_hot = np.zeros((len(concept_ids),num_concepts))
    one_hot[np.arange(len(concept_ids)),concept_ids] = 1
    return one_hot

In [33]:


# Takes every observed concept along with the eventual graph prediction and then generates X,Y pairs for training a classifier
def generate_concept_to_prediction_map(model,cd,loader):
    f = ModelForward(model)
    act,pred_prob,ground =  f.forward_node_level(loader)
    
    return cd.predict_proba(act),ground
    

In [34]:

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

class ConceptNodeCancerPredictor:
    def __init__(self,k):
        self.dt = DecisionTreeClassifier(max_leaf_nodes=k)
        self.k = k
        
    
    def fit(self,concepts,graph_ground):
        self.dt.fit(concepts.reshape(-1,1),graph_ground)

    
    def predict(self,concepts):
        return self.predict_proba(concepts.reshape(-1,1)).argmax(axis=1)
    
    def predict_proba(self,concepts):
        return self.dt.predict_proba(concepts)
        
    def get_concept_prob_of_cancer(self):
        return self.dt.predict_proba(np.arange(self.k).reshape(-1,1))

class ConceptGraphClassPredictor:
    def __init__(self,k,trained_model,trained_concept_discoverer, trained_node_graph_class_predictor):
        self.node_predictor = ConceptNodeCancerPredictor(k)
        self.k = k
        self.forwarder = ModelForward(trained_model)
        self.cd = trained_concept_discoverer
        self.node_predictor = trained_node_graph_class_predictor
        
    def get_graph_concept_distribution(self, graph, rejection_threshold = 0.99):
        act,pred_prob,ground = self.forwarder.forward_node_level(loader_from_one_graph(graph))
        graph_pred = pred_prob[0]
        graph_ground = ground[0]
        #ensure graph ground is flattened
        ground = ground.reshape(-1)

        concept_prob = one_hot_concept(self.cd.predict(act),self.k)
        num_nodes = graph.x.shape[0]
        
        assert concept_prob.shape == (num_nodes,self.k)
        
                
        # Filter out concepts that are not accepted
        accepted = (concept_prob.max(axis=1) > rejection_threshold)
        num_remaining_nodes =  accepted.sum()
        
        assert accepted.shape == (num_nodes,)
        concept_prob = concept_prob[accepted]
        assert concept_prob.shape == (num_remaining_nodes,self.k)
        pred_prob = pred_prob[accepted]
        assert pred_prob.shape == (num_remaining_nodes,4)
        ground = ground[accepted]
        assert ground.shape == (num_remaining_nodes,)
        
        concept_dist = concept_prob.mean(axis=0)
        assert concept_dist.shape == (self.k,)
        
        
        # Graph-pred/ground is the prediction and ground truth for the graph (one value)
        # Node-pred/ground is the prediction and ground truth for the graph but one for each node
        return {'concept-dist':concept_dist,'node-pred':pred_prob, 'node-ground':ground, 'graph-pred':graph_pred, 'graph-ground':graph_ground, 'num-nodes':num_nodes, 'num-nodes-above-threshold':num_remaining_nodes}
    
    def vote_on_graph(self,graph,rejection_threshold = 0.99):
        prediction_info = self.get_graph_concept_distribution(graph,rejection_threshold)
        concept_prob_of_class = self.node_predictor.get_concept_prob_of_cancer()
        # Perform a weighted sum vote
        assert prediction_info['concept-dist'].shape == (self.k,)
        
        weighted_sum = prediction_info['concept-dist'].reshape(-1,1)*concept_prob_of_class
        
        assert concept_prob_of_class.shape == (self.k,4)
        
        assert weighted_sum.shape == (self.k,4)
        vote = weighted_sum.sum(axis=0)
                
        assert vote.shape == (4,)
        prediction_info['concept-vote'] = vote
        return prediction_info


Q: For a particular graph, what concepts are present, how many times for each, and what is the contribution of that concept

In [35]:
cncp = ConceptNodeCancerPredictor(num_concepts)
cncp.fit(train_concepts,node_ground_train)
cgcp = ConceptGraphClassPredictor(num_concepts,model,cd,cncp)
    

In [36]:
random_graph_id = np.random.randint(len(val_loader))
graph = train_loader.dataset[random_graph_id]

prediction_info = cgcp.vote_on_graph(graph)

graph_concept_vote = prediction_info['concept-vote']

graph_ground = prediction_info['graph-ground']
graph_prediction = prediction_info['graph-pred']

print("Ground: ",graph_ground)
print("Prediction: ",graph_prediction.argmax().item(),graph_prediction)
print("Concept vote: ",graph_concept_vote.argmax(),graph_concept_vote)

Computing label assignment and total inertia
Ground:  0
Prediction:  0 [0.8672594  0.06075229 0.01158582 0.0604026 ]
Concept vote:  0 [0.40886307 0.27866339 0.11668984 0.1957837 ]


In [37]:

# Import cross entropy and accuracy
from sklearn.metrics import accuracy_score,log_loss

# Across a whole dataset, let us evaluate how good the concept_voter is


class ConceptCompleteness:
    def __init__(self, trained_model,num_concepts,verbose = False,rejection_threshold=0.99,**kwargs):
        self.model = trained_model
        self.k = num_concepts
        self.max_samples_for_concept_discovery = kwargs.get('max_samples_for_concept_discovery',10000)
        self.forwarder = ModelForward(self.model,verbose=verbose)
        self.verbose = verbose
        self.rejection_threshold = rejection_threshold
        self.max_iter = kwargs.get('max_iter',200)
        
    def fit(self,loader):
        activations,node_pred,node_ground = self.forwarder.forward_node_level(loader)
        
        self.cd = ConceptDiscoverer(self.k,verbose=self.verbose,max_iter=self.max_iter)
        subset_of_activations = activations[np.random.choice(activations.shape[0],self.max_samples_for_concept_discovery,replace=False)]
        self.cd.fit(subset_of_activations)
        concepts = self.cd.predict(activations)
        
        self.cncp = ConceptNodeCancerPredictor(self.k)
        self.cncp.fit(concepts,node_ground)
            
        self.cgcp = ConceptGraphClassPredictor(self.k,self.model,self.cd,self.cncp)
        
    def evaluate(self,graph):
        prediction_info = self.cgcp.vote_on_graph(graph,self.rejection_threshold)
        return prediction_info['concept-vote'].argmax(),prediction_info['graph-ground']
            
        
    def completeness(self,loader):
        Y_HAT = []
        Y = []
        for batch in tqdm(loader,disable=not self.verbose):
            for graph_id in range(batch.batch.max()):
                graph = extract_graph(batch,graph_id)
                vote,ground = self.evaluate(graph)
                Y_HAT.append(vote)
                Y.append(ground)
        matches = np.equal(np.array(Y),np.array(Y_HAT)).sum()
        return matches/len(Y)
    
                
    



In [38]:
completeness = ConceptCompleteness(model,num_concepts,max_samples_for_concept_discovery=10000)
completeness.fit(train_loader)
completeness.completeness(val_loader)

0.7543859649122807

# Determine K
- Perform GM
- Get Concept Completeness
- Plot

In [40]:
k_to_inspect = np.arange(5,100,5)
completeness_scores = []

for k in k_to_inspect:
    completeness = ConceptCompleteness(model,k,max_samples_for_concept_discovery=100000)
    completeness.fit(train_loader)
    score = completeness.completeness(val_loader)
    completeness_scores.append((k,score))   
    print(f"K: {k}, , Score: {score}")    
    

K: 5, , Score: 0.631578947368421
K: 10, , Score: 0.6666666666666666
K: 15, , Score: 0.7543859649122807
K: 20, , Score: 0.7017543859649122
K: 25, , Score: 0.7192982456140351
K: 30, , Score: 0.7017543859649122
K: 35, , Score: 0.7368421052631579
K: 40, , Score: 0.7017543859649122
K: 45, , Score: 0.7192982456140351
K: 50, , Score: 0.7368421052631579
K: 55, , Score: 0.7192982456140351
K: 60, , Score: 0.7017543859649122
K: 65, , Score: 0.7192982456140351
K: 70, , Score: 0.7192982456140351
K: 75, , Score: 0.7017543859649122
K: 80, , Score: 0.7719298245614035
K: 85, , Score: 0.7543859649122807
K: 90, , Score: 0.7192982456140351
K: 95, , Score: 0.7192982456140351


In [None]:
print(completeness_scores)
plt.plot(k_to_inspect,completeness_scores)

NameError: name 'completeness_scores' is not defined

## Using Silhoutte Average

In [None]:
k_to_inspect = [2,4,8,16,32,64,80,100]

scores = []
for k in k_to_inspect:
    cd = ConceptDiscoverer(k,verbose=False,whiten=True,max_iter=2000)
    cd.fit(activations_train)
    siho_avg = cd.silouhette_score(activations_train,sample_size=50000)
    scores.append((k,siho_avg))
    print(f"K: {k}, Score: {siho_avg}")

K: 2, Score: 0.17153768241405487
K: 4, Score: 0.21925793588161469


KeyboardInterrupt: 

# Concept Representation



In [47]:
class ConceptGraphGenerator:
    def __init__(self,model_forwarder,concept_discoverer):
        self.forwarder = model_forwarder
        self.cd = concept_discoverer
        
    def generate_concept_graph(self,graph):
        # 1) Get the activations for the graph
        act,pred_prob,ground = self.forwarder.forward_node_level(loader_from_one_graph(graph))
        # 2) Get the concepts
        concepts = self.cd.predict(act)
        graph.concepts = concepts
        graph.activations = act
         


In [43]:
num_concepts =32

cd = ConceptDiscoverer(num_concepts=num_concepts,verbose=False)
cd.fit(activations_train)
cgg = ConceptGraphGenerator(ModelForward(model),cd)

In [46]:
class TopN:
    def __init__(self, n, key=lambda x:x):
        self.n = n
        self.container = []
        self.key = key
    def add(self, element):
        self.container.append(element)
        self.container.sort(reverse=True, key=self.key)
        if len(self.container) > self.n:
            self.container = self.container[:self.n]
    
    def get_top(self):
        return self.container
    
    def __str__(self):
        return str(self.container)


In [51]:
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data

class ConceptRepresentationExtractor:
    def __init__(self,concept_graph_generator, concept_discoverer):
        self.cgg = concept_graph_generator
        self.cd = concept_discoverer
        self.k = self.cd.k
        
    def generate_top_n_representations(self,loader,n,num_hops):
        # Hold a list of the top n representations for each of the k concepts
        top_n_representations = [TopN(n,lambda x:-x[0]) for i in range(self.k)]
        for batch in loader:
            for graph in batch_to_graphs(batch):
                self.cgg.generate_concept_graph(graph)
                for node_id in range(graph.x.shape[0]):
                    subgraph = self.extract_representation(graph,node_id,num_hops)
                    concept = graph.concepts[node_id]
                    act = graph.activations[node_id]
                    # Get the distance to the concept
                    concept_dist = self.cd.get_concept_distances(act)[0][concept]
                    
                    top_n_representations[concept].add((concept_dist,subgraph))

    def extract_representation(self,graph,node_id,n):
        # Create a subgraph around the node
        subset,edges,_,_, = k_hop_subgraph(node_id,n,graph.edge_index)
        #
        subgraph = Data(x=graph.x[subset],edge_index=edges)
        subgraph.concepts = graph.concepts[subset]
        subgraph.image = graph.image
        return subgraph
        

In [54]:
rep_extractor = ConceptRepresentationExtractor(cgg,cd)
rep_extractor.generate_top_n_representations(train_loader,10,2)