In [7]:
import fire, os, sys, random
import torch, json
from transformers import BertTokenizerFast, BertConfig, BertForMaskedLM
# from src.token_util import tokenizeData
from tqdm import tqdm

# Note: causal1 and causal2 are not different causal variables; they are different values of the same causal variable.
#   Bear this in mind when working on tasks with more than one causal variable.

# If examples aren't split up by the value of the causal variable, they only have a single 'causal' key.

def clean_data(data, data_name, excluded_verbs=None):
    '''
    Filter a subject-verb agreement dataset before probing.

    Steps:
      1. Sanity-check the schema of the first example. Every top-level key must be
         one of 'other', 'causal', 'causal1', 'causal2', 'envir1'. An example has
         either a single 'causal' entry or both 'causal1' and 'causal2' (two values
         of the same causal variable), never a mix.
      2. For 'marvin_linzen', drop examples whose other['sent_type'] contains 'coord'.
      3. Drop every example whose (verb, flipped verb) pair contains a form listed
         in `excluded_verbs`.
      4. Assert that no (masked sentence, target word) pair occurs twice. Print a
         warning when the same masked sentence appears with different target words.

    Args:
        data: list of example dicts, as loaded from the JSON data file.
        data_name: 'marvin_linzen', 'lgd' or 'lgd_orig'; selects how the verb pair is read.
        excluded_verbs: iterable of verb forms to exclude. None uses the default
            list (swim/smile/bring/interest in both forms).

    Returns:
        A new list with the surviving examples; the input list is not modified.

    Raises:
        ValueError: if `data_name` is not supported.
        AssertionError: on schema violations or duplicate examples.
    '''
    if excluded_verbs is None:
        excluded_verbs = ["swim", "swims", "smile", "smiles", "brings", "bring", "interests", "interest"]
    excluded_verbs = set(excluded_verbs)

    # Only the first example is checked; the rest are assumed to share its schema.
    if data:
        keys = set(data[0].keys())
        assert keys <= {'other', 'causal', 'causal1', 'causal2', 'envir1'}, 'Unexpected keys: %r' % (keys,)
        if 'causal' in keys:
            assert 'causal1' not in keys and 'causal2' not in keys, "'causal' cannot be mixed with 'causal1'/'causal2'"
        if 'causal1' in keys or 'causal2' in keys:
            assert 'causal' not in keys and {'causal1', 'causal2'} <= keys, "'causal1' and 'causal2' must appear together"

    # Choose how the (verb, flipped verb) pair is read for this dataset.
    if data_name == 'marvin_linzen':
        pair_extractor = lambda ex: (ex['causal1']['trg_wd'], ex['causal1']['trg_wd_flip'])
    elif data_name in {'lgd', 'lgd_orig'}:
        # The pair is ordered by the label, presumably so that both orientations of
        # the same verb pair map to one key.
        pair_extractor = lambda ex: ((ex['causal']['trg_wd'], ex['causal']['trg_wd_flip']) if ex['causal']['label']
                                     else (ex['causal']['trg_wd_flip'], ex['causal']['trg_wd']))
    else:
        # An explicit raise still fires under `python -O`, unlike `assert False`.
        raise ValueError("Unsupported data_name: {!r}".format(data_name))

    print("Length of data:", len(data))
    if not data:
        return []

    if data_name == 'marvin_linzen':
        # No coordination
        data = [x for x in data if 'coord' not in x['other']['sent_type']]
        print("Length of data with no coordination:", len(data))

    # Count how often each verb pair occurs.
    verb_pairs = {}
    for x in data:
        pair = pair_extractor(x)
        verb_pairs[pair] = verb_pairs.get(pair, 0) + 1

    # A pair is kept only if neither of its forms is excluded.
    print('Verb pairs:')
    keep_pair = {}
    for pair, count in verb_pairs.items():
        keep = pair[0] not in excluded_verbs and pair[1] not in excluded_verbs
        print('\t', pair, keep, '- will remove sentences' if not keep else '', '-', count)
        keep_pair[pair] = keep

    data = [x for x in data if keep_pair[pair_extractor(x)]]
    print('Length of data after removing invalid verbs:', len(data))
    if not data:
        # Nothing left to check for duplicates (and data[0] below would fail).
        return data

    # Confirm there's no duplicate data w.r.t. the input sentence-output pair.
    has_causal_by_value = 'causal1' in data[0]
    first_var = 'causal1' if has_causal_by_value else 'causal'
    sentwd_extractor = lambda ex, var: (ex[var]['mask'], ex[var]['trg_wd'])

    sents_causal1 = {sentwd_extractor(x, first_var) for x in data}
    masks_causal1 = {mask for mask, _ in sents_causal1}
    if has_causal_by_value:
        sents_causal2 = {sentwd_extractor(x, 'causal2') for x in data}
        masks_causal2 = {mask for mask, _ in sents_causal2}
        assert len(sents_causal1) == len(sents_causal2) and len(masks_causal1) == len(masks_causal2)
    assert len(sents_causal1) == len(data), 'There is duplicate data!!!'

    if len(masks_causal1) != len(data): print("WARNING: There are sentences that are the same input with [MASK],"
                                              " but the overall target verb (i.e. prediction) is different. Number of unique sentences"
                                              " up to verb (i.e. unique inputs) is", len(masks_causal1))

    return data



def getMaskVectorsAtLayerForOneBatch(batch_tokens, batch_mask_idxes, model, layer, pad_token_id=0):
    '''
    Run one batch through the model and return the hidden state at each row's mask position.

    Args:
        batch_tokens: (batch_size, seq_len) tensor of token ids. Positions equal to
            `pad_token_id` are excluded via the attention mask.
        batch_mask_idxes: (batch_size,) integer tensor giving, for each row, the
            position whose hidden state is returned.
        model: HuggingFace-style model called as
            `model(tokens, attention_mask=..., output_hidden_states=True)`. It must
            return a mapping with a 'hidden_states' sequence.
        layer: index into 'hidden_states'; only -1 (the last layer) is supported for now.
        pad_token_id: padding token id. The default 0 is the value previously
            hard-coded (BERT's [PAD] id).

    Returns:
        (batch_size, hidden_dim) tensor of mask-position embeddings.
    '''
    assert layer == -1, 'only last layer for now'

    # Attend to every non-padding position.
    attention_mask = (batch_tokens != pad_token_id).float()
    with torch.no_grad():
        # Ask for hidden states explicitly; otherwise they are only returned when the
        # model's config happens to enable them, and the lookup below fails.
        output = model(batch_tokens, attention_mask=attention_mask, output_hidden_states=True)
    embeds = output['hidden_states'][layer]  # (batch_size, seq_len, hidden_dim)

    # Pick position batch_mask_idxes[i] from row i.
    rows = torch.arange(embeds.shape[0], device=embeds.device)
    return embeds[rows, batch_mask_idxes.to(embeds.device)]  # (batch_size, hidden_dim)


def getMaskVectorsAtLayer(tokens, mask_idxes, batch_size, model, layer, do_tqdm):
    '''
    Compute mask-position embeddings for all examples, `batch_size` rows at a time.

    Args:
        tokens: (num_examples, seq_len) tensor of token ids.
        mask_idxes: (num_examples,) tensor of mask positions, aligned with `tokens`.
        batch_size: number of rows sent to the model per forward pass.
        model, layer: forwarded to `getMaskVectorsAtLayerForOneBatch`.
        do_tqdm: show a tqdm progress bar. If False, print a progress line every
            4th batch instead.

    Returns:
        (num_examples, hidden_dim) float32 tensor on `tokens.device`. The hidden
        dimension is taken from the model output; an empty input yields shape (0, 768).
    '''
    num_examples = tokens.shape[0]
    embeddings = None
    iterator = range(0, num_examples, batch_size)
    iterator = tqdm(iterator) if do_tqdm else iterator
    for batch_no, start_idx in enumerate(iterator):
        # Deterministic thinning of progress output. Using random.random() here would
        # advance the global RNG as a hidden side effect.
        if not do_tqdm and batch_no % 4 == 0: print(start_idx, '/', num_examples)
        stop_idx = start_idx + batch_size
        embeds = getMaskVectorsAtLayerForOneBatch(tokens[start_idx:stop_idx], mask_idxes[start_idx:stop_idx], model, layer)
        if embeddings is None:
            # Size the output from the model's actual hidden dimension rather than
            # assuming 768 (BERT-base).
            embeddings = torch.zeros(num_examples, embeds.shape[-1], device=tokens.device)
        embeddings[start_idx:stop_idx] = embeds
    if embeddings is None:
        # No batches were run: keep the previous (0, 768) result shape.
        embeddings = torch.zeros(num_examples, 768, device=tokens.device)
    return embeddings


In [8]:


def run(data_file,
        num_examples=None,
        batch_size=200,
        do_tqdm=True):
    '''
    Load `data/<data_file>.json`, clean it with `clean_data`, and write the result to
    `pipeline_artefacts/<data_file>/s1_specific_data_clean_ma_lin.json`.

    Both the `data` and `pipeline_artefacts` folders are looked for in the current
    working directory first, and in its parent otherwise.

    Args:
        data_file: dataset name; also passed to `clean_data` as `data_name`.
        num_examples: keep only the first `num_examples` cleaned examples (None keeps all).
        batch_size, do_tqdm: accepted for interface compatibility; unused by this step.

    NOTE(review): the output file name contains 'ma_lin' even for non-marvin_linzen datasets.
    '''
    cwd_entries = os.listdir()
    data_prefix = '' if 'data' in cwd_entries else '../'
    artefact_prefix = '' if 'pipeline_artefacts' in cwd_entries else '../'
    UNCLEAN_DATA = data_prefix + 'data/' + data_file + '.json'
    ARTEFACT_PATH = artefact_prefix + 'pipeline_artefacts/' + data_file + '/'

    # The device is only reported here; nothing in this step runs on it.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('DEVICE:', DEVICE)

    if not os.path.exists(ARTEFACT_PATH):
        os.makedirs(ARTEFACT_PATH)

    with open(UNCLEAN_DATA, 'r') as src:
        raw_examples = json.load(src)
    clean_examples = clean_data(raw_examples, data_file)[:num_examples]

    # Serialize before opening the output file, so a failure leaves no partial file behind.
    serialized = json.dumps(clean_examples)
    out_path = ARTEFACT_PATH + 's1_specific_data_clean_ma_lin.json'
    with open(out_path, 'w', newline='') as dst:
        dst.write(serialized)
    print('Done!')

In [9]:
# Clean the full marvin_linzen dataset and save it to pipeline_artefacts/.
run(data_file="marvin_linzen", num_examples=None, batch_size=200, do_tqdm=True)

DEVICE: cuda
Length of data: 63130
Length of data with no coordination: 62510
Verb pairs:
	 ('laughs', 'laugh') True  - 3730
	 ('swims', 'swim') False - will remove sentences - 3730
	 ('smiles', 'smile') False - will remove sentences - 3730
	 ('is', 'are') True  - 24920
	 ('brings', 'bring') False - will remove sentences - 2000
	 ('interests', 'interest') False - will remove sentences - 2000
	 ('likes', 'like') True  - 5600
	 ('admires', 'admire') True  - 5600
	 ('hates', 'hate') True  - 5600
	 ('loves', 'love') True  - 5600
Length of data after removing invalid verbs: 51050
Done!
