In [1]:
import utils
import csv

from collections import Counter
from nltk.tokenize import TweetTokenizer
from joblib import Parallel, delayed
from multiprocessing import Manager
from tqdm import tqdm

In [2]:
# Single NLTK TweetTokenizer instance; the tokenize/count helpers in the next
# cell read this module-level name instead of building their own tokenizer.
tokenizer = TweetTokenizer()

In [3]:
# List proxy created through a multiprocessing Manager; tokenize() below
# extends it with the tokens of each text it is given.
# NOTE(review): the only call to tokenize() in this cell is commented out, so
# nothing visible here ever fills this list — confirm whether it is still needed.
tokens = Manager().list()

# tokenizer in parallel
def tokenize(text):
    """Tokenize ``text`` and append the resulting tokens to ``tokens``.

    Uses the module-level ``tokenizer`` and writes into the module-level
    ``tokens`` list proxy. Nothing is returned; the result is delivered
    by extending ``tokens``.
    """
    pieces = tokenizer.tokenize(text)
    tokens.extend(pieces)

# read the babylm 100M training file through the project helper
# NOTE(review): presumably yields one text item per entry (it is iterated
# item by item by the parallel helpers) — confirm against utils.read_file.
# NOTE(review): absolute, machine-specific path; the notebook will not run
# elsewhere without editing it.
babylm = utils.read_file("/home/km55359/rawdata/babylm_data/babylm_100M/babylm_100M_train.txt")

# tokenize (disabled: the call below is commented out, so tokenize() is not
# executed in this cell)
# Parallel(n_jobs=32)(delayed(tokenize)(text) for text in tqdm(babylm))

def count_word_frequencies(sentence):
    """Return a ``Counter`` of the tokens found in the lowercased ``sentence``.

    The sentence is lowercased first and then split with the module-level
    ``tokenizer``; each distinct token maps to its number of occurrences.
    """
    lowered = sentence.lower()
    return Counter(tokenizer.tokenize(lowered))

def merge_counters(counters):
    """Sum an iterable of counters into one new ``Counter``.

    Each element of ``counters`` is folded in with ``Counter.update``, so
    counts for the same key are added together. An empty iterable gives an
    empty ``Counter``.
    """
    merged = Counter()
    for partial in counters:
        merged.update(partial)
    return merged

def count_word_frequencies_in_parallel(sentences):
    """Count token frequencies over ``sentences`` using joblib workers.

    One ``count_word_frequencies`` task is scheduled per sentence with
    ``n_jobs=-1``; the per-sentence counters are then combined with
    ``merge_counters`` and the combined counter is returned.
    """
    # tqdm wraps the input iterable, so the progress bar advances as
    # sentences are consumed to build the delayed tasks.
    jobs = (delayed(count_word_frequencies)(item) for item in tqdm(sentences))
    per_sentence_counts = Parallel(n_jobs=-1)(jobs)

    # Collapse the list of per-sentence counters into a single one.
    return merge_counters(per_sentence_counts)

In [4]:
# Lowercased token frequencies over every item of babylm, computed with one
# parallel task per item and merged into a single Counter.
word_counts = count_word_frequencies_in_parallel(babylm)

100%|██████████| 10175732/10175732 [01:17<00:00, 131364.95it/s]


In [5]:
# Write the unigram counts as a two-column CSV, most frequent word first.
# newline="" lets the csv module control line endings itself (otherwise extra
# blank rows can appear on platforms that translate "\n"), and the explicit
# UTF-8 encoding keeps non-ASCII tokens from depending on the locale default.
with open("../data/babylm-analysis/babylm-unigrams.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["word", "count"])
    # most_common() yields (word, count) pairs, which writerows emits as rows.
    writer.writerows(word_counts.most_common())

In [6]:
# Read both AANN tables through the project CSV helper.
# NOTE(review): presumably each returns a list of dict rows keyed by column
# name (rows are indexed by 'construction' below) — confirm in utils.
mahowald_aanns = utils.read_csv_dict("../data/mahowald-aann/aanns_good.csv")
babylm_aanns = utils.read_csv_dict("../data/babylm-aanns/aanns_indef_all.csv")

# 'construction' value of every babylm AANN entry, in iteration order.
aanns_in_babylm = [row['construction'] for row in babylm_aanns]

In [7]:
# Number of Mahowald AANN entries loaded in the previous cell.
len(mahowald_aanns)

12960

In [8]:
# Store the Mahowald AANNs whose construction does not occur in babylm.
# The babylm constructions are copied into a set once so that each membership
# test is a hash lookup instead of a scan over the whole list.
babylm_constructions = set(aanns_in_babylm)

mahowald_unseen_aanns = []
for aann in mahowald_aanns:
    if aann["construction"] not in babylm_constructions:
        mahowald_unseen_aanns.append(aann)
    else:
        # Echo every entry that is dropped because babylm already contains it.
        print(aann["construction"])

print(f"\nOriginal: {len(mahowald_aanns)}; Unseen: {len(mahowald_unseen_aanns)}")

a mere three years
a mere twenty years
an extra five days
a mere three years
a mere twenty years
an extra five days
a mere three years
a mere twenty years
an extra five days

Original: 12960; Unseen: 12951


In [9]:
# Keep only the unseen AANNs whose adjective, noun and numeral each appear
# (lowercased) as a key in word_counts; print the construction of every
# entry that fails the check.
mahowald_final = []
for aann in mahowald_unseen_aanns:
    if all(aann[key].lower() in word_counts for key in ('ADJ', 'NOUN', 'NUMERAL')):
        mahowald_final.append(aann)
    else:
        print(aann['construction'])


In [10]:
# Save the filtered entries (absent from babylm, with ADJ/NOUN/NUMERAL all
# present in word_counts) through the project helper.
utils.write_dict_list_to_csv(mahowald_final, "../data/mahowald-aann/mahowald-aanns-unseen_good.csv")