<a href="https://colab.research.google.com/github/albert-yue/gcn-explanability/blob/master/gcn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [4]:
%cd gdrive/My\ Drive/gcn_explainability

/content/gdrive/My Drive/gcn_explainability


In [9]:
! git pull

Already up to date.


In [13]:
%ls

[0m[01;34mdata[0m/  gcn.ipynb  logs.pt  [01;34mnotebooks[0m/  README.md  [01;34msrc[0m/


In [13]:
from src.models.gcn import GCN
import torch
import numpy as np
import scipy.sparse as sparse
import scipy

logs = (200, 0.5, 41664, 23, 'logs.pt')
ng20 = (200, 0.5, 80035, 20, 'gcn_20ng_full.pt', 61189, 18846, 11314,7532)

hidden_size,dropout,num_vertices,num_labels,path, num_words, num_docs, training_docs, test_docs = ng20

model = GCN(num_vertices, hidden_size, num_labels, dropout=dropout)
model.load_state_dict(torch.load(path))
model.eval()

GCN(
  (dropout): Dropout(p=0.5, inplace=False)
  (act_func): ReLU()
  (softmax): LogSoftmax()
  (layer1): GraphConv(input_size=80035, output_size=200)
  (layer2): GraphConv(input_size=200, output_size=20)
)

In [0]:
from src.data import Corpus, get_data, get_labels, get_vocabulary
from src.utils import load_sparse_tensor

vocabulary = get_vocabulary('data/20ng-vocabulary.txt')
labels = get_labels('data/20ng-labels.txt')
test_corpus = get_data('data/test-20news-clean.txt', labels)
test_adj_matrix = load_sparse_tensor('data/20ng_full_adj_matrix_norm.pt')
train_corpus = get_data('data/train-20news-clean.txt', labels)


In [0]:
import torch.nn as nn
from torch.autograd import Variable

loss_fn = nn.CrossEntropyLoss()
#input_value = Variable(test_adj_matrix, requires_grad=True)
model.zero_grad()
output = model(test_adj_matrix)
loss = loss_fn(output[-7532:,:], test_corpus.labels())
loss.retain_grad()
loss.backward()



In [16]:
weight_matrices = []
gradients = []
for n, p in model.named_parameters():
  print(n)
  weight_matrices.append(p.detach().numpy())
  gradients.append(p.grad.numpy())
  print(p.grad.shape)

#weight_matrices[1] = weight_matrices[1].T
#gradients[1] = gradients[1].T

layer1.weight
torch.Size([80035, 200])
layer2.weight
torch.Size([200, 20])


In [0]:
def grad_cam_avg(adj, weight_matrices, gradients):
  adj = adj.coalesce()
  V = sparse.csr_matrix((adj.values(), adj.indices()), shape=list(adj.size()))
  F_1 = scipy.special.expit(V.dot(weight_matrices[0]))
  F_2 = scipy.special.expit(V.dot(np.matmul(F_1, weight_matrices[1])))
  F = [F_1, F_2]
  grad_cams = []
  for l in range(2):
    grad_cam = []
    grad_cams.append(grad_cam)
    alphas = np.mean(gradients[l], axis = 0)
    for n in range(num_vertices):
      grad_cam.append(np.maximum(0, np.dot(alphas, F[l][n, :])))

  grad_cam_avgs = np.mean(grad_cams, axis=0)
  return grad_cams, grad_cam_avgs

In [18]:
from src.preprocessing import build_adj_matrix, normalize_adj

#grad_cam for specific docs
indices = [1]
for i in indices:
  single_corpus = Corpus([test_corpus.data[i]])
  adj = build_adj_matrix(single_corpus, vocabulary, num_docs, 1)
  adj = normalize_adj(adj)
  grad_cams, grad_cam_avgs = grad_cam_avg(adj, weight_matrices, gradients)
  print(test_corpus.data[-i].text)
  print(labels[test_corpus.data[-i].label])
  print(labels[np.argmax(output[-i,:].detach().numpy())])
  for name, grad_cam_vals in [('layer1', grad_cams[0]), ('layer2', grad_cams[1]), ('overall', grad_cam_avgs)]:
    print(name+":\n")
    top_words = np.asarray(grad_cam_vals).argsort()[-10:]
    for word_idx in top_words:
      if word_idx < num_words:
        print(vocabulary[word_idx])
      elif word_idx < num_words+training_docs:
        print(labels[train_corpus[word_idx-num_words].label])
      else:
        print(labels[train_corpus[word_idx-num_words-training_docs].label])
  

Building word frequencies per doc


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


Building word frequencies per window


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


Calculating PMIs


HBox(children=(IntProgress(value=0, max=1967), HTML(value='')))


Calculating TF-IDF


HBox(children=(IntProgress(value=0, max=61189), HTML(value='')))


Identities


HBox(children=(IntProgress(value=0, max=80035), HTML(value='')))


['david', '<unk>', 'sacco', '<unk>', 'andrew', 'cmu', 'edu', 'subject', 'abortion', 'private', 'health', 'coverage', 'letters', 'regarding', 'organization', 'misc', 'student', 'carnegie', 'mellon', 'pittsburgh', 'pa', 'lines', '<unk>', '<unk>', 'desire', 'wright', 'edu', 'nntp', 'posting', 'host', '<unk>', 'andrew', 'cmu', 'edu', 'reply', '<unk>', 'desire', 'wright', 'edu', '<unk>', 'apr', '<unk>', 'abortion', 'private', 'user', 'boomer', 'desire', 'wrig', 'writes', 'courts', 'found', 'ok', 'charge', 'women', 'less', 'auto', 'insurance', 'illegal', 'charge', 'health', 'insurance', 'live', 'longer', 'make', 'pay', 'retirement', 'funds', 'legal', 'arena', '<unk>', 'consistent', 'gender', 'issue', 'pa', 'recently', 'gender', 'auto', 'insurance', 'removed', 'point']
talk.religion.misc
talk.politics.mideast
layer1:

articulation
governer
sdscpub
xnews
superglue
innqfk
atheist
proceeded
errr
talk.politics.mideast
layer2:

rec.motorcycles
curtech
seizures
simultaneously
blinking
aired
acf
fa

In [19]:
total = 0
for i in range(1, 7500):
  if test_corpus.data[-i].label == np.argmax(output[-i,:].detach().numpy()):
    print(i)
    total += 1
print(total/7500)

for i in range(1, 7500):
  if test_corpus.data[-i].label != np.argmax(output[-i,:].detach().numpy()):
    print(i)
    break

1270
1468
1641
1643
1689
2682
2700
2846
2911
3028
3029
3040
3094
3119
3150
3186
3204
3211
3212
3216
3250
3283
3294
3295
3350
3352
3765
3929
4091
4172
4203
4240
4350
5072
5896
6452
6825
6969
7087
7146
7174
7301
7331
7350
7374
7399
7416
7459
0.0064
1


In [12]:
def BE(adj, weight_matrices, gradients):
  adj = adj.coalesce()
  V = sparse.csr_matrix((adj.values(), adj.indices()), shape=list(adj.size()))
  F_carrot = [V.dot(F_1), V.dot(F_2)]




NameError: ignored