In [None]:
!rm ag_news_csv.tgz
!wget https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
!tar -xvzf ag_news_csv.tgz

In [None]:
import pandas as pd
from tqdm import tqdm_notebook

# Load the raw AG News CSVs (read without a header row) into one DataFrame per split.
df = {}
keys = ['train', 'test']
for k in keys :
    df[k] = pd.read_csv('ag_news_csv/' + k + '.csv', header=None)

# classes.txt is split on newlines into the list of class names. The context
# manager closes the file handle deterministically instead of leaving it open
# until the garbage collector gets to it.
with open('ag_news_csv/classes.txt') as classes_file :
    classes = classes_file.read().strip().split('\n')

# Column 0 is used as a 1-based index into `classes`; replace it with the class name.
for k in keys :
    df[k][0] = df[k][0].apply(lambda x : classes[x-1])

import pandas as pd
# Flatten the per-split frames into parallel lists (all train rows first, then
# all test rows), recording for every row which split it came from.
# Column 2 supplies the text and column 0 the label.
splits = ['train', 'test']
df_texts = [text for key in splits for text in df[key][2]]
df_labels = [label for key in splits for label in df[key][0]]
df_exp_splits = [key for key in splits for _ in range(len(df[key][2]))]

# From here on `df` is a single DataFrame rather than a dict of DataFrames.
df = pd.DataFrame({'text' : df_texts, 'label' : df_labels, 'exp_split' : df_exp_splits})

from sklearn.model_selection import train_test_split
# Carve a dev set out of the training rows: 15% of their index labels, chosen
# with a fixed seed, are relabelled 'dev'. The full table is then written to disk.
is_train = df['exp_split'] == 'train'
train_idx, dev_idx = train_test_split(
    df.index[is_train],
    test_size=0.15,
    random_state=16377,
)
df.loc[dev_idx, 'exp_split'] = 'dev'
df.to_csv('agnews_dataset_split_full.csv', index=False)

In [None]:
import pandas as pd
import json
from tqdm import tqdm
import os
# Create the output directory for the per-split CSVs; exist_ok makes re-running this cell a no-op.
os.makedirs('data', exist_ok=True)

In [None]:
# Reload the combined table (columns: text, label, exp_split) written by the split cell above.
data = pd.read_csv('agnews_dataset_split_full.csv')

In [None]:
import spacy
# Try the legacy 'en' shortcut link first so existing spaCy 2 environments load
# exactly the model they did before. spaCy 3 removed shortcut links and raises
# OSError for 'en', so fall back to the full package name of the small English model.
try :
    nlp = spacy.load('en')
except OSError :
    nlp = spacy.load('en_core_web_sm')

def clean_text(text) :
    """Tokenize ``text`` with spaCy and return the tokens joined by single spaces.

    Iterating a spaCy ``Doc`` yields tokens, not sentences, so the result is
    the space-separated token sequence of the input.

    Only the tokenizer is run (``nlp.make_doc``): nothing produced by later
    pipeline components is read here, and running them on every row would
    dominate the runtime of the notebook.
    NOTE(review): this assumes no component in the loaded pipeline merges or
    splits tokens, which holds for the stock English models. Confirm if a
    custom pipeline is used.

    Args:
        text: The raw string to tokenize.

    Returns:
        A string of the token texts separated by single spaces.
    """
    return " ".join(token.text for token in nlp.make_doc(text))

In [None]:
# Rename the existing "label" column to "label_id", which frees the name "label" for a new column.
data = data.rename(columns = {"label":"label_id"})

In [None]:
# Map every distinct value of `label_id` to an integer, numbered in order of
# first appearance in the table.
label_ids = {name: idx for idx, name in enumerate(data.label_id.unique())}

# Register `progress_apply` on pandas objects so the slow passes below show a progress bar.
tqdm.pandas()

# "label" receives the integer code; "text" is replaced by its cleaned form.
data["label"] = data["label_id"].progress_apply(lambda name : label_ids[name])
data["text"] = data["text"].progress_apply(clean_text)

In [None]:
def _export_split(frame, split_name) :
    """Write the rows of one experiment split to ``data/<split_name>.csv``.

    The rows are copied before the new column is added, so the assignment
    never touches a view of ``frame``. No global pandas warning option has to
    be switched off for this to run cleanly.

    Args:
        frame: DataFrame with an ``exp_split`` column.
        split_name: Value of ``exp_split`` to select; also used as the prefix
            of the generated ids and as the output file stem.

    Returns:
        The selected rows with an added ``annotation_id`` column holding
        ``"<split_name>_<n>"``, where ``n`` counts rows from 0 in their
        existing order.
    """
    subset = frame.loc[frame.exp_split == split_name].copy()
    subset["annotation_id"] = [split_name + "_" + str(i) for i in range(len(subset))]
    subset.to_csv("data/" + split_name + ".csv", index = False)
    return subset

# Same three outputs as before, each bound to its usual name for later cells.
train = _export_split(data, "train")
dev = _export_split(data, "dev")
test = _export_split(data, "test")