Skip to content

Commit

Permalink
Moved RCV1 elif Block
Browse files Browse the repository at this point in the history
  • Loading branch information
SuyashLakhotia committed Mar 20, 2018
1 parent 4550389 commit e2cdd42
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions data.py
Expand Up @@ -442,6 +442,18 @@ def prepare_dataset(dataset, out, vocab_size, **params):
train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:]
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
elif dataset == "RCV1":
print("Preparing data...")
all_data = TextRCV1()
all_data.preprocess(out=out, vocab_size=vocab_size, **params)

# Split train/test set
train = copy.deepcopy(all_data)
test = copy.deepcopy(all_data)
split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper
train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:]
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
elif dataset == "RCV1-Vectors-Original":
assert out == "tfidf"
assert vocab_size == None
Expand All @@ -467,18 +479,6 @@ def prepare_dataset(dataset, out, vocab_size, **params):
split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
elif dataset == "RCV1":
print("Preparing data...")
all_data = TextRCV1()
all_data.preprocess(out=out, vocab_size=vocab_size, **params)

# Split train/test set
train = copy.deepcopy(all_data)
test = copy.deepcopy(all_data)
split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper
train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:]
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]

return train, test

Expand Down

0 comments on commit e2cdd42

Please sign in to comment.