In [1]:
import argparse
import config
import inflect
import os
import pathlib
import re
import torch
import utils

import pandas as pd

from collections import defaultdict
from dataclasses import dataclass
from minicons import scorer
from minicons import utils as mu
from minicons import openai as mo
from torch.utils.data import DataLoader
from tqdm import tqdm


# Shared inflect engine: used by the corruptors below to pick an article for a
# new first word (`inflector.a`) and to singularize nouns (`singular_noun`).
inflector = inflect.engine()

In [2]:
# Location of the OpenAI key file handed to minicons. Taken from the
# OPENAI_API_KEY_PATH environment variable so the notebook is not tied to one
# machine; the original per-user path is kept as the fallback.
# NOTE(review): presumably register_api_key reads the key from this file —
# confirm against the minicons version in use.
PATH = os.environ.get("OPENAI_API_KEY_PATH", "/Users/kanishka/<OPENAIKEY>")
mo.register_api_key(PATH)

In [3]:
@dataclass
class AANN:
    """An article-adjective-numeral-noun construction broken into its slots.

    Any slot may be the empty string. After construction, ``string`` holds the
    surface form: the four slots joined by single spaces, with every run of two
    or more whitespace characters squeezed to one space and the ends stripped.
    """

    article: str
    adjective: str
    numeral: str
    noun: str

    def __post_init__(self):
        slots = (self.article, self.adjective, self.numeral, self.noun)
        joined = "{} {} {} {}".format(*slots)
        # Empty slots leave doubled separators behind; collapse them.
        collapsed = re.sub(r"\s{2,}", " ", joined)
        self.string = collapsed.strip()


In [4]:
def parse_aann(string, pattern):
    """Split an AANN surface ``string`` into an ``AANN`` using ``pattern``.

    ``pattern`` is split on whitespace and its token positions are used to
    index the whitespace tokens of ``string``, so the two are assumed to be
    aligned token-for-token. The adjective and numeral regions are located
    with ``config.ADJ_PATTERN`` / ``config.NUM_PATTERN``.

    NOTE(review): ``mu.find_pattern`` is assumed to return (start, end) token
    indices usable as slice bounds — confirm against minicons.

    Slots:
        article   -- the first token of ``string``
        adjective -- tokens covered by the adjective region
        numeral   -- tokens covered by the numeral region
        noun      -- every token after the numeral region
    """
    words = string.split()

    adjective_region = re.search(config.ADJ_PATTERN, pattern).group(0)
    numeral_region = re.search(config.NUM_PATTERN, pattern).group(0)

    adj_bounds = mu.find_pattern(adjective_region.split(), pattern.split())
    num_bounds = mu.find_pattern(numeral_region.split(), pattern.split())

    adj_lo, adj_hi = adj_bounds[0], adj_bounds[1]
    num_lo, num_hi = num_bounds[0], num_bounds[1]

    return AANN(
        article=words[0],
        adjective=" ".join(words[adj_lo:adj_hi]),
        numeral=" ".join(words[num_lo:num_hi]),
        noun=" ".join(words[num_hi:]),
    )

In [5]:
def parse_instance(aann):
    """Parse one dataset row (needs "construction" and "pattern") into an AANN."""
    construction_text = aann["construction"]
    tag_pattern = aann["pattern"]
    return parse_aann(construction_text, tag_pattern)


def construction_pieces(sentence, construction):
    """Cut ``sentence`` into (text before, construction text, text after).

    The cut points come from ``mu.character_span(sentence, construction)``;
    no whitespace is stripped from any piece.
    """
    start, end = mu.character_span(sentence, construction)
    before = sentence[:start]
    inside = sentence[start:end]
    after = sentence[end:]
    return before, inside, after


def reconstruct(left, middle, right, left_only=False):
    """Rebuild a sentence from its pieces, joined with single spaces.

    With ``left_only=True`` the right piece is dropped. The result is stripped
    and runs of two or more spaces (spaces only, not tabs/newlines) are
    collapsed to one.

    Note: because the pieces are always space-joined, a ``right`` piece that
    starts with punctuation ends up preceded by a space (e.g. "days .").
    """
    pieces = [left, middle] if left_only else [left, middle, right]
    joined = " ".join(pieces).strip()
    return re.sub(r" {2,}", " ", joined)


def left_only(sentence, construction):
    """Return (stripped text before the construction, construction text)."""
    start, end = mu.character_span(sentence, construction)
    prefix = sentence[:start].strip()
    return prefix, sentence[start:end]


def default_ann(aann):
    """Reorder to '<numeral> <adjective> <noun>' with no article.

    The numeral is stored in the ``adjective`` slot and vice versa so that the
    surface string comes out in the default order.
    """
    return AANN(
        article="",
        adjective=aann.numeral,
        numeral=aann.adjective,
        noun=aann.noun,
    )


def corrupt_order(aann):
    """Swap adjective and numeral, giving '<article> <numeral> <adjective> <noun>'.

    The article is re-chosen with ``inflector.a`` for the first word of the
    numeral, since the numeral now directly follows it.
    """
    leading_word = aann.numeral.split(" ")[0]
    new_article = inflector.a(leading_word).split(" ")[0]
    return AANN(
        article=new_article,
        adjective=aann.numeral,
        numeral=aann.adjective,
        noun=aann.noun,
    )


def corrupt_article(aann):
    """Drop the article, leaving '<adjective> <numeral> <noun>'."""
    return AANN(
        article="",
        adjective=aann.adjective,
        numeral=aann.numeral,
        noun=aann.noun,
    )


def corrupt_modifier(aann):
    """Drop the adjective, leaving '<article> <numeral> <noun>'.

    The article is re-chosen with ``inflector.a`` for the first word of the
    numeral, which now directly follows it.
    """
    leading_word = aann.numeral.split(" ")[0]
    new_article = inflector.a(leading_word).split(" ")[0]
    return AANN(
        article=new_article,
        adjective="",
        numeral=aann.numeral,
        noun=aann.noun,
    )


def corrupt_numeral(aann):
    """Drop the numeral, leaving '<article> <adjective> <noun>'."""
    return AANN(
        article=aann.article,
        adjective=aann.adjective,
        numeral="",
        noun=aann.noun,
    )


def corrupt_noun_num(aann):
    """Singularize the head noun (e.g. "days" -> "day"), keeping other slots.

    Only the last word of the noun slot is singularized; any earlier words in
    the slot are kept in place (previously they were silently dropped).
    ``inflect``'s ``singular_noun`` returns ``False`` for words it does not
    treat as plural; in that case the word is kept unchanged rather than
    letting ``False`` end up in the construction string.
    """
    *leading_words, head = aann.noun.split(" ")
    singular = inflector.singular_noun(head)
    if not singular:
        singular = head
    noun = " ".join(leading_words + [singular])
    return AANN(aann.article, aann.adjective, aann.numeral, noun)


# Extractors, implemented as AANN corruptors: each returns an AANN holding only
# the region of the construction whose score is read off.
def non_article_region(aann):
    """Keep everything except the article: '<adjective> <numeral> <noun>'."""
    return AANN(
        article="",
        adjective=aann.adjective,
        numeral=aann.numeral,
        noun=aann.noun,
    )


def numeral_noun_region(aann):
    """Keep only '<numeral> <noun>'."""
    return AANN(article="", adjective="", numeral=aann.numeral, noun=aann.noun)


def just_noun_region(aann):
    """Keep only the noun slot."""
    return AANN(article="", adjective="", numeral="", noun=aann.noun)


def left_context(sentence, construction, token_span):
    """Split ``sentence`` into (prefix, token_span) around ``token_span``.

    If ``token_span`` occurs exactly once in ``sentence``, that occurrence is
    used. Otherwise the occurrence lying inside ``construction`` (as decided
    by ``utils.belongingness``) is chosen.

    The prefix is the text before the span minus the one separating character
    just before it (normally a space). A span at the very start of the
    sentence yields an empty prefix.

    Raises:
        ValueError: if ``token_span`` cannot be located, or (with several
            candidates) ``construction`` is missing or contains none of them.
    """
    if sentence == construction == token_span:
        return "", sentence

    # Both spans are literal text, not regexes: escape them so characters such
    # as "$", "." or "(" are matched verbatim.
    candidate_spans = [
        match.span() for match in re.finditer(re.escape(token_span), sentence)
    ]
    if len(candidate_spans) == 1:
        selected_span = candidate_spans[0]
    else:
        construction_match = re.search(re.escape(construction), sentence)
        if construction_match is None:
            raise ValueError(
                f"construction {construction!r} not found in {sentence!r}"
            )
        construction_span = construction_match.span()
        inside = [
            cs
            for cs in candidate_spans
            if utils.belongingness(cs, construction_span)
        ]
        if not inside:
            raise ValueError(
                f"{token_span!r} not found inside {construction!r} "
                f"in {sentence!r}"
            )
        selected_span = inside[0]

    start, end = selected_span
    # Drop the separator before the span; clamp at 0 so a span at the start
    # of the sentence does not turn into sentence[:-1].
    return sentence[: max(start - 1, 0)], sentence[start:end]


def segment(instances, extractor, corruptor=None, only_construction=False):
    """Turn each instance into a (prefix, continuation) pair for scoring.

    Args:
        instances: iterable of dicts with "sentence", "construction" and
            "pattern" keys.
        extractor: callable mapping an AANN to the AANN whose ``string`` is
            the continuation to score.
        corruptor: optional callable applied to the parsed AANN. When given,
            the sentence is rebuilt around the corrupted construction with
            ``reconstruct``.
        only_construction: if true, drop the sentence context and work on the
            (possibly corrupted) construction alone.

    Returns:
        Three parallel lists: ``prefix + " " + continuation`` (stripped),
        prefixes, and continuations.
    """
    full_length, prefixes, continuations = [], [], []
    for instance in instances:
        aann = parse_instance(instance)
        if corruptor is None:
            sentence = instance["sentence"]
            construction = instance["construction"]
        else:
            aann = corruptor(aann)
            before, _, after = construction_pieces(
                instance["sentence"], instance["construction"]
            )
            construction = aann.string
            sentence = reconstruct(before, construction, after)

        target = extractor(aann).string
        if only_construction:
            sentence = construction

        prefix, continuation = left_context(sentence, construction, target)
        prefixes.append(prefix)
        continuations.append(continuation)
        full_length.append(f"{prefix} {continuation}".strip())
    return full_length, prefixes, continuations

# Extractor name -> callable returning the AANN whose ``string`` is the region
# scored as the continuation; "construction" scores the whole construction.
EXTRACTORS = dict(
    construction=lambda x: x,
    non_article_region=non_article_region,
    numeral_noun_region=numeral_noun_region,
    just_noun_region=just_noun_region,
)

def compute_scores(model, data, batch_size=32, modifier=None, extractors=None):
    """Score AANN instances with an OpenAI model, one API query per batch.

    For each batch, the full (optionally corrupted) sequences are sent to the
    model once. Each requested extractor then selects a continuation region,
    and its conditional score is read from that single query.

    Args:
        model: model name passed to ``mo.OpenAIQuery``.
        data: sequence of instance dicts with "sentence", "construction" and
            "pattern" keys (as read by ``segment``).
        batch_size: number of instances per query.
        modifier: optional AANN corruptor applied before scoring; ``None``
            scores the stored sentences unchanged.
        extractors: names of ``EXTRACTORS`` entries to score. ``None`` (the
            default) scores all of them; previously ``None`` crashed.

    Returns:
        dict mapping extractor name -> list of scores, accumulated in ``data``
        order (presumably one score per instance — confirm against
        ``conditional_score``).
    """
    if extractors is None:
        extractors = list(EXTRACTORS)

    scores = defaultdict(list)
    for batch in tqdm(mu.get_batch(data, batch_size)):
        sequences, _, _ = segment(batch, lambda x: x, modifier)

        lm = mo.OpenAIQuery(model, sequences)
        lm.query()

        for extractor in extractors:
            _, _, continuations = segment(batch, EXTRACTORS[extractor], modifier)
            scores[extractor].extend(lm.conditional_score(continuations))

    return dict(scores)

In [6]:
# Acceptable AANN instances; each row is used as an instance dict downstream.
good = utils.read_csv_dict("../data/mahowald/aanns_good.csv")

In [7]:
# Corruption name -> (EXTRACTORS key for the region to score, corruptor).
EXTRACTORS_AND_MODIFIERS = dict(
    default_ann=("non_article_region", default_ann),
    order_swap=("non_article_region", corrupt_order),
    no_article=("non_article_region", corrupt_article),
    no_modifier=("numeral_noun_region", corrupt_modifier),
    no_numeral=("just_noun_region", corrupt_numeral),
)

In [8]:
# Column store for the output table; score columns are appended below.
results = {"idx": [row["idx"] for row in good]}

In [11]:
# Model and batch size shared by every scoring pass in this cell.
# NOTE(review): OpenAI has retired text-davinci-003; re-running this cell
# likely needs a different model name — confirm before re-running.
MODEL_NAME = "text-davinci-003"
BATCH_SIZE = 64

# Baseline: score the uncorrupted constructions under every extractor. The
# identity modifier makes `segment` rebuild sentences with `reconstruct`, the
# same path the corruptions below take.
baseline_scores = compute_scores(
    MODEL_NAME, good, BATCH_SIZE, lambda x: x, list(EXTRACTORS.keys())
)
for extractor_name, extractor_scores in baseline_scores.items():
    results[f"{extractor_name}_score"] = extractor_scores

# Corruptions: score the corrupted construction as a whole, plus the region
# paired with each corruption in EXTRACTORS_AND_MODIFIERS.
# NOTE(review): the recorded output of the next cell lists
# "<name>_construction_score" keys, while this cell writes
# "<name>_corruption_score"; the saved CSV may come from an older version of
# this cell — confirm which column name downstream analysis expects.
for corruption_name, (region, modifier) in EXTRACTORS_AND_MODIFIERS.items():
    corrupted_scores = compute_scores(
        MODEL_NAME, good, BATCH_SIZE, modifier, ["construction", region]
    )
    results[f"{corruption_name}_corruption_score"] = corrupted_scores["construction"]
    results[f"{corruption_name}_region_score"] = corrupted_scores[region]

203it [01:33,  2.16it/s]
203it [01:40,  2.02it/s]
203it [01:49,  1.85it/s]
203it [01:34,  2.14it/s]
203it [01:47,  1.88it/s]
203it [01:36,  2.11it/s]


In [12]:
column_lengths = {k: len(v) for k, v in results.items()}
print(column_lengths)
# Every score column must line up with 'idx' before building the DataFrame.
assert len(set(column_lengths.values())) == 1, column_lengths

{'idx': 12960, 'construction_score': 12960, 'non_article_region_score': 12960, 'numeral_noun_region_score': 12960, 'just_noun_region_score': 12960, 'default_ann_construction_score': 12960, 'default_ann_region_score': 12960, 'order_swap_construction_score': 12960, 'order_swap_region_score': 12960, 'no_article_construction_score': 12960, 'no_article_region_score': 12960, 'no_modifier_construction_score': 12960, 'no_modifier_region_score': 12960, 'no_numeral_construction_score': 12960, 'no_numeral_region_score': 12960}


In [13]:
# Output path without extension; ".csv" is appended when writing.
results_prefix = "../data/results/mahowald/text-davinci-003"

In [14]:
# One row per instance, one column per score.
results_df = pd.DataFrame(data=results)

In [15]:
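# Obsolete: score columns already get the "_score" suffix when `results` is
# filled in (see the printed column names above), so this rename is not needed.
# results_df = results_df.rename(columns={"construction": "construction_score", 
#                            "non_article_region": "non_article_region_score", 
#                            "numeral_noun_region": "numeral_noun_region_score",
#                            "just_noun_region": "just_noun_region_score"})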

In [16]:
# Create the results directory if needed so the write does not fail on a
# fresh checkout.
results_path = pathlib.Path(f"{results_prefix}.csv")
results_path.parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(results_path, index=False)