In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

In [3]:
# Load the merged PartMC table; everything below rescales columns of this frame.
PartMC_data = pd.read_csv('/data/keeling/a/xx24/e/proj_ml/code_ml_surfactant_ccn/data/merged_data.csv')

# Species columns that are scaled by 1e9 (to ug/m3).
# NOTE(review): the table also has BC and OIN columns, which are not listed
# here and therefore keep their original scale -- confirm that is intended.
columns_to_multiply = ['SO4', 'Na', 'Cl', 'NH4', 'NO3', 'OC', 'SOA']

# Number-concentration columns that are scaled by 1e-6 (to #/cm3): the two
# CCN_* families at each supersaturation label, plus the total Num_Conc.
columns_to_divide = [
    'CCN_CS_0.1', 'CCN_CS_0.15', 'CCN_CS_0.2', 'CCN_CS_0.25', 'CCN_CS_0.3',
    'CCN_CS_0.4', 'CCN_CS_0.6', 'CCN_CS_0.8', 'CCN_CS_1.0',
    'CCN_VS_0.1', 'CCN_VS_0.15', 'CCN_VS_0.2', 'CCN_VS_0.25', 'CCN_VS_0.3',
    'CCN_VS_0.4', 'CCN_VS_0.6', 'CCN_VS_0.8', 'CCN_VS_1.0',
    'Num_Conc',
]

# Apply the three rescalings.
PartMC_data[columns_to_multiply] = PartMC_data[columns_to_multiply].mul(1e9)   # -> ug/m3
PartMC_data[columns_to_divide] = PartMC_data[columns_to_divide].mul(1e-6)      # -> #/cm3
PartMC_data['Bulk_DP'] = PartMC_data['Bulk_DP'].mul(1e6)                       # -> um

In [4]:
# OA is the row-wise sum of the OC and SOA columns.
PartMC_data['OA'] = PartMC_data['OC'].add(PartMC_data['SOA'])

In [5]:
# Preview the first five rows after the unit conversions and the added OA column.
PartMC_data.head(n=5)

Unnamed: 0,Global_ID,DayofYear,Time_hr,Temperature,RH,Pressure,Density,Bulk_DP,BC,OC,...,CCN_VS_0.1,CCN_VS_0.15,CCN_VS_0.2,CCN_VS_0.25,CCN_VS_0.3,CCN_VS_0.4,CCN_VS_0.6,CCN_VS_0.8,CCN_VS_1.0,OA
0,1,116,1,283.512,0.5398,100000.0,1.228736,0.152147,8.212753e-11,14.849793,...,4.393146,23.236317,39.971849,62.671743,90.753696,142.039871,170.976377,178.892829,179.822475,14.851338
1,1,116,2,283.512,0.5398,100000.0,1.228736,0.159248,1.195284e-10,22.459556,...,158.038764,228.278524,243.033744,254.813585,260.599394,265.719094,266.94708,267.330619,267.377005,22.470256
2,1,116,3,283.512,0.5398,100000.0,1.228736,0.178122,1.346392e-10,26.319633,...,221.337726,290.780047,299.068844,303.015667,305.43734,307.936013,308.6115,308.676534,308.685859,26.348357
3,1,116,4,283.512,0.5398,100000.0,1.228736,0.195206,1.550557e-10,29.973418,...,287.096055,348.133693,354.525344,358.642348,360.327751,361.310519,361.533768,361.789294,361.789294,30.031067
4,1,116,5,283.512,0.5398,100000.0,1.228736,0.209752,1.856348e-10,36.333209,...,378.84403,430.1512,434.75856,438.060016,440.008467,441.526181,442.013976,442.069904,442.069904,36.441919


In [6]:
# Count null entries per column (one integer per column name) and print the result.
missing_values = PartMC_data.isna().sum()
print(missing_values)

Global_ID      0
DayofYear      0
Time_hr        0
Temperature    0
RH             0
Pressure       0
Density        0
Bulk_DP        0
BC             0
OC             0
OIN            0
Na             0
Cl             0
SO4            0
NO3            0
NH4            0
SOA            0
Num_Conc       0
CCN_CS_0.1     0
CCN_CS_0.15    0
CCN_CS_0.2     0
CCN_CS_0.25    0
CCN_CS_0.3     0
CCN_CS_0.4     0
CCN_CS_0.6     0
CCN_CS_0.8     0
CCN_CS_1.0     0
CCN_VS_0.1     0
CCN_VS_0.15    0
CCN_VS_0.2     0
CCN_VS_0.25    0
CCN_VS_0.3     0
CCN_VS_0.4     0
CCN_VS_0.6     0
CCN_VS_0.8     0
CCN_VS_1.0     0
OA             0
dtype: int64


In [7]:
# Two-stage split grouped on Global_ID, so rows that share an ID are kept
# together in a single subset.

# Stage 1: hold out the test portion (test_size=0.15) from the full table.
splitter_test = GroupShuffleSplit(n_splits=1, test_size=0.15, random_state=42)
train_val_idxs, test_idxs = next(
    splitter_test.split(PartMC_data, groups=PartMC_data['Global_ID'])
)
train_val = PartMC_data.iloc[train_val_idxs]
test = PartMC_data.iloc[test_idxs]

# Stage 2: carve validation out of the remainder. 0.1765 ≈ 15%/85%, which
# makes validation roughly 15% of the full table and leaves ~70% for training.
splitter_val = GroupShuffleSplit(n_splits=1, test_size=0.1765, random_state=42)
train_idxs, val_idxs = next(
    splitter_val.split(train_val, groups=train_val['Global_ID'])
)
train = train_val.iloc[train_idxs]
val = train_val.iloc[val_idxs]

# Report each subset's row count and its share of the full table.
n_total = len(PartMC_data)
print(f"Train: {len(train)} ({len(train)/n_total:.1%})")
print(f"Valid: {len(val)} ({len(val)/n_total:.1%})")
print(f"Test: {len(test)} ({len(test)/n_total:.1%})")

Train: 33576 (70.0%)
Valid: 7224 (15.0%)
Test: 7200 (15.0%)


In [8]:
# Leakage check: verify that no Global_ID is shared between any two splits.
# Collect the distinct Global_ID values present in each split.
train_ids = set(train['Global_ID'])
val_ids = set(val['Global_ID'])
test_ids = set(test['Global_ID'])

# An ID shared by ANY two splits is leakage, so every pair is inspected.
# (Testing only the three-way intersection would miss an ID that appears in
# exactly two splits.)
train_val_overlap = train_ids & val_ids
train_test_overlap = train_ids & test_ids
val_test_overlap = val_ids & test_ids

# Three-way intersection, still defined under its previous name; it is
# always a subset of each pairwise overlap above.
all_overlap = train_ids & val_ids & test_ids

pairwise_overlaps = {
    "TRAIN and VAL": train_val_overlap,
    "TRAIN and TEST": train_test_overlap,
    "VAL and TEST": val_test_overlap,
}
# Keep only the pairs that actually share at least one ID.
leaky_pairs = {pair: ids for pair, ids in pairwise_overlaps.items() if ids}

if leaky_pairs:
    for pair, ids in leaky_pairs.items():
        print(f"❌ Critical Error: Global_IDs present in both {pair}:", ids)
else:
    print("✅ Final Check: No Global_IDs are shared between any two sets")

✅ Final Check: No Global_IDs appear in all three sets


In [9]:
# Write each split to its own CSV file (row index omitted), in the order
# train, test, validation.
for split_frame, csv_path in (
    (train, '/data/keeling/a/xx24/e/proj_ml/code_ml_surfactant_ccn/data/partmc_train.csv'),
    (test, '/data/keeling/a/xx24/e/proj_ml/code_ml_surfactant_ccn/data/partmc_test.csv'),
    (val, '/data/keeling/a/xx24/e/proj_ml/code_ml_surfactant_ccn/data/partmc_valid.csv'),
):
    split_frame.to_csv(csv_path, index=False)