In [1]:
from os.path import join as pathjoin
from data_processing import *
from interpretation import *
from models import *
from training import *
import tqdm

In [2]:
DATA_DIR = '/home/mlepekhin/data/'
MODELS_DIR = '/home/mlepekhin/models'
transformer_model = 'DeepPavlov/rubert-base-cased'
MAX_TOKENS = 512

In [3]:
import pandas as pd

topic_generated_df = pd.read_csv(pathjoin(DATA_DIR, 'min_gpt_bpe/ru_train_topic_big_sep_generators.csv'))
topic_list = np.unique(topic_generated_df.topic.values)

print(topic_list)

['arhitecture' 'business' 'crime' 'education' 'games' 'literature' 'music'
 'politics' 'sport' 'travel']


In [4]:
from sklearn.metrics import accuracy_score

In [5]:
test_df = pd.read_csv(pathjoin(DATA_DIR, "ru_test"))
test_df.head()

Unnamed: 0.1,Unnamed: 0,target,text
0,726,A7,Глава 1 Приступая к работе 1.1 Знакомство с те...
1,1871,A17,Kawasaki D-Tracker С недавних пор Kawasaki d-t...
2,1265,A17,"По моему , вполне достойные книги , может и не..."
3,205,A11,Тест-драйв Lada Granta : новая надежда автогра...
4,141,A8,"среда , 2 декабря 2009 года , 12.33 Бумага всё..."


In [6]:
def interpret_topic(topic, calc_triggers):
    MODEL_ID = f'allennlp_rubert_from_topic_generated_{topic}'
    CHECKPOINTS_DIR = pathjoin(MODELS_DIR, MODEL_ID, 'checkpoints')
    BEST_MODEL = pathjoin(CHECKPOINTS_DIR, 'best.th')
    
    vocab = Vocabulary().from_files(pathjoin(MODELS_DIR, MODEL_ID, 'vocab'))
    model = build_transformer_model(vocab, transformer_model)
    if torch.cuda.is_available():
        cuda_device = 1
    else:
        cuda_device = -1
    model.load_state_dict(torch.load(BEST_MODEL, map_location=f'cuda:{cuda_device}'))
    
    dataset_reader = build_transformer_dataset_reader(transformer_model, MAX_TOKENS)
    predictor = TextClassifierPredictor(model, dataset_reader=dataset_reader)
    predicted_classes = np.array(predict_classes(test_df.text.values, predictor, vocab))
    good_sentences = get_all_correctly_predicted_sentences(
        test_df.text.values, test_df.target.values, predicted_classes
    )
    smooth_grad = SmoothGradient(predictor)
    triggers = []
    if calc_triggers:
        triggers = get_most_frequent_trigger_words(good_sentences, dataset_reader.tokenizer, 50, smooth_grad)
        
    return accuracy_score(test_df.target.values, predicted_classes), triggers

In [7]:
trigger_words = {}
accuracy = {}

In [8]:
for topic in tqdm.tqdm(topic_list):
    cur_accuracy, cur_triggers = interpret_topic(topic, topic in ['arhitecture', 'business', 'politics'])
    accuracy[topic] = cur_accuracy
    trigger_words[topic] = cur_triggers

  0%|          | 0/10 [00:00<?, ?it/s]

Building the model


 10%|█         | 1/10 [43:08<6:28:13, 2588.18s/it]

Building the model


 20%|██        | 2/10 [1:24:23<5:40:34, 2554.27s/it]

Building the model


 30%|███       | 3/10 [1:26:18<3:32:37, 1822.45s/it]

Building the model


 40%|████      | 4/10 [1:28:14<2:11:02, 1310.47s/it]

Building the model


 50%|█████     | 5/10 [1:30:05<1:19:14, 950.86s/it] 

Building the model


 60%|██████    | 6/10 [1:31:59<46:38, 699.56s/it]  

Building the model


 70%|███████   | 7/10 [1:33:50<26:09, 523.28s/it]

Building the model


 80%|████████  | 8/10 [2:15:04<36:56, 1108.49s/it]

Building the model


 90%|█████████ | 9/10 [2:16:59<13:30, 810.19s/it] 

Building the model


100%|██████████| 10/10 [2:18:52<00:00, 833.22s/it]


### Loading of the pretrained model

In [10]:
print(accuracy)

{'arhitecture': 0.7681159420289855, 'business': 0.7598343685300207, 'crime': 0.7184265010351967, 'education': 0.7329192546583851, 'games': 0.7204968944099379, 'literature': 0.7329192546583851, 'music': 0.7370600414078675, 'politics': 0.7494824016563147, 'sport': 0.7329192546583851, 'travel': 0.7329192546583851}


In [12]:
open('topic_accuracy_results.txt', 'w').write(str(accuracy))
open('topic_trigger_words_results.txt', 'w').write(str(trigger_words))

432122

In [20]:
for key, value in trigger_words.items():
    print(key, '\n', [pair[0] for pair in value][:100])

arhitecture 
 ['.', ',', '[SEP]', '[CLS]', '-', 'в', '>', '<', '"', 'росс', 'заявил', ')', ':', 'я', 'мы', '(', '##еи', '##и', ';', '/', 'и', 'на', 'br', 'вы', '##ои', '!', '?', 'сегодня', 'котор', 'президент', 'москв', '##ии', 'вам', 'позволяет', 'россия', 'не', 'этом', '##p', 'это', 'с', 'будет', 'что', 'р', 'как', 'меня', 'украин', 'сказал', 'европ', 'для', 'можно', '##а', 'деи', 'по', 'компании', '#', 'президента', 'ga', 'компания', '[', 'а', '«', 'словам', '2', 'модели', '»', 'сш', '*', 'ее', 'сообщает', 'се', 'нас', 'пресс', 'сообщил', 'новости', '##ев', '—', 'напомним', 'вас', '##ф', '1', 'все', 'он', 'отметил', 'технология', 'просто', '##мите', 'года', 'очень', '3', 'является', '##ичас', 'можете', 'кажд', '##ическо', 'мне', 'будут', 'са', '2011', '2013', 'наша']
business 
 ['.', ',', '[SEP]', '[CLS]', 'в', '"', '-', 'росс', ':', 'и', '<', '##и', 'президент', 'заявил', '>', '(', ';', ')', '/', 'я', 'москв', 'мы', '!', '##еи', 'не', 'на', 'что', 'можно', 'br', 'сегодня', '?', '['