In [1]:
import pickle
import geopandas as gpd
import numpy as np
from sklearn.model_selection import train_test_split

In [1]:
def mine_fraction(df):
    """Return the fraction of rows in ``df`` whose ``label`` column equals 1.0.

    Returns NaN for an empty frame instead of raising ZeroDivisionError.
    """
    if len(df) == 0:
        return float("nan")
    return int((df["label"] == 1.0).sum()) / len(df)


def split_data(
    path="/n/home07/kayan/asm/data/filtered_labels.geojson", 
    stratify_col="country", 
    save = True,
    out_path = "/n/home07/kayan/asm/data/train_test_split",
    n = None,
    test_size=0.2,
    val_size=0.2,
    random_state=None,
):
    """Split the labelled tiles into train / validation / test sets of unique ids.

    The test set is held out first (``test_size`` of all rows). The validation set
    is then taken from the remainder (``val_size`` of the remaining rows). With the
    defaults this gives roughly 64% / 16% / 20%.

    Parameters
    ----------
    path : str
        File readable by ``geopandas.read_file``. It must contain a ``unique_id``
        column and a ``label`` column; rows with ``label == 1.0`` are counted as mines.
    stratify_col : str or None
        Column whose class proportions are preserved in every split. ``None``
        disables stratification.
    save : bool
        If True, pickle the resulting dict to ``out_path``.
    out_path : str
        Destination of the pickle when ``save`` is True.
    n : int or None
        If given, only the first ``n`` rows are used. This is NOT a random sample:
        if the file is ordered (e.g. by label), the subset can be unrepresentative.
    test_size, val_size : float or int
        Forwarded to ``sklearn.model_selection.train_test_split`` for the first and
        second split respectively.
    random_state : int, RandomState or None
        Forwarded to both ``train_test_split`` calls. Set an int for a reproducible
        split; ``None`` keeps the previous non-deterministic behaviour.

    Returns
    -------
    dict
        ``{"train": ..., "val": ..., "test": ...}``, each an array of ``unique_id`` values.
    """
    label_df = gpd.read_file(path)
    if n is not None:
        # First n rows only (not shuffled) -- intended for quick debugging runs.
        label_df = label_df.head(n)

    # Hold out the test set first.
    train_val, test = train_test_split(
        label_df,
        stratify=label_df[stratify_col] if stratify_col is not None else None,
        test_size=test_size,
        random_state=random_state,
    )
    # Carve the validation set out of the remaining rows.
    train, val = train_test_split(
        train_val,
        stratify=train_val[stratify_col] if stratify_col is not None else None,
        test_size=val_size,
        random_state=random_state,
    )

    # Get unique identifiers for each split.
    train_ids = train["unique_id"].values
    val_ids = val["unique_id"].values
    test_ids = test["unique_id"].values
    print(f"Split with {len(train_ids)} train images, {len(val_ids)} validation images, and {len(test_ids)} test images")
    print(f"Mine proportions\n Train: {mine_fraction(train)}")
    print(f" Validation: {mine_fraction(val)}")
    print(f" Test: {mine_fraction(test)}")
    split_ids = {"train": train_ids, "val": val_ids, "test": test_ids}

    if save:
        # Persist the split so downstream notebooks use the same ids.
        with open(out_path, 'wb') as handle:
            pickle.dump(split_ids, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return split_ids

In [4]:
# Smoke-test the split on the first 100 rows only.
# NOTE(review): save=True (the default) means this call overwrites the pickle at
# split_data's default out_path with the 100-row split; pass save=False to keep a full split.
split_data(n=100, save=True)

Split with 64 train images, 16 validation images, and 20 test images
Mine proportions
 Train: 1.0
 Validation: 1.0
 Test: 1.0


{'train': array(['lat_9--045__lon_-12--115', 'lat_8--925__lon_-12--835',
        'lat_7--765__lon_-11--875', 'lat_9--555__lon_-12--235',
        'lat_8--965__lon_-11--885', 'lat_8--665__lon_-11--885',
        'lat_8--605__lon_-11--965', 'lat_8--415__lon_-11--865',
        'lat_8--605__lon_-11--975', 'lat_9--445__lon_-12--005',
        'lat_7--645__lon_-11--965', 'lat_8--615__lon_-11--955',
        'lat_7--665__lon_-11--965', 'lat_9--425__lon_-12--025',
        'lat_9--425__lon_-12--035', 'lat_8--655__lon_-11--925',
        'lat_9--395__lon_-12--035', 'lat_9--535__lon_-12--235',
        'lat_9--445__lon_-12--015', 'lat_9--175__lon_-12--865',
        'lat_7--645__lon_-11--975', 'lat_8--595__lon_-11--975',
        'lat_7--685__lon_-11--905', 'lat_7--745__lon_-11--935',
        'lat_8--645__lon_-11--935', 'lat_7--675__lon_-12--055',
        'lat_9--545__lon_-12--205', 'lat_7--665__lon_-11--975',
        'lat_8--515__lon_-11--915', 'lat_9--565__lon_-12--185',
        'lat_7--665__lon_-11--9

In [None]:
# Reload the split dict pickled by split_data.
# Only unpickle files you created yourself: pickle.load can execute arbitrary code.
split_file = "/n/home07/kayan/asm/data/train_test_split"
with open(split_file, mode="rb") as handle:
    b = pickle.load(handle)