In [1]:
import collections
import numpy as np
import pandas as pd
import re

from argparse import Namespace

In [2]:
# Run configuration for building the surname train/val/test splits.
# NOTE: `raw_datacet_csv` is a typo for "dataset", kept because the loading
# cell reads `args.raw_datacet_csv`.
args = Namespace(
    # input CSV; later cells read its `nationality` column
    raw_datacet_csv="../data/surnames/surnames.csv",
    # per-nationality split fractions; the split cell gives test whatever
    # remains after train and val are taken
    train_proportion=0.70,
    val_proportion=0.15,
    test_proportion=0.15,
    # output CSV written by the last cell (input rows plus a `split` column)
    output_munged_csv="../data/surnames/surname_with_splits.csv",
    # numpy seed used before shuffling each nationality
    seed=1337,
)

In [3]:
# Load the raw surnames; with no `names` given, pandas takes the first row
# as the header (equivalent to header=0).
surnames = pd.read_csv(filepath_or_buffer=args.raw_datacet_csv)

In [4]:
# Peek at the first five loaded rows.
surnames.head(n=5)

Unnamed: 0,surname,nationality
0,Woodford,English
1,Coté,French
2,Kore,English
3,Koury,Arabic
4,Lebzak,Russian


In [5]:
# Distinct nationality labels, in order of first appearance.
surnames["nationality"].unique()

array(['English', 'French', 'Arabic', 'Russian', 'Japanese', 'Chinese',
       'Italian', 'Czech', 'Irish', 'German', 'Greek', 'Spanish',
       'Polish', 'Dutch', 'Vietnamese', 'Korean', 'Portuguese',
       'Scottish'], dtype=object)

In [6]:
# Group the raw rows (as plain dicts) by nationality, preserving file order
# both for the keys (first appearance) and within each list.
by_nationality = collections.defaultdict(list)
for record in surnames.to_dict(orient="records"):
    by_nationality[record["nationality"]].append(record)

In [8]:
# Split each nationality independently so every class appears in
# train/val/test in roughly the configured proportions. Train and val sizes
# are floored; test receives everything left over, so rounding remainders
# land in test and no row is dropped. `test_proportion` is therefore only
# validated here, not used to size the test slice.
split_total = args.train_proportion + args.val_proportion + args.test_proportion
assert abs(split_total - 1.0) < 1e-9, (
    f"train/val/test proportions must sum to 1, got {split_total}"
)

final_list = []
np.random.seed(args.seed)
for _, item_list in sorted(by_nationality.items()):
    # Shuffle and tag *copies* so `by_nationality` stays untouched: shuffling
    # it in place made a re-run start from an already-shuffled order and
    # produce a different split despite the re-seed above.
    shuffled = list(item_list)
    np.random.shuffle(shuffled)
    n = len(shuffled)
    n_train = int(args.train_proportion * n)
    n_val = int(args.val_proportion * n)

    final_list.extend({**item, 'split': 'train'} for item in shuffled[:n_train])
    final_list.extend({**item, 'split': 'val'} for item in shuffled[n_train:n_train + n_val])
    final_list.extend({**item, 'split': 'test'} for item in shuffled[n_train + n_val:])
final_surnames = pd.DataFrame(final_list)

# Every input row must appear exactly once in the output.
assert len(final_surnames) == sum(len(items) for items in by_nationality.values())

In [10]:
# Peek at the first five rows of the split-annotated frame.
final_surnames.head(n=5)

Unnamed: 0,surname,nationality,split
0,Totah,Arabic,train
1,Abboud,Arabic,train
2,Fakhoury,Arabic,train
3,Srour,Arabic,train
4,Sayegh,Arabic,train


In [12]:
# Row counts per (nationality, split) to sanity-check the stratified split.
final_surnames.groupby(by='nationality')['split'].value_counts()

nationality  split
Arabic       train    1122
             test      241
             val       240
Chinese      train     154
             test       33
             val        33
Czech        train     289
             test       63
             val        62
Dutch        train     165
             test       36
             val        35
English      train    2080
             test      447
             val       445
French       train     160
             test       35
             val        34
German       train     403
             test       87
             val        86
Greek        train     109
             test       24
             val        23
Irish        train     128
             test       28
             val        27
Italian      train     420
             test       90
             val        90
Japanese     train     542
             test      117
             val       116
Korean       train      53
             test       13
             val        11
Polish   

In [13]:
# Persist the rows with their split labels; the RangeIndex is not written.
final_surnames.to_csv(path_or_buf=args.output_munged_csv, index=False)