In [1]:
# Setup
%matplotlib inline
%load_ext autoreload
%autoreload 2
import warnings
import spacy
from optimized_anchor import anchor_text, anchor_base
import pickle
import myUtils
from myUtils import *
from transformer.utils import *
from dataset.dataset_loader import *
import datetime
%load_ext line_profiler

# Import torch explicitly instead of relying on one of the wildcard imports
# above to re-export it; the lines below use it directly.
import torch

# Fixed seed for PyTorch's global RNG so repeated runs draw the same samples.
SEED = 84
torch.manual_seed(SEED)
# NOTE(review): this silences every warning category for the rest of the
# session, including deprecation notices — narrow the filter if that hides
# something relevant.
warnings.simplefilter("ignore")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Dataset selector handed to get_dataset; accepted values are
# 'sentiment', 'spam' or 'offensive'. It also picks the checkpoint folder
# used when the model is loaded.
dataset_name = 'sentiment'
# get_dataset returns four objects, unpacked here as the text parser, the
# label parser and two dataset splits (roles inferred from the names —
# confirm against dataset.dataset_loader).
text_parser, label_parser, ds_train, ds_val = get_dataset(dataset_name)

Number of tokens in training samples: 3307
Number of tokens in training labels: 2


In [3]:
# Sanity check: number of items in ds_val.
len(ds_val)

1000

In [4]:
# Load the 'gru' checkpoint stored under the selected dataset's folder and
# compile it with TorchScript in one expression.
model = torch.jit.script(
    load_model('gru', f'transformer/{dataset_name}/gru.pt', text_parser)
)
# Expose the parser and the scripted model as module-level attributes of
# myUtils.
myUtils.text_parser = text_parser
myUtils.model = model

{'embedding_dim': 100, 'batch_size': 32, 'hidden_dim': 256, 'num_layers': 2, 'dropout': 0.3, 'lr': 5e-05, 'early_stopping': 5, 'output_classes': 2}
VanillaGRU(
  (embedding_layer): Embedding(3307, 100)
  (GRU_layer): GRU(100, 256, num_layers=2, dropout=0.3)
  (dropout_layer): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=256, out_features=2, bias=True)
  (log_softmax): LogSoftmax(dim=1)
)


In [5]:
# Load spaCy's small English pipeline; it is passed to AnchorText when the
# explainer is built.
nlp = spacy.load('en_core_web_sm')

In [6]:
# Split ds_train into five pieces via preprocess_examples. `test` and
# `anchor_examples` are the ones consumed later when TextUtils is built; the
# exact contents of each piece are defined by preprocess_examples (not shown
# in this notebook).
train, train_labels, test, test_labels, anchor_examples = preprocess_examples(ds_train)

In [7]:
# Ignore list derived from the examples. NOTE: a later cell rebinds `ignored`
# to an empty list, so this value is not the one used for the run.
ignored = get_ignored(anchor_examples)
# Occurrence data computed from the (still untruncated) examples.
normal_occurences = get_occurences(anchor_examples)
# Install a BestGroup built from that data as an attribute on AnchorBaseBeam
# itself, i.e. it is set on the class object rather than on an instance.
anchor_base.AnchorBaseBeam.best_group = BestGroup(normal_occurences)

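## Notice: the next two cells override the setup above — `ignored` is reset to an empty list and `anchor_examples` is cut down to its first element.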

In [8]:
# Override: discard the list produced by get_ignored above so that nothing is
# ignored in this run.
ignored = []

In [9]:
# Keep only the first example (presumably to keep the profiled run short).
# NOTE(review): normal_occurences and the BestGroup installed earlier were
# computed from the full list before this truncation — confirm that is intended.
anchor_examples = anchor_examples[:1]

In [10]:
# Shell out to nvidia-smi to record the GPU state alongside the results
# (requires the NVIDIA driver utilities on PATH).
!nvidia-smi

Wed Jul 20 22:16:13 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.54       Driver Version: 510.54       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 6000     Off  | 00000000:AF:00.0 Off |                  Off |
| 33%   33C    P2    63W / 260W |   1221MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [11]:
# Single switch for the optimized code path: it is passed to
# AnchorText.set_optimize here and again to TextUtils when the run is set up.
optimize = True
anchor_text.AnchorText.set_optimize(optimize)
# Build the explainer on top of the spaCy pipeline. The list is presumably the
# class names in label order — TODO confirm against label_parser.
explainer = anchor_text.AnchorText(nlp, ['positive', 'negative'], use_unk_distribution=False)

In [48]:
my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences_optimized, ignored, f"profile.pickle", optimize=optimize)
myUtils.model = model
myUtils.text_parser = text_parser
set_seed()
%prun -s cumtime -T profile.txt my_utils.compute_explanations(list(range(len(anchor_examples))))

number 0
[0.9551020408163265]
[0.922077922077922]
[0.9337016574585635]
[0.7619047619047619]
[0.9803921568627451]
[0.8292682926829268]
[0.9337016574585635]
[0.9022556390977443]
[0.9178082191780822]
[0.8641975308641975]
[1.0]
[0.847457627118644]
 
*** Profile printout saved to text file 'profile.txt'. 


         1774605 function calls (1774497 primitive calls) in 1.888 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.901    1.901 {built-in method builtins.exec}
        1    0.000    0.000    1.901    1.901 <string>:1(<module>)
        1    0.000    0.000    1.901    1.901 myUtils.py:277(compute_explanations)
        1    0.003    0.003    1.839    1.839 myUtils.py:254(get_exp)
        1    0.000    0.000    1.835    1.835 anchor_text.py:218(explain_instance)
        1    0.001    0.001    1.778    1.778 anchor_base.py:259(anchor_beam)
      138    0.014    0.000    1.760    0.013 anchor_text.py:178(sample_fn)
      134    0.000    0.000    1.687    0.013 anchor_base.py:221(<lambda>)
      134    0.005    0.000    1.686    0.013 anchor_base.py:190(complete_sample_fn)
     1308    0.018    0.000    1.625    0.001 anchor_text.py:97(sample)
     1320    0.010    0.000    1.217    0.001 anchor_

In [13]:
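# Disabled: line-by-line profiling variant of the run above, using the
# non-optimized predict_sentences. Uncomment both lines below to use it.
# NOTE(review): the %lprun line names modified_anchor, while this notebook
# imports optimized_anchor — confirm the module names before re-enabling.
# my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle")
#%lprun -s -m modified_anchor.anchor_text -m modified_anchor.anchor_base -m myUtils -T profile.txt  my_utils.compute_explanations(list(range(len(anchor_examples))))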

In [14]:
# Record the wall-clock time at which this cell was executed.
print(datetime.datetime.now())

2022-07-20 22:16:23.252859
