In [None]:
import sys
sys.path.append('ext/LAMA')

In [2]:
from typing import List, Tuple, Union
from pathlib import Path

import pandas as pd
import json
from itertools import chain

## Create dataset

In [3]:
MASK = '[MASK]'

def should_include_emotion(e: str):
    if len(e.split()) != 1:
        return False
    elif e == 'none':
        return False

    return True

def create_dataset(csv_path: Union[str, Path]) -> List[Tuple[str, Tuple[str]]]:
    if isinstance(csv_path, str):
        csv_path = Path(csv_path)

    # Load dataframe
    df = pd.read_csv(csv_path)
    
    # Create personx dataframe
    df = df[['Event', 'Xemotion']].rename(columns={'Event': 'event', 'Xemotion': 'emotion'})
    df = df[df.event.str.contains("PersonX") & ~df.event.str.contains("___")]
    df.emotion = df.emotion.apply(lambda x: json.loads(x))
    
    data = []
    for event, rows in df.groupby('event'):
        masked_sentence = f'{event} and PersonX feels {MASK}.'
        emotions = tuple(set(e for e in chain.from_iterable(rows.emotion) if should_include_emotion(e)))
        if emotions:
            data.append((masked_sentence, emotions))
            
    return df, data

In [6]:
df, dataset = create_dataset('data/test.csv')

In [39]:
with open('test.jsonl', 'a') as f:
    for event, emotions in dataset:
        out = {
            'masked_sentences': [event],
            'obj_label': emotions[0],
            'sub_label': 'Squad'
        }
        json.dump(out, f)
        print(file=f)

## Evaluate dataset

In [19]:
from lama.modules import build_model_by_name
from lama.utils import print_sentence_predictions, load_vocab
import lama.options as options
import lama.evaluation_metrics as evaluation_metrics
import argparse

In [17]:
def evaluate(sentences, model):
    if len(sentences) > 2:
        print("WARNING: only the first two sentences in the text will be considered!")
        sentences = sentences[:2]

    original_log_probs_list, [token_ids], [masked_indices] = model.get_batch_generation([sentences], try_cuda=False)

    index_list = None
    filtered_log_probs_list = original_log_probs_list

    # rank over the subset of the vocab (if defined) for the SINGLE masked tokens
    if masked_indices and len(masked_indices) > 0:
        evaluation_metrics.get_ranking(filtered_log_probs_list[0], masked_indices, model.vocab, index_list=index_list)

    # prediction and perplexity for the whole softmax
    return print_sentence_predictions(original_log_probs_list[0], token_ids, model.vocab, masked_indices=masked_indices)

In [22]:
parser = options.get_eval_generation_parser()
args = parser.parse_args(['--lm', 'bert'])

model = build_model_by_name('bert', args)

Loading bert model...


In [23]:
sentences = ['The cat is on the [MASK].', 'This is a cool [MASK].']
for s in sentences:
    x = evaluate([s], model)


| Top10 predictions
0       phone               -2.345      
1       floor               -2.630      
2       ground              -2.968      
3       couch               -3.387      
4       move                -3.649      
5       roof                -3.651      
6       way                 -3.718      
7       run                 -3.757      
8       bed                 -3.802      
9       left                -3.965      

----------------------------------------------------------------------------------
index   token               log_prob    prediction          log_prob    rank@1000   
----------------------------------------------------------------------------------
0       [CLS]               -12.386     .                   -2.570      -1          
1       The                 -5.547      .                   -0.607      14          
2       cat                 -0.367      cat                 -0.367      0           
3       is                  -0.019      is                  -0

In [30]:
for i, (event, emotions) in enumerate(dataset):
    print("EVENT:", event)
    print("GOLD:", emotions)
    evaluate([event], model)
    print("*********************************")
    
    if i == 10:
        break

EVENT: PersonX accidentally burned and PersonX feels [MASK].
GOLD: ('painfull', 'scared', 'embarassed', 'angry')

| Top10 predictions
0       guilty              -2.584      
1       sick                -3.396      
2       trapped             -3.481      
3       dizzy               -3.562      
4       bad                 -3.596      
5       ill                 -3.802      
6       better              -3.846      
7       embarrassed         -4.138      
8       worse               -4.162      
9       weak                -4.183      

----------------------------------------------------------------------------------
index   token               log_prob    prediction          log_prob    rank@1000   
----------------------------------------------------------------------------------
0       [CLS]               -13.019     .                   -2.677      -1          
1       Person              -0.043      Person              -0.043      0           
2       ##X                 -0.001

# Statistics

In [9]:
df, dataset = create_dataset('data/test.csv')
df.emotion = df.emotion.apply(lambda x: len(x))