In [136]:
import csv
import glob
import os
import random
import re
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from minicons import scorer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, pipeline

import utils

In [2]:
# Tokenizer from the smolm model directory; all token counting in this notebook uses it.
TOKENIZER_PATH = "../../smolm/models/smolm-autoreg-bpe-babylm-1e-3/"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

In [3]:
def read_file(path):
    """Read a UTF-8 text file and return its non-blank lines, stripped of surrounding whitespace.

    Blank (or whitespace-only) lines are dropped, so list positions do NOT
    correspond to line numbers in the file.

    TODO (original note: "make read all") -- intent unclear; confirm.
    """
    # `with` guarantees the handle is closed; iterating the file avoids readlines().
    with open(path, encoding="utf-8") as fh:
        return [stripped for stripped in (line.strip() for line in fh) if stripped]

def read_babylm(path):
    """Read a UTF-8 text file and return every line, stripped of surrounding whitespace.

    Unlike ``read_file``, blank lines are kept (as ``""``), so list index i is
    line i of the file. Later cells index these lists by ``sentence_idx``.

    TODO (original note: "make read all") -- intent unclear; confirm.
    """
    # `with` guarantees the handle is closed even if reading fails.
    with open(path, encoding="utf-8") as fh:
        return [line.strip() for line in fh]

In [51]:
babylm_aanns = utils.read_csv_dict("../../rawdata/babylm_data/babylm_100M/aanns/aann_data.csv")
openbooks_aanns = utils.read_csv_dict("../data/openbooks_aanns.csv")
openbooks_aanns_idx = [int(row['sentence_idx']) for row in openbooks_aanns]

# Hand-corrected `construction` spans, keyed by sentence_idx (a string, as read from the CSV).
# Presumably these make the span match the sentence text exactly -- confirm.
# NOTE(review): " a full 40mm" has a leading space, so masking it also removes the
# space before it; confirm that is intended.
construction_fixes = {
    "563794": " a full 40mm",
    "4347718": "a fuckin' 30 days",
}

# The row dicts are updated IN PLACE, so `babylm_aanns` itself carries the fixes
# (the DataLoader cell batches over `babylm_aanns`). `babylm_aanns_modified` holds
# the same dict objects.
babylm_aanns_modified = []
for row in babylm_aanns:
    fix = construction_fixes.get(row['sentence_idx'])
    if fix is not None:
        row['construction'] = fix
    babylm_aanns_modified.append(row)

In [139]:
# Concatenate the non-blank lines of every OpenBooks text file into one flat list.
# NOTE(review): glob order is filesystem-dependent. `openbooks_aanns_idx` indexes into
# this list, so it must match the order used when openbooks_aanns.csv was built -- confirm.
train_files = glob.glob('../../rawdata/books1/epubtxt/*.txt')
openbooks = []
for path in tqdm(train_files):
    openbooks += read_file(path)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17868/17868 [00:15<00:00, 1128.78it/s]


In [128]:
def corpus_name_from_path(path):
    """Return the corpus name encoded in a BabyLM split file path.

    e.g. ``"../postags_100M/aochildes.train"`` -> ``"aochildes"``.
    Replaces ``re.split(r"(/|.train)", path)[-3]``, whose unescaped ``.``
    mis-split any name containing ``<char>train`` (``retrain.train`` -> ``""``).
    """
    return os.path.splitext(os.path.basename(path))[0]


# Map corpus name -> list of lines from babylm_100M/<corpus>.train (read_babylm keeps
# blank lines, so positions match line numbers / sentence_idx).
# The postags_100M directory is only used to enumerate corpus names.
# sorted() makes corpus order -- and therefore the written corpus -- deterministic.
babylm = {}
for file in sorted(glob.glob("../../rawdata/babylm_data/postags_100M/*.train")):
    corpus = corpus_name_from_path(file)
    babylm[corpus] = read_babylm(f"../../rawdata/babylm_data/babylm_100M/{corpus}.train")

In [8]:
# Planning note, kept as a bare string (so it is echoed as this cell's output).
# NOTE(review): the cells below do not length-match per instance. Instead they top up the
# total token count with text from the first OpenBooks lines.
'''
for each babylm aann instance, get number of tokens. Then sample similar amounts of tokens from openbooks non aann lines.

Should I index all openbooks by number of tokens (within min max)?? Yeah probably..
'''

'\nfor each babylm aann instance, get number of tokens. Then sample similar amounts of tokens from openbooks non aann lines.\n\nShould I index all openbooks by number of tokens (within min max)?? Yeah probably..\n'

In [55]:
def count_tokens(strings, tok=tokenizer):
    """Return, for each input string, its token count minus one special token.

    ``strings`` may be a single ``str`` or a list of strings; the result is always
    a list of ints, one per string.

    NOTE(review): ``len - 1`` assumes ``tok`` adds exactly one special token per
    sequence (e.g. a leading BOS). This matches the ``[1:]`` slice used when topping
    up from OpenBooks -- confirm for the tokenizer in use.
    """
    batch = [strings] if isinstance(strings, str) else strings
    return [len(ids) - 1 for ids in tok(batch)['input_ids']]

In [12]:
# RoBERTa-large masked-LM scorer from minicons. Below, only its .model and .tokenizer
# are used (to build the fill-mask pipeline and to get the mask token).
lm = scorer.MaskedLMScorer("roberta-large")

In [21]:
# Fill-mask pipeline reusing the RoBERTa weights and tokenizer already loaded in `lm`.
# NOTE(review): the GPU index is hard-coded; adjust "cuda:2" for the machine running this.
unmasker = pipeline(
    'fill-mask',
    model=lm.model,
    tokenizer=lm.tokenizer,
    device="cuda:2",
)

In [106]:
def unmask_unmask(results, inputs, mask_token=None):
    """Turn fill-mask pipeline output back into one filled-in sentence per input.

    Args:
        results: per-input pipeline output, aligned with ``inputs``. For an input with
            one mask it is a ranked list of candidate dicts. For an input with several
            masks it is a list of such ranked lists, one per mask, in order.
        inputs: the masked input strings.
        mask_token: mask string to fill in multi-mask inputs; defaults to
            ``lm.tokenizer.mask_token`` (looked up only when needed).

    Returns:
        list of str. For single-mask inputs this is the top candidate's ``'sequence'``.
        For multi-mask inputs it is the input with each mask replaced, left to right,
        by that mask's own top ``'token_str'``.
    """
    sequences = []
    for result, input_sequence in zip(results, inputs):
        if isinstance(result[0], dict):
            # Single mask. Detected by element type rather than len(result) == 5,
            # which breaks for top_k != 5 or for an input with exactly 5 masks.
            sequences.append(result[0]['sequence'])
            continue
        if mask_token is None:
            mask_token = lm.tokenizer.mask_token
        sequence = input_sequence
        for candidates in result:
            # Fill one mask at a time, so each mask gets its own top prediction
            # instead of every mask receiving the first mask's token.
            sequence = sequence.replace(mask_token, candidates[0]['token_str'], 1)
        sequences.append(sequence)
    return sequences

In [115]:
# Pre-compute an infilled replacement for every BabyLM AANN sentence. The AANN span
# (`construction`) is swapped for the mask token, and the fill-mask pipeline's top
# prediction is put back in its place.
# Result: {source corpus: {sentence_idx (int): replacement sentence}}.
babylm_aann_replacements = defaultdict(dict)

babylm_dl = DataLoader(babylm_aanns, batch_size=8)
for batch in tqdm(babylm_dl):
    # str.replace masks EVERY occurrence of the construction in the sentence.
    inputs = [s.replace(c, lm.tokenizer.mask_token) for s, c in zip(batch['sentence'], batch['construction'])]
    unmasked = unmasker(inputs)
    if len(inputs) == 1:
        # FillMaskPipeline returns outputs[0] (not a 1-element list) for a 1-element
        # list input. Re-wrap it so a final batch of size 1 is not misparsed.
        unmasked = [unmasked]
    sentences = unmask_unmask(unmasked, inputs)

    ids = [int(x) for x in batch['sentence_idx']]
    for source, idx, sentence in zip(batch['source'], ids, sentences):
        babylm_aann_replacements[source][idx] = sentence

babylm_aann_replacements = dict(babylm_aann_replacements)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [00:16<00:00,  7.33it/s]


In [141]:
# Sanity check: print any AANN record whose construction does not occur in the
# BabyLM line it points to (no output means every record lines up).
for record in babylm_aanns:
    line = babylm[record['source']][int(record['sentence_idx'])]
    if record['construction'] not in line:
        print(record)

In [134]:
# Rebuild BabyLM with every AANN line swapped for its infilled replacement. Also tally
# how many tokens the swaps removed (negative if replacements are longer overall).
replaced_corpus = []
tokens_lost = 0

for corpus, utterances in tqdm(babylm.items()):
    # .get: a corpus with no AANN instances has no entry in the replacement map
    # (indexing it directly would raise KeyError).
    replacements = babylm_aann_replacements.get(corpus, {})
    for i, utterance in enumerate(utterances):
        replacement = replacements.get(i)
        if replacement is None:
            replaced_corpus.append(utterance)
            continue
        utterance_tokens, replacement_tokens = count_tokens([utterance, replacement])
        tokens_lost += utterance_tokens - replacement_tokens
        replaced_corpus.append(replacement)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.17it/s]


In [142]:
# Sanity check: the replacement pass must preserve the total number of lines.
np.sum([len(sents) for sents in babylm.values()]) == len(replaced_corpus)

True

In [143]:
# Net tokens removed by the AANN replacements (sum of original minus replacement counts).
tokens_lost

3320

In [180]:
def roundup(x, base=1000):
    """Round ``x`` up to the nearest multiple of ``base`` (default 1000).

    Exact multiples are returned unchanged, e.g. ``roundup(3320) == 4000`` and
    ``roundup(4000) == 4000``.
    """
    remainder = x % base
    return x if remainder == 0 else x + base - remainder

In [177]:
# Top the corpus back up with exactly `tokens_lost` tokens of non-AANN OpenBooks text,
# so the total token count matches the original BabyLM corpus.
excess_corpus = []
tokens_added = 0
openbooks_aanns_idx_set = set(openbooks_aanns_idx)  # O(1) membership instead of a list scan per line

# Heuristic: only scan the first roundup(tokens_lost) OpenBooks lines. Every line from
# read_file is non-blank and so yields at least one token. The window should therefore hold
# enough text unless many of its lines are AANN sentences (checked by the assert below).
upper_bound = roundup(tokens_lost)
for i, utterance in enumerate(tqdm(openbooks[:upper_bound])):
    if tokens_added >= tokens_lost:
        break  # budget met; remaining lines would contribute nothing
    if i in openbooks_aanns_idx_set:
        continue
    # [1:] drops the leading special token (presumably BOS), matching count_tokens' len - 1.
    tokens = tokenizer(utterance)['input_ids'][1:]
    # Take at most the remaining budget (the last line may be cut mid-sentence).
    # Using `<` semantics fixes the old `<=` check, which added one token too many.
    added = tokens[:tokens_lost - tokens_added]
    tokens_added += len(added)
    string = tokenizer.decode(added).strip()
    if string != "":
        excess_corpus.append(string)

assert tokens_added == max(tokens_lost, 0), (
    f"only found {tokens_added} of {tokens_lost} tokens in the first {upper_bound} OpenBooks lines"
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 4036.73it/s]


In [179]:
# Final corpus: AANN-replaced BabyLM lines followed by the OpenBooks top-up lines.
full_corpus = [*replaced_corpus, *excess_corpus]

In [187]:
# Write the final corpus, one utterance per line. The encoding is explicit UTF-8 to
# match how the inputs were read; the platform default may be different.
with open("../../rawdata/babylm_data/babylm_100M/babylm_100M_no_aann_infilling_roberta_openbooks.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{line}\n" for line in full_corpus)

In [10]:
# openbooks_tokens = defaultdict(list)
# for batch in tqdm(openbooks_dl):
#     idxes, sents = batch[0]
#     idxes = idxes.tolist()
#     num_tokens = count_tokens(sents)
#     for i, nt in zip(idxes, num_tokens):
#         if i not in openbooks_aanns_idx:
#             openbooks_tokens[nt].append(i)

In [None]:
# random.seed(42)
# random.shuffle(openbooks_non_aanns)
# ---
# target_lengths = defaultdict(int)
# for instance in babylm_aanns:
#     counts = count_tokens(instance['sentence'])[0]
#     target_lengths[counts] += 1

# target_lengths = dict(target_lengths)
# target_length_space = []
# for k,v in target_lengths.items():
#     target_length_space.extend([k] * v)

# total = len(target_length_space)

# sampled = defaultdict(list)

# for i, utterrance in enumerate(pbar := tqdm(openbooks)):
#     pbar.set_description(f"Length: {len(target_length_space)} out of {total}")
#     if i not in openbooks_aanns_idx:
#         count = count_tokens(utterrance)[0]
#         if count in target_length_space:
#             sampled[count].append(i)
#             target_length_space.remove(count)
#         if len(target_length_space) == 0:
#             break

Length: 87 out of 970:   0%|                                                                                  | 2971/36768629 [00:01<5:30:12, 1855.65it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Length: 59 out of 970:   0%|                                                                                  | 9022/36768629 [00:04<5:12:17, 1961.80it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Length: 37 out of 970:   0%|                            