# Tidy source files to keys csv

In [1]:
%reload_ext lab_black
import pandas as pd
import numpy as np
import gzip
from IPython.display import clear_output

# Paths of the raw training corpus and the file-based test sets
train_file = "issues/preprocessing/6ktraining_v2.dict"

strain_file = "issues/preprocessing/strain.txt"
strain_key_file = "issues/preprocessing/strain_key.txt"

grain_file = "issues/preprocessing/grain_nws.dict"
grain_key_file = "issues/preprocessing/grain_key.txt"

# Imageability norms: keep only item/rating, renamed to the word/img keys
# that the later merges join on
cortese_path = "issues/preprocessing/cortese2004norms.csv"
cortese = pd.read_csv(cortese_path, skiprows=9, na_filter=False)
img_map = cortese.loc[:, ["item", "rating"]]
img_map.columns = ["word", "img"]

# Zeno norm: only the empty string counts as missing
zeno = pd.read_csv("issues/preprocessing/EWFG.csv", keep_default_na=False, na_values="")
zeno["gr14"] = pd.to_numeric(zeno.f, errors="coerce")  # Stage 14 is adult frequency
for unused_col in ("sfi", "d", "u", "f"):
    # Remove in place the columns that are not carried into the training frame
    del zeno[unused_col]
clear_output()

# Chang training set (from Chang 2019 github)
y_wordnet = np.genfromtxt("issues/preprocessing/wordNet_6229.csv", delimiter=",")

wordnet_columns = ["word", "ort", "pho", "wf"]
wordnet_dict = pd.read_csv(
    "issues/preprocessing/6kdict",
    sep="\t",
    header=None,
    names=wordnet_columns,
    # na_filter is off so that the literal word "null" is not read as a missing value
    na_filter=False,
)

# Keep the row position as an explicit key (used later to select rows of y_wordnet)
wordnet_dict["wn_idx"] = wordnet_dict.index

# Duplicated spellings are kept on purpose (do NOT drop_duplicates on 'word').
# HS04: There were 39 words in which a single spelling was associated with two
# or more meanings (mainly words such as SHEEP, FISH, or HIT, whose plural or
# past tense morphological inflection involves no change from the stem).

In [None]:
# Merge Zeno and IMG into train
train = pd.read_csv(
    train_file,
    sep="\t",
    header=None,
    names=["word", "ort", "pho", "wf"],
    na_filter=False,  # Bug fix: incorrectly treated null as missing value in the corpus
)

# validate="1:1" makes the merge fail loudly if either side repeats a word
train = pd.merge(train, zeno, on="word", how="left", validate="1:1")

# Assume Zeno missing = 0.
# Vectorised fillna replaces the per-element np.isnan lambda, which shadowed the
# loop variable and raised TypeError on any non-float value.
zeno_grade_cols = [f"gr{grade}" for grade in range(1, 15)]
train[zeno_grade_cols] = train[zeno_grade_cols].fillna(0)

train = pd.merge(train, img_map, on="word", how="left", validate="1:1")


# Merge Chang: the inner join keeps only words present in wordnet_dict, and a
# spelling listed several times there yields one training row per wn_idx
wnid = wordnet_dict.loc[:, ['word', 'wn_idx']]
train = train.merge(wnid, how='inner', on='word')
print(f'Words in training set: {len(train)}')


In [None]:
# Strain test set: item file (word / ort / pho / wf) plus a key file holding
# the condition labels of each item
strain = pd.read_csv(
    strain_file, sep="\t", header=None, names=["word", "ort", "pho", "wf"]
)

strain_key = pd.read_table(
    strain_key_file,
    header=None,
    sep=r"\s+",  # replaces the deprecated delim_whitespace=True (same parsing)
    names=["word", "frequency", "pho_consistency", "imageability"],
)

# "word" is the only column the two frames share; naming it makes the join key explicit
strain = pd.merge(strain, strain_key, on="word")
strain = pd.merge(strain, img_map, on="word", how="left")
strain.sample(5)

In [None]:
# Average per frequency band. numeric_only=True skips the string columns
# (word, ort, pho, ...), which make mean() raise TypeError on pandas >= 2.0.
strain.groupby("frequency").mean(numeric_only=True)

In [None]:
# Grain nonword test set: each item has a large-grain and a small-grain pronunciation
grain = pd.read_csv(
    grain_file,
    sep='\t',
    header=None,
    names=['word', 'ort', 'pho_large', 'pho_small']
)

grain_key = pd.read_table(
    grain_key_file,
    header=None,
    sep=r'\s+',  # replaces the deprecated delim_whitespace=True (same parsing)
    names=['word', 'condition']
)

# Recode the key: 'critical' -> 'ambiguous', every other label -> 'unambiguous'
grain_key['condition'] = np.where(
    grain_key['condition'] == 'critical', 'ambiguous', 'unambiguous'
)

# "word" is the only column the two frames share; naming it makes the join key explicit
grain = pd.merge(grain, grain_key, on='word')

# Placeholder imageability / frequency so grain has the same columns as the word sets
grain['img'] = 0
grain['wf'] = 0
grain.sample(5)

In [2]:
# Taraban word set: relabel the columns, then attach imageability ratings
taraban_raw = pd.read_csv("issues/preprocessing/taraban.csv")
taraban_raw.columns = ["id", "cond", "word", "ort", "pho", "wf"]
taraban = taraban_raw.merge(img_map, how="left", on="word")
taraban.sample(5)

Unnamed: 0,id,cond,word,ort,pho,wf,img
125,125,Regular control for High-frequency regular-inc...,draw,__draw________,_dra______,1538,5.0
173,173,Regular control for Low-frequency regular-inco...,dime,___di_me______,__dAm_____,209,6.8
11,11,High-frequency exception,have,___ha_ve______,__h@v_____,165137,1.9
37,37,High-frequency regular-inconsistent,grow,__grow________,_gro______,3372,3.9
59,59,Low-frequency exception,plow,__plow________,_plW______,90,5.6


In [3]:
# Glushko nonword set: note that pho comes before ort in this file
glushko_columns = ["id", "cond", "word", "pho", "ort"]
glushko = pd.read_csv("issues/preprocessing/glushko_nonword.csv")
glushko.columns = glushko_columns

# Placeholder imageability / frequency, added in the same order as before
for placeholder_col in ("img", "wf"):
    glushko[placeholder_col] = 0
glushko.sample(5)

Unnamed: 0,id,cond,word,pho,ort,img,wf
56,56,Regular,soad,"['__sod_____', '__sad_____']",___soad_______,0,0
49,49,Regular,pold,['__pold____'],___po_ld______,0,0
46,46,Regular,sheed,['__Sid_____'],__sheed_______,0,0
67,67,Regular,wote,['__wot_____'],___wo_te______,0,0
69,69,Regular,prain,['_pren_____'],__prain_______,0,0


### Check raw data integrity

In [None]:
# Check that every representation follows the 14-character ort / 10-character pho format
for frame in (train, strain, grain, taraban, glushko):
    assert all(len(o) == 14 for o in frame.ort)

for pho_col in (train.pho, strain.pho, grain.pho_small, grain.pho_large, taraban.pho):
    assert all(len(code) == 10 for code in pho_col)

# Glushko stores a stringified list of acceptable pronunciations per item
from ast import literal_eval
for pho in glushko.pho:
    for p in literal_eval(pho):
        assert len(p) == 10

# Check that every item fulfils the trim_ort criteria: the positions that
# trim_ort discards must hold only the '_' filler
locs = [0, 11, 12, 13]

for l in locs:
    for frame in (train, strain, grain, taraban, glushko):
        assert all(ch == '_' for ch in frame.ort.str.get(l))

# No missing data in critical variables
critical_columns = (
    train.ort, train.pho, train.wf,
    strain.ort, strain.pho, strain.wf,
    grain.ort, grain.pho_small, grain.pho_large,
    taraban.ort, taraban.pho,
    glushko.ort, glushko.pho,
)
for column in critical_columns:
    assert not column.isna().any()

In [None]:
def trim_ort(t):
    """Trim each orthographic string in ``t.ort`` to its 10 informative slots.

    The first slot and the last 3 slots are empty in this source dataset
    (6ktraining.dict), so characters 1..10 of each 14-character string are kept.

    The frame is modified IN PLACE and the same object is returned; later cells
    rely on that aliasing (e.g. ``df_train`` and ``train`` are one object).

    Strings that already have the trimmed width are left untouched, so running
    the cell a second time no longer shortens them again.

    Parameters
    ----------
    t : pandas.DataFrame
        Frame with a string column named ``ort``.

    Returns
    -------
    pandas.DataFrame
        The same object ``t``, with ``ort`` trimmed.
    """
    start, stop = 1, 11
    trimmed_width = stop - start

    def _trim(code):
        # Already trimmed on a previous run: slicing again would drop real slots
        if len(code) == trimmed_width:
            return code
        return code[start:stop]

    t['ort'] = t.ort.apply(_trim)
    return t


# trim_ort returns the frame it was given, so each df_* name is an alias of
# the corresponding raw frame rather than a copy
df_train, df_strain, df_grain, df_taraban, df_glushko = (
    trim_ort(raw_frame) for raw_frame in (train, strain, grain, taraban, glushko)
)

# Imageability missing data replacement

In [None]:
def chk_missing(df, var, label=None):
    """Print how many values of column ``var`` are missing.

    Parameters
    ----------
    df : str or pandas.DataFrame
        Either the NAME of a module-level DataFrame (original behaviour, looked
        up through ``globals()``) or the DataFrame itself.
    var : str
        Column to inspect.
    label : str, optional
        Text shown in the message. Defaults to the name when ``df`` is a
        string, otherwise to the object's type name.

    Returns
    -------
    None
        The result is only printed, as ``'<var> missing in <label>: <n>/<rows>'``.
    """
    if isinstance(df, str):
        # Name-based lookup kept for backward compatibility with existing calls
        frame = globals()[df]
        shown_as = df if label is None else label
    else:
        frame = df
        shown_as = type(df).__name__ if label is None else label

    n_missing = int(frame[var].isna().sum())
    print('{} missing in {}: {}/{}'.format(var, shown_as, n_missing, len(frame)))

# Report missing imageability for every dataset, in the same order as before
for frame_name in ('df_train', 'df_strain', 'df_grain', 'df_taraban', 'df_glushko'):
    chk_missing(frame_name, 'img')

In [None]:
# Training set: missing ratings take the overall mean rating
mean_img = df_train.img.mean()
df_train['img'] = df_train.img.fillna(mean_img)

# Strain: missing ratings take the mean of their own imageability condition.
# Both means are computed before any value is filled.
is_hi_img = df_strain.imageability == "HI"
is_lo_img = df_strain.imageability == "LI"

mean_strain_hi_img = df_strain.loc[is_hi_img, 'img'].mean()
mean_strain_lo_img = df_strain.loc[is_lo_img, 'img'].mean()

df_strain.loc[is_hi_img, "img"] = df_strain.loc[is_hi_img, "img"].fillna(mean_strain_hi_img)
df_strain.loc[is_lo_img, "img"] = df_strain.loc[is_lo_img, "img"].fillna(mean_strain_lo_img)

# Taraban does not manipulate imageability, so use the training-set mean
df_taraban['img'] = df_taraban.img.fillna(mean_img)

# Handle homograph

From Jay (201217)
I'm guessing they just split the real frequency in two. If it's possible to check that, even approximately, that would be good. (If we don't have the WJ frequencies independently of these training sets, we could just ballpark it: do the frequencies in the file look comparable to, or half of, the frequencies of words that are more or less the same frequency in some other norms? I'm not sure if that's clear. If not, we can talk about it.)

- Split all frequency into n_dup

In [5]:
# Look-up table word -> n_dup, restricted to spellings that occur more than
# once in the training frame

tmp_count = train.groupby('word')['wn_idx'].count().reset_index()
tmp_dups = tmp_count[tmp_count.wn_idx > 1]
dups_dict = {word: n_dup for word, n_dup in zip(tmp_dups.word, tmp_dups.wn_idx)}



NameError: name 'train' is not defined

In [None]:
# Strain items whose spelling is duplicated in the training set
strain_is_dup = strain.word.isin(dups_dict.keys())
strain[strain_is_dup]


In [4]:
# Taraban items whose spelling is duplicated in the training set
taraban_is_dup = taraban.word.isin(dups_dict.keys())
taraban[taraban_is_dup]

NameError: name 'dups_dict' is not defined

In [None]:
def adjust_wf(row):
    """Return the row's frequency, split evenly across duplicated spellings.

    Rows whose ``word`` is a key of the module-level ``dups_dict`` get
    ``wf / n_dup``; every other row keeps its ``wf`` unchanged.
    """
    n_dup = dups_dict.get(row.word)
    if n_dup is None:
        return row.wf
    return row.wf / n_dup


# NOTE: this overwrites train.wf from itself, so running the cell a second
# time divides the duplicated words again.
train['wf'] = train.apply(adjust_wf, axis=1)

In [None]:
# Write every tidy frame to its key CSV
for tidy_frame, csv_path in (
    (df_train, 'dataset/df_train.csv'),
    (df_strain, 'dataset/df_strain.csv'),
    (df_grain, 'dataset/df_grain.csv'),
    (df_taraban, 'dataset/df_taraban.csv'),
    (df_glushko, 'dataset/df_glushko.csv'),
):
    tidy_frame.to_csv(csv_path)

# Save semantics

In [None]:
# Export the wordnet semantic representation for the training set:
# one row of y_wordnet per training row, selected through its wn_idx key
train_wn_rows = train.wn_idx.values
sem_train = y_wordnet[train_wn_rows]
print(f'Shape of selected semantic representation: {sem_train.shape}')
np.savez_compressed('dataset/sem_train.npz', data=sem_train)

In [None]:
# Export Strain semantics: for each Strain word take the index of its FIRST
# matching training row, then pick those rows out of sem_train
strain_word_idx = []
for w in strain.word:
    matching_rows = df_train.index[df_train.word == w]
    strain_word_idx.append(matching_rows[0])

sem_strain = sem_train[strain_word_idx]
np.savez_compressed('dataset/sem_strain.npz', data=sem_strain)


# Encode input and output

In [None]:
# Encode orthographic representation
def ort2bin(o_col, trimMode=True, verbose=True):
    """One-hot encode slot-based orthographic strings, one vocabulary per slot.

    Replicates support.py (o_char). Each character position (slot) is fitted
    with its own Keras ``Tokenizer`` and encoded with ``texts_to_matrix``; the
    per-slot matrices are then placed side by side.

    Parameters
    ----------
    o_col : sequence of str
        Orthographic strings; the number of slots is taken from the first one.
        Items are read positionally, so a pandas Series with any index works.
    trimMode : bool, default True
        When truthy, drop the first column of every slot matrix.
    verbose : bool, default True
        When truthy, print each slot's tokenizer dictionary details.

    Returns
    -------
    pandas.DataFrame
        One row per input string, columns numbered 0..k-1 across all slots.
        An empty frame is returned for empty input.
    """
    from tensorflow.keras.preprocessing.text import Tokenizer

    words = list(o_col)
    if not words:
        return pd.DataFrame()

    nSlot = len(words[0])
    slotFrames = []

    for slotId in range(nSlot):
        slotData = [word[slotId] for word in words]

        t = Tokenizer(filters='', lower=False)
        t.fit_on_texts(slotData)

        thisSlotBinData = t.texts_to_matrix(slotData)
        if trimMode:
            # Remove the first bit of the slot (original note: it marks a
            # separate slot and is probably only useful in a recurrent network)
            thisSlotBinData = thisSlotBinData[:, 1:]

        # Print dictionary details
        if verbose:
            print(
                'Slot {} (n = {}, unique token = {}) {} \n'.format(
                    slotId, t.document_count, len(t.word_index.items()),
                    t.word_docs
                )
            )

        slotFrames.append(pd.DataFrame(thisSlotBinData))

    if not slotFrames:
        return pd.DataFrame()

    # Concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(slotFrames, axis=1, ignore_index=True)

def ort2bin_v2(o_col):
    """Encode orthography with one character-level tokenizer shared by all slots.

    Alternative to ``ort2bin``: a single vocabulary is fitted over every
    character, the fitted dictionary is printed, and the ``texts_to_matrix``
    encoding of ``o_col`` is returned. Original note: this achieves the same
    thing with extra zero columns and will be useful for a letter-level
    recurrent model.
    """
    from tensorflow.keras.preprocessing.text import Tokenizer

    char_tokenizer = Tokenizer(char_level=True, lower=False, filters='')
    char_tokenizer.fit_on_texts(o_col)
    print('dictionary:', char_tokenizer.word_index)

    encoded = char_tokenizer.texts_to_matrix(o_col)
    return encoded


# Pool the words and orthographic strings of all five datasets (train first)
# so that every dataset is encoded with the same per-slot dictionaries
frames_in_order = [df_train, df_strain, df_grain, df_taraban, df_glushko]

all_word = pd.concat([frame.word for frame in frames_in_order], ignore_index=True)
all_ort = pd.concat([frame.ort for frame in frames_in_order], ignore_index=True)

# Encoding orthographic representation
all_ort_bin = ort2bin(all_ort, verbose=True)

In [None]:
# Row offsets at which each dataset starts inside the pooled encoding
splitId_strain = len(df_train)
splitId_grain = splitId_strain + len(df_strain)
splitId_taraban = splitId_grain + len(df_grain)
splitId_glushko = splitId_taraban + len(df_taraban)

slice_starts = (0, splitId_strain, splitId_grain, splitId_taraban, splitId_glushko)
slice_stops = (splitId_strain, splitId_grain, splitId_taraban, splitId_glushko, None)

ort_train, ort_strain, ort_grain, ort_taraban, ort_glushko = (
    np.array(all_ort_bin[lo:hi]) for lo, hi in zip(slice_starts, slice_stops)
)

# Save to disk
for npz_path, ort_matrix in (
    ('dataset/ort_train.npz', ort_train),
    ('dataset/ort_strain.npz', ort_strain),
    ('dataset/ort_grain.npz', ort_grain),
    ('dataset/ort_taraban.npz', ort_taraban),
    ('dataset/ort_glushko.npz', ort_glushko),
):
    np.savez_compressed(npz_path, data=ort_matrix)

print('==========Orthographic representation==========')
print('all shape:', all_ort_bin.shape)
for shape_label, ort_matrix in (
    ('ort_train shape:', ort_train),
    ('ort_strain shape:', ort_strain),
    ('ort_grain shape:', ort_grain),
    ('ort_taraban shape:', ort_taraban),
    ('ort_glushko shape:', ort_glushko),
):
    print(shape_label, ort_matrix.shape)

In [None]:
def pho2bin_v2(p_col, p_key):
    """Encode phonological strings as concatenated per-phoneme bit vectors.

    Parameters
    ----------
    p_col : pandas.Series of str
        Phonological strings; the number of slots is taken from the first one.
    p_key : dict
        Maps each phoneme character to its bit vector. The vector length is
        read from the ``'_'`` entry.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(p_col), n_slots * bits_per_phoneme)``; slot ``s`` fills
        columns ``s * bits_per_phoneme`` up to ``(s + 1) * bits_per_phoneme``.
    """
    binLength = len(p_key['_'])
    # Positional access: works whatever the Series index is
    nPhoChar = len(p_col.iloc[0])

    p_output = np.empty([len(p_col), binLength * nPhoChar])

    for slot in range(nPhoChar):
        slotSeries = p_col.str.slice(start=slot, stop=slot + 1)
        out = slotSeries.map(p_key).to_list()
        # Use the measured vector width rather than a hard-coded 25
        p_output[:, slot * binLength:(slot + 1) * binLength] = out
    return p_output


from src.data_wrangling import gen_pkey

# One phoneme key shared by every dataset
phon_key = gen_pkey()

pho_train = pho2bin_v2(train.pho, phon_key)
pho_strain = pho2bin_v2(strain.pho, phon_key)
pho_large_grain = pho2bin_v2(grain.pho_large, phon_key)
pho_small_grain = pho2bin_v2(grain.pho_small, phon_key)
pho_taraban = pho2bin_v2(taraban.pho, phon_key)

# Save to disk
for npz_path, pho_matrix in (
    ('dataset/pho_train.npz', pho_train),
    ('dataset/pho_strain.npz', pho_strain),
    ('dataset/pho_large_grain.npz', pho_large_grain),
    ('dataset/pho_small_grain.npz', pho_small_grain),
    ('dataset/pho_taraban.npz', pho_taraban),
):
    np.savez_compressed(npz_path, data=pho_matrix)

print('\n==========Phonological representation==========')
print(len(phon_key), ' phonemes: ', phon_key.keys())
for shape_label, pho_matrix in (
    ('pho_train shape:', pho_train),
    ('pho_strain shape:', pho_strain),
    ('pho_large_grain shape:', pho_large_grain),
    ('pho_small_grain shape:', pho_small_grain),
    ('pho_taraban shape:', pho_taraban),
):
    print(shape_label, pho_matrix.shape)

### Decoding check

In [None]:
from src.evaluate import get_all_pronunciations_fast as gapf

# Round trip: decoding each encoded matrix must give back the source strings
for encoded, source_pho in (
    (pho_train, df_train.pho),
    (pho_strain, df_strain.pho),
    (pho_large_grain, df_grain.pho_large),
    (pho_small_grain, df_grain.pho_small),
    (pho_taraban, df_taraban.pho),
):
    assert all(gapf(encoded, phon_key) == source_pho)

## Special format for Glushko PHO (due to multiple correct answer with different length)

In [None]:
import ast, pickle

# Glushko pho dictionary: word -> list of acceptable pronunciation strings
# (the pho column holds a stringified Python list)
phonology_glushko = {}
for i, x in enumerate(df_glushko.word):
    phonology_glushko[x] = ast.literal_eval(df_glushko.loc[i, 'pho'])

# Glushko encoded output dictionary: word -> one flat bit list per
# pronunciation, built by chaining the phon_key vectors of its characters
pho_glushko = {
    word: [
        [bit for char in pronunciation for bit in phon_key[char]]
        for pronunciation in pronunciations
    ]
    for word, pronunciations in phonology_glushko.items()
}

with open('dataset/pho_glushko.pkl', 'wb') as pkl_out:
    pickle.dump(pho_glushko, pkl_out)

print('y_glushko dimension: {}'.format(len(pho_glushko['beed'][0])))

# Testing and evaluating new sampling probability

In [None]:
import pandas as pd
import numpy as np
df_train = pd.read_csv('dataset/df_train.csv', index_col=0)

# Plot sampling conversion graph
import matplotlib.pyplot as plt
import data_wrangling

# Sort by frequency so every curve is drawn left to right
plot_f = df_train.sort_values('wf')

fig, ax = plt.subplots(facecolor="w")
for sampling_method, legend_label in (("log", 'Log'), ("hs04", 'HS04'), ("jay", 'JAY')):
    ax.plot(
        plot_f.wf,
        data_wrangling.get_sampling_probability(plot_f, sampling_method),
        label=legend_label,
    )

ax.legend(loc='lower right')
plt.xlabel('Word frequency')
plt.ylabel('Sampling probability')
plt.title('Tested sampling p vs. word frequency')
plt.show()

# Create new dictionary style representation (super fast hash table)

In [None]:
# word -> row index of the training frame; when a spelling occurs more than
# once, the last occurrence is the one that stays in the mapping
word_representation_mapping = {w: row for w, row in zip(train.word, train.index)}

ort = {w: ort_train[word_representation_mapping[w]] for w in df_train.word}
pho = {w: pho_train[word_representation_mapping[w]] for w in df_train.word}
sem = {w: sem_train[word_representation_mapping[w]] for w in df_train.word}

representation = {'ort':ort, 'pho':pho, 'sem':sem}
print(f'Total no. of training items: {len(ort.keys())}')


In [None]:
# Usage: representation["ort" or "pho" or "sem"]["word"]
print(f'Representations of "cat" are:\n')
for rep_name in ("ort", "pho", "sem"):
    cat_vector = representation[rep_name]["cat"]
    print(f'{rep_name}: {cat_vector} \n with shape = {cat_vector.shape}\n')

In [None]:
# Export the representation dictionary as a gzip-compressed pickle
import pickle, gzip

representation_path = 'dataset/representation_dictionary.pkl.gz'
with gzip.open(representation_path, 'wb') as gz_out:
    pickle.dump(representation, gz_out)

# New test set pickle format
A dictionary with 4 keys:
- item: the word or nonword string (human-readable)
- ort: orthographic representation
- pho: phonological representation
- sem: semantic representation

If every item is in the training set, use data_wrangling.MyData.create_testset_from_train_idx() to create the dictionary.

Otherwise, create it manually.

In [3]:
import data_wrangling
import gzip, pickle
from importlib import reload

# Re-import the module so that edits to data_wrangling are picked up without
# restarting the kernel
reload(data_wrangling)
# NOTE(review): MyData() presumably loads the dataset/ files written by the
# cells above — confirm against data_wrangling
data = data_wrangling.MyData()

In [None]:
# Strain (all): every training row whose word is a Strain item
strain_items = data.df_strain.word.unique()
is_strain_item = data.df_train.word.isin(strain_items)
strain_items_idx = list(data.df_train.index[is_strain_item])
strain_dict = data.create_testset_from_train_idx(strain_items_idx)

# Saving is currently disabled:
# with gzip.open("dataset/testsets/strain.pkl.gz", "wb") as f:
#     pickle.dump(strain_dict, f)

In [None]:
# Strain by each condition

def make_strain_sub_testsets(df, f, c, i, save_file):
    """Pickle the test set of one Strain condition.

    Selects the unique words of ``df`` whose ``frequency``, ``pho_consistency``
    and ``imageability`` equal ``f``, ``c`` and ``i``, finds the rows of the
    module-level ``data.df_train`` holding those words, builds a test set with
    ``data.create_testset_from_train_idx`` and writes it to
    ``dataset/testsets/<save_file>.pkl.gz``.
    """
    in_condition = (
        (df.frequency == f) & (df.pho_consistency == c) & (df.imageability == i)
    )
    words = df.loc[in_condition, "word"].unique()

    is_selected = data.df_train.word.isin(words)
    idx = list(data.df_train.loc[is_selected].index)

    # The handle has its own name so it does not shadow the `f` parameter
    with gzip.open(f"dataset/testsets/{save_file}.pkl.gz", "wb") as gz_out:
        pickle.dump(data.create_testset_from_train_idx(idx), gz_out)


# One test-set file per cell of the 2 x 2 x 2 Strain design
# (frequency x consistency x imageability), in the original order
for freq_level, cons_level, img_level, file_stem in (
    ("HF", "CON", "HI", "strain_hf_con_hi"),
    ("HF", "CON", "LI", "strain_hf_con_li"),
    ("HF", "INC", "HI", "strain_hf_inc_hi"),
    ("HF", "INC", "LI", "strain_hf_inc_li"),
    ("LF", "CON", "HI", "strain_lf_con_hi"),
    ("LF", "CON", "LI", "strain_lf_con_li"),
    ("LF", "INC", "HI", "strain_lf_inc_hi"),
    ("LF", "INC", "LI", "strain_lf_inc_li"),
):
    make_strain_sub_testsets(
        data.df_strain, f=freq_level, c=cons_level, i=img_level, save_file=file_stem
    )

In [None]:
# Train test set without any spelling that occurs more than once
# (the one-ort, multi-sem words)
train_count = data.df_train.groupby("word").count().reset_index()
occurs_more_than_once = train_count.pho > 1
word_with_dup = list(train_count.word[occurs_more_than_once])

is_unique_spelling = ~data.df_train.word.isin(word_with_dup)
train_no_dup_idx = list(data.df_train.index[is_unique_spelling])

train_dict = data.create_testset_from_train_idx(train_no_dup_idx)
with gzip.open("dataset/testsets/train.pkl.gz", "wb") as gz_out:
    pickle.dump(train_dict, gz_out)



In [None]:
import numpy as np
# Sanity check: show the distinct items stored in the de-duplicated train test set
np.unique(train_dict['item'])

In [None]:
# Grain ambiguous
import numpy as np

is_ambiguous = data.df_grain.condition == "ambiguous"
nw_amb_idx = list(data.df_grain.index[is_ambiguous])

# Nonwords get an all-zero semantic placeholder of width 2446
grain_ambiguous_dict = {}
grain_ambiguous_dict["item"] = list(data.df_grain.word[nw_amb_idx])
grain_ambiguous_dict["ort"] = data.ort_grain[nw_amb_idx]
grain_ambiguous_dict["pho_large_grain"] = data.pho_large_grain[nw_amb_idx]
grain_ambiguous_dict["pho_small_grain"] = data.pho_small_grain[nw_amb_idx]
grain_ambiguous_dict["sem"] = np.zeros((len(nw_amb_idx), 2446))

with gzip.open("dataset/testsets/grain_ambiguous.pkl.gz", "wb") as gz_out:
    pickle.dump(grain_ambiguous_dict, gz_out)


In [None]:
# Grain unambiguous
is_unambiguous = data.df_grain.condition == "unambiguous"
nw_un_idx = list(data.df_grain.index[is_unambiguous])

# Nonwords get an all-zero semantic placeholder of width 2446
grain_unambiguous_dict = {}
grain_unambiguous_dict["item"] = list(data.df_grain.word[nw_un_idx])
grain_unambiguous_dict["ort"] = data.ort_grain[nw_un_idx]
grain_unambiguous_dict["pho_large_grain"] = data.pho_large_grain[nw_un_idx]
grain_unambiguous_dict["pho_small_grain"] = data.pho_small_grain[nw_un_idx]
grain_unambiguous_dict["sem"] = np.zeros((len(nw_un_idx), 2446))

with gzip.open("dataset/testsets/grain_unambiguous.pkl.gz", "wb") as gz_out:
    pickle.dump(grain_unambiguous_dict, gz_out)

In [None]:
# Train img: rebuild the word -> imageability table and show its median rating
# NOTE(review): `cortese` is only loaded in the first cell of the notebook —
# confirm it exists in the kernel session that runs this section
img_map = cortese.loc[:, ["item", "rating"]]
img_map.columns = ["word", "img"]
img_map.img.median()

In [None]:
# Low imageability (median split by Cortese rating): rating below 4
is_low_img = img_map.img < 4
low_img_cortese_word = list(img_map.word[is_low_img])
in_low_img = data.df_train.word.isin(low_img_cortese_word)
low_img_cortese_idx = list(data.df_train.index[in_low_img])
testset_low_img_cortest = data.create_testset_from_train_idx(low_img_cortese_idx)
with gzip.open("dataset/testsets/train_cortese_low_img.pkl.gz", "wb") as gz_out:
    pickle.dump(testset_low_img_cortest, gz_out)

# Hi imageability (median split by Cortese rating): rating of 4 or above
is_hi_img = img_map.img >= 4
hi_img_cortese_word = list(img_map.word[is_hi_img])
in_hi_img = data.df_train.word.isin(hi_img_cortese_word)
hi_img_cortese_idx = list(data.df_train.index[in_hi_img])
testset_hi_img_cortest = data.create_testset_from_train_idx(hi_img_cortese_idx)
with gzip.open("dataset/testsets/train_cortese_hi_img.pkl.gz", "wb") as gz_out:
    pickle.dump(testset_hi_img_cortest, gz_out)

In [9]:
# Taraban: short tags used in the output file names, one per condition label
taraban_name_map = {
    "High-frequency exception": "HF-EXC",
    "High-frequency regular-inconsistent": "HF-REG-INC",
    "Low-frequency exception": "LF-EXC",
    "Low-frequency regular-inconsistent": "LF-REG-INC",
    "Regular control for High-frequency exception": "CTRL-HF-EXC",
    "Regular control for High-frequency regular-inconsistent": "CTRL-HF-REG-INC",
    "Regular control for Low-frequency exception": "CTRL-LF-EXC",
    "Regular control for Low-frequency regular-inconsistent": "CTRL-LF-REG-INC",
}


# One test-set file per condition, built from the training rows of its words
for c in taraban.cond.unique():
    condition_words = taraban.loc[taraban.cond == c, "word"]
    in_condition = data.df_train.word.isin(condition_words)
    idx = list(data.df_train.index[in_condition])
    print(idx)

    testset_path = f"dataset/testsets/taraban_{taraban_name_map[c].lower()}.pkl.gz"
    with gzip.open(testset_path, "wb") as gz_out:
        pickle.dump(data.create_testset_from_train_idx(idx), gz_out)

[47, 376, 431, 703, 885, 1221, 1228, 1238, 1688, 1889, 2016, 2178, 3007, 3572, 3597, 3598, 3892, 3924, 4091, 5523, 5550, 5602, 5612, 5728, 5731]
[124, 353, 539, 600, 898, 1137, 1147, 1553, 1583, 1598, 1699, 1944, 1955, 2073, 2238, 2291, 2881, 3203, 3377, 3818, 3833, 3907, 4392, 4554]
[387, 465, 536, 1141, 1232, 1629, 2064, 2767, 3272, 3315, 3352, 3408, 3843, 4075, 4161, 4609, 4919, 4927, 5245, 5484, 5517, 5544, 5725, 5735]
[474, 895, 914, 943, 960, 1118, 1710, 1711, 2104, 2157, 2281, 2599, 2625, 2768, 2818, 2979, 2980, 2995, 3004, 3035, 3465, 3470, 3477, 3699, 5191, 5203]
[202, 215, 564, 764, 1121, 1189, 1428, 1972, 2067, 2254, 3183, 3202, 3363, 4043, 4504, 4791, 5048, 5583, 5621, 5626, 5676, 5707, 5770]
[81, 232, 547, 770, 920, 1281, 1372, 1459, 1485, 1535, 1546, 1547, 1860, 1883, 1956, 2223, 2247, 2913, 3325, 3381, 3752, 3809, 4066, 4281, 4883]
[153, 470, 535, 1158, 1255, 1431, 1620, 2000, 2790, 3287, 3359, 3578, 3782, 4307, 4351, 4875, 4974, 5340, 5506, 5556, 5591, 5689, 5705]
[462,