# SAGE demo for sentiment analysis

This demo applies SAGE (Shapley Additive Global importancE) to estimate global feature importance for a simple RNN sentiment classifier trained on the IMDB movie-review dataset.

In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# local `code` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Register the `%tensorboard` magic used further down in the notebook.
%load_ext tensorboard

In [9]:
import torch
from code.model import RNN
from code.data_module import IMDBDataModule
import pytorch_lightning as pl
import sage

In [3]:
# Load the dataset / build the vocab
# (per the progress bar below, this downloads aclImdb_v1.tar.gz, ~84 MB).
data_module = IMDBDataModule()
data_module.setup_datasets()
# `vocab` is reused later for INPUT_DIM and for inspecting token frequencies.
vocab = data_module.get_vocab()

aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:36<00:00, 2.31MB/s]
100%|██████████| 25000/25000 [00:04<00:00, 5468.27lines/s]


In [4]:
# Start tensorboard, reading event files from the `lightning_logs` directory.
# NOTE(review): --bind_all makes the TensorBoard server listen on all network
# interfaces; drop the flag unless remote access is actually needed.
%tensorboard --logdir lightning_logs --bind_all

## Model Configuration

Set the model configuration here:

In [10]:
# Number of entries in the vocabulary (presumably one embedding row per token).
INPUT_DIM = len(vocab)
# Presumably the size of each token embedding vector — confirm in code.model.RNN.
EMBEDDING_DIM = 100
# Presumably the width of the recurrent hidden state — confirm in code.model.RNN.
HIDDEN_DIM = 256
# Single output unit (the `fc` layer in the summary below has 256 + 1 = 257 params).
OUTPUT_DIM = 1

## Training + Validation


In [12]:
# Seed Python, NumPy and PyTorch RNGs so that weight initialisation and batch
# order are reproducible across "Restart Kernel -> Run All".
pl.seed_everything(42)

model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

# NOTE(review): `overfit_batches=10` is a Lightning debugging aid that restricts
# training to a small fixed set of batches. Presumably it is kept here so the
# CPU-only demo finishes quickly; remove it for a real training run.
trainer = pl.Trainer(
    auto_select_gpus=True,
    fast_dev_run=False,
    max_epochs=10,
    overfit_batches=10,
)
trainer.fit(model, data_module)

GPU available: False, used: False
INFO:lightning:GPU available: False, used: False
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores

  | Name      | Type      | Params
----------------------------------------
0 | embedding | Embedding | 10 M  
1 | rnn       | RNN       | 91 K  
2 | fc        | Linear    | 257   
INFO:lightning:
  | Name      | Type      | Params
----------------------------------------
0 | embedding | Embedding | 10 M  
1 | rnn       | RNN       | 91 K  
2 | fc        | Linear    | 257   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [29]:
# Inspect the 30 highest-count entries of `vocab.freqs` (token, count).
vocab.freqs.most_common(30)

[('the', 335746),
 ('.', 327192),
 (',', 276280),
 ('and', 163290),
 ('a', 162473),
 ('of', 145437),
 ('to', 135208),
 ("'", 133857),
 ('is', 107221),
 ('it', 96024),
 ('in', 93307),
 ('i', 87401),
 ('this', 75878),
 ('that', 73153),
 ('s', 62933),
 ('was', 48170),
 ('as', 46807),
 ('for', 44116),
 ('with', 44041),
 ('movie', 43421),
 ('but', 42410),
 ('film', 39459),
 (')', 36175),
 ('(', 35397),
 ('you', 34141),
 ('t', 33927),
 ('on', 33740),
 ('not', 30408),
 ('he', 30012),
 ('are', 29406)]

## Show Feature Importance (using SAGE)

In [20]:
# Add an output activation so the wrapped model emits probabilities.
# The network has a single output unit (OUTPUT_DIM = 1), so Sigmoid is the
# matching activation: Softmax normalises across the class dimension, and over
# a size-1 dimension it returns 1.0 for every input.
model_activation = torch.nn.Sequential(model, torch.nn.Sigmoid())

In [None]:
# NOTE(review): `test` and `Y_test` are not defined anywhere in the visible
# notebook; they must be built from `data_module` (held-out inputs and labels)
# before this cell can run — TODO confirm the data module's API.

# Use the activation-wrapped model so SAGE sees probabilities rather than raw
# scores; the first 512 test rows serve as the background sample for the
# marginal imputer.
imputer = sage.MarginalImputer(model_activation, test[:512])
estimator = sage.PermutationEstimator(imputer, 'mse')
sage_values = estimator(test, Y_test)