# Overview

This notebook shows how to split the data into a training set and a test set.

In order to run this notebook, you first need to produce the file `experimental_data.csv`, e.g. by running the notebook `Extract_experimental_data.ipynb`.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Read the experimental measurements from experimental_data.csv (expected in
# the current working directory) into the DataFrame used by every later cell.
df = pd.read_csv(filepath_or_buffer='experimental_data.csv')

# Data Split into Test and Training

In [None]:
# Split 1: divide the data at z_target = 0. Rows with a negative z_target
# become the test set and rows with z_target >= 0 become the training set
# (this catches a little bit of the downcurve on the graph).
test_set_1 = df.loc[df['z_target (m)'] < 0]
training_set_1 = df.loc[df['z_target (m)'] >= 0]

# Break each subset into its individual columns. z_target is rescaled by 1e6
# (metres -> micrometres, going by the column label) and TOD by (1e15)**3
# (s^3 -> fs^3); the proton count is used as-is.
protons_training_set_1 = training_set_1['n_protons (1/sr)']
TOD_training_set_1 = training_set_1['TOD (s^3)'].mul((1.e15)**3)
z_training_set_1 = training_set_1['z_target (m)'].mul(1.e6)

protons_test_set_1 = test_set_1['n_protons (1/sr)']
TOD_test_set_1 = test_set_1['TOD (s^3)'].mul((1.e15)**3)
z_test_set_1 = test_set_1['z_target (m)'].mul(1.e6)

In [None]:
# Split 2: divide the data at z_target = 40e-6 (in the units of the
# 'z_target (m)' column). Rows at or below the threshold become the test set,
# rows strictly above it become the training set. Both comparisons are written
# out explicitly (rather than negating one mask) so that a row with a missing
# z_target would fall into neither subset.
test_set_2 = df.loc[df['z_target (m)'] <= 40e-6]
training_set_2 = df.loc[df['z_target (m)'] > 40e-6]

# Per-column views of each subset, with the same rescaling as split 1:
# z_target * 1e6 and TOD * (1e15)**3; proton counts are left untouched.
protons_training_set_2 = training_set_2['n_protons (1/sr)']
TOD_training_set_2 = training_set_2['TOD (s^3)'].mul((1.e15)**3)
z_training_set_2 = training_set_2['z_target (m)'].mul(1.e6)

protons_test_set_2 = test_set_2['n_protons (1/sr)']
TOD_test_set_2 = test_set_2['TOD (s^3)'].mul((1.e15)**3)
z_test_set_2 = test_set_2['z_target (m)'].mul(1.e6)

In [None]:
# 3D scatter of split 1: training points in red, test points in blue, plotted
# as (TOD, z_target, n_protons).
# The figure is created explicitly; calling plt.clf() first is unnecessary and,
# when no figure is open yet, it creates an extra empty figure.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.scatter(TOD_training_set_1, z_training_set_1, protons_training_set_1, c='r', alpha=0.3, label='training data')
ax.scatter(TOD_test_set_1, z_test_set_1, protons_test_set_1, c='b', alpha=0.3, label='test data')
ax.view_init(elev=40., azim=40, roll=0)

# The plotted values are the rescaled ones (TOD * (1e15)**3, z * 1e6), so the
# axis labels carry the rescaled units. pyplot has no zlabel() helper, so the
# labels are set on the 3D axes object directly.
ax.set_xlabel(r'TOD (fs$^3$)')
ax.set_ylabel(r'z_target ($\mu$m)')
ax.set_zlabel('Number of protons (1/sr)')
ax.legend()

# Visualizing Data Split

In [None]:
# Plot training set 1 against the full experimental data set.
# An explicit figure/axes pair is used so that this cell never draws onto a
# figure left open by an earlier cell (which would also leak into the saved PNG).
fig, ax = plt.subplots()
ax.scatter(z_training_set_1, protons_training_set_1, label='training data')
# Hollow red circles mark every experimental point; the ones left unfilled are
# the points that are not part of this subset.
ax.scatter(1.e6*df['z_target (m)'], df['n_protons (1/sr)'], s=50, facecolors='none', edgecolors='r', label='expt data')
ax.set_title("Z_target Training data set 1")
# The z values are multiplied by 1e6 before plotting, so the axis is in
# micrometres rather than metres.
ax.set_xlabel(r'z_target ($\mu$m)')
ax.set_ylabel('Number of protons (1/sr)')
ax.legend()
fig.savefig("training_set_1.png")

In [None]:
# Plot test set 1 against the full experimental data set.
# An explicit figure/axes pair is used so that this cell never draws onto a
# figure left open by an earlier cell (which would also leak into the saved PNG).
fig, ax = plt.subplots()
ax.scatter(z_test_set_1, protons_test_set_1, label='test data')
# Hollow red circles mark every experimental point; the ones left unfilled are
# the points that are not part of this subset.
ax.scatter(1.e6*df['z_target (m)'], df['n_protons (1/sr)'], s=50, facecolors='none', edgecolors='r', label='expt data')
ax.set_title("Z_target Test data set 1")
# The z values are multiplied by 1e6 before plotting, so the axis is in
# micrometres rather than metres.
ax.set_xlabel(r'z_target ($\mu$m)')
ax.set_ylabel('Number of protons (1/sr)')
ax.legend()
fig.savefig("test_set_1.png")

In [None]:
# Plot training set 2 against the full experimental data set.
# An explicit figure/axes pair is used so that this cell never draws onto a
# figure left open by an earlier cell (which would also leak into the saved PNG).
fig, ax = plt.subplots()
ax.scatter(z_training_set_2, protons_training_set_2, label='training data')
# Hollow red circles mark every experimental point; the ones left unfilled are
# the points that are not part of this subset.
ax.scatter(1.e6*df['z_target (m)'], df['n_protons (1/sr)'], s=50, facecolors='none', edgecolors='r', label='expt data')
ax.set_title("Z_target Training data set 2")
# The z values are multiplied by 1e6 before plotting, so the axis is in
# micrometres rather than metres.
ax.set_xlabel(r'z_target ($\mu$m)')
ax.set_ylabel('Number of protons (1/sr)')
ax.legend()
fig.savefig("training_set_2.png")

In [None]:
# Plot test set 2 against the full experimental data set.
# An explicit figure/axes pair is used so that this cell never draws onto a
# figure left open by an earlier cell (which would also leak into the saved PNG).
fig, ax = plt.subplots()
ax.scatter(z_test_set_2, protons_test_set_2, label='test data')
# Hollow red circles mark every experimental point; the ones left unfilled are
# the points that are not part of this subset.
ax.scatter(1.e6*df['z_target (m)'], df['n_protons (1/sr)'], s=50, facecolors='none', edgecolors='r', label='expt data')
ax.set_title("Z_target Test data set 2")
# The z values are multiplied by 1e6 before plotting, so the axis is in
# micrometres rather than metres.
ax.set_xlabel(r'z_target ($\mu$m)')
ax.set_ylabel('Number of protons (1/sr)')
ax.legend()
fig.savefig("test_set_2.png")

# Saving Split Data to CSV files

In [None]:
# Collect the three training-set-1 columns into one DataFrame and write it to
# training_set_1.csv. The row index is written too (to_csv default).
# NOTE(review): the z and TOD values stored here were rescaled when the split
# was made (x 1e6 and x (1e15)**3), while the headers still read '(m)' and
# '(s^3)' - confirm that whoever reads this file expects the rescaled values.
training_set_1_df = pd.DataFrame(
    data={
        'z_target (m)': z_training_set_1,
        'TOD (s^3)': TOD_training_set_1,
        'n_protons (1/sr)': protons_training_set_1,
    }
)
training_set_1_df.to_csv(path_or_buf='training_set_1.csv')

In [None]:
# Collect the three test-set-1 columns into one DataFrame and write it to
# test_set_1.csv. The row index is written too (to_csv default).
# NOTE(review): the z and TOD values stored here were rescaled when the split
# was made (x 1e6 and x (1e15)**3), while the headers still read '(m)' and
# '(s^3)' - confirm that whoever reads this file expects the rescaled values.
test_set_1_df = pd.DataFrame(
    data={
        'z_target (m)': z_test_set_1,
        'TOD (s^3)': TOD_test_set_1,
        'n_protons (1/sr)': protons_test_set_1,
    }
)
test_set_1_df.to_csv(path_or_buf='test_set_1.csv')

In [None]:
# Collect the three training-set-2 columns into one DataFrame and write it to
# training_set_2.csv. The row index is written too (to_csv default).
# NOTE(review): the z and TOD values stored here were rescaled when the split
# was made (x 1e6 and x (1e15)**3), while the headers still read '(m)' and
# '(s^3)' - confirm that whoever reads this file expects the rescaled values.
training_set_2_df = pd.DataFrame(
    data={
        'z_target (m)': z_training_set_2,
        'TOD (s^3)': TOD_training_set_2,
        'n_protons (1/sr)': protons_training_set_2,
    }
)
training_set_2_df.to_csv(path_or_buf='training_set_2.csv')

In [None]:
# Collect the three test-set-2 columns into one DataFrame and write it to
# test_set_2.csv. The row index is written too (to_csv default).
# NOTE(review): the z and TOD values stored here were rescaled when the split
# was made (x 1e6 and x (1e15)**3), while the headers still read '(m)' and
# '(s^3)' - confirm that whoever reads this file expects the rescaled values.
test_set_2_df = pd.DataFrame(
    data={
        'z_target (m)': z_test_set_2,
        'TOD (s^3)': TOD_test_set_2,
        'n_protons (1/sr)': protons_test_set_2,
    }
)
test_set_2_df.to_csv(path_or_buf='test_set_2.csv')