In [1]:
import json
import re
import pathlib
import csv

from minicons import scorer
from tqdm import tqdm

In [2]:
# read jsonl
def read_jsonl(path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Location of the ``.jsonl`` file; anything ``open`` accepts.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification, so do not depend on the locale's
    # default encoding.
    with open(path, encoding="utf-8") as f:
        # Whitespace-only lines (typically a trailing blank line) are skipped;
        # json.loads would otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]

def find_and_split(sentence, target):
    """Split ``sentence`` immediately before the last occurrence of ``target``.

    ``target`` is treated as literal text: it is escaped before being handed
    to the regex engine, so punctuation such as ".", "?" or "(" matches
    itself instead of being interpreted as a regex operator.

    Args:
        sentence: The full sentence to split.
        target: The literal substring to locate.

    Returns:
        A ``(prefix, target)`` tuple, where ``prefix`` is everything in
        ``sentence`` before the last match of ``target``, stripped of leading
        and trailing whitespace, and ``target`` is returned unchanged.

    Raises:
        ValueError: If ``target`` does not occur in ``sentence``.
    """
    matches = list(re.finditer(re.escape(target), sentence))
    if not matches:
        raise ValueError(f"target {target!r} not found in sentence {sentence!r}")
    # Only the final occurrence matters: an earlier copy of the target (e.g. a
    # fronted filler) must stay inside the prefix.
    last_start = matches[-1].start()
    return sentence[:last_start].strip(), target

# Load the PiPP stimuli; each record is one item with its per-condition sentences.
pipps = read_jsonl(path='../../data/pipps/materials.jsonl')

In [3]:
# Checkpoint to score with. Swap in the commented line to use the other model.
# model_name = "kanishka/smolm-autoreg-bpe-babylm-1e-3"
model_name = "gpt2-large"

# NOTE(review): the device is pinned to the fourth GPU ("cuda:3"); adjust when
# running on a machine with a different layout.
lm = scorer.IncrementalLMScorer(model_name, "cuda:3")

# Turn the model identifier into a name usable as a file stem: drop the local
# checkpoint directory and the hub namespace, then flatten remaining slashes.
for _old, _new in (("../smolm/models/", ""), ("kanishka/", ""), ("/", "_")):
    model_name = model_name.replace(_old, _new)

model_name

'gpt2-large'

In [4]:
# Sanity check on a subject-verb agreement pair: score "is" and "are" as
# continuations of the same prefix. The reduction negates the mean over axis 0
# of what the scorer passes in (presumably per-token log-probabilities, so a
# lower value means a more expected continuation; confirm against minicons).
lm.conditional_score(
    ['the keys to the cabinet', 'the keys to the cabinet'],
    ['is', 'are'],
    reduction=lambda x: -x.mean(0).item(),
)

[5.080320358276367, 2.4086380004882812]

In [5]:
# Reduction shared by every call: negated mean over axis 0 of the values the
# scorer passes in (presumably per-token log-probabilities; confirm against
# minicons).
negative_mean = lambda x: -x.mean(0).item()

# Score the target of every condition given the sentence prefix before it, and
# store the result on the condition record under 'score'.
for pipp in tqdm(pipps):
    for key, condition in pipp.items():
        # Item-level metadata fields carry no sentence to score.
        if key in ['idx', 'embedding', 'preposition']:
            continue
        target = condition['target']
        prefix, query = find_and_split(condition['sentence'], target)
        # A sentence-final period attaches directly to the prefix; any other
        # target is joined with a space.
        sep = "" if target == "." else " "
        score = lm.conditional_score(
            prefix,
            query,
            separator=sep,
            reduction=negative_mean,
            base_two=True,
        )
        condition['score'] = score[0]

100%|██████████| 198/198 [00:15<00:00, 12.61it/s]


In [6]:
def write_to_csv(results, filename):
    """Write one CSV row of per-condition scores for every scored item.

    Args:
        results: Iterable of item dicts. Each must hold the metadata keys
            'idx', 'preposition' and 'embedding', plus one dict with a
            'score' entry for each of the four condition keys.
        filename: Destination path; an existing file is overwritten.

    Raises:
        KeyError: If an item lacks a metadata key, a condition key, or a
            condition's 'score'.
    """
    metadata_columns = ['idx', 'preposition', 'embedding']
    condition_columns = ['pipp_filler_gap', 'pp_no_filler_no_gap', 'filler_no_gap', 'no_filler_gap']

    # newline='' lets the csv module control line endings itself; without it
    # every row is followed by a blank line on platforms that translate "\n".
    with open(filename, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        # Header and rows are built from the same column lists so they cannot
        # drift apart.
        writer.writerow(metadata_columns + condition_columns)
        for result in results:
            writer.writerow(
                [result[column] for column in metadata_columns]
                + [result[column]['score'] for column in condition_columns]
            )

In [7]:
# Make sure the output directory exists, then write one CSV named after the model.
results_dir = pathlib.Path('../../data/results/pipps/')
results_dir.mkdir(parents=True, exist_ok=True)
write_to_csv(
    pipps,
    f'../../data/results/pipps/{model_name}.csv',
)

In [8]:
# Per-token scores (base 2) for an example sentence with a fronted "though" phrase.
lm.token_score(
    "Happy though we were with the idea, we decided to move on.",
    base_two=True,
)

[[('Happy', 0.0),
  ('though', -14.23239517211914),
  ('we', -5.657886505126953),
  ('were', -3.171238899230957),
  ('with', -6.144097805023193),
  ('the', -2.1990251541137695),
  ('idea', -7.152507781982422),
  (',', -2.537522792816162),
  ('we', -1.4758175611495972),
  ('decided', -5.810355186462402),
  ('to', -0.955513596534729),
  ('move', -6.110415458679199),
  ('on', -2.3879406452178955),
  ('.', -2.4258055686950684)]]

In [37]:
# Scratch check: arithmetic mean of two scores, presumably copied by hand from
# a per-token output that is not shown here. TODO confirm their origin.
(-12.507351875305176 - 0.00291132228448987) / 2

-6.255131598794833

In [13]:
# Inspect one scored item (the element at position 32) to eyeball its per-condition scores.
pipps[32]

{'idx': 32,
 'preposition': 'though',
 'embedding': '',
 'pipp_filler_gap': {'sentence': 'The vacationers emphasized that the vacation was fun, frantic though it may have seemed.',
  'target': '.',
  'score': 1.3217829465866089},
 'pp_no_filler_no_gap': {'sentence': 'The vacationers emphasized that the vacation was fun, though it may have seemed frantic.',
  'target': 'frantic',
  'score': 15.600930213928223},
 'filler_no_gap': {'sentence': 'The vacationers emphasized that the vacation was fun, frantic though it may have seemed frantic.',
  'target': 'frantic',
  'score': 23.60569190979004},
 'no_filler_gap': {'sentence': 'The vacationers emphasized that the vacation was fun, though it may have seemed.',
  'target': '.',
  'score': 12.073512077331543}}

In [18]:
# Per-token scores (base 2) for a sentence in which "frantic" appears twice,
# to compare against the single score stored for its final "frantic".
lm.token_score(
    "The vacationers emphasized that the vacation was fun, frantic though it may have seemed frantic.",
    base_two=True,
)

[[('The', 0.0),
  ('vacation', -16.906583786010742),
  ('ers', -5.41076135635376),
  ('emphasized', -15.571462631225586),
  ('that', -1.1680978536605835),
  ('the', -2.063772439956665),
  ('vacation', -9.088308334350586),
  ('was', -3.983067274093628),
  ('fun', -8.611307144165039),
  (',', -1.6558159589767456),
  ('frantic', -18.40604591369629),
  ('though', -12.963072776794434),
  ('it', -0.2623327076435089),
  ('may', -2.8368868827819824),
  ('have', -0.9901316165924072),
  ('seemed', -5.349593162536621),
  ('frantic', -23.60569190979004),
  ('.', -1.239820122718811)]]