In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
from pandas import *
import torch.nn.functional as F
import dictionary_corpus
from torch.autograd import Variable
from collections import defaultdict

import transformers
import json

In [2]:
# Reproducibility: seed the NumPy and PyTorch random number generators with the
# same fixed value ...
np.random.seed(1111)
torch.manual_seed(1111)
# ... and ask cuDNN for deterministic kernels with auto-tuning switched off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Path to the model checkpoint file; it is read with torch.load in a later cell.
fn = "../data/lm/Norwegian/state-20-819.pt" # Change! (path is specific to the local data layout)
# Human-readable label for the model; no other cell visible in this notebook reads it.
model_name = "GPT2-Wikipedia"

In [4]:
# Read the model configuration from the directory below.
# NOTE(review): the recorded run of this cell failed with an OSError because no
# config.json could be loaded from this directory -- verify the path before running.
config = transformers.AutoConfig.from_pretrained("../data/lm/Norwegian/")

# Build the causal language model from the configuration; its weights are filled in
# from the checkpoint `fn` in the following cell.
model = transformers.AutoModelForCausalLM.from_config(config)

OSError: Can't load the configuration of '../data/lm/Norwegian/'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure '../data/lm/Norwegian/' is the correct path to a directory containing a config.json file

In [None]:
# Switch the model to evaluation mode before it is used for scoring.
model.eval()

# Load the checkpoint onto the CPU and take the weights stored under its "model" key.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints that come from a trusted source.
state_dict = torch.load(fn, map_location=torch.device('cpu'))["model"]
model.load_state_dict(state_dict)

In [None]:
# Path to the data files; this directory is handed to dictionary_corpus.Dictionary
# in the next cell to build the vocabulary.
data_path = "../data/lm/Norwegian/"

In [None]:
# Vocabulary object; its `word2idx` mapping (word -> index) is what check_vocab and
# sent_surprisal use to look words up.
dictionary = dictionary_corpus.Dictionary(data_path)

In [None]:
def check_vocab(word_list):
    """
    Report which elements of word_list are missing from the model's vocab.

    Prints the set of missing words, followed by a line with their count.
    Nothing is returned.
    """
    def _in_vocab(token):
        # A failed lookup in the word -> index mapping marks the token as unknown.
        try:
            dictionary.word2idx[token]
        except KeyError:
            return False
        return True

    missing = {token for token in word_list if not _in_vocab(token)}
    print(missing)
    print(len(missing), "word(s) is/are not in the model's vocabulary")

In [None]:
def sent_surprisal(prompt):
    """
    Assigns surprisal values to a sentence
    prompt: list with sentence tokens, expected to start with <bos> and end with <eos>
    Returns a list with one surprisal value per token of the sentence. Surprisal is
    the negative log-probability the model assigns to a token given the tokens before
    it; the first (<bos>) and last (<eos>) positions are always assigned 0.

    Words that are not in the vocabulary are scored as "<unk>", both as context and
    as prediction target, so an out-of-vocabulary word no longer raises KeyError.
    """
    # With fewer than two tokens there is nothing to predict; return one 0 per token
    # so the result always has the same length as the input.
    if len(prompt) < 2:
        return [0] * len(prompt)

    indices = [dictionary.word2idx[w] if w in dictionary.word2idx
               else dictionary.word2idx["<unk>"]
               for w in prompt]
    indices = torch.tensor(indices, dtype=torch.long)

    # Scoring only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(indices.view(1, -1)).logits[0]  # one input at a time, thus batch_size = 1
    log_probs = F.log_softmax(logits, dim=-1)  # row i: log-probabilities for the token at i + 1

    # Targets are the tokens strictly between <bos> and <eos>; the target at offset i
    # is predicted from output row i. Reusing `indices` keeps the <unk> mapping.
    targets = indices[1:-1]
    surprisals = -log_probs[torch.arange(targets.numel()), targets]

    # Frame the scored tokens with the fixed 0 surprisal for <bos> and <eos>.
    return [0, *surprisals.tolist(), 0]

In [None]:
def get_surprisal_values(data):
    """
    Get surprisal values for a 'word' column in a df
    Returns a list with surprisal values for the whole df

    The column is split into sentences at every '<eos>' row (the '<eos>' row closes
    its sentence) and each sentence is scored with sent_surprisal. Sentence
    boundaries are tracked by row position, so the result does not depend on the
    labels of the DataFrame index. Rows after the last '<eos>' do not belong to a
    complete sentence and produce no values.
    """
    words = data['word'].to_list()
    surprisal_values = []
    sent_start = 0  # position of the first word of the sentence being collected
    for pos, word in enumerate(words):
        if word == '<eos>':
            surprisal_values.extend(sent_surprisal(words[sent_start:pos + 1]))
            sent_start = pos + 1
    return surprisal_values

In [None]:
def filename_from_dataset(dataset, result_dir='../data/results/gpt2/'):
    """
    Get output filename from input filename by adding "_result"

    dataset: path of the input file; only its base name without extension is used
    result_dir: directory the result path points into (defaults to the GPT-2 results folder)
    Prints and returns the path "<result_dir>/<input name>_result.csv".
    """
    # splitext removes whatever extension is present (or none), instead of assuming
    # that the name ends in exactly four characters such as ".csv".
    stem = os.path.splitext(os.path.basename(dataset))[0]
    result_filename = os.path.join(result_dir, stem + '_result.csv')
    print(result_filename)
    return result_filename

In [None]:
def analyze_data(dataset, dependency, language="Norwegian"):
    """
    Compute surprisal for every word of a test-sentence file and save the result.

    dataset: path to a CSV file with a 'word' column; its first column is read as the index
    dependency: label written to the 'dependency' column of every row
    language: label written to the 'language' column of every row

    Reports out-of-vocabulary words via check_vocab, adds the 'surprisal',
    'dependency' and 'language' columns and writes the frame (without its index) to
    the path given by filename_from_dataset. Returns nothing.
    """
    data = read_csv(dataset, index_col=0)
    check_vocab(data['word'].tolist())
    data["surprisal"] = get_surprisal_values(data)
    data["dependency"] = dependency
    data["language"] = language
    result_fn = filename_from_dataset(dataset)
    # Create the results directory if needed so that writing does not fail on a
    # fresh checkout where it does not exist yet.
    result_dir = os.path.dirname(result_fn)
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    data.to_csv(result_fn, encoding="utf-8-sig", index=False)

In [None]:
# Score the four Wh test sets, one after the other.
for _wh_dataset in (
    '../data/test_sentences/eq_wh.csv',
    '../data/test_sentences/whether_wh.csv',
    '../data/test_sentences/subject_wh.csv',
    '../data/test_sentences/unbound_wh.csv',
):
    analyze_data(_wh_dataset, "Wh")

In [None]:
# Score the four RC test sets, one after the other.
for _rc_dataset in (
    '../data/test_sentences/eq_rc.csv',
    '../data/test_sentences/whether_rc.csv',
    '../data/test_sentences/subject_rc.csv',
    '../data/test_sentences/unbound_rc.csv',
):
    analyze_data(_rc_dataset, "RC")