In [None]:
import collections
import numpy as np
import pandas as pd
import re

from argparse import Namespace

In [None]:
args = Namespace(
    raw_dataset_csv="data_for_model.csv",
    train_proportion=0.6,
    val_proportion=0.2,
    test_proportion=0.2,
    output_munged_csv="data_with_splits.csv",
    seed=101
)

In [None]:
# Load the raw labelled comments; the first CSV row supplies the column names.
comments = pd.read_csv(filepath_or_buffer=args.raw_dataset_csv, header=0)

In [None]:
# Preview a few rows only: the frame holds raw offensive-language text, so keep
# the rendered output small when the notebook is saved or shared.
comments.head()

In [None]:
# Distinct values of the label column — the classes the split below is stratified over.
{*comments['Kind of offensive language']}

In [None]:
# Group rows by their offensive-language label so each class can be split
# separately in the next cell (a stratified split).
# `to_dict(orient='records')` keeps each column's dtype, whereas `iterrows()`
# builds one Series per row and may upcast values (e.g. ints to floats).
by_kind_language = collections.defaultdict(list)
for record in comments.to_dict(orient='records'):
    by_kind_language[record['Kind of offensive language']].append(record)

In [None]:
# Create split data: shuffle each class independently, then cut it into
# train / val / test so every class keeps roughly the configured proportions.
final_list = []
np.random.seed(args.seed)
for _, class_items in sorted(by_kind_language.items(), key=lambda kv: kv[0]):
    # Shuffle a copy: shuffling `by_kind_language` in place would make a re-run
    # of this cell start from an already-shuffled order and yield a different
    # split even though the seed is reset above.
    shuffled = list(class_items)
    np.random.shuffle(shuffled)
    n = len(shuffled)
    n_train = int(args.train_proportion * n)
    n_val = int(args.val_proportion * n)
    # Test takes every remaining row, so rows lost to int() flooring of the
    # train/val sizes are not dropped; args.test_proportion is not used here.
    for position, item in enumerate(shuffled):
        if position < n_train:
            split = 'train'
        elif position < n_train + n_val:
            split = 'val'
        else:
            split = 'test'
        # New dict rather than mutating `item`, so `by_kind_language` stays untouched.
        final_list.append({**item, 'split': split})

In [None]:
# Assemble the split records into a DataFrame (it is written to CSV further down).
final_comments = pd.DataFrame(final_list)
# Every raw row must land in exactly one split.
assert len(final_comments) == len(comments), "rows were lost or duplicated while splitting"

In [None]:
# Rows per split (rows) and class (columns): each class should appear in every
# split in roughly the configured train/val/test proportions.
# `.size()` counts rows; `.count()` would repeat non-null counts for every column.
final_comments.groupby(['split', 'Kind of offensive language']).size().unstack(fill_value=0)

In [None]:
# Persist the frame, now carrying its 'split' column, without the pandas index.
final_comments.to_csv(path_or_buf=args.output_munged_csv, index=False)