In [1]:
import pandas as pd
import os
import pickle
import torch
import random
import re
import numpy as np
random.seed(21)
from transformers import DistilBertForMaskedLM, DistilBertConfig, AutoTokenizer



In [2]:
# Decade labels. Each one names a sub-directory of ./crs_models/ (one model per
# decade) and is the second-level key into the sentence pickles. The last label
# breaks the 'XXX0_XXX9' pattern ('2010_2020'), so it is spelled out explicitly.
year_list = [f'{start}_{start + 9}' for start in range(1900, 2010, 10)] + ['2010_2020']

In [3]:
# Country terms to analyse. Each entry is a first-level key into the country
# sentence pickle and also names an output sub-directory; multi-word names are
# split on spaces when they are located inside a sentence.
country_list = [
    'china',
    'north korea',
    'south korea',
    'canada',
    'united kingdom',
    'germany',
]

In [4]:
# Concept terms to analyse. Each entry is a first-level key into the concept
# sentence pickle and also names an output sub-directory.
concept_list = [
    'autocracy',
    'autocratic',
    'dictator',
    'dictatorship',
    'authoritarianism',
    'democracy',
]

In [5]:
def load_decade_model(year, model_dir='./crs_models'):
    """Load the DistilBERT masked-LM checkpoint and tokenizer for one decade.

    Args:
        year: Decade label (e.g. '1900_1909'); names the sub-directory of
            ``model_dir`` that holds ``config.json``, ``pytorch_model.bin`` and
            the tokenizer files.
        model_dir: Root directory containing one sub-directory per decade.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode and is configured to
        return hidden states and attentions. It is constructed on the default
        (CPU) device and is not moved anywhere else here.
    """
    decade_dir = os.path.join(model_dir, year)

    config = DistilBertConfig.from_json_file(os.path.join(decade_dir, 'config.json'))
    config.output_hidden_states = True
    config.output_attentions = True
    model = DistilBertForMaskedLM(config)

    # The model lives on the CPU, so load the weights onto the CPU too. Mapping
    # them to CUDA made this fail on GPU-less machines, and load_state_dict
    # copied them back into the CPU parameters anyway.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(os.path.join(decade_dir, 'pytorch_model.bin'), map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(decade_dir)
    return model, tokenizer

In [6]:
def get_attentions(sents, term, model, tokenizer, layers=(3, 4, 5), heads=None):
    """For each sentence, find the word that ``term`` attends to most strongly.

    Each sentence is split on single spaces and tokenized as pre-split words.
    Each word of ``term`` is located at its first occurrence in the sentence.
    For every selected layer/head, the attention rows of that word's
    sub-tokens are averaged. The columns belonging to any word of ``term`` are
    then zeroed, and all of these vectors are averaged together. The
    highest-weighted position, excluding the first and last positions (the
    special tokens that frame a single encoded sequence), gives the top word.

    Args:
        sents: Iterable of sentence strings. Every word of ``term`` must occur
            in each sentence as a space-separated word; ``list.index`` raises
            ``ValueError`` otherwise.
        term: Target term; multi-word terms are split on spaces.
        model: Model whose outputs end with the per-layer attention tuple.
            Each layer's tensor is shaped (batch, heads, seq, seq), as in
            Hugging Face models.
        tokenizer: Tokenizer whose output supports ``word_to_tokens`` and
            ``token_to_word`` (i.e. a Hugging Face *fast* tokenizer).
        layers: Layer indices whose attention is averaged.
        heads: Head indices to average. ``None`` (the default) uses every head
            of the layer. ``range(11)`` reproduces the earlier behaviour, which
            skipped the last of DistilBERT's 12 heads.

    Returns:
        DataFrame with one row per sentence and columns ``top_word`` and
        ``word_weight``. Both are NaN when truncation removed a term word.
    """
    term_words = term.split(' ')
    top_words = []
    word_weights = []

    for sent in sents:
        words = sent.split(' ')
        encoded = tokenizer(words, truncation=True, return_tensors='pt', is_split_into_words=True)

        # Sub-token span of each term word. word_to_tokens returns None when
        # truncation cut the word off, in which case the sentence gets NaNs.
        spans = [encoded.word_to_tokens(words.index(t)) for t in term_words]
        if any(span is None for span in spans):
            top_words.append(np.nan)
            word_weights.append(np.nan)
            continue

        with torch.no_grad():
            outputs = model(encoded.input_ids)
        attention = outputs[-1]

        # Token positions of every term word. They are masked out so the term
        # cannot pick itself (or another of its own words) as the top word.
        term_positions = torch.tensor([pos for span in spans for pos in range(span.start, span.end)])

        avgs = []
        for span in spans:
            for layer in layers:
                layer_attention = attention[layer][0]
                head_ids = range(layer_attention.shape[0]) if heads is None else heads
                for head in head_ids:
                    # Average the attention rows of this word's sub-tokens.
                    avg_weight = layer_attention[head][span.start:span.end].mean(dim=0)
                    avg_weight[term_positions] = 0
                    avgs.append(avg_weight)

        mean = torch.stack(avgs).mean(dim=0)
        # argmax over the inner positions (skipping the first/last special
        # tokens) yields exactly one index; on ties the first one wins. The old
        # equality search could match several positions, including the
        # excluded special tokens, and then failed in token_to_word.
        top_pos = int(mean[1:-1].argmax()) + 1
        word_weights.append(mean[top_pos].item())
        top_words.append(words[encoded.token_to_word(top_pos)])

    return pd.DataFrame({'top_word': top_words, 'word_weight': word_weights})

In [7]:
def process_attentions(sentence_dict, term_list, year_list, term_type):
    """Run get_attentions for every (term, decade) pair and write one CSV each.

    Args:
        sentence_dict: Nested mapping where ``sentence_dict[term][year]`` holds
            the sentences for that term and decade, or ``None`` when that pair
            should be skipped.
        term_list: Terms to process (first-level keys of ``sentence_dict``).
        year_list: Decade labels to process (second-level keys).
        term_type: Sub-directory under ./crs_attention_words/ (e.g. 'countries').

    Writes ``./crs_attention_words/<term_type>/<term>/<term>_<year>.csv``,
    creating directories as needed.
    """
    for year in year_list:
        # Load each decade model at most once and share it across all terms,
        # instead of reloading it for every term. Resetting to None also lets
        # the previous decade's model be freed.
        model = tokenizer = None
        for term in term_list:
            sents = sentence_dict[term][year]
            # `is None` rather than `== None`: the latter is element-wise (and
            # then ambiguous in an `if`) for numpy arrays / pandas Series.
            if sents is None:
                continue
            if model is None:
                model, tokenizer = load_decade_model(year)
            attention_df = get_attentions(sents, term, model, tokenizer)
            outdir = os.path.join('./crs_attention_words', term_type, term)
            os.makedirs(outdir, exist_ok=True)
            attention_df.to_csv(os.path.join(outdir, term + '_' + year + '.csv'), index=False)

In [8]:
# Load the country sentences; process_attentions below reads them as
# country_sentence_dict[country][decade].
# NOTE: pickle.load can execute arbitrary code -- only open trusted pickles.
with open(os.path.join('./crs_sents', 'country_sents.pkl'), 'rb') as handle:
    country_sentence_dict = pickle.load(handle)

In [10]:
# Write one CSV per (country, decade) under ./crs_attention_words/countries/.
process_attentions(
    sentence_dict=country_sentence_dict,
    term_list=country_list,
    year_list=year_list,
    term_type='countries',
)

In [11]:
# Drop the country sentences before the concept pickle is loaded, so their
# memory can be reclaimed.
del country_sentence_dict

In [17]:
# Load the concept sentences; process_attentions below reads them as
# concept_sentence_dict[concept][decade].
# NOTE: pickle.load can execute arbitrary code -- only open trusted pickles.
with open(os.path.join('./crs_sents', 'concept_sents.pkl'), 'rb') as handle:
    concept_sentence_dict = pickle.load(handle)

In [18]:
# Write one CSV per (concept, decade) under ./crs_attention_words/concepts/.
process_attentions(
    sentence_dict=concept_sentence_dict,
    term_list=concept_list,
    year_list=year_list,
    term_type='concepts',
)

In [None]:
# Release the concept sentences once their CSVs have been written.
del concept_sentence_dict