In [1]:
# Setup
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import warnings
import spacy
from modified_anchor import anchor_text
import pickle
from myUtils import *
from transformer.utils import *
from dataset.dataset_loader import *
import datetime
%load_ext line_profiler

SEED = 84
torch.manual_seed(SEED)
warnings.simplefilter("ignore")

In [None]:
%prun?

In [None]:
%lprun -f sum_of -f check -T alon.txt sum_of(1000)

In [2]:
plt.rcParams['font.size'] = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# can be sentiment/spam/offensive
dataset_name = 'sentiment'
review_parser, label_parser, ds_train, ds_val, _ = create_sentiment_dataset()

Number of tokens in training samples: 3307
Number of tokens in training labels: 2


In [4]:
model = load_model('gru' , f'transformer/{dataset_name}/gru.pt', review_parser)

{'embedding_dim': 100, 'batch_size': 32, 'hidden_dim': 256, 'num_layers': 2, 'dropout': 0.3, 'lr': 5e-05, 'early_stopping': 5, 'output_classes': 2}
VanillaGRU(
  (embedding_layer): Embedding(3307, 100)
  (GRU_layer): GRU(100, 256, num_layers=2, dropout=0.3)
  (dropout_layer): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=256, out_features=2, bias=True)
  (log_softmax): LogSoftmax(dim=1)
)


In [5]:
# 1 = pad 2=sos 3 = eos
def tokenize(text, max_len):
    sentence = review_parser.tokenize(text)
    input_tokens = [2] + [review_parser.vocab.stoi[word] for word in sentence] + [3] + [1]*(max_len-len(sentence))

    return input_tokens

In [6]:
def predict_sentences(sentences):
    half_length = len(sentences)//2
    if(half_length>100):
        return np.concatenate([predict_sentences(sentences[:half_length]), predict_sentences(sentences[half_length:])])
    max_len = max([len(sentence) for sentence in sentences])
    sentences = torch.tensor([tokenize(sentence, max_len) for sentence in sentences], device=device)
    input_tokens = torch.transpose(sentences, 0, 1)
    output = model(input_tokens)

    return torch.argmax(output, dim=1).cpu().numpy()

# Anchor Part

In [7]:
nlp = spacy.load('en_core_web_sm')

In [8]:
explainer = anchor_text.AnchorText(nlp, ['positive', 'negative'], use_unk_distribution=False)

In [9]:
train, train_labels = [' '.join(example.text) for example in ds_train], [example.label for example in ds_train]
test, test_labels = [' '.join(example.text) for example in ds_train], [example.label for example in ds_train]

In [10]:
anchor_examples = [example for example in train if len(example) < 90 and len(example)>20]

In [59]:
len(anchor_examples)

2272

In [13]:
from collections import Counter, defaultdict
from nltk.corpus import stopwords
def get_ignored(anchor_sentences):
    sentences = [[x.text for x in nlp(sentence)] for sentence in anchor_sentences]
    min_occurence = 1
    c = Counter()
    stop_words = list(".,- \'\"\s[]?():!;")
    stop_words.extend(["--", "'s", 'sos', 'eos'])
    stop_words.extend(stopwords.words('english'))
    """
    for sentence in sentences:
        c.update(sentence)
    sums = 0
    for ignore_s in stop_words:
        sums+=c[ignore_s]
        del c[ignore_s]
    print(sums)
    ignored_anchors = stop_words
    for key in c.keys():
        if c[key]<=min_occurence:
            ignored_anchors.append(key)
    print(len(c.keys()))
    return ignored_anchors
    """
    return stop_words

In [14]:
ignored = get_ignored(anchor_examples)

## notice!

In [11]:
ignored = []

In [12]:
print(datetime.datetime.now())

2022-06-09 16:59:40.851284


In [12]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [13]:
anchor_examples = anchor_examples[:1]

In [30]:
!nvidia-smi

Sat Jun 11 15:55:46 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:82:00.0 Off |                  N/A |
| 46%   80C    P2   239W / 250W |   5868MiB / 11178MiB |     73%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [31]:
my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle")
set_seed()
%prun -s cumtime -T profile.txt my_utils.compute_explanations(list(range(len(anchor_examples))))

number 0
[1.0]
[1.0]
[0.9572649572649573]
[1.0]
[0.8823529411764706]
[0.9803921568627451]
[0.9402985074626866]
[0.9381443298969072]
[0.8877551020408163]
[0.89]
[0.9298245614035088]
[0.948051948051948]
[1.0]
 
*** Profile printout saved to text file 'profile.txt'. 


         3339626 function calls (3300421 primitive calls) in 4.552 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    4.553    4.553 {built-in method builtins.exec}
        1    0.000    0.000    4.553    4.553 <string>:1(<module>)
        1    0.000    0.000    4.553    4.553 myUtils.py:62(compute_explanations)
        1    0.010    0.010    4.470    4.470 myUtils.py:39(get_exp)
        1    0.004    0.004    4.459    4.459 anchor_text.py:210(explain_instance)
        1    0.001    0.001    4.335    4.335 anchor_base.py:276(anchor_beam)
      154    0.022    0.000    4.234    0.027 anchor_text.py:168(sample_fn)
      148    0.000    0.000    4.084    0.028 anchor_base.py:229(<lambda>)
      148    0.010    0.000    4.083    0.028 anchor_base.py:175(complete_sample_fn)
     1468    0.029    0.000    3.555    0.002 anchor_text.py:99(sample)
     1481    0.015    0.000    2.940    0.002 anchor_te

In [None]:
###### my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle")
%lprun -s -m modified_anchor.anchor_text -m modified_anchor.anchor_base -m myUtils -T profile.txt  my_utils.compute_explanations(list(range(len(anchor_examples))))

In [None]:
print(datetime.datetime.now())