In [None]:
import collections
import numpy as np
import pandas as pd
import re

from argparse import Namespace
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/CSC-project/PyTorchNLPBook/
!pip install -r requirements.txt

In [None]:
# Run configuration: input/output CSV paths, per-category split fractions and the
# RNG seed. Note: `test_proportion` is informational only -- the split cell gives
# every row not assigned to train/val to 'test'.
args = Namespace(**{
    "raw_dataset_csv": "data/ag_news/news.csv",
    "train_proportion": 0.7,
    "val_proportion": 0.15,
    "test_proportion": 0.15,
    "output_munged_csv": "data/ag_news/news_with_splits.csv",
    "seed": 1337,
})

In [None]:
# Load the raw news CSV; its first line holds the column names.
news = pd.read_csv(filepath_or_buffer=args.raw_dataset_csv, header=0)

In [None]:
news.head(n=5)  # peek at the first five raw rows

Unnamed: 0,category,title
0,Business,Wall St. Bears Claw Back Into the Black (Reuters)
1,Business,Carlyle Looks Toward Commercial Aerospace (Reu...
2,Business,Oil and Economy Cloud Stocks' Outlook (Reuters)
3,Business,Iraq Halts Oil Exports from Main Southern Pipe...
4,Business,"Oil prices soar to all-time record, posing new..."


In [None]:
# Distinct values of the `category` column (the classes)
set(news["category"])

{'Business', 'Sci/Tech', 'Sports', 'World'}

In [None]:
# Group rows by category so each class can be split separately (stratified split).
# Each value is a list of plain row dicts, in the CSV's original row order.
# `to_dict(orient="records")` avoids building a pandas Series per row, which
# `iterrows()` does (slow on ~120k rows and does not preserve dtypes).
by_category = collections.defaultdict(list)
for record in news.to_dict(orient="records"):
    by_category[record["category"]].append(record)

In [None]:
# Create split data: split each category separately so train/val/test keep the
# same class balance.
def _split_items(items, train_proportion, val_proportion):
    """Return shuffled copies of ``items`` (row dicts), each tagged with a 'split' key.

    After shuffling, the first ``int(train_proportion * n)`` items become 'train',
    the next ``int(val_proportion * n)`` become 'val', and all remaining items
    become 'test' -- so rows lost to ``int()`` truncation land in 'test' instead
    of being dropped. Uses (and advances) NumPy's global RNG.

    ``items`` and its dicts are left untouched, so re-running this cell after
    re-seeding reproduces exactly the same splits.
    """
    shuffled = list(items)  # shuffle a copy; shuffling in place made re-runs differ
    np.random.shuffle(shuffled)
    n = len(shuffled)
    n_train = int(train_proportion * n)
    n_val = int(val_proportion * n)

    tagged = []
    for position, item in enumerate(shuffled):
        if position < n_train:
            split = 'train'
        elif position < n_train + n_val:
            split = 'val'
        else:
            split = 'test'
        tagged.append({**item, 'split': split})
    return tagged


final_list = []
np.random.seed(args.seed)
# Iterate categories in sorted order so the RNG draws are consumed deterministically.
for _, item_list in sorted(by_category.items()):
    final_list.extend(_split_items(item_list, args.train_proportion, args.val_proportion))

In [None]:
# Collect the split-tagged rows into a DataFrame (nothing is written to disk here).
final_news = pd.DataFrame(final_list)
# Every raw row must land in exactly one split.
assert len(final_news) == len(news), "rows were dropped or duplicated while splitting"

In [None]:
final_news["split"].value_counts()  # number of rows per split

train    84000
test     18000
val      18000
Name: split, dtype: int64

In [None]:
# Normalize the news titles
def preprocess_text(text):
    """Lowercase ``text`` and reduce it to space-separated word/punctuation tokens.

    Each of ``. , ! ?`` becomes its own token. Every other run of characters outside
    ``a-zA-Z.,!?`` (digits, quotes, brackets, whitespace, ...) collapses to a single
    space. Leading/trailing spaces are stripped so ``text.split(" ")`` yields no
    empty tokens.

    Example: ``"Mr. Coffee's Maker"`` -> ``"mr . coffee s maker"``.
    """
    text = text.lower()
    text = re.sub(r"([.,!?])", r" \1 ", text)  # pad punctuation so it becomes a separate token
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)  # also collapses the double spaces added above
    return text.strip()
    
final_news["title"] = final_news["title"].map(preprocess_text)  # replace each title with its normalized form

In [None]:
final_news.head(n=5)  # first five split-tagged, normalized rows

Unnamed: 0,category,split,title
0,Business,train,"jobs , tax cuts key issues for bush"
1,Business,train,jarden buying mr . coffee s maker
2,Business,train,retail sales show festive fervour
3,Business,train,intervoice s customers come calling
4,Business,train,boeing expects air force contract


In [None]:
# Persist the split-tagged, normalized rows; the DataFrame index is not written.
final_news.to_csv(path_or_buf=args.output_munged_csv, index=False)