In [2]:
import os
import torch
import torch.nn as nn
import numpy as np
from pandas import *
import torch.nn.functional as F
import dictionary_corpus
from torch.autograd import Variable
from collections import defaultdict

import transformers
import json

In [2]:
torch.manual_seed(1111)
np.random.seed(1111)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [11]:
# Path to model file
fn = "../../models/test-56-2095.pt" # Change!
model_name = "GPT2-Wikipedia"

In [9]:
config = transformers.AutoConfig.from_pretrained("../../models/")

model = transformers.AutoModelForCausalLM.from_config(config)

In [15]:
model.eval()

state_dict = torch.load(fn, map_location=torch.device('cpu'))["model"]
model.load_state_dict(state_dict)

  state_dict = torch.load(fn, map_location=torch.device('cpu'))["model"]


<All keys matched successfully>

In [18]:
# path to data files
data_path = "../../data/"

In [19]:
dictionary = dictionary_corpus.Dictionary(data_path)

In [20]:
def check_vocab(word_list):
    """
    Check if elements from word_list are in the model's vocab
    """
    unknown = set()
    for w in word_list:
        try:
            idx = dictionary.word2idx[w]
        except KeyError:
            unknown.add(w)
    print(unknown)
    print(len(unknown), "word(s) is/are not in the model's vocabulary")

In [40]:
def sent_surprisal(prompt):
    """
    Assigns surprisal values to a sentence
    prompt: list with sentence tokens
    Returns a list with surprisal values for the sentence
    """
    # Sent has <bos> and <eos> tokens for which surprisal of 0 is assigned
    surprisal_arr = [0]  # surprisal for <bos> token already added
    indices = [dictionary.word2idx[w] if w in dictionary.word2idx
               else dictionary.word2idx["<unk>"]
               for w in prompt]
    indices = torch.tensor(indices, dtype=torch.long)
    output = model(indices.view(-1, 1)).logits  # one input at a time, thus batch_size = 1
    for position, next_word in enumerate(prompt[1:-1]):  # excluding actual surprisal for <bos> and <eos>
        current_word_scores = output[position].view(-1)  # the output vector corresponding to the current word
        current_word_probs = F.log_softmax(current_word_scores, dim=0) # (log) softmax the score to get probabilities
        next_word_prob = current_word_probs[dictionary.word2idx[next_word]] # get the prob of the true next word
        surprisal = next_word_prob*(-1) 
        surprisal_arr.append(surprisal.item())
    surprisal_arr.append(0)  # surprisal for <eos> 
    return surprisal_arr

In [32]:
def get_surprisal_values(data):
    """
    Get surprisal values for a 'word' column in a df
    Returns a list with surprisal values for the whole df
    """
    surprisal_values = []
    end_idx = data.loc[data['word'] == '<eos>'].index.to_list()  # list with idx of rows that contain <eos>
    end_idx = [-1,*end_idx]  # inserting -1 as the start index to get the first sentence right
    for i in range(len(end_idx)-1):
        sent_range = range(end_idx[i]+1, end_idx[i+1]+1)
        sent_words = data.iloc[sent_range]['word'].to_list()
        surprisal_arr = sent_surprisal(sent_words)
        for s in surprisal_arr:
            surprisal_values.append(s)
    return surprisal_values

In [23]:
def filename_from_dataset(dataset):
    """
    Get output filename from input filename by adding "_result"
    """
    result_filename = '../data/results/gpt2/' + os.path.basename(dataset)[:-4] + '_result.csv' 
    print(result_filename)
    return result_filename

In [34]:
def analyze_data(dataset):
    words = []
    data = read_csv(dataset, index_col=0)
    for index, row in data.iterrows():
        if row["word"] != "<bos>":
            words.append(row["word"])
    check_vocab(words)
    surprisal_values = get_surprisal_values(data)
    data["surprisal"] = surprisal_values
    data["dependency"] = "Wh"
    data["language"] = "English"
    result_fn = filename_from_dataset(dataset)
    data.to_csv(result_fn, encoding="utf-8-sig", index=False)

In [41]:
analyze_data('../data/test_sentences/eq_wh_en.csv')
analyze_data('../data/test_sentences/whether_wh_en.csv')
analyze_data('../data/test_sentences/subject_wh_en.csv')
analyze_data('../data/test_sentences/unbound_wh_en.csv')

set()
0 word(s) is/are not in the model's vocabulary
['<bos>', 'She', 'mentioned', 'that', 'the', 'designer', 'specified', 'that', 'the', 'shelf', 'should', 'be', 'mounted', 'in', 'the', 'hallway', 'as', 'soon', 'as', 'possible', '.', '<eos>']
[62, 0, 8044, 34, 3, 15296, 14711, 34, 3, 9566, 324, 281, 9963, 29, 3, 47754, 38, 3256, 38, 1745, 18, 19]
<unk>
['<bos>', 'He', 'mentioned', 'that', 'the', 'teacher', 'knew', 'that', 'the', 'food', 'was', 'purchased', 'before', 'the', 'meeting', 'in', 'fall', '.', '<eos>']
[62, 168, 8044, 34, 3, 6229, 11329, 34, 3, 2834, 138, 6734, 1699, 3, 115, 29, 5034, 18, 19]
<unk>
['<bos>', 'I', 'hear', 'that', 'they', 'had', 'to', 'explain', 'that', 'the', 'machine', 'worked', 'in', 'an', 'unusual', 'way', 'in', 'the', 'office', '.', '<eos>']
[62, 162, 24903, 34, 1754, 68, 70, 9822, 34, 3, 5545, 108, 29, 534, 4555, 2609, 29, 3, 4871, 18, 19]
<unk>
['<bos>', 'You', 'heard', 'that', 'the', 'driver', 'forgot', 'that', 'the', 'car', 'should', 'be', 'picked', 'u