# Train/Test Split

Census tracts in the same county tend to be highly correlated because they often share a transit system and similar walkability policies. To avoid geographic leakage, we use GroupShuffleSplit to split the data into a training set and a test set that contain disjoint county codes.
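
To illustrate the leakage concern, here is a minimal sketch on toy data (the `toy` frame, its `county` column, and the letters A to E are made up for illustration and are not part of the project data). It compares a plain row-level `ShuffleSplit` with `GroupShuffleSplit` and reports which counties end up on both sides of each split.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit, ShuffleSplit

# Toy data: 5 hypothetical counties with 4 tracts each
toy = pd.DataFrame({
    "county": np.repeat(["A", "B", "C", "D", "E"], 4),
    "value": np.arange(20),
})

# Row-level shuffle: rows are assigned independently of their county
ss = ShuffleSplit(n_splits=1, test_size=0.4, random_state=0)
tr, te = next(ss.split(toy))
plain_overlap = set(toy["county"].iloc[tr]) & set(toy["county"].iloc[te])

# Grouped shuffle: whole counties are assigned to one side only
gss_toy = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=0)
g_tr, g_te = next(gss_toy.split(toy, groups=toy["county"]))
group_overlap = set(toy["county"].iloc[g_tr]) & set(toy["county"].iloc[g_te])

print("counties on both sides (row-level):", sorted(plain_overlap))
print("counties on both sides (grouped):  ", sorted(group_overlap))
```

With the grouped split the overlap is empty by construction, which is the property the real split is meant to guarantee.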

## Setup

In [None]:
# Setup

import numpy as np
import pandas as pd

## Load Data

In [11]:
# Load the cleaned dataset and list its columns

df = pd.read_csv("clean_data.csv")
df.sample()
df.columns

Index(['census_tract', 'StCoFIPS2019', 'StAbbr', 'walkability_index',
       'Pop2018', 'HU2018', 'HH2018', 'employment_mix',
       'employment_residential_mix', 'intersection_density',
       'transit_accessibility', 'employment_mix_ranked',
       'employment_residential_mix_ranked', 'intersection_density_ranked',
       'transit_accessibility_ranked', 'median_income', 'percent_unemployed',
       'percent_below_poverty', 'percent_bachelor_and_higher',
       'percent_over_65', 'percent_commute_car', 'percent_commute_transit',
       'percent_white', 'percent_black', 'percent_native_american',
       'percent_asian', 'percent_pacific_islander', 'statedesc', 'countyname',
       'total_population', 'arthritis_crudeprev', 'arthritis_crude95ci',
       'high_blood_pressure_prevalence', 'high_blood_pressure_95ci',
       'cancer_prevalence', 'cancer_95ci', 'current_asthma_prevalence',
       'current_asthma_95ci', 'coronary_heart_disease_prevalence',
       'coronary_heart_disease_95ci'
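
One thing that may be worth checking when loading: `pd.read_csv` infers numeric dtypes by default, so if the file stores FIPS codes with leading zeros they would be parsed as integers and the zeros dropped. This does not affect grouping, but it can matter for later joins on county codes. The sketch below uses a made-up two-row CSV string in place of `clean_data.csv`; whether the real file has leading zeros is an assumption.

```python
import io
import pandas as pd

# Hypothetical two-row stand-in for clean_data.csv
csv_text = "census_tract,StCoFIPS2019\n01001020100,01001\n06037101110,06037\n"

# Default parsing: codes become integers, leading zeros are lost
as_default = pd.read_csv(io.StringIO(csv_text))

# Explicit string dtype: codes are kept exactly as written
as_text = pd.read_csv(
    io.StringIO(csv_text),
    dtype={"census_tract": str, "StCoFIPS2019": str},
)

print(as_default["StCoFIPS2019"].tolist())
print(as_text["StCoFIPS2019"].tolist())
```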

## GroupShuffleSplit

In [14]:
from sklearn.model_selection import GroupShuffleSplit

# Configuration
county_col = "StCoFIPS2019"
test_size  = 0.20
seed       = 123

# One split grouped by county
gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
train_idx, test_idx = next(gss.split(df, groups=df[county_col])) # get the first (and only) split

df_train = df.iloc[train_idx].copy()
df_test  = df.iloc[test_idx].copy()

# Save both splits for the team
df_train.to_csv("data_train.csv", index=False)
df_test.to_csv("data_test.csv", index=False)

print("Saved:")
print("  data_train.csv  ->", len(df_train), "rows,", df_train[county_col].nunique(), "counties")
print("  data_test.csv   ->", len(df_test),  "rows,", df_test[county_col].nunique(),  "counties")

# Sanity check: no county appears in both splits
assert set(df_train[county_col]).isdisjoint(set(df_test[county_col]))

Saved:
  data_train.csv  -> 52586 rows, 2496 counties
  data_test.csv   -> 17572 rows, 624 counties
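
Note that in `GroupShuffleSplit`, `test_size` is a proportion of groups, not of rows. In the output above, 624 of 3,120 counties (20%) are in the test set, but 17,572 of 70,158 rows (about 25%), because counties differ in how many tracts they contain. A small helper along these lines could report both shares; `split_shares` and the toy frames are hypothetical names used only for illustration.

```python
import pandas as pd

def split_shares(train: pd.DataFrame, test: pd.DataFrame, group_col: str) -> dict:
    """Return the test share measured in rows and in groups.

    Assumes the two frames contain disjoint groups, as the split guarantees.
    """
    n_rows = len(train) + len(test)
    n_groups = train[group_col].nunique() + test[group_col].nunique()
    return {
        "test_row_share": len(test) / n_rows,
        "test_group_share": test[group_col].nunique() / n_groups,
    }

# Toy frames standing in for df_train / df_test
toy_train = pd.DataFrame({"StCoFIPS2019": [1, 1, 2, 3]})
toy_test = pd.DataFrame({"StCoFIPS2019": [4, 4, 4, 4]})

shares = split_shares(toy_train, toy_test, "StCoFIPS2019")
print(shares)
```

On the real data this would be called as `split_shares(df_train, df_test, county_col)`. If a row share close to 20% matters, one option is to try several seeds and keep the split whose row share is nearest the target.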
