This colab trains a subject-object classifier probe.



In [None]:
# %pip installs into the kernel's own environment; the version is pinned to the
# one this notebook was last run with so that re-runs are reproducible.
%pip install conllu==4.5.3

Collecting conllu
  Downloading conllu-4.5.3-py2.py3-none-any.whl (16 kB)
Installing collected packages: conllu
Successfully installed conllu-4.5.3


In [None]:
%%writefile create_dataset.py

"""
Build a probing dataset: read sentences from a UD treebank or a CSV file, run them
through a BERT-style model, and save the labels, tokenization maps and embeddings.
"""
import argparse
import csv
import json
import numpy as np
import os
import pandas as pd
import pickle
import sys
import torch
from transformers import AutoTokenizer, AutoModel

from utils import get_tokens_and_labels, get_tokens_and_labels_csv, get_bert_tokens, shuffle_positions, save_sample, save_bert_outputs, save_just_position_word

# Folder joined with "<dataset>_<bert-name>" in make_dataset to name a
# per-dataset directory. NOTE(review): this is the relative path "content", and
# make_dataset currently replaces the resulting directory with "/content".
base_path = "content"
def __main__():
    """Read the command-line options and hand them to make_dataset."""
    # One (flag, add_argument keyword arguments) entry per supported option.
    option_table = [
        ('--ud-path', dict(type=str, default=None)),
        ('--csv-file', dict(type=str, default=None,
                            help="If data is in a csv file with subjects and objects marked")),
        ('--bert-name', dict(type=str, help="Like 'bert-base-uncased'")),
        ('--shuffle-positions', dict(action="store_true")),
        ('--single-position', dict(type=int, default=-1,
                                   help="Make all positions this one index")),
        ('--local-shuffle', dict(type=int, default=-1)),
    ]
    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    args = cli.parse_args()
    print("args:", args)
    make_dataset(args)

def make_dataset(args):
    """
    Build every artefact of a probing dataset and write it to disk.

    Sentences and labels are read from ``args.ud_path`` (a CoNLL-U treebank)
    or, when that is not given, from ``args.csv_file``. Word order is
    optionally shuffled, then these files are written to the output directory:
    ``labels.json``, ``sample.csv``, ``bert_info.pkl`` and the hdf5 files
    produced by ``save_bert_outputs``.

    Parameters:
    args: parsed command-line namespace providing ud_path, csv_file,
        bert_name, shuffle_positions, single_position and local_shuffle.

    Raises:
    ValueError: if neither ``args.ud_path`` nor ``args.csv_file`` is set.
    """
    # Fail early with a clear message instead of a NameError further down.
    if args.ud_path is None and args.csv_file is None:
        raise ValueError("Please provide a data source: --ud-path or --csv-file")

    # Per-dataset directory name: "<file stem>_<bert name>[_<position suffix>]".
    if args.ud_path is not None:
        tb_name = os.path.split(args.ud_path)[1]
        tb_name = os.path.splitext(tb_name)[0]
        directory = os.path.join(base_path, f"{tb_name}_{args.bert_name}")
    if args.csv_file is not None:
        dataset_name = os.path.split(args.csv_file)[1]
        dataset_name = os.path.splitext(dataset_name)[0]
        directory = os.path.join(base_path, f"{dataset_name}_{args.bert_name}")
    if args.single_position >= 0:
        directory += f"_singlepos{args.single_position}"
    elif args.shuffle_positions:
        directory += f"_shuffled-pos"
    elif args.local_shuffle >= 0:
        directory += f"_localshuffle{args.local_shuffle}"
    # The per-dataset name computed above is currently not used: everything is
    # written flat into /content.
    #os.mkdir(directory)
    directory = "/content"

    tokenizer = AutoTokenizer.from_pretrained(args.bert_name)
    model = AutoModel.from_pretrained(args.bert_name, output_hidden_states=True)
    model.eval()

    # The treebank takes precedence when both sources are given.
    if args.ud_path is not None:
        labels = get_tokens_and_labels(args.ud_path)
    else:
        labels = get_tokens_and_labels_csv(args.csv_file)
    labels = shuffle_positions(labels, args.shuffle_positions, args.local_shuffle)
    with open(os.path.join(directory, "labels.json"), "w") as labels_file:
        json.dump(labels, labels_file)
    save_sample(20, labels, directory)

    bert_info = {}
    bert_info["bert_tokens"], bert_info["bert_ids"], \
    bert_info["orig_to_bert_map"], bert_info["bert_to_orig_map"] =\
        get_bert_tokens(labels["token"], tokenizer)
    with open(os.path.join(directory, "bert_info.pkl"), "wb") as info_file:
        pickle.dump(bert_info, info_file)
    save_bert_outputs(directory, bert_info["bert_ids"], model, args.shuffle_positions, args.single_position)

# Build the dataset only when this file is executed as a script, not on import.
if __name__ == "__main__":
    __main__()

Writing create_dataset.py


In [None]:
%%writefile create_index.py

import argparse
from collections import defaultdict
import json
import os
import pickle
from random import shuffle
import sys
import torch

from utils import load_embeddings

# Root folder for dataset files. NOTE(review): not referenced by the live code
# of this script; create_index hardcodes "/content" and keeps the base_path
# based path only in a comment.
base_path = "/content"

def main():
    """Read the command-line options and build the index they describe."""
    # One (flag, add_argument keyword arguments) entry per supported option.
    option_table = [
        ('--dataset', dict(type=str)),
        ('--roles', dict(nargs='+', type=str,
                         help="Roles, like A or S-passive. Should be capitalized correctly")),
        ('--cases', dict(nargs='+', type=str,
                         help="Cases, like Erg or Nom. Should be capitalized correctly")),
        ('--balance', dict(action="store_true")),
        ('--only-non-prototypical', dict(action="store_true")),
        ('--limit', dict(type=int, default=-1)),
    ]
    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    args = cli.parse_args()
    print(args)
    create_index(args)

def create_index(args):
    """
    Select (sentence, word) positions from labels.json and save them as an index.

    A word is selected when it has a role, that role is in ``args.roles`` (if
    given) and its case is in ``args.cases`` (if given). With
    ``args.only_non_prototypical`` a word is kept only if
    ``check_non_prototypical`` accepts it. With ``args.balance`` every role is
    cut down, after a random shuffle, to the size of the rarest role
    (optionally capped through ``args.limit``).

    The index, a list of (sentence index, word index) pairs, is written as JSON
    into /content under a file name derived from the options.

    Parameters:
    args: parsed command-line namespace providing dataset, roles, cases,
        balance, only_non_prototypical and limit.
    """
    directory = "/content" #os.path.join(base_path, args.dataset)
    with open(os.path.join(directory, "labels.json"), "r") as labels_file:
        labels = json.load(labels_file)
    index = []
    if args.roles:
        role_index = dict([(role, []) for role in args.roles])
    else:
        role_index = defaultdict(list)

    plain_index_name = "index"
    if args.balance:
        plain_index_name += "_balance"
    if args.roles:
        plain_index_name += "_roles-" + "".join(args.roles)
    if args.cases:
        plain_index_name += "_cases-" + "".join(args.cases)
    if args.limit > 0:
        plain_index_name += f"limit-{args.limit}"

    if args.only_non_prototypical:
        dataset_directory = os.path.join("dataset_storing", args.dataset)
        word_embeddings = load_embeddings(dataset_directory, "word_embeddings")
        with open(os.path.join(dataset_directory, "bert_info.pkl"), "rb") as info_file:
            orig_to_bert_map = pickle.load(info_file)["orig_to_bert_map"]
        classifier_dir = os.path.join("classifiers", args.dataset, plain_index_name)
        # NOTE: unpickling can execute arbitrary code; only load classifier
        # files that were produced by this pipeline.
        with open(os.path.join(classifier_dir, "mlp_layer-word_embeddings"), "rb") as classifier_file:
            classifier, _labelset, labeldict = pickle.load(classifier_file)
        A_index = labeldict["A"]
        filename = plain_index_name + "only-non-prototypical.json"
    else:
        filename = plain_index_name + ".json"

    for sent_i in range(len(labels["token"])):
        for word_i in range(len(labels["token"][sent_i])):
            role = labels["role"][sent_i][word_i]
            # Datasets built from a csv file carry no "case" labels.
            case = labels["case"][sent_i][word_i] if "case" in labels else None
            role_ok = args.roles is None or role in args.roles
            role_ok = role_ok and role is not None
            case_ok = args.cases is None or case in args.cases
            if role_ok and case_ok:
                if not args.only_non_prototypical or \
                   check_non_prototypical(sent_i, word_i, word_embeddings, orig_to_bert_map, classifier, A_index, role):
                    role_index[role].append((sent_i, word_i))
    if args.balance:
        # default=0 covers the case where no word matched at all.
        min_role_len = min([len(role_index[role]) for role in role_index.keys()], default=0)
        if args.limit > 0:
            if min_role_len * len(role_index.keys()) >= args.limit:
                min_role_len = args.limit // len(role_index.keys())
            else:
                print(f"Please pick a limit which is less than the balanced length. Limit = {args.limit}, min_role_len = {min_role_len} for roles {role_index.keys()}")
                sys.exit(1)
        print(f"Culling all roles to have length {min_role_len}")
        for role in role_index.keys():
            # Shuffle first so that the kept examples are a random subset.
            shuffle(role_index[role])
            index.extend(role_index[role][:min_role_len])
    else:
        if args.limit > 0:
            print("Limit not implemented for unbalanced index yet!")
            sys.exit(1)
        for role in role_index.keys():
            index.extend(role_index[role])
    with open(os.path.join(directory, filename), "w") as index_file:
        json.dump(index, index_file)
    print("Index has length", len(index))

def check_non_prototypical(sent_i, word_i, word_embeddings, orig_to_bert_map, classifier, A_index, role):
    """
    Tell whether the classifier's guess for a word contradicts its real role.

    The classifier is applied to the embedding stored at the word's first
    position in ``orig_to_bert_map``. The word counts as non-prototypical when
    it is an "A" that receives an A-probability below 0.5, or an "O" that
    receives an A-probability above 0.5; in both cases the role and the
    probability are printed. Every other role returns False.
    """
    first_piece = orig_to_bert_map[sent_i][word_i]
    sentence_vectors = word_embeddings[sent_i].squeeze()
    logits = classifier(torch.Tensor(sentence_vectors[first_piece]))
    A_prob = torch.softmax(logits, 0)[A_index].item()

    if role == "A" and A_prob < 0.5:
        printed_role = "A"
    elif role == "O" and A_prob > 0.5:
        printed_role = "O"
    else:
        return False
    print(printed_role, A_prob)
    return True

# Build the index only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Writing create_index.py


In [None]:
%%writefile train_classifiers.py

import argparse
import json
import os
import pickle

from data import SimpleDataset
from utils import get_num_layers, train_classifier

def main():
    """Read the command-line options and train the classifiers they describe."""
    # One (flag, add_argument keyword arguments) entry per supported option.
    option_table = [
        ('--dataset-name', dict(type=str)),
        ('--index-name', dict(type=str)),
        ("--classifier-type", dict(type=str, default="mlp")),
    ]
    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    args = cli.parse_args()
    print("args", args)

    train_classifiers(args)

def train_classifiers(args):
    """
    Train one classifier per embedding layer and pickle each of them.

    A classifier is trained on the word embeddings and on every hidden state
    stored for the dataset. Each result is saved, together with the dataset's
    labelset and labeldict, to
    ``classifiers/<dataset_name>/<index_name>/<classifier_type>_layer-<layer>``.

    Parameters:
    args: parsed command-line namespace providing dataset_name, index_name and
        classifier_type ("logistic" selects logistic regression, anything else
        the MLP).
    """
    classifier_dir = os.path.join("classifiers", args.dataset_name, args.index_name)
    print("making classifier dir at", classifier_dir)
    os.makedirs(classifier_dir, exist_ok=True)
    num_layers = get_num_layers(args.dataset_name)
    print(f"There are {num_layers} layers in this model")
    # get_num_layers counts the hidden states stored per sentence, so the valid
    # layer indices are 0 .. num_layers - 1.
    layers = ["word_embeddings"] + [str(i) for i in range(num_layers)]
    logistic = args.classifier_type == "logistic"
    for layer in layers:
        print(f"Layer {layer}")
        dataset = SimpleDataset(args.dataset_name, args.index_name, layer)
        classifier = train_classifier(dataset, logistic)
        output_path = os.path.join(classifier_dir, f"{args.classifier_type}_layer-{layer}")
        with open(output_path, "wb") as classifier_file:
            pickle.dump((classifier, dataset.labelset, dataset.labeldict), classifier_file)

# Train only when this file is executed as a script, so that importing the
# module (for example to reuse train_classifiers) has no side effects.
if __name__ == "__main__":
    main()

Writing train_classifiers.py


In [None]:
%%writefile utils.py

from collections import defaultdict
import conllu
import csv
import h5py
import numpy as np
import os
import random
import sys
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import BertModel

def get_tokens_and_labels(filename):
    """
    Parse a CoNLL-U treebank into per-sentence label lists.

    Parameters:
    filename: the location of the treebank (conll file)

    Returns:
    labels: A dict, whose keys are types of labels (eg, "animacy"), and each
        value is a list of length num_sentences. Entry i of every value is a
        list with one item per token of sentence i.
    """
    # CoNLL-U files are UTF-8; do not depend on the platform default encoding.
    with open(filename, encoding="utf-8") as f:
        conll_data = f.read()
    sentences = conllu.parse(conll_data)
    labels = defaultdict(list)
    for tokenlist in sentences:
        sentence_info = defaultdict(list)
        # Every sentence gets a sent_id entry (None when the treebank has none)
        # so that all label lists stay aligned sentence by sentence.
        if "sent_id" in tokenlist.metadata.keys():
            sent_id = tokenlist.metadata["sent_id"]
        else:
            sent_id = None
        sentence_info["sent_id"] = [sent_id]*len(tokenlist)
        noun_count = 0
        for token in tokenlist:
            token_info = get_token_info(token, tokenlist)
            token_case = None
            token_animacy = ""
            # Case and animacy are only recorded for tokens that have a role.
            if token_info["role"] is not None:
                if token['feats'] and 'Case' in token['feats']:
                    token_case = token['feats']['Case']
                if token['feats'] and 'Animacy' in token['feats']:
                    token_animacy = token['feats']['Animacy']
            token_info["case"] = token_case
            token_info["animacy"] = token_animacy
            sentence_info["token"].append(token['form'])
            for label_type in token_info.keys():
                sentence_info[label_type].append(token_info[label_type])
            # Number of nouns, proper nouns and pronouns seen before this token.
            sentence_info["preceding_nouns"].append(noun_count)
            if token["upostag"] == "NOUN" or token["upostag"] == "PROPN" or token["upostag"]=="PRON":
                noun_count += 1
        for label_type in sentence_info.keys():
            labels[label_type].append(sentence_info[label_type])
        labels["word_index"].append(list(range(len(sentence_info["token"]))))
        assert len(sentence_info["case"]) == len(sentence_info["role"]), \
               "Length of case and role should be the same for every sentence (though both lists can include Nones)"
    print("returning from get_tokens, the keys are", list(labels.keys()))
    return dict(labels)

def get_tokens_and_labels_csv(filename):
    """
    Read sentences with a marked subject and object from a csv file.

    The file needs the columns sentence, sentence_id, subject_idx, object_idx,
    subject, object and verb. Sentences are split on single spaces. The word
    at subject_idx gets role "A" and the word at object_idx role "O"; the
    subject, object and verb of the sentence are recorded at both of these
    positions, and every other position holds None.

    Returns a defaultdict mapping each label type to a list with one entry
    (a per-word list) per sentence.
    """
    labels = defaultdict(list)
    # (label type, csv column) pairs that are copied onto both argument slots.
    copied_columns = (
        ("subject_word", "subject"),
        ("object_word", "object"),
        ("verb_word", "verb"),
    )
    with open(filename, 'r') as csvfile:
        for row in csv.DictReader(csvfile, delimiter=','):
            words = row['sentence'].split(" ")
            n_words = len(words)
            labels["token"].append(words)
            labels['sent_id'].append([row['sentence_id']] * n_words)

            subject_idx = int(row["subject_idx"])
            object_idx = int(row["object_idx"])
            roles = [None] * n_words
            roles[subject_idx] = "A"
            roles[object_idx] = "O"
            labels["role"].append(roles)

            for label_type, column in copied_columns:
                per_word = [None] * n_words
                for slot in (subject_idx, object_idx):
                    per_word[slot] = row[column]
                labels[label_type].append(per_word)

            labels["word_index"].append(list(range(n_words)))
    return labels

def get_token_info(token, tokenlist):
    """
    Work out the grammatical role of one token within its sentence.

    Only NOUN and PROPN tokens whose head is a VERB or AUX (or whose head is
    not found in the sentence) can receive a role:
    - a token whose deprel contains "nsubj" is "S-expletive" if a sibling has
      deprel "expl", else "A" if a sibling's deprel contains "obj", else "S";
      "-passive" is appended when the deprel contains "pass"
    - a token whose deprel contains "obj" is "O"
    - "-aux" is appended to any role when the head is an AUX

    Returns a dict with the keys role, verb_word, verb_idx, subject_word and
    object_word; role is None when no role applies.
    """
    info = {
        "role": None,
        "verb_word": "",
        "verb_idx": -1,
        "subject_word": "",
        "object_word": "",
    }
    if token["upostag"] not in ("NOUN", "PROPN"):
        return info

    head_id = token['head']
    matching_heads = tokenlist.filter(id=head_id)
    head_pos = None
    if len(matching_heads) > 0:
        head = matching_heads[0]
        head_upos = head["upostag"]
        if head_upos == "VERB":
            head_pos = "verb"
        elif head_upos == "AUX":
            head_pos = "aux"
        else:
            return info
        info["verb_word"] = head["lemma"]
        info["verb_idx"] = int(head["id"]) - 1

    deprel = token['deprel']
    if "nsubj" in deprel:
        info["subject_word"] = token['form']
        # 'deps' field is often empty in treebanks, so scan the whole sentence
        # for tokens sharing this subject's head to tell an A from an S.
        siblings = [other for other in tokenlist if other['head'] == head_id]
        objects = [sib for sib in siblings if "obj" in sib['deprel']]
        if objects:
            # The last object in the sentence is the one recorded.
            info["object_word"] = objects[-1]["form"]
        if any(sib['deprel'] == "expl" for sib in siblings):
            role = "S-expletive"
        elif objects:
            role = "A"
        else:
            role = "S"
        if "pass" in deprel:
            role += "-passive"
        info["role"] = role
    elif "obj" in deprel:
        info["role"] = "O"
        info["object_word"] = token['form']
        for other in tokenlist:
            if other['head'] == head_id and "subj" in other['deprel']:
                info["subject_word"] = other['form']

    if head_pos == "aux" and info["role"] is not None:
        info["role"] += "-aux"
    return info

def get_bert_tokens(orig_tokens, tokenizer, max_length=512):
    """
    Split every word of every sentence into the tokenizer's word pieces.

    Parameters:
    orig_tokens: a list of sentences, each a list of words.
    tokenizer: object with ``tokenize(word)`` and ``convert_tokens_to_ids(tokens)``.
    max_length: maximum number of tokens per sentence including the
        surrounding "[CLS]" and "[SEP]"; longer sentences are cut off.

    Returns a tuple of four lists with one entry per sentence:
    bert_tokens: "[CLS]", the word pieces, then "[SEP]".
    bert_ids: the ids of bert_tokens.
    orig_to_bert_map: orig_to_bert_map[i][j] is the position in bert_tokens[i]
        of the first piece of word j; one extra final entry holds the position
        just after the last word.
    bert_to_orig_map: for each word piece (not counting "[CLS]"), the index of
        the word it belongs to.

    Note that the two maps are not shortened when a sentence is cut off at
    max_length, so they can point past the end of such a sentence.
    """
    bert_tokens = []
    orig_to_bert_map = []
    bert_to_orig_map = []
    for sentence in orig_tokens:
        sentence_bert_tokens = ["[CLS]"]
        sentence_map_otb = []
        sentence_map_bto = []
        for orig_idx, orig_token in enumerate(sentence):
            sentence_map_otb.append(len(sentence_bert_tokens))
            # Tokenize each word once and use the pieces for both outputs.
            word_pieces = tokenizer.tokenize(orig_token)
            sentence_map_bto.extend([orig_idx] * len(word_pieces))
            sentence_bert_tokens.extend(word_pieces)
        # Sentinel entry, so that word j spans map[j] .. map[j + 1].
        sentence_map_otb.append(len(sentence_bert_tokens))
        # Leave room for the closing "[SEP]".
        sentence_bert_tokens = sentence_bert_tokens[:max_length - 1]
        sentence_bert_tokens.append("[SEP]")
        bert_tokens.append(sentence_bert_tokens)
        orig_to_bert_map.append(sentence_map_otb)
        bert_to_orig_map.append(sentence_map_bto)
    bert_ids = [tokenizer.convert_tokens_to_ids(b) for b in bert_tokens]
    return bert_tokens, bert_ids, orig_to_bert_map, bert_to_orig_map

def shuffle_positions(labels, shuffle_positions, local_shuffle):
    """
    Reorder the words of every sentence, applying the same order to all labels.

    Parameters:
    labels: dict mapping each label type to a list of per-sentence lists; it
        is modified in place and also returned.
    shuffle_positions: if true, shuffle all words of each sentence.
    local_shuffle: if positive, shuffle words only within consecutive chunks
        of this many words. Cannot be combined with shuffle_positions.

    When neither option is active the labels are returned untouched. Otherwise
    a "shuffled_index" label is added that holds 0 .. length - 1 for every
    sentence.
    NOTE(review): "shuffled_index" does not record the permutation itself; the
    original positions are only recoverable from a shuffled label such as
    "word_index", if present - confirm this is intended.
    """
    if not shuffle_positions and local_shuffle <= 0:
        print("Not shuffling positions this time")
        return labels
    assert not shuffle_positions or local_shuffle <= 0, \
        "Must choose between local and global shuffling!"
    labels["shuffled_index"] = []

    for sent_i, sentence in enumerate(labels["token"]):
        length = len(sentence)
        permutation = list(range(length))
        if shuffle_positions:
            random.shuffle(permutation)
        else:
            for chunk_start in range(0, length, local_shuffle):
                chunk_end = min(chunk_start + local_shuffle, length)
                chunk = permutation[chunk_start:chunk_end]
                random.shuffle(chunk)
                permutation[chunk_start:chunk_end] = chunk
        for label in labels:
            # Compare by value: "is not" only tests object identity and is not
            # a reliable way to compare strings.
            if label != "shuffled_index":
                labels[label][sent_i] = \
                    [labels[label][sent_i][source_i] for source_i in permutation]
        labels["shuffled_index"].append(list(range(length)))
    return labels

def save_sample(num_samples, labels, directory):
    """
    Write a random sample of sentences to ``<directory>/sample.csv``.

    Parameters:
    num_samples: how many sentences to sample; if the dataset has fewer
        sentences, all of them are written.
    labels: label dict with at least the "token" and "sent_id" entries.
    directory: existing folder in which sample.csv is created.

    The file is tab-separated with the header "sentence_id", "sentence"; each
    sentence is written as its words joined by single spaces.
    """
    num_sentences = len(labels['token'])
    # random.sample raises ValueError when asked for more items than exist.
    sampled_sentences = random.sample(range(num_sentences), min(num_samples, num_sentences))
    samples = [
        [labels["sent_id"][sent_i][0], " ".join(labels["token"][sent_i])]
        for sent_i in sampled_sentences
    ]
    # newline='' lets the csv module control line endings itself.
    with open(os.path.join(directory, "sample.csv"), "w", newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(["sentence_id", "sentence"])
        writer.writerows(samples)

def save_bert_outputs(directory, bert_ids, bert_model, shuffle_positions=False, single_position=-1):
    """
    Run every sentence through the model and store the results in hdf5 files.

    Three files are (re)created in ``directory``, each with one dataset per
    sentence named by the sentence's index:
    - bert_vectors.hdf5: all hidden states, shape (layers, tokens, dim)
    - bert_word_embs.hdf5: the model's word embeddings, shape (tokens, dim)
    - bert_position_embs.hdf5: the position embeddings that were used

    Parameters:
    directory: existing folder for the output files.
    bert_ids: list of sentences, each a list of token ids.
    bert_model: model called as ``bert_model(ids, position_ids=...)`` whose
        output provides "hidden_states", and which exposes
        ``embeddings.word_embeddings`` and ``embeddings.position_embeddings``.
    shuffle_positions: only checked for consistency with single_position;
        this function does not permute positions itself.
    single_position: if >= 0, every token is given this position id.
    """
    assert not shuffle_positions or not single_position >= 0, \
        "Choose beetween shuffling and putting a single position"

    # The context managers close all three files even if a sentence fails.
    with h5py.File(os.path.join(directory, "bert_vectors.hdf5"), 'w') as datafile, \
         h5py.File(os.path.join(directory, "bert_word_embs.hdf5"), 'w') as word_file, \
         h5py.File(os.path.join(directory, "bert_position_embs.hdf5"), 'w') as position_file, \
         torch.no_grad():
        print(f"Running {len(bert_ids)} sentences through BERT. This takes a while")
        for idx, sentence in enumerate(tqdm(bert_ids)):
            if single_position >= 0:
                positions = torch.ones((1, len(sentence)), dtype=torch.long) * single_position
            else:
                positions = torch.tensor(range(len(sentence))).unsqueeze(0)

            bert_output = bert_model(torch.tensor(sentence).unsqueeze(0),
                                     position_ids = positions)
            hidden_layers = bert_output["hidden_states"]
            layer_count = len(hidden_layers)
            _, sentence_length, dim = hidden_layers[0].shape
            dset = datafile.create_dataset(str(idx), (layer_count, sentence_length, dim))
            # Stack the per-layer (1, tokens, dim) outputs along the first axis.
            dset[:, :, :] = np.vstack([np.array(x) for x in hidden_layers])

            word_embedding = bert_model.embeddings.word_embeddings(torch.tensor(sentence))
            sentence_length, dim = word_embedding.shape
            word_dset = word_file.create_dataset(str(idx), (sentence_length, dim))
            word_dset[:,:] = word_embedding

            position_embedding = bert_model.embeddings.position_embeddings(positions)
            pos_dset = position_file.create_dataset(str(idx), (sentence_length, dim))
            pos_dset[:,:] = position_embedding

def save_just_position_word(directory, bert_ids, bert_model, shuffle_positions=False, single_position=-1):
    """
    NOTE: NOT USED. save_bert_outputs includes this functionality.

    Store only the word and position embeddings of every sentence.

    Two files are (re)created in ``directory``, bert_word_embs.hdf5 and
    bert_position_embs.hdf5, each with one dataset per sentence named by the
    sentence's index.

    Parameters:
    directory: existing folder for the output files.
    bert_ids: list of sentences, each a list of token ids.
    bert_model: model exposing ``embeddings.word_embeddings`` and
        ``embeddings.position_embeddings``.
    shuffle_positions: if true, randomly permute the position ids of all
        tokens except the first and the last one.
    single_position: if >= 0, every token is given this position id. Cannot
        be combined with shuffle_positions.
    """
    assert not shuffle_positions or not single_position >= 0, \
        "Choose beetween shuffling and putting a single position"

    # The context managers close both files even if a sentence fails.
    with h5py.File(os.path.join(directory, "bert_word_embs.hdf5"), 'w') as word_file, \
         h5py.File(os.path.join(directory, "bert_position_embs.hdf5"), 'w') as position_file, \
         torch.no_grad():
        print(f"Running {len(bert_ids)} sentences through BERT. This takes a while")
        for idx, sentence in enumerate(tqdm(bert_ids)):
            if single_position >= 0:
                positions = torch.ones((1, len(sentence)), dtype=torch.long) * single_position
            elif shuffle_positions:
                # Shuffle positions of everything except for first and last BERT tokens.
                positions = torch.arange(len(sentence), dtype=torch.long)
                permutation = torch.randperm(len(sentence)-2)
                positions[1:-1] = positions[1:-1][permutation]
                positions = positions.unsqueeze(0)
            else:
                positions = torch.tensor(range(len(sentence))).unsqueeze(0)

            word_embedding = bert_model.embeddings.word_embeddings(torch.tensor(sentence))
            sentence_length, dim = word_embedding.shape
            word_dset = word_file.create_dataset(str(idx), (sentence_length, dim))
            word_dset[:,:] = word_embedding

            position_embedding = bert_model.embeddings.position_embeddings(positions)
            pos_dset = position_file.create_dataset(str(idx), (sentence_length, dim))
            pos_dset[:,:] = position_embedding

def load_bert_outputs(directory, layer):
    """
    Load one hidden layer of every sentence from ``<directory>/bert_vectors.hdf5``.

    Parameters:
    directory: folder containing bert_vectors.hdf5.
    layer: index of the hidden layer to load, as an int or a numeric string.

    Returns a list with one numpy array per sentence, ordered by sentence
    index.

    Raises:
    ValueError: if ``layer`` is not a number (word and position embeddings are
        loaded with load_embeddings instead).

    Exits the program if the hdf5 file cannot be read.
    """
    hdf5_path = os.path.join(directory, "bert_vectors.hdf5")
    try:
        layer = int(layer)
    except (TypeError, ValueError):
        raise ValueError(
            f"{layer!r} is not a valid layer number. If you want word embeddings, use load_embeddings"
        )
    outputs = []
    try:
        with h5py.File(hdf5_path, 'r') as datafile:
            # Datasets are named "0", "1", ...; default=-1 covers an empty file.
            max_key = max([int(key) for key in datafile.keys()], default=-1)
            for i in tqdm(range(max_key + 1), desc='[Loading from disk]'):
                hidden_layers = datafile[str(i)]
                outputs.append(np.array(hidden_layers[layer]))
            print(f"Loaded {len(outputs)} sentences from disk.")
    except OSError:
        print(f"Encountered hdf5 reading error on file {hdf5_path}. Please re-create the hdf5 file")
        sys.exit(1)
    return outputs

def load_embeddings(directory, embeddings_type):
    """
    Load the stored word or position embeddings of every sentence.

    ``embeddings_type`` selects the file inside ``directory``:
    "word_embeddings" reads bert_word_embs.hdf5 and "position_embeddings"
    reads bert_position_embs.hdf5. Any other value, or a file that cannot be
    read, prints a message and exits the program.

    Returns a list with one numpy array per sentence, ordered by sentence
    index.
    """
    known_files = (
        ("word_embeddings", "bert_word_embs.hdf5"),
        ("position_embeddings", "bert_position_embs.hdf5"),
    )
    file_name = next(
        (name for kind, name in known_files if embeddings_type == kind), None)
    if file_name is None:
        print(embeddings_type, "Is not not word_embeddings or position_embeddings")
        sys.exit(1)
    hdf5_path = os.path.join(directory, file_name)

    try:
        with h5py.File(hdf5_path, 'r') as datafile:
            # Datasets are named "0", "1", ... up to the highest sentence index.
            last_key = max([int(key) for key in datafile.keys()])
            sentence_ids = tqdm(range(last_key + 1), desc='[Loading from disk]')
            outputs = [np.array(datafile[str(i)][:]) for i in sentence_ids]
    except OSError:
        print(f"Encountered hdf5 reading error on file {hdf5_path}. Please re-create the hdf5 file")
        sys.exit(1)
    return outputs

def get_num_layers(dataset_name):
    """
    Return how many hidden states are stored per sentence.

    The number is the first dimension of dataset "0" in
    /content/bert_vectors.hdf5. ``dataset_name`` is currently not used because
    the dataset folder is fixed to /content. Exits the program if the file
    cannot be read.
    """
    #dataset_directory = os.path.join("dataset_storing", dataset_name)
    hdf5_path = os.path.join("/content", "bert_vectors.hdf5")
    try:
        with h5py.File(hdf5_path, 'r') as vectors_file:
            first_sentence = vectors_file["0"]
            layer_count = first_sentence.shape[0]
            return layer_count
    except OSError:
        print(f"Encountered hdf5 reading error on file {hdf5_path}. Please re-create the hdf5 file")
        sys.exit(1)

class _classifier(nn.Module):
    def __init__(self, nlabel, bert_dim):
        super(_classifier, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(bert_dim, 64),
            nn.ReLU(),
            nn.Linear(64, nlabel),
            nn.Dropout(.1)
        )
    def forward(self, input):
        return self.main(input)

def train_classifier(dataset, logistic):
    """Train a logistic-regression classifier if ``logistic`` is true, else an MLP."""
    trainer = train_classifier_logistic if logistic else train_classifier_mlp
    return trainer(dataset)

def train_classifier_mlp(train_dataset, epochs=20):
    """
    Train the MLP probe on ``train_dataset`` and return it.

    The dataset must offer get_num_labels(), get_bert_dim() and
    get_dataloader(); every batch is an (embeddings, role labels, other
    labels) triple. Training uses Adam with default settings and
    cross-entropy loss, and prints the mean batch loss after each epoch.
    """
    model = _classifier(train_dataset.get_num_labels(), train_dataset.get_bert_dim())
    adam = torch.optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()
    batches = train_dataset.get_dataloader()

    for epoch_number in range(1, epochs + 1):
        epoch_losses = []
        for embeddings, role_labels, _other_labels in batches:
            batch_loss = loss_fn(model(embeddings), role_labels)
            adam.zero_grad()
            batch_loss.backward()
            adam.step()
            epoch_losses.append(batch_loss.data.mean().item())
        print('[%d/%d] Train loss: %.3f' % (epoch_number, epochs, np.mean(epoch_losses)))
    return model

def train_classifier_logistic(train_dataset):
    """
    Fit a logistic-regression classifier on all examples of ``train_dataset``.

    The embeddings are standardized (zero mean, unit variance per feature)
    before fitting. Returns the fitted LogisticRegression object.

    NOTE(review): the fitted scaler is not returned, so the classifier expects
    inputs that the caller has standardized in the same way - confirm how the
    evaluation code handles this.
    """
    # scikit-learn is only needed for this classifier type and is not among
    # this module's top-level imports, so it is imported here.
    from sklearn import preprocessing
    from sklearn.linear_model import LogisticRegression

    X, y = [], []
    # Batches of size one make it easy to collect the examples one by one.
    dataloader = train_dataset.get_dataloader(batch_size=1)
    for emb_batch, role_label_batch, _ in dataloader:
        X.append(emb_batch[0])
        y.append(role_label_batch[0])
    X = np.stack(X, axis=0)
    y = np.stack(y, axis=0)
    scaler = preprocessing.StandardScaler().fit(X)
    X_scaled = scaler.transform(X)
    classifier = LogisticRegression(random_state=0, max_iter=10000).fit(X_scaled, y)
    return classifier

Writing utils.py


In [None]:
%%writefile data.py

from collections import defaultdict, Counter
import json
import numpy as np
import os
import pickle
import random
import sys
import torch
import torch.utils.data as data

from utils import load_bert_outputs, load_embeddings

class SimpleDataset(data.Dataset):
    """
    Dataset of (embedding, role label index, other labels) examples.

    Each example is one indexed word: its embedding at the chosen layer, the
    integer id of its role, and a dict with all of its labels. All files are
    read from /content.
    """
    def __init__(self, dataset_name, index_name, layer_num, old_labeldict = None, pool_method = "first"):
        """
        Parameters:
        dataset_name: currently not used, the dataset folder is fixed to /content.
        index_name: name (without ".json") of the index file listing the
            (sentence, word) pairs that make up the examples.
        layer_num: "word_embeddings", "position_embeddings" or the number of a
            hidden layer (int or numeric string).
        old_labeldict: optional existing role-to-id mapping to keep using; see
            get_label_dict.
        pool_method: "first" (embedding of the word's first piece) or
            "average" (mean over all of the word's pieces).
        """
        self.layer_num = layer_num
        self.pool_method = pool_method
        #dataset_directory = os.path.join("dataset_storing", dataset_name)
        dataset_directory = "/content"
        with open(os.path.join(dataset_directory, "labels.json"), "r") as labels_file:
            self.labels = json.load(labels_file)
        with open(os.path.join(dataset_directory, "bert_info.pkl"), "rb") as info_file:
            self.bert_info = pickle.load(info_file)

        if layer_num in ["word_embeddings", "position_embeddings"]:
            self.bert_outputs = load_embeddings(dataset_directory, layer_num)
        else:
            try:
                int(layer_num)
                self.bert_outputs = load_bert_outputs(dataset_directory, layer_num)
            except Exception as error:
                # "except Exception" lets SystemExit and KeyboardInterrupt pass
                # through instead of reporting them as an invalid layer.
                print(f"Please put a valid layer name, {layer_num} is not a layer ({error})")
                sys.exit(1)
        with open(os.path.join(dataset_directory, index_name + ".json"), "r") as index_file:
            self.index = json.load(index_file)
        print("Examples #", len(self.index))
        self.labeldict = self.get_label_dict(old_labeldict)
        self.labelset = sorted(self.labeldict.keys(), key=lambda x: self.labeldict[x])

    def __getitem__(self, idx):
        """Return (embedding, role label id or -1 if unknown, dict of all labels)."""
        sentence_num, word_num = self.index[idx]
        bert_start_index = self.bert_info["orig_to_bert_map"][sentence_num][word_num]
        bert_end_index = self.bert_info["orig_to_bert_map"][sentence_num][word_num + 1]
        embedding = self.get_pooled_embedding(sentence_num, bert_start_index,
                                              bert_end_index)
        role = self.labels["role"][sentence_num][word_num]
        role_label_idx = self.labeldict[role] if role in self.labeldict else -1
        aux_labels = {}
        for label_type in self.labels.keys():
            label = self.labels[label_type][sentence_num][word_num]
            # Missing labels are replaced by an empty string.
            if label is None:
                label = ""
            aux_labels[label_type] = label
        #aux_labels["word_index"] = word_num
        return embedding, role_label_idx, aux_labels

    # Make a labeldict of all of the labels in this dataset, keeping the same
    # order for labels already in the old labeldict
    def get_label_dict(self, old_labeldict):
        """
        Map every role in the index to an integer id.

        Roles already in ``old_labeldict`` keep their id; new roles get the
        next free ids in alphabetical order. Note that ``old_labeldict`` itself
        is extended and returned, not a copy of it.
        """
        all_labels = set()
        for sent_i, word_i in self.index:
            new_role = self.labels["role"][sent_i][word_i]
            if new_role is not None:
                all_labels.add(new_role)
        labelset = sorted(list(all_labels))
        if old_labeldict is None:
            curr_label = 0
            labeldict = {}
        else:
            labeldict = old_labeldict
            curr_label = len(old_labeldict)
        for label in labelset:
            if old_labeldict is None or label not in old_labeldict:
                labeldict[label] = curr_label
                curr_label += 1
        return labeldict

    def get_num_labels(self):
        """Number of distinct roles known to this dataset."""
        return len(self.labeldict)

    def get_bert_dim(self):
        """Size of one embedding vector, taken from the first sentence."""
        return self.bert_outputs[0].shape[1]

    def __len__(self):
        return len(self.index)

    def get_pooled_embedding(self, sentence_num, bert_start_index, bert_end_index):
        """
        Return one vector for the word spanning the given piece positions.

        "first" returns the embedding at bert_start_index; "average" returns
        the mean of the embeddings from bert_start_index up to, but not
        including, bert_end_index.

        Raises:
        ValueError: if self.pool_method is neither "first" nor "average".
        """
        bert_sentence = \
            self.bert_outputs[sentence_num].squeeze()
        if self.pool_method == "first":
            return bert_sentence[bert_start_index]
        elif self.pool_method == "average":
            # axis=0 averages over the word's pieces and keeps the vector.
            return np.mean(bert_sentence[bert_start_index:bert_end_index], axis=0)
        raise ValueError(f"Unknown pool method {self.pool_method!r}, use 'first' or 'average'")

    def get_dataloader(self, batch_size=32, shuffle=True):
      """Wrap this dataset in a torch DataLoader."""
      return data.DataLoader(self, batch_size=batch_size, shuffle=shuffle)

Writing data.py


In [None]:
# Step 1: parse the treebank, run every sentence through bert-base-uncased and
# write labels.json, sample.csv, bert_info.pkl and the hdf5 files to /content.
!python /content/create_dataset.py --ud-path /content/concat-ud.conllu --bert-name bert-base-uncased

  if label is not "shuffled_index":
args: Namespace(ud_path='/content/concat-ud.conllu', csv_file=None, bert_name='bert-base-uncased', shuffle_positions=False, single_position=-1, local_shuffle=-1)
tokenizer_config.json: 100% 28.0/28.0 [00:00<00:00, 113kB/s]
config.json: 100% 570/570 [00:00<00:00, 2.55MB/s]
vocab.txt: 100% 232k/232k [00:00<00:00, 3.40MB/s]
tokenizer.json: 100% 466k/466k [00:00<00:00, 3.43MB/s]
model.safetensors: 100% 440M/440M [00:03<00:00, 129MB/s]
returning from get_tokens, the keys are ['sent_id', 'token', 'role', 'verb_word', 'verb_idx', 'subject_word', 'object_word', 'case', 'animacy', 'preceding_nouns', 'word_index']
Not shuffling positions this time
Running 10625 sentences through BERT. This takes a while
100% 10625/10625 [23:05<00:00,  7.67it/s]


In [None]:
# Step 2: index the nouns with role A or O, cut both roles down to the same
# number of examples, and write /content/index_balance_roles-AO.json.
!python /content/create_index.py --dataset concat-ud_bert-base-uncased --roles A O --balance

Namespace(dataset='concat-ud_bert-base-uncased', roles=['A', 'O'], cases=None, balance=True, only_non_prototypical=False, limit=-1)
Culling all roles to have length 1469
Index has length 2938


In [None]:
# Step 3: train one classifier per layer. The option is spelled out in full
# (--dataset-name) instead of relying on argparse accepting the prefix --dataset.
!python train_classifiers.py --dataset-name concat-ud_bert-base-uncased --index-name index_balance_roles-AO --classifier-type mlp

args Namespace(dataset_name='concat-ud_bert-base-uncased', index_name='index_balance_roles-AO', classifier_type='mlp')
making classifier dir at classifiers/concat-ud_bert-base-uncased/index_balance_roles-AO
There are 13 layers in this model
Layer word_embeddings
[Loading from disk]: 100% 10625/10625 [00:03<00:00, 3482.90it/s]
Examples # 2938
[1/20] Train loss: 0.681
[2/20] Train loss: 0.623
[3/20] Train loss: 0.575
[4/20] Train loss: 0.556
[5/20] Train loss: 0.530
[6/20] Train loss: 0.517
[7/20] Train loss: 0.513
[8/20] Train loss: 0.502
[9/20] Train loss: 0.492
[10/20] Train loss: 0.490
[11/20] Train loss: 0.473
[12/20] Train loss: 0.469
[13/20] Train loss: 0.458
[14/20] Train loss: 0.446
[15/20] Train loss: 0.438
[16/20] Train loss: 0.433
[17/20] Train loss: 0.421
[18/20] Train loss: 0.418
[19/20] Train loss: 0.405
[20/20] Train loss: 0.400
Layer 0
in try 0
[Loading from disk]: 100% 10625/10625 [00:18<00:00, 586.49it/s] 
Loaded 10624 sentences from disk.
Examples # 2938
[1/20] Train 

In [None]:
# Test the trained probes on the original and the argument-swapped versions of our sentences.

In [None]:
%%writefile eval_classifiers.py

import argparse
from collections import defaultdict
import datetime
import json
import numpy as np
import os
import pandas as pd
import pickle
import torch

from data import SimpleDataset

def main():
    """Parse the command-line options and run the probe evaluation.

    Options:
        --train-dataset / --train-index: dataset and index the classifiers
            were trained on; together they select the classifier directory.
        --classifier-type: "mlp" (default) or "logistic".
        --eval-dataset / --eval-index: dataset and index to evaluate on.
        --output-file: path of the CSV the results are written to. Defaults
            to the path this script has always written to, so pass a
            different value per evaluation run to avoid overwriting results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dataset', type=str)
    parser.add_argument('--train-index', type=str)
    parser.add_argument("--classifier-type", type=str, default="mlp")
    parser.add_argument('--eval-dataset', type=str)
    parser.add_argument('--eval-index', type=str)
    parser.add_argument('--output-file', type=str, default="/content/results-switched.csv",
        help="CSV file the per-example probabilities are written to")
    args = parser.parse_args()
    print("args", args)

    eval_classifiers(args)


def eval_classifiers(args):
    """Evaluate one pickled classifier per layer and save P(A) per example.

    For every layer ("word_embeddings" and "0".."12") the pickled
    (classifier, labelset, labeldict) triple is loaded, the evaluation
    dataset for that layer is built with the classifier's label mapping, and
    the probability the classifier assigns to label "A" is recorded for each
    example together with the example's other labels.

    Args:
        args: namespace with `train_dataset`, `train_index`,
            `classifier_type`, `eval_dataset` and `eval_index`. An optional
            `output_file` attribute overrides the CSV destination.

    Raises:
        ValueError: if `args.classifier_type` is neither "logistic" nor "mlp".

    Side effects:
        Writes a CSV with one row per (layer, example); its columns are the
        keys of the dataset's extra labels plus "layer" and "probability_A".
    """
    if args.classifier_type not in ("logistic", "mlp"):
        raise ValueError(
            f"Unsupported classifier type {args.classifier_type!r}; expected 'logistic' or 'mlp'")

    # Honor the training dataset/index given on the command line; fall back to
    # the previously hard-coded directory when they are not provided.
    if args.train_dataset and args.train_index:
        classifier_dir = os.path.join("/content", "classifiers", args.train_dataset, args.train_index)
    else:
        classifier_dir = "/content/classifiers/concat-ud_bert-base-uncased/index_balance_roles-AO"
    # `output_file` is optional so namespaces built without it keep working.
    output_path = getattr(args, "output_file", None) or "/content/results-switched.csv"
    print("Evaluating classifiers at", classifier_dir)

    layers = ["word_embeddings"] + [str(i) for i in range(13)]
    results = defaultdict(list)
    for layer in layers:
        classifier_path = os.path.join(classifier_dir, f"{args.classifier_type}_layer-{layer}")
        # NOTE: unpickling executes arbitrary code; only load classifier files
        # produced by your own training run.
        with open(classifier_path, "rb") as classifier_file:
            classifier, labelset, labeldict = pickle.load(classifier_file)
        print("Classifier labeldict", labeldict)
        A_index = labeldict["A"]
        eval_dataset = SimpleDataset(args.eval_dataset, args.eval_index, layer, old_labeldict = labeldict)
        dataloader = eval_dataset.get_dataloader(batch_size=1)
        classifier.eval()
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            for embedding, role, other_labels in dataloader:
                if args.classifier_type == "logistic":
                    probs = classifier.predict_proba(torch.Tensor(embedding))[0]
                    A_prob = probs[A_index]
                else:
                    output = classifier(torch.Tensor(embedding))
                    probs = torch.softmax(output, 1)
                    A_prob = probs[:,A_index][0].item()
                # Batch size is 1, so take the single entry of every label.
                for label in other_labels.keys():
                    val = other_labels[label][0]
                    if isinstance(val, torch.Tensor):
                        val = val.item()
                    results[label].append(val)
                results["layer"].append(layer)
                results["probability_A"].append(A_prob)
    df = pd.DataFrame(results)
    df.to_csv(output_path)



def add_classifier_predictions(labels_file, vecs_file, classifier, label_dict, layer, classifier_type):
    """Add the classifier's probability of label "A" to a JSON labels file.

    Reads the vectors stored under `bert_layer_{layer}` in the HDF5 file
    `vecs_file`, runs `classifier` on each of them, and stores the resulting
    probabilities in `labels_file` under the key
    `probA_{classifier_type}_layer_{layer}`. The JSON file is rewritten in place.

    Args:
        labels_file: path of the JSON file to read and overwrite.
        vecs_file: path of the HDF5 file holding the vectors.
        classifier: object offering `predict_proba` (for "logistic") or
            callable on a tensor (for "mlp").
        label_dict: mapping that contains the index of label "A".
        layer: layer identifier used in the HDF5 dataset name and the new key.
        classifier_type: "logistic" or "mlp".

    Raises:
        ValueError: if `classifier_type` is neither "logistic" nor "mlp".
    """
    # h5py is not imported at the top of this script; import it here so the
    # function does not fail with a NameError when it is called.
    import h5py

    if classifier_type not in ("logistic", "mlp"):
        raise ValueError(
            f"Unsupported classifier type {classifier_type!r}; expected 'logistic' or 'mlp'")

    with open(labels_file, "r") as labels_in:
        labels = json.load(labels_in)
    print("after loading", labels.keys())

    column = f"probA_{classifier_type}_layer_{layer}"
    A_index = label_dict["A"]
    # Keep the HDF5 file open only while reading, and skip autograd bookkeeping.
    with h5py.File(vecs_file, "r") as vecs_h5, torch.no_grad():
        vecs = vecs_h5[f"bert_layer_{layer}"]
        length = len(vecs)
        probabilities = [0]*length
        for i in range(length):
            # Add a leading batch dimension of size 1 for the classifier.
            vec = torch.Tensor(vecs[i].astype(np.float32)).unsqueeze(0)
            if classifier_type == "logistic":
                probs = classifier.predict_proba(vec)[0]
                A_prob = probs[A_index]
            else:
                output = classifier(vec)
                probs = torch.softmax(output, 1)
                A_prob = probs[:,A_index][0].item()
            probabilities[i] = A_prob
    labels[column] = probabilities

    print("before saving", labels.keys())
    with open(labels_file, "w") as labels_out:
        json.dump(labels, labels_out)


if __name__ == "__main__":
    main()

Overwriting eval_classifiers.py


In [None]:
!python create_dataset.py --csv-file  /content/object-subject-original.csv --bert-name bert-base-uncased

args: Namespace(ud_path=None, csv_file='/content/object-subject-original.csv', bert_name='bert-base-uncased', shuffle_positions=False, single_position=-1, local_shuffle=-1)
Not shuffling positions this time
Running 53 sentences through BERT. This takes a while
100% 53/53 [00:07<00:00,  7.57it/s]


In [None]:
!python create_index.py --dataset object-subject-original_bert-base-uncased

Namespace(dataset='object-subject-original_bert-base-uncased', roles=None, cases=None, balance=False, only_non_prototypical=False, limit=-1)
Index has length 106


In [None]:
!python eval_classifiers.py --train-dataset concat-ud_bert-base-uncased --train-index index_balance_roles-AO --classifier-type mlp --eval-dataset object-subject-original_bert-base-uncased --eval-index index

args Namespace(train_dataset='concat-ud_bert-base-uncased', train_index='index_balance_roles-AO', classifier_type='mlp', eval_dataset='object-subject-original_bert-base-uncased', eval_index='index')
Evaluating classifiers at /content/classifiers/concat-ud_bert-base-uncased/index_balance_roles-AO
Classifier labeldict {'A': 0, 'O': 1}
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00, 3994.86it/s]
Examples # 106
Classifier labeldict {'A': 0, 'O': 1}
in try 0
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00, 4657.80it/s]
Loaded 52 sentences from disk.
Examples # 106
Classifier labeldict {'A': 0, 'O': 1}
in try 1
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00, 5708.15it/s]
Loaded 52 sentences from disk.
Examples # 106
Classifier labeldict {'A': 0, 'O': 1}
in try 2
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00

In [None]:
!python create_dataset.py --csv-file  /content/object-subject-switched.csv --bert-name bert-base-uncased


args: Namespace(ud_path=None, csv_file='/content/object-subject-switched.csv', bert_name='bert-base-uncased', shuffle_positions=False, single_position=-1, local_shuffle=-1)
Not shuffling positions this time
Running 53 sentences through BERT. This takes a while
100% 53/53 [00:07<00:00,  7.17it/s]


In [None]:
!python create_index.py --dataset object-subject-switched_bert-base-uncased

Namespace(dataset='object-subject-switched_bert-base-uncased', roles=None, cases=None, balance=False, only_non_prototypical=False, limit=-1)
Index has length 106


In [None]:
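# NOTE: eval_classifiers.py writes its results to a fixed CSV path. Change that path in the script (and re-run its %%writefile cell) before running this, or the results of the previous evaluation will be overwritten.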
!python eval_classifiers.py --train-dataset concat-ud_bert-base-uncased --train-index index_balance_roles-AO --classifier-type mlp --eval-dataset object-subject-switched_bert-base-uncased --eval-index index

args Namespace(train_dataset='concat-ud_bert-base-uncased', train_index='index_balance_roles-AO', classifier_type='mlp', eval_dataset='object-subject-switched_bert-base-uncased', eval_index='index')
Evaluating classifiers at /content/classifiers/concat-ud_bert-base-uncased/index_balance_roles-AO
Classifier labeldict {'A': 0, 'O': 1}
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00, 4328.07it/s]
Examples # 106
Classifier labeldict {'A': 0, 'O': 1}
in try 0
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00, 4955.71it/s]
Loaded 52 sentences from disk.
Examples # 106
Classifier labeldict {'A': 0, 'O': 1}
in try 1
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00, 5650.26it/s]
Loaded 52 sentences from disk.
Examples # 106
Classifier labeldict {'A': 0, 'O': 1}
in try 2
[Loading from disk]:   0% 0/53 [00:00<?, ?it/s][Loading from disk]: 100% 53/53 [00:00<00:00