In [1]:
# Setup
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import warnings
import spacy
from modified_anchor import anchor_text
import pickle
from myUtils import *
from transformer.utils import *
from dataset.dataset_loader import *
import datetime
%load_ext line_profiler

SEED = 84
torch.manual_seed(SEED)
warnings.simplefilter("ignore")

In [None]:
%prun?

In [24]:
%lprun -f sum_of -f check -T alon.txt sum_of(1000)


*** Profile printout saved to text file 'alon.txt'. 


Timer unit: 1e-06 s

Total time: 0.002409 s
File: /tmp/ipykernel_56833/2452235315.py
Function: sum_of at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def sum_of(N):
     2         1          3.0      3.0      0.1      total = 0
     3         1         11.0     11.0      0.5      check()
     4         6          9.0      1.5      0.4      for i in range(5):
     5         5       2318.0    463.6     96.2          L = [j ^ (j >> i) for j in range(N)]
     6         5         67.0     13.4      2.8          total += sum(L)
     7         1          1.0      1.0      0.0      return total

Total time: 3e-06 s
File: /tmp/ipykernel_56833/3948132612.py
Function: check at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def check():
     2         1          1.0      1.0     33.3      x = 5
     3         1          1.0      1.0     33.3      y=17


In [2]:
plt.rcParams['font.size'] = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# can be sentiment/spam/offensive
dataset_name = 'sentiment'
review_parser, label_parser, ds_train, ds_val, _ = create_sentiment_dataset()

Number of tokens in training samples: 3307
Number of tokens in training labels: 2


In [4]:
model = load_model('gru' , f'transformer/{dataset_name}/gru.pt', review_parser)

{'embedding_dim': 100, 'batch_size': 32, 'hidden_dim': 256, 'num_layers': 2, 'dropout': 0.3, 'lr': 5e-05, 'early_stopping': 5, 'output_classes': 2}
VanillaGRU(
  (embedding_layer): Embedding(3307, 100)
  (GRU_layer): GRU(100, 256, num_layers=2, dropout=0.3)
  (dropout_layer): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=256, out_features=2, bias=True)
  (log_softmax): LogSoftmax(dim=1)
)


In [5]:
# 1 = pad 2=sos 3 = eos
def tokenize(text, max_len):
    sentence = review_parser.tokenize(str(text))
    input_tokens = [2] + [review_parser.vocab.stoi[word] for word in sentence] + [3] + [1]*(max_len-len(sentence))

    return input_tokens

In [6]:
def predict_sentences(sentences):
    half_length = len(sentences)//2
    if(half_length>100):
        return np.concatenate([predict_sentences(sentences[:half_length]), predict_sentences(sentences[half_length:])])
    max_len = max([len(sentence) for sentence in sentences])
    sentences = torch.tensor([tokenize(sentence, max_len) for sentence in sentences]).to(device)
    input_tokens = torch.transpose(sentences, 0, 1)
    output = model(input_tokens)

    return torch.argmax(output, dim=1).cpu().numpy()

# Anchor Part

In [7]:
nlp = spacy.load('en_core_web_sm')

In [8]:
explainer = anchor_text.AnchorText(nlp, ['positive', 'negative'], use_unk_distribution=False)

In [9]:
train, train_labels = [' '.join(example.text) for example in ds_train], [example.label for example in ds_train]
test, test_labels = [' '.join(example.text) for example in ds_train], [example.label for example in ds_train]

In [10]:
anchor_examples = [example for example in train if len(example) < 90 and len(example)>20]

In [59]:
len(anchor_examples)

2272

In [13]:
from collections import Counter, defaultdict
from nltk.corpus import stopwords
def get_ignored(anchor_sentences):
    sentences = [[x.text for x in nlp(sentence)] for sentence in anchor_sentences]
    min_occurence = 1
    c = Counter()
    stop_words = list(".,- \'\"\s[]?():!;")
    stop_words.extend(["--", "'s", 'sos', 'eos'])
    stop_words.extend(stopwords.words('english'))
    """
    for sentence in sentences:
        c.update(sentence)
    sums = 0
    for ignore_s in stop_words:
        sums+=c[ignore_s]
        del c[ignore_s]
    print(sums)
    ignored_anchors = stop_words
    for key in c.keys():
        if c[key]<=min_occurence:
            ignored_anchors.append(key)
    print(len(c.keys()))
    return ignored_anchors
    """
    return stop_words

In [14]:
ignored = get_ignored(anchor_examples)

## notice!

In [11]:
ignored = []

In [12]:
print(datetime.datetime.now())

2022-06-09 16:59:40.851284


In [13]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [14]:
anchor_examples = anchor_examples[:1]

In [179]:
my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle")
set_seed()
%prun -s cumtime -T profile.txt my_utils.compute_explanations(list(range(len(anchor_examples))))

number 0
[1.0]
[1.0]
[0.9572649572649573]
[1.0]
[0.8823529411764706]
[0.9803921568627451]
[0.9402985074626866]
[0.9381443298969072]
[0.8877551020408163]
[0.89]
[0.9298245614035088]
[0.948051948051948]
[1.0]
 
*** Profile printout saved to text file 'profile.txt'. 


         5037417 function calls (4998212 primitive calls) in 8.333 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    8.334    8.334 {built-in method builtins.exec}
        1    0.000    0.000    8.334    8.334 <string>:1(<module>)
        1    0.000    0.000    8.334    8.334 myUtils.py:62(compute_explanations)
        1    0.010    0.010    8.239    8.239 myUtils.py:39(get_exp)
        1    0.002    0.002    8.229    8.229 anchor_text.py:190(explain_instance)
        1    0.001    0.001    8.115    8.115 anchor_base.py:275(anchor_beam)
      154    0.030    0.000    7.730    0.050 anchor_text.py:153(sample_fn)
      148    0.000    0.000    7.696    0.052 anchor_base.py:228(<lambda>)
      148    0.009    0.000    7.695    0.052 anchor_base.py:175(complete_sample_fn)
     1468    0.035    0.000    6.948    0.005 anchor_text.py:89(sample)
     1481    0.005    0.000    4.999    0.003 anchor_te

In [None]:
###### my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle")
%lprun -s -m modified_anchor.anchor_text -m modified_anchor.anchor_base -m myUtils -T profile.txt  my_utils.compute_explanations(list(range(len(anchor_examples))))

In [None]:
print(datetime.datetime.now())