In [1]:
import captum
import torch
import utils
import model.net as net
from model.data_loader import DataLoader
from torch import device as dev
from captum.attr import LayerIntegratedGradients, LayerGradientShap, LayerFeatureAblation, TokenReferenceBase, visualization
from captum_viz import interpret_sequence

In [None]:
# Load IPython's autoreload extension; the `%autoreload 2` cell further down depends on it.
%load_ext autoreload

Parameter Setup

In [2]:
# Root folders for the trained model artefacts and for the dataset.
MODEL_DIR = 'experiments/base_model/'
DATA_DIR = 'data/'

# Files used below: training sequences, their labels, and the best checkpoint.
sentences_file = f'{DATA_DIR}train/sentences.txt'
labels_file = f'{DATA_DIR}train/labels.txt'
weights = f'{MODEL_DIR}best.pth'

# Hyper-parameters are read from params.json; the fields below are then set explicitly.
params = utils.Params(f'{MODEL_DIR}params.json')
params.cuda = torch.cuda.is_available()  # use the GPU only when one is available
params.number_of_classes = 10
params.vocab_size = 25

Data setup

In [3]:
# Names of the ten localization classes (presumably ordered by label index — TODO confirm).
classes = [
    'Extracellular',
    'Plastid',
    'Cytoplasm',
    'Mitochondrion',
    'Nucleus',
    'ER',
    'Golgi',
    'Membrane',
    'Lysosome',
    'Peroxisome',
]

# Build the project data loader and read the train and validation splits.
loader = DataLoader(DATA_DIR, params)
data = loader.load_data(['train', 'val'], DATA_DIR)
train_data = data['train']

# Iterate without shuffling and keep only the first batch; later cells pair it with
# sentences[:5], so the batch is presumably in file order — TODO confirm.
train_data_iterator = loader.data_iterator(train_data, params, shuffle=False)
train_batch, label_batch = next(train_data_iterator)

Model setup

In [4]:
# Build the network and move it to the GPU only when params.cuda is set.
model = net.Net(params)
if params.cuda:
    model = model.cuda()

# Restore the trained weights; the checkpoint tensors are mapped to CPU while loading.
checkpoint = torch.load(weights, map_location=dev('cpu'))
model.load_state_dict(checkpoint['state_dict'])

# Switch to evaluation mode; as the cell's last expression this also displays the module summary.
model.eval()

Net(
  (embedding): Embedding(25, 20)
  (lstm): LSTM(20, 40, num_layers=2, dropout=0.5, bidirectional=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (attention): Attention(
    (linear_in): Linear(in_features=512, out_features=512, bias=False)
    (linear_out): Linear(in_features=1024, out_features=512, bias=False)
    (softmax): Softmax(dim=-1)
    (tanh): Tanh()
  )
  (fc): Linear(in_features=80, out_features=10, bias=True)
)

In [5]:
# Show whether the plain model has its attention flag enabled.
bool(model.use_attention)

False

**Visualizing Integrated Gradients**

In [6]:
# Read the raw training sequences and labels, one entry per line.
# Context managers make sure both file handles are closed once the text is read.
with open(sentences_file) as sentences_fh:
    sentences = sentences_fh.read().splitlines()
with open(labels_file) as labels_fh:
    labels = labels_fh.read().splitlines()

# Reference-token helper built from the loader's padding index.
token_reference = TokenReferenceBase(reference_token_idx=loader.pad_ind)

# Integrated Gradients computed at the embedding layer of the plain model.
layer_ig = LayerIntegratedGradients(model, model.embedding)

# Record lists filled by the Integrated Gradients runs (plain model / attention model).
ig_vis_data_records = []
ig_attn_vis_data_records = []

In [7]:
# Run Integrated Gradients on the first five sequences with the plain model.
# NOTE(review): interpret_sequence lives in captum_viz (not shown here); it presumably appends
# one record per sequence to ig_vis_data_records, which the next cell visualizes — confirm.
interpret_sequence(model, sentences[:5], train_batch, layer_ig, ig_vis_data_records)

pred:  Extracellular ( 0.76 ) , delta:  5.581566488166434e-08
pred:  Cytoplasm ( 0.64 ) , delta:  1.432005203305664e-08
pred:  Cytoplasm ( 0.69 ) , delta:  2.3265631227120664e-08
pred:  Mitochondrion ( 0.83 ) , delta:  1.6991394513610203e-08
pred:  Cytoplasm ( 0.69 ) , delta:  1.871156388166817e-08


In [8]:
# Render the Integrated Gradients records of the plain model with Captum's text visualization.
visualization.visualize_text(ig_vis_data_records)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Extracellular,Extracellular (0.76),location,0.95,M C A K I A L L L V L V G A A S A A V L P D K F Y G T F D L D H S E N F D E Y L T A K G Y G W F T R K L V T F A T F K K V F T K T S N K N L F D Y S N L T S K K D V H Y K N V Q L G K A F Q G E G L D S T K H E I T F T L K D G H L F E H H K P L E G G D A K E E T Y E Y L F D K E F L L V R M S F N G V E G R R F Y K R L P
,,,,
Plastid,Cytoplasm (0.64),location,3.4,M A L Y G T L Q L S H S L G L C R N Q R F C N P E N S A M R R R L H I S N G P L S L G V P L G Q H G F S N I L L S N Y L R R P I C S V P C R T T A F R C H S F S A S G K A I E P A V K A V T V V L T K S H G L M Q Q F P F V Y K L V P A V A L L V F S L W G L V P F A R Q G R N I L L N K N D N G W K K S G T Y H V M T S Y V Q P L L L W L G A L F I C R A L D P V V L P T E A S K I V K D R L L N F V R S L S T V L A F A Y C L S S L I Q Q T Q K L F S E T S N P S D T R N M G F Q F A G K A L Y S A V W V A A V S L F M E L L G F S T Q K W L T A G G L G T V L I T L A G R E I L T N F L S S V M I H A T R P F V L N E W I Q T K I E G Y E V S G T V E H V G W W S P T I I R G E D R E A I H I P N H K F T V N V V R N L T Q K T H W R I K T H L A I S H L D V N K I N N I V A D M R K V L A K N P M V E Q Q R L H R R V F L E N V I P E N Q A L S I L I S C F V K T S H H E E Y L G V K E A I L L D L L R V I S H H R A R L A T P I R T I R K M Y T E T D V E N T P F G E S M Y G G V T S R R P L M L I E P A Y K I N G E D K S K S Q N R A A K P T A E Q E N K G S N P K S K E T S S P D L K A N V K V G E S P V S D T N K V P E E T V A K P V I K A V S K P P T P K D T E T S G T E K P K A K R S G G T I K S T K T D E T D S S T S S A S R S T L E E N I V L G V A L E G S K R T L P I E E E I H S P P M E T D A K E L T G A R R S G G N G P L V A D K E Q K D S Q S Q P N S G A S T E P
,,,,
Cytoplasm,Cytoplasm (0.69),location,1.1,M S A F K P Y T E A L E V L K K Y E K K D G L S I D D L I R H N F Q G G L T F N D F L I L P G Y I D F V P N N V S L E T R I S R N I V L K T P F M S S P M D T V T E D Q M A I Y M A L L G G I G V I H H N C T P E E Q A A M V R K V K K Y E N G F I L D P V V F S P Q H T V G D V L K I K E T K G F S G I P I T E N G K L R G K L V G I V T S R D V Q F H K D T N T P V T E V M T P R E E L I T T A E G I S L E R A N E M L R K S K K G K L P V V D K D D N L V A L L S L T D L M K N L H F P L A S K T S D T K Q L M V A A A I G T R D D D R T R L A L L A E A G L D A V V I D S S Q G N S C F Q I E M I K W I K K T Y P K I D V I A G N V V T R E Q T A S L I A A G A D G L R V G M G S G S A C I T Q E V M A C G R P Q A T A I A Q V A E F A S Q F G I G V I A D G G I Q N V G H M V K S L S L G A T A V M M G G L L A G T T E S P G E Y Y V R E G Q R Y K S Y R G M G S I A A M E G T G V N K N A S T G R Y F S E N D A V R V A Q G V S G L V V D K G S L L R F L P Y L Y T G L Q H A L Q D I G T K S L D E L H E A V D K H E V R F E L R S S A A I R E G D I Q G F A T Y E K R L Y
,,,,
Mitochondrion,Mitochondrion (0.83),location,2.98,M L K L A R P F I P P L S R N N A I S S G I V L T S R R F Q S S F T F L S N Q S L L S K N Q M K S K R K K G S K K A A Y H R Q P P E H E H T A P L I K Q N K T I T K K E H S D V R G S H L K K K R S D F S W L P R V P S T S H L K Q S D M T T N V L Y S G Y R P L F I N P N D P K L K E D T G S T L Y E F A M K L E D L N E P L S P W I S S A T G L E F F S E W E N I P S E L L K N L K P F H P P K E K S M N T N E L I H V S A K R N T L V D N K T S E T L Q R K M D E F S K R R G K G R K K S V V T L L Q M K K K L E G
,,,,
Nucleus,Cytoplasm (0.69),location,0.94,M P S C D P G P G P A C L P T K T F R S Y L P R C H R T Y S C V H C R A H L A K H D E L I S K S F Q G S H G R A Y L F N S V V N V G C G P A E Q R L L L T G L H S V A D I F C E S C K T T L G W K Y E Q A F E T S Q K Y K E G K Y I I E M S H M V K D N G W D
,,,,


In [9]:
# Turn attention on in the shared params object before building the second network.
# NOTE(review): params is mutated in place, so re-running the earlier model cell after this
# one would presumably build that model with attention as well — confirm.
params.attn = 1

# Build the attention network and move it to the GPU only when params.cuda is set.
attn_model = net.Net(params)
if params.cuda:
    attn_model = attn_model.cuda()

# Load the checkpoint from the same weights file, mapped to CPU while loading.
checkpoint = torch.load(weights, map_location=dev('cpu'))
attn_model.load_state_dict(checkpoint['state_dict'])

# Evaluation mode; as the cell's last expression this also displays the module summary.
attn_model.eval()

Net(
  (embedding): Embedding(25, 20)
  (lstm): LSTM(20, 40, num_layers=2, dropout=0.5, bidirectional=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (attention): Attention(
    (linear_in): Linear(in_features=512, out_features=512, bias=False)
    (linear_out): Linear(in_features=1024, out_features=512, bias=False)
    (softmax): Softmax(dim=-1)
    (tanh): Tanh()
  )
  (fc): Linear(in_features=80, out_features=10, bias=True)
)

In [10]:
# Show whether the second model has its attention flag enabled.
bool(attn_model.use_attention)

True

In [11]:
# Integrated Gradients at the embedding layer of the attention model, run on the same
# five sequences and the same batch as the plain model.
layer_ig_attn = LayerIntegratedGradients(attn_model, attn_model.embedding)
interpret_sequence(attn_model, sentences[:5], train_batch, layer_ig_attn, ig_attn_vis_data_records)

pred:  Extracellular ( 0.99 ) , delta:  8.579375077033546e-08
pred:  Plastid ( 1.00 ) , delta:  8.487881277829956e-07
pred:  Cytoplasm ( 0.88 ) , delta:  3.6450237450580403e-07
pred:  Mitochondrion ( 1.00 ) , delta:  2.3572954965800363e-07
pred:  Cytoplasm ( 0.82 ) , delta:  7.547191882562032e-08


In [12]:
# Render the Integrated Gradients records of the attention model.
visualization.visualize_text(ig_attn_vis_data_records)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Extracellular,Extracellular (0.99),location,3.22,M C A K I A L L L V L V G A A S A A V L P D K F Y G T F D L D H S E N F D E Y L T A K G Y G W F T R K L V T F A T F K K V F T K T S N K N L F D Y S N L T S K K D V H Y K N V Q L G K A F Q G E G L D S T K H E I T F T L K D G H L F E H H K P L E G G D A K E E T Y E Y L F D K E F L L V R M S F N G V E G R R F Y K R L P
,,,,
Plastid,Plastid (1.00),location,3.09,M A L Y G T L Q L S H S L G L C R N Q R F C N P E N S A M R R R L H I S N G P L S L G V P L G Q H G F S N I L L S N Y L R R P I C S V P C R T T A F R C H S F S A S G K A I E P A V K A V T V V L T K S H G L M Q Q F P F V Y K L V P A V A L L V F S L W G L V P F A R Q G R N I L L N K N D N G W K K S G T Y H V M T S Y V Q P L L L W L G A L F I C R A L D P V V L P T E A S K I V K D R L L N F V R S L S T V L A F A Y C L S S L I Q Q T Q K L F S E T S N P S D T R N M G F Q F A G K A L Y S A V W V A A V S L F M E L L G F S T Q K W L T A G G L G T V L I T L A G R E I L T N F L S S V M I H A T R P F V L N E W I Q T K I E G Y E V S G T V E H V G W W S P T I I R G E D R E A I H I P N H K F T V N V V R N L T Q K T H W R I K T H L A I S H L D V N K I N N I V A D M R K V L A K N P M V E Q Q R L H R R V F L E N V I P E N Q A L S I L I S C F V K T S H H E E Y L G V K E A I L L D L L R V I S H H R A R L A T P I R T I R K M Y T E T D V E N T P F G E S M Y G G V T S R R P L M L I E P A Y K I N G E D K S K S Q N R A A K P T A E Q E N K G S N P K S K E T S S P D L K A N V K V G E S P V S D T N K V P E E T V A K P V I K A V S K P P T P K D T E T S G T E K P K A K R S G G T I K S T K T D E T D S S T S S A S R S T L E E N I V L G V A L E G S K R T L P I E E E I H S P P M E T D A K E L T G A R R S G G N G P L V A D K E Q K D S Q S Q P N S G A S T E P
,,,,
Cytoplasm,Cytoplasm (0.88),location,1.8,M S A F K P Y T E A L E V L K K Y E K K D G L S I D D L I R H N F Q G G L T F N D F L I L P G Y I D F V P N N V S L E T R I S R N I V L K T P F M S S P M D T V T E D Q M A I Y M A L L G G I G V I H H N C T P E E Q A A M V R K V K K Y E N G F I L D P V V F S P Q H T V G D V L K I K E T K G F S G I P I T E N G K L R G K L V G I V T S R D V Q F H K D T N T P V T E V M T P R E E L I T T A E G I S L E R A N E M L R K S K K G K L P V V D K D D N L V A L L S L T D L M K N L H F P L A S K T S D T K Q L M V A A A I G T R D D D R T R L A L L A E A G L D A V V I D S S Q G N S C F Q I E M I K W I K K T Y P K I D V I A G N V V T R E Q T A S L I A A G A D G L R V G M G S G S A C I T Q E V M A C G R P Q A T A I A Q V A E F A S Q F G I G V I A D G G I Q N V G H M V K S L S L G A T A V M M G G L L A G T T E S P G E Y Y V R E G Q R Y K S Y R G M G S I A A M E G T G V N K N A S T G R Y F S E N D A V R V A Q G V S G L V V D K G S L L R F L P Y L Y T G L Q H A L Q D I G T K S L D E L H E A V D K H E V R F E L R S S A A I R E G D I Q G F A T Y E K R L Y
,,,,
Mitochondrion,Mitochondrion (1.00),location,3.47,M L K L A R P F I P P L S R N N A I S S G I V L T S R R F Q S S F T F L S N Q S L L S K N Q M K S K R K K G S K K A A Y H R Q P P E H E H T A P L I K Q N K T I T K K E H S D V R G S H L K K K R S D F S W L P R V P S T S H L K Q S D M T T N V L Y S G Y R P L F I N P N D P K L K E D T G S T L Y E F A M K L E D L N E P L S P W I S S A T G L E F F S E W E N I P S E L L K N L K P F H P P K E K S M N T N E L I H V S A K R N T L V D N K T S E T L Q R K M D E F S K R R G K G R K K S V V T L L Q M K K K L E G
,,,,
Nucleus,Cytoplasm (0.82),location,1.31,M P S C D P G P G P A C L P T K T F R S Y L P R C H R T Y S C V H C R A H L A K H D E L I S K S F Q G S H G R A Y L F N S V V N V G C G P A E Q R L L L T G L H S V A D I F C E S C K T T L G W K Y E Q A F E T S Q K Y K E G K Y I I E M S H M V K D N G W D
,,,,


**GradientSHAP Visualization**

In [49]:
# Autoreload mode 2: re-import edited modules before each cell runs.
%autoreload 2

In [50]:
# Record lists for the two GradientSHAP runs (plain model / attention model).
shap_vis_data_records, shap_attn_vis_data_records = [], []

# GradientSHAP explainers attached to the embedding layer of each model.
layer_shap = LayerGradientShap(model, model.embedding)
layer_shap_attn = LayerGradientShap(attn_model, attn_model.embedding)

In [51]:
# GradientSHAP on the first five sequences with the plain model.
# NOTE(review): this call currently raises the RuntimeError shown in the output (the embedding
# receives a FloatTensor). The `.long()` cast on the batch does not prevent it, so the float
# tensor is presumably produced downstream, inside interpret_sequence or the attribution
# call itself — TODO confirm and fix there.
interpret_sequence(model, sentences[:5], train_batch.long(), layer_shap, shap_vis_data_records)

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)

In [None]:
# Render the GradientSHAP records of the plain model.
visualization.visualize_text(shap_vis_data_records)

In [None]:
# GradientSHAP on the first five sequences with the attention model.
interpret_sequence(attn_model, sentences[:5], train_batch, layer_shap_attn, shap_attn_vis_data_records)

In [None]:
# Render the GradientSHAP records of the attention model. The list is the one created
# together with layer_shap_attn and filled by the preceding interpret_sequence call.
visualization.visualize_text(shap_attn_vis_data_records)

**Feature Ablation Visualization**

In [41]:
# Display the plain model's module summary for reference.
model

Net(
  (embedding): Embedding(25, 20)
  (lstm): LSTM(20, 40, num_layers=2, dropout=0.5, bidirectional=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (attention): Attention(
    (linear_in): Linear(in_features=512, out_features=512, bias=False)
    (linear_out): Linear(in_features=1024, out_features=512, bias=False)
    (softmax): Softmax(dim=-1)
    (tanh): Tanh()
  )
  (fc): Linear(in_features=80, out_features=10, bias=True)
)

In [52]:
# Feature Ablation explainer attached to the embedding layer of the plain model,
# plus the list that collects its visualization records.
layer_fa = LayerFeatureAblation(model, model.embedding)
fa_vis_data_records = []

In [55]:
# Feature Ablation on the first five sequences with the plain model.
# Pass the raw token-index batch, exactly like the other interpret_sequence calls: the model
# runs its own embedding layer, so embedding the batch up front and casting the result to
# long would feed truncated embedding values back in as token indices.
# NOTE(review): if an error persists after this change, its cause is presumably inside
# interpret_sequence (captum_viz, not shown here) — TODO confirm.
interpret_sequence(model, sentences[:5], train_batch, layer_fa, fa_vis_data_records)

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)

In [None]:
# Feature Ablation explainer for the attention model. The forward function and the ablated
# layer must belong to the same network, so both are taken from attn_model.
layer_fa_attn = LayerFeatureAblation(attn_model, attn_model.embedding)
fa_attn_vis_data_records = []

In [36]:
# Feature Ablation on the first five sequences with the attention model.
# NOTE(review): the recorded output of this cell is a RuntimeError about mismatched tensor
# sizes (673 vs 20). Its cause is not visible in this notebook; it presumably involves how
# interpret_sequence builds inputs for this explainer — TODO confirm.
interpret_sequence(attn_model, sentences[:5], train_batch, layer_fa_attn, fa_attn_vis_data_records)

RuntimeError: The size of tensor a (673) must match the size of tensor b (20) at non-singleton dimension 3