# Create train/validation splits (train.csv, valid.csv and per-view AP/PA files) from train.csv

In [18]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import os
from pathlib import Path
from sklearn.model_selection import train_test_split

In [19]:
# Resolve the data directory as <parent of the current working directory>/data.
# NOTE(review): this depends on the directory the notebook kernel runs from — confirm.
cwd_path = Path.cwd()
path = cwd_path.parent.joinpath("data")

# Load the raw train.csv as "data_df". The path is built from pathlib components
# so it resolves on any OS; a hard-coded backslash (r'unprocessed\train.csv') is
# only a separator on Windows and becomes part of the file name elsewhere.
data_df = pd.read_csv(path / "unprocessed" / "train.csv")

# Output directory for all split CSVs. Create it up front so the to_csv calls
# below do not fail when the directory does not exist yet.
path_splitted = path.joinpath("splitted")
path_splitted.mkdir(parents=True, exist_ok=True)

# Step 1: Perform a 95/5 train/validation split over all rows, without any
# balancing or stratification; random_state makes the split reproducible.
train_df, valid_df = train_test_split(
    data_df,
    test_size=0.05, random_state=42, shuffle=True)

# Step 2: Split data_df by its 'AP/PA' column into AP and PA subsets.
# Rows with any other value (including missing) land in neither subset.
ap_df = data_df[data_df['AP/PA'] == 'AP']
pa_df = data_df[data_df['AP/PA'] == 'PA']

# Steps 3-4: Assign each view's rows to train/valid by whether their index is in
# the Step 1 split. The per-view frames are therefore subsets of train_df /
# valid_df (same split, no row moves between train and valid). They keep
# data_df's original row order, not the shuffled order of train_df / valid_df.
train_index = train_df.index
valid_index = valid_df.index

train_ap = ap_df[ap_df.index.isin(train_index)]
valid_ap = ap_df[ap_df.index.isin(valid_index)]

train_pa = pa_df[pa_df.index.isin(train_index)]
valid_pa = pa_df[pa_df.index.isin(valid_index)]

print('# of rows in train:', len(train_df))
print('# of rows in valid:', len(valid_df))
print('# of rows in ap_train:', len(train_ap))
print('# of rows in ap_valid:', len(valid_ap))
print('# of rows in pa_train:', len(train_pa))
print('# of rows in pa_valid:', len(valid_pa))

# Save every split to path_splitted as CSV, without the DataFrame index.
split_files = {
    'train.csv': train_df,
    'valid.csv': valid_df,
    'ap_train.csv': train_ap,
    'ap_valid.csv': valid_ap,
    'pa_train.csv': train_pa,
    'pa_valid.csv': valid_pa,
}
for filename, frame in split_files.items():
    frame.to_csv(path_splitted / filename, index=False)


# of rows in train: 181458
# of rows in valid: 9551
# of rows in ap_train: 153556
# of rows in ap_valid: 8034
# of rows in pa_train: 27902
# of rows in pa_valid: 1517
