In [1]:
# Setup
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import warnings
import spacy
from modified_anchor import anchor_text
import pickle
from myUtils import *
from transformer.utils import *
from dataset.dataset_loader import *
import datetime
import re
%load_ext line_profiler

# Notebook-level seed, applied to PyTorch's default RNG only.
# `torch` is not imported by name in this cell; it presumably arrives through
# one of the wildcard imports above.
SEED = 84
torch.manual_seed(SEED)
# No category is given, so every warning is suppressed from here on.
warnings.simplefilter("ignore")

In [2]:
# Larger default font for every matplotlib figure drawn by this notebook.
plt.rcParams.update({'font.size': 20})
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
# can be sentiment/spam/offensive
# NOTE(review): `dataset_name` only feeds the checkpoint path used later; the
# loader call below is always `sentiment_dataset()`, so switching datasets
# requires editing both lines.
dataset_name = 'sentiment'
# The fifth value returned by the loader is discarded.
review_parser, label_parser, ds_train, ds_val, _ = sentiment_dataset()

Number of tokens in training samples: 3307
Number of tokens in training labels: 2


In [4]:
# Ask `load_model` for the 'gru' model stored under the checkpoint path of the
# selected dataset, then compile the returned module with TorchScript.
# NOTE(review): there is no explicit `model.eval()` here — confirm that
# `load_model` already disables dropout for inference.
model = load_model('gru' , f'transformer/{dataset_name}/gru.pt', review_parser)
model = torch.jit.script(model)

{'embedding_dim': 100, 'batch_size': 32, 'hidden_dim': 256, 'num_layers': 2, 'dropout': 0.3, 'lr': 5e-05, 'early_stopping': 5, 'output_classes': 2}
VanillaGRU(
  (embedding_layer): Embedding(3307, 100)
  (GRU_layer): GRU(100, 256, num_layers=2, dropout=0.3)
  (dropout_layer): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=256, out_features=2, bias=True)
  (log_softmax): LogSoftmax(dim=1)
)


In [5]:
# spaCy English pipeline; its `.tokenizer` is what `tokenize` calls to split text.
spacy_tokenizer = spacy.load("en_core_web_sm")

In [6]:
# Special vocabulary ids: 1 = pad, 2 = sos, 3 = eos
def tokenize(text, max_len):
    """Turn ``text`` into a list of vocabulary ids framed by sos/eos and padded.

    Args:
        text: Raw sentence; it is split with ``spacy_tokenizer.tokenizer``.
        max_len: Target number of word positions. ``max_len - n_tokens`` pad
            ids are appended after the eos id.

    Returns:
        ``[sos] + word ids + [eos] + padding`` as a plain Python list. Word ids
        are looked up in ``review_parser.vocab.stoi``.
    """
    pad_id, sos_id, eos_id = 1, 2, 3
    words = spacy_tokenizer.tokenizer(text)

    token_ids = [sos_id]
    token_ids.extend(review_parser.vocab.stoi[word.text] for word in words)
    token_ids.append(eos_id)
    # A non-positive pad count multiplies to an empty list, so nothing is added
    # when the sentence already has max_len tokens or more.
    token_ids.extend([pad_id] * (max_len - len(words)))

    return token_ids

In [7]:
def predict_sentences(sentences):
    """Predict a class index for every sentence in ``sentences``.

    Large inputs are split in half recursively (whenever half the input is
    longer than 100 sentences) so that a single forward pass never sees more
    than roughly 200 sentences.

    Args:
        sentences: Sequence of raw sentence strings.

    Returns:
        1-D numpy array holding the argmax class index for each sentence, in
        input order. An empty input yields an empty int64 array.
    """
    # max() below would raise on an empty batch; answer it directly instead.
    if len(sentences) == 0:
        return np.array([], dtype=np.int64)

    half_length = len(sentences)//2
    if half_length > 100:
        return np.concatenate([predict_sentences(sentences[:half_length]), predict_sentences(sentences[half_length:])])

    # Character length is used as the padding target; tokenize() pads by
    # token count up to this value.
    max_len = max(len(sentence) for sentence in sentences)
    batch = torch.tensor([tokenize(sentence, max_len) for sentence in sentences], device=device)
    # Built as (n_sentences, seq_len); swap the two axes before calling the model.
    input_tokens = torch.transpose(batch, 0, 1)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(input_tokens)

    return torch.argmax(output, dim=1).cpu().numpy()

# Anchor Part

In [8]:
# Same spaCy model name as `spacy_tokenizer`, loaded a second time; this
# instance is handed to AnchorText and used by `get_occurences`.
nlp = spacy.load('en_core_web_sm')

In [34]:
# Switch on the optimisation flag of the modified AnchorText at class level,
# before any explainer instance is created.
anchor_text.AnchorText.set_optimize(True)
# Explainer built on the spaCy pipeline with the two class names.
# NOTE(review): assumes class index 0 means 'positive' and 1 'negative' —
# confirm against `label_parser`.
explainer = anchor_text.AnchorText(nlp, ['positive', 'negative'], use_unk_distribution=False)

In [10]:
# Rebuild every example as one string: join its tokens with single spaces and
# collapse any run of whitespace. The pattern is a raw string so that `\s` is a
# regex escape rather than an invalid Python string escape.
train = [re.sub(r'\s+', ' ', ' '.join(example.text).strip()) for example in ds_train]
train_labels = [example.label for example in ds_train]
# Held-out sentences come from the validation split. Both lists used to be
# built from ds_train again, which made `test` an exact copy of `train`.
test = [re.sub(r'\s+', ' ', ' '.join(example.text).strip()) for example in ds_val]
test_labels = [example.label for example in ds_val]

In [91]:
# Keep the training sentences whose character count is strictly between 20 and 90.
anchor_examples = [sentence for sentence in train if 20 < len(sentence) < 90]

In [12]:
# Number of sentences currently held in `anchor_examples`.
len(anchor_examples)

2274

In [13]:
from collections import Counter, defaultdict
from nltk.corpus import stopwords
def get_ignored(anchor_sentences, min_value=1):
    """Build the set of tokens that should not be considered as anchors.

    The set is the union of:
      * single punctuation/whitespace characters,
      * a few multi-character markers ("--", "'s", 'sos', 'eos'),
      * NLTK's English stop words,
      * every token that occurs at most ``min_value`` times in
        ``anchor_sentences`` (tokenised with ``review_parser.tokenize``).

    Args:
        anchor_sentences: Iterable of sentence strings to count tokens over.
        min_value: Tokens with a count less than or equal to this are ignored.
            Defaults to 1, i.e. tokens seen only once.

    Returns:
        A ``set`` of token strings.
    """
    # The original literal spelled "\s", which Python keeps as the two
    # characters backslash and 's'. They are written out explicitly here so the
    # resulting set is unchanged while the invalid escape sequence is gone.
    stop_words = set(".,#&- '\"\\s\t[]?():!;")
    stop_words.update(["--", "'s", 'sos', 'eos'])
    stop_words.update(stopwords.words('english'))

    def get_below_occurences(sentences):
        # Tokens whose total count over all sentences does not exceed min_value.
        counts = Counter()
        for sentence in sentences:
            counts.update(review_parser.tokenize(sentence))
        return {word for word, count in counts.items() if count <= min_value}

    return stop_words | get_below_occurences(anchor_sentences)

In [14]:
# Ignore set derived from the anchor sentences (stop words plus rare tokens).
ignored = get_ignored(anchor_examples)

## Notice: the ignore set computed above is discarded — the next cell resets `ignored` to an empty list

In [15]:
# Deliberate override: nothing is ignored, replacing whatever `ignored` held before.
ignored = []

In [16]:
from collections import Counter, defaultdict
def get_occurences(sentences):
    """Count how often each token text appears across ``sentences``.

    Every sentence is split with ``nlp.tokenizer``; the ``.text`` of each
    resulting token is tallied.

    Returns:
        A ``Counter`` mapping token text to its total number of occurrences.
    """
    counts = Counter()
    for text in sentences:
        tokens = nlp.tokenizer(text)
        counts.update(token.text for token in tokens)

    return counts

In [94]:
# Token frequencies over the anchor sentences, as counted by `get_occurences`.
normal_occurences = get_occurences(anchor_examples)

In [84]:
class BestGroup:
    """Bounded table of the anchors that have been reported most often.

    ``all`` counts every reported anchor; ``best`` holds at most ``max_size``
    of them together with their counts. ``min_name``/``min_val`` cache the
    weakest entry of ``best`` so a newcomer can evict it once it has been seen
    more often. ``occurences_left`` tracks how many more times each anchor may
    still show up, which ``should_calculate`` uses to decide whether an anchor
    can still reach the table.

    Args:
        occurences: Mapping (typically a ``Counter``) from anchor to the number
            of times it occurs in the data. It is copied, so the caller's
            object is never modified.
        max_size: Capacity of ``best``. Defaults to 50.
        factor: Slack applied to ``min_val`` in ``should_calculate``.
            Defaults to 0.75.
    """

    def __init__(self, occurences, max_size=50, factor=0.75):
        # Private copy: update() decrements these counts, and sharing the
        # caller's Counter would silently deplete it between notebook runs.
        self.occurences_left = Counter(occurences)
        self.best = defaultdict(int)
        self.all = defaultdict(int)
        self.min_val = 0
        self.min_name = None
        self.full = False
        self.factor = factor
        self.max_size = max_size

    def update(self, anchor):
        """Record one more sighting of ``anchor`` and refresh the table."""
        self.occurences_left[anchor] -= 1

        self.all[anchor] += 1

        if anchor in self.best:
            self.best[anchor] += 1
            # Only a change to the cached minimum can invalidate the cache.
            if anchor == self.min_name:
                self._update_min(anchor, self.best[anchor])
        elif not self.full:
            self.best[anchor] = self.all[anchor]

            # ">=" rather than "==" so the table still closes if it ever
            # overshoots (e.g. a stray defaultdict read inserted a key).
            if len(self.best) >= self.max_size:
                self.full = True
                self._update_min(anchor, self.best[anchor])
        # in case anchor with equal value was outside the best
        elif self.all[anchor] > self.min_val:
            del self.best[self.min_name]
            self.best[anchor] = self.all[anchor]
            self._update_min(anchor, self.best[anchor])

    def _update_min(self, candid_name, candid_val):
        """Refresh ``min_name``/``min_val`` starting from the given candidate.

        The scan stops at the first entry strictly below the candidate. Counts
        grow by exactly one per update, so the candidate is at most one above
        the previous minimum and any smaller entry is already at that minimum.
        """
        for anchor, value in self.best.items():
            if value < candid_val:
                candid_name, candid_val = anchor, value
                break

        self.min_name = candid_name
        self.min_val = candid_val

    def should_calculate(self, anchor):
        """Return True while ``anchor`` can still reach ``factor * min_val``.

        The best case for the anchor is its count so far plus every occurrence
        that is still left for it.
        """
        return (self.all[anchor] + self.occurences_left[anchor]) >= self.min_val * self.factor
        

In [19]:
def set_seed(seed=42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic, non-benchmarked
    kernel selection.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Imported locally so the helper does not depend on these names leaking
    # into the notebook namespace through wildcard imports.
    import random

    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [96]:
# Restrict the run to the first 10 sentences.
# NOTE(review): this rebinds `anchor_examples`, so the full filtered list is
# only recoverable by re-running the cell that built it.
anchor_examples = anchor_examples[:10]

In [None]:
# Shell out to nvidia-smi to show the current GPU state.
!nvidia-smi

In [95]:
from modified_anchor import anchor_base
# Install a fresh BestGroup as a class attribute of AnchorBaseBeam, initialised
# with the token counts in `normal_occurences`.
anchor_base.AnchorBaseBeam.best_group = BestGroup(normal_occurences)

In [87]:
# Inspect the current contents of the best-anchor table.
anchor_base.AnchorBaseBeam.best_group.best

defaultdict(int, {'bland': 1, 'on': 1, '-': 1})

In [97]:
# Explanation helper over the anchor sentences, the test sentences, the
# explainer and the batch prediction function.
my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle", optimize = True)
# NOTE(review): set_seed() uses its default of 42, not the notebook-level
# SEED (84) set in the first cell — confirm which one is intended.
set_seed()
# Profile the explanation of every anchor sentence, sorted by cumulative time;
# the report is also written to profile.txt.
%prun -s cumtime -T profile.txt my_utils.compute_explanations(list(range(len(anchor_examples))))

number 0
[0.948936170212766]
[0.9195402298850575]
[0.9337016574585635]
[0.9008264462809917]
[0.9803921568627451]
[0.8765432098765432]
[0.9083969465648855]
[1.0]
[0.9736842105263158]
[0.9083969465648855]
[0.9205298013245033]
[0.9245283018867925]
number 1
[0.0]
[0.2727272727272727]
[0.18181818181818182]
[0.2727272727272727]
[0.45454545454545453]
[0.36363636363636365]
[0.36363636363636365]
[0.2727272727272727]
[0.36363636363636365]
[0.45454545454545453]
[0.09090909090909091]
[0.2727272727272727]
[0.45454545454545453]
[0.0]
[0.0]
number 2
[0.7619047619047619]
[1.0]
[0.6363636363636364]
[0.6111111111111112]
[0.6363636363636364]
[0.6818181818181818]
[0.6666666666666666]
[0.8064516129032258]
[0.6363636363636364]
[0.5454545454545454]
[0.7804878048780488]
[0.76]
[0.6666666666666666]
[0.868421052631579]
number 3
[0.09090909090909091]
[0.18181818181818182]
[0.36363636363636365]
[0.09090909090909091]
[0.09090909090909091]
[0.45454545454545453]
[0.18181818181818182]
[0.0]
[0.2727272727272727]
[0.0]

         16192105 function calls (16187728 primitive calls) in 23.395 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   23.395   23.395 {built-in method builtins.exec}
        1    0.000    0.000   23.395   23.395 <string>:1(<module>)
        1    0.004    0.004   23.394   23.394 myUtils.py:65(compute_explanations)
       10    0.061    0.006   22.487    2.249 myUtils.py:42(get_exp)
       10    0.013    0.001   22.426    2.243 anchor_text.py:216(explain_instance)
       10    0.007    0.001   21.555    2.155 anchor_base.py:284(anchor_beam)
      607    0.100    0.000   20.732    0.034 anchor_text.py:174(sample_fn)
      574    0.001    0.000   20.576    0.036 anchor_base.py:237(<lambda>)
      574    0.062    0.000   20.575    0.036 anchor_base.py:183(complete_sample_fn)
     5296    0.123    0.000   17.846    0.003 anchor_text.py:97(sample)
     5410    0.101    0.000   15.868    0.003 anchor

In [23]:
###### my_utils = TextUtils(anchor_examples, test, explainer, predict_sentences, ignored,f"profile.pickle")
#%lprun -s -m modified_anchor.anchor_text -m modified_anchor.anchor_base -m myUtils -T profile.txt  my_utils.compute_explanations(list(range(len(anchor_examples))))

In [24]:
# Wall-clock time at which this cell was executed.
print(datetime.datetime.now())

2022-06-25 21:16:15.328675
