In [8]:
import os
import numpy as np
import pandas as pd
import re
import collections

from argparse import Namespace

In [21]:
# Notebook configuration: input/output paths, split proportions and RNG seed.
args = Namespace(**{
    # Input file read by the loading cell.
    "raw_dataset_csv": "../data/ag_news/news.csv",
    # Fraction of each category assigned to each split.
    "train_proportion": 0.7,
    "val_proportion": 0.15,
    "test_proportion": 0.15,
    # Destination for the frame with the added split column.
    "output_munged_csv": "../data/ag_news/news_with_splits.csv",
    # Seed passed to np.random.seed before shuffling.
    "seed": 1337,
})

In [10]:
# Load the raw dataset; row 0 of the file supplies the column names.
news = pd.read_csv(filepath_or_buffer=args.raw_dataset_csv, header=0)

In [11]:
# Preview the first five rows of the raw frame.
news.head(n=5)

Unnamed: 0,category,title
0,Business,Wall St. Bears Claw Back Into the Black (Reuters)
1,Business,Carlyle Looks Toward Commercial Aerospace (Reu...
2,Business,Oil and Economy Cloud Stocks' Outlook (Reuters)
3,Business,Iraq Halts Oil Exports from Main Southern Pipe...
4,Business,"Oil prices soar to all-time record, posing new..."


In [12]:
# Number of rows per category label.
news["category"].value_counts()

Business    30000
Sci/Tech    30000
Sports      30000
World       30000
Name: category, dtype: int64

In [13]:
# Bucket every row, as a plain dict keyed by column name, under its category.
# The frame is converted to records in one call rather than via iterrows(),
# which builds a Series per row (slow, and it can upcast values to a common
# dtype across the row).
by_category = collections.defaultdict(list)
for record in news.to_dict(orient="records"):
    by_category[record["category"]].append(record)

In [14]:
# Stratified split: within each category, shuffle the rows and label them
# 'train' / 'val' / 'test' according to the configured proportions.
split_records = []
np.random.seed(args.seed)
for category in sorted(by_category):
    # Shuffle a copy so by_category keeps its original order; re-running this
    # cell with the same seed then reproduces exactly the same split.
    shuffled = list(by_category[category])
    np.random.shuffle(shuffled)

    n_total = len(shuffled)
    n_train = int(args.train_proportion * n_total)
    n_val = int(args.val_proportion * n_total)

    # Train and val sizes are truncated to ints; the test split takes every
    # remaining row, so no row is dropped by rounding.
    split_labels = (
        ['train'] * n_train
        + ['val'] * n_val
        + ['test'] * max(n_total - n_train - n_val, 0)
    )

    # Build new dicts with the split attribute instead of writing into the
    # records held by by_category.
    split_records.extend(
        {**item, 'split': label} for item, label in zip(shuffled, split_labels)
    )

# Despite its name, final_list is a DataFrame: one row per record, with the
# added 'split' column.
final_list = pd.DataFrame(split_records)

In [15]:
# Rows per split within each category.
final_list.groupby(by='category')['split'].value_counts()

category  split
Business  train    21000
          test      4500
          val       4500
Sci/Tech  train    21000
          test      4500
          val       4500
Sports    train    21000
          test      4500
          val       4500
World     train    21000
          test      4500
          val       4500
Name: split, dtype: int64

In [17]:
def preprocess_text(text):
    """Normalize a string for whitespace tokenization.

    The text is lowercased, each of ``. , ! ?`` is surrounded by spaces, and
    every run of characters other than ASCII letters and those four
    punctuation marks is collapsed into a single space.

    Args:
        text (str): the raw string.

    Returns:
        str: the normalized string. It is not stripped, so it may begin or
        end with a single space.
    """
    lowered = text.lower()
    # Pad punctuation so each mark becomes its own space-delimited token.
    padded = re.sub(r"([.,!?])", r" \1 ", lowered)
    # Anything outside letters and the kept punctuation becomes one space.
    return re.sub(r"[^a-zA-Z.,!?]+", r" ", padded)
    
# Replace each title with its normalized form.
final_list['title'] = final_list['title'].apply(preprocess_text)

In [19]:
# Preview the first five rows after preprocessing.
final_list.head(n=5)

Unnamed: 0,category,title,split
0,Business,"jobs , tax cuts key issues for bush",train
1,Business,jarden buying mr . coffee s maker,train
2,Business,retail sales show festive fervour,train
3,Business,intervoice s customers come calling,train
4,Business,boeing expects air force contract,train


In [22]:
# Write the frame, without the index column, to the configured output path.
final_list.to_csv(path_or_buf=args.output_munged_csv, index=False)