## Prepare the dataset

This notebook processes our dataset as follows:

1) obtain the held-out set (HOS), which contains the 1515 data samples whose positions (`y` values) are -12.502, -29.5, and -41.9;

2) split the remaining 5636 data samples (the model-learning set, MLS) into a training set (80%) and a validation set (20%);

3) after running, you will get a new folder named **dnn_dataset** containing:

   - `hos.csv`: the HOS (1515 data samples)
   - `training.csv`: the training set (80% of the 5636 MLS samples, i.e. 4508)
   - `validation.csv`: the validation set (20% of the 5636 MLS samples, i.e. 1128)
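
As a sanity check on the sizes above, the arithmetic can be reproduced from the per-position counts printed by the counting cell of this notebook. This is a minimal sketch: the dictionary below is copied by hand from that output (with -21.034000000000002 and -24.076999999999998 shortened for readability), and it assumes scikit-learn's documented behaviour of rounding the test-set size up when `test_size` is a float.

```python
import math

# Per-position sample counts, copied by hand from the position-count output
counts = {
    0.0: 924, -3.969: 500, -9.988: 613, -12.502: 395, -17.992: 357,
    -19.7: 376, -21.034: 747, -24.077: 567, -29.5: 634, -36.116: 560,
    -39.4: 386, -41.01: 606, -41.9: 486,
}
held_out_positions = [-12.502, -29.5, -41.9]

total = sum(counts.values())
n_hos = sum(counts[p] for p in held_out_positions)  # 395 + 634 + 486
n_mls = total - n_hos

# For a float test_size, scikit-learn rounds the validation size up (ceil),
# and the training set receives the remainder
n_validation = math.ceil(0.2 * n_mls)
n_training = n_mls - n_validation

print(total, n_hos, n_mls, n_training, n_validation)
# -> 7151 1515 5636 4508 1128
```

The rounding rule explains why the validation set has 1128 samples rather than 1127.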
    

First, let's import packages.

```python
#------ import packages ------#
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
```

```python
#----- read the extracted csv data into a dataframe
df = pd.read_csv('processed_csv/processed_combined.csv')
```

```python
#----- count the samples at each unique position in the dataset
y = df['y'].unique()
for item in y:
    cur_df = df[df['y']==item]
    print('>>>y={}, num={}'.format(item, len(cur_df)))
print('The total samples we have are {}'.format(len(df)))
```


```
>>>y=0.0, num=924
>>>y=-3.969, num=500
>>>y=-9.988, num=613
>>>y=-12.502, num=395
>>>y=-17.992, num=357
>>>y=-19.7, num=376
>>>y=-21.034000000000002, num=747
>>>y=-24.076999999999998, num=567
>>>y=-29.5, num=634
>>>y=-36.116, num=560
>>>y=-39.4, num=386
>>>y=-41.01, num=606
>>>y=-41.9, num=486
The total samples we have are 7151
```
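
The same per-position counts can be obtained in a single pandas call with `value_counts`, instead of filtering the dataframe once per position. The sketch below uses a small hypothetical dataframe in place of `processed_combined.csv`, and sorts positions from 0 downwards to mirror the order printed above.

```python
import pandas as pd

# Hypothetical toy data standing in for processed_csv/processed_combined.csv
df = pd.DataFrame({'y': [0.0, 0.0, -3.969, -12.502, -12.502, -12.502]})

# One pass over the column; positions sorted from 0 downwards
counts = df['y'].value_counts().sort_index(ascending=False)
for position, num in counts.items():
    print('>>>y={}, num={}'.format(position, num))
print('The total samples we have are {}'.format(counts.sum()))
```

Values such as -21.034000000000002 in the real output are ordinary floating-point representation effects of the stored positions, not extra positions.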


The function **get_hos_mls** will split the dataset (7151 data samples) into:

1) held-out set (1515 data samples);

2) model-learning set: training set (80% of 5636 data samples) and validation set (20% of 5636 data samples).
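
The held-out split is position-based: every sample at a held-out position goes to the HOS, so the model never sees those positions while learning. A boolean mask on `y` and its negation give two complementary subsets that keep every row exactly once, including exact duplicate rows. The toy dataframe and its `rssi` column below are hypothetical.

```python
import pandas as pd

# Hypothetical toy data: 'y' is the position, 'rssi' an invented feature column.
# The two rows at y=-3.969 are exact duplicates on purpose.
df = pd.DataFrame({
    'y': [0.0, -12.502, -3.969, -29.5, -3.969, -41.9],
    'rssi': [-60, -71, -65, -80, -65, -85],
})
held_out_positions = [-12.502, -29.5, -41.9]

# Whole positions go to the HOS; everything else forms the MLS
is_held_out = df['y'].isin(held_out_positions)
hos_df = df[is_held_out]
mls_df = df[~is_held_out]

print(len(hos_df), len(mls_df))  # -> 3 3
```

Because `isin` tests exact equality, a held-out position must match the stored float exactly; a value printed as -21.034000000000002 would not be matched by `-21.034`.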

```python
#------ held_out_positions: the positions (values of 'y') to put into the HOS
#------ input_file: the extracted csv file
#------ hos_file: the file name of the HOS
#------ training_file: the file name of the training set
#------ validation_file: the file name of the validation set

def get_hos_mls(held_out_positions, input_file, hos_file, training_file, validation_file):

    org_combined_df = pd.read_csv(input_file)

    #-First, divide the whole dataset into the held-out set (HOS) and the model-learning set (MLS).
    #-A mask and its negation give complementary subsets, so duplicate rows are never dropped.
    is_held_out = org_combined_df['y'].isin(held_out_positions)
    hos_df = org_combined_df[is_held_out]
    mls_df = org_combined_df[~is_held_out]

    #-Second, split the MLS into a training set and a validation set.
    #-Stratifying on 'y' gives both sets the same distribution of positions.
    training_df, validation_df = train_test_split(
        mls_df, test_size=0.2, random_state=42, stratify=mls_df['y'])

    #-Save our datasets
    hos_df.to_csv(hos_file, index=False)
    training_df.to_csv(training_file, index=False)
    validation_df.to_csv(validation_file, index=False)

    print('We have {} data samples in HOS'.format(len(hos_df)))
    print('We have {} data samples in Training set'.format(len(training_df)))
    print('We have {} data samples in Validation set'.format(len(validation_df)))
    print('')
    print('')
    print('>>> Congrats! Datasets have been saved successfully!')
```
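
Stratifying on `y` means each position contributes the same fraction of its samples to the validation set. The sketch below, on a hypothetical MLS with 60 samples at one position and 40 at another, shows that both splits keep the 60/40 ratio.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical MLS: 60 samples at one position and 40 at another
mls_df = pd.DataFrame({'y': [-3.969] * 60 + [-9.988] * 40})

training_df, validation_df = train_test_split(
    mls_df, test_size=0.2, random_state=42, stratify=mls_df['y'])

# Fraction of samples at each position within each split
print(training_df['y'].value_counts(normalize=True))
print(validation_df['y'].value_counts(normalize=True))
```

Without `stratify`, a random 20% draw could over- or under-represent some positions in the validation set, which matters when positions have a few hundred samples each.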
      

Now we run the code.

```python
#------ Start to run ------#
if __name__ == '__main__':

    # create the output folder if it does not exist yet
    if not os.path.exists('dnn_dataset'):
        os.makedirs('dnn_dataset')
    else:
        print('There is no need to create it!')

    #-set up parameters
    input_file = 'processed_csv/processed_combined.csv'
    hos_file = 'dnn_dataset/hos.csv'
    training_file = 'dnn_dataset/training.csv'
    validation_file = 'dnn_dataset/validation.csv'
    held_out_positions = [-12.502, -29.5, -41.9]

    get_hos_mls(held_out_positions, input_file, hos_file, training_file, validation_file)
```

```
There is no need to create it!
We have 1515 data samples in HOS
We have 4508 data samples in Training set
We have 1128 data samples in Validation set


>>> Congrats! Datasets have been saved successfully!
```
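
Once the files are saved, it can be worth checking that no held-out position leaked into the model-learning data. The helper `check_split` below is hypothetical (not part of the original notebook); it assumes the three CSV files keep the `y` column and that the folder layout is the one used above, and it only runs the file-based check when `dnn_dataset/hos.csv` exists.

```python
import os
import pandas as pd

def check_split(hos_df, training_df, validation_df, held_out_positions):
    """Hypothetical helper: raise AssertionError if the HOS/MLS split is inconsistent."""
    held_out = set(held_out_positions)
    # No held-out position may appear in the training or validation set
    for name, part in [('training', training_df), ('validation', validation_df)]:
        leaked = set(part['y']) & held_out
        assert not leaked, '{} set contains held-out positions {}'.format(name, leaked)
    # Every HOS sample must come from a held-out position
    assert set(hos_df['y']) <= held_out, 'HOS contains non-held-out positions'
    # With a stratified split and hundreds of samples per position,
    # every MLS position should appear in both training and validation sets
    assert set(training_df['y']) == set(validation_df['y']), 'position sets differ'

if os.path.exists('dnn_dataset/hos.csv'):
    check_split(pd.read_csv('dnn_dataset/hos.csv'),
                pd.read_csv('dnn_dataset/training.csv'),
                pd.read_csv('dnn_dataset/validation.csv'),
                [-12.502, -29.5, -41.9])
    print('Split checks passed')
```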
