diff --git a/data.py b/data.py index cd225ce..24b4e33 100644 --- a/data.py +++ b/data.py @@ -442,6 +442,18 @@ def prepare_dataset(dataset, out, vocab_size, **params): train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:] train.data, test.data = all_data.data[:split_index], all_data.data[split_index:] train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:] + elif dataset == "RCV1": + print("Preparing data...") + all_data = TextRCV1() + all_data.preprocess(out=out, vocab_size=vocab_size, **params) + + # Split train/test set + train = copy.deepcopy(all_data) + test = copy.deepcopy(all_data) + split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper + train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:] + train.data, test.data = all_data.data[:split_index], all_data.data[split_index:] + train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:] elif dataset == "RCV1-Vectors-Original": assert out == "tfidf" assert vocab_size == None @@ -467,18 +479,6 @@ def prepare_dataset(dataset, out, vocab_size, **params): split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper train.data, test.data = all_data.data[:split_index], all_data.data[split_index:] train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:] - elif dataset == "RCV1": - print("Preparing data...") - all_data = TextRCV1() - all_data.preprocess(out=out, vocab_size=vocab_size, **params) - - # Split train/test set - train = copy.deepcopy(all_data) - test = copy.deepcopy(all_data) - split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper - train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:] - train.data, test.data = all_data.data[:split_index], all_data.data[split_index:] - train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:] return train, test