Skip to content

Commit

Permalink
Refactor split_train_and_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hrayrhar committed Mar 26, 2018
1 parent 253299a commit 406cd51
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions scripts/split_train_and_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,7 @@
import argparse


parser = argparse.ArgumentParser(description='Split data into train and test sets.')
parser.add_argument('subjects_root_path', type=str, help='Directory containing subject sub-directories.')
args, _ = parser.parse_known_args()

testset = set()
with open("resources/testset.csv", "r") as test_set_file:
for line in test_set_file:
x, y = line.split(',')
if int(y) == 1:
testset.add(x)

def move_to_partition(patients, partition):
def move_to_partition(args, patients, partition):
if not os.path.exists(os.path.join(args.subjects_root_path, partition)):
os.mkdir(os.path.join(args.subjects_root_path, partition))
for patient in patients:
Expand All @@ -23,12 +12,24 @@ def move_to_partition(patients, partition):
shutil.move(src, dest)


folders = os.listdir(args.subjects_root_path)
folders = list((filter(str.isdigit, folders)))
train_patients = [x for x in folders if not x in testset]
test_patients = [x for x in folders if x in testset]
def main():
parser = argparse.ArgumentParser(description='Split data into train and test sets.')
parser.add_argument('subjects_root_path', type=str, help='Directory containing subject sub-directories.')
args, _ = parser.parse_known_args()

assert len(set(train_patients) & set(test_patients)) == 0
test_set = set()
with open("resources/testset.csv", "r") as test_set_file:
for line in test_set_file:
x, y = line.split(',')
if int(y) == 1:
test_set.add(x)

move_to_partition(train_patients, "train")
move_to_partition(test_patients, "test")
folders = os.listdir(args.subjects_root_path)
folders = list((filter(str.isdigit, folders)))
train_patients = [x for x in folders if x not in test_set]
test_patients = [x for x in folders if x in test_set]

assert len(set(train_patients) & set(test_patients)) == 0

move_to_partition(args, train_patients, "train")
move_to_partition(args, test_patients, "test")

0 comments on commit 406cd51

Please sign in to comment.