In [3]:
import pickle
import geopandas as gpd
import numpy as np
from sklearn.model_selection import train_test_split

In [1]:
def split_data(
    path="/n/home07/kayan/asm/data/filtered_labels.geojson", 
    stratify_col="country", 
    save = True,
    out_path = "/n/home07/kayan/asm/data/train_test_split",
    n = None,
    random_state = None,
    test_size = 0.2,
    val_size = 0.2,
):
    """Split the labelled samples into train / validation / test identifier sets.

    The label file is read with geopandas, optionally truncated to its first
    ``n`` rows, and split in two stages: first into train+val vs. test, then
    the train+val part into train vs. val. With the default fractions this
    gives roughly a 64% / 16% / 20% split.

    Parameters
    ----------
    path : str
        Vector file readable by ``geopandas.read_file``. It must contain the
        columns ``unique_id`` and ``label`` (and ``stratify_col`` if given).
    stratify_col : str or None
        Column whose values are used to stratify both splits. ``None``
        disables stratification.
    save : bool
        If true, pickle the returned dictionary to ``out_path``.
    out_path : str
        Destination of the pickle file; only used when ``save`` is true.
    n : int or None
        If given, only the first ``n`` rows of the file are used.
    random_state : int, RandomState instance or None
        Forwarded to both ``train_test_split`` calls. Pass an integer to get
        the same split on every run; ``None`` (the default) keeps the previous
        behaviour of drawing a different split each call.
    test_size : float or int
        ``test_size`` of the first split (test vs. train+val).
    val_size : float or int
        ``test_size`` of the second split (val vs. train), i.e. it is
        relative to the train+val part, not to the whole dataset.

    Returns
    -------
    dict
        ``{"train": ..., "val": ..., "test": ...}`` where each value is the
        array of ``unique_id`` values belonging to that split.

    Raises
    ------
    KeyError
        If a required column is missing from the label file.
    """
    label_df = gpd.read_file(path)
    if n is not None:
        label_df = label_df.head(n)

    # Fail early, with a readable message, rather than after the split.
    required_cols = ["unique_id", "label"]
    if stratify_col is not None:
        required_cols.append(stratify_col)
    missing_cols = [col for col in required_cols if col not in label_df.columns]
    if missing_cols:
        raise KeyError(f"{path} is missing required column(s): {missing_cols}")

    def _stratify_values(frame):
        # train_test_split expects the stratification labels themselves, or None.
        return frame[stratify_col] if stratify_col is not None else None

    # split into train/val and test
    train_val, test = train_test_split(
        label_df,
        stratify=_stratify_values(label_df),
        test_size=test_size,
        random_state=random_state,
    )
    # split further into train and val
    train, val = train_test_split(
        train_val,
        stratify=_stratify_values(train_val),
        test_size=val_size,
        random_state=random_state,
    )

    # get unique identifiers for each split
    train_ids = train["unique_id"].values
    val_ids = val["unique_id"].values
    test_ids = test["unique_id"].values
    print(f"Split with {len(train_ids)} train images, {len(val_ids)} validation images, and {len(test_ids)} test images")
    # Fraction of rows with label == 1.0 in each split, as a sanity check on class balance.
    print(f"Mine proportions\n Train: {len(train[train['label']==1.0])/len(train)}")
    print(f" Validation: {len(val[val['label']==1.0])/len(val)}")
    print(f" Test: {len(test[test['label']==1.0])/len(test)}")
    split_ids = {"train": train_ids, "val": val_ids, "test":test_ids}

    if save:
        # save as pickle file
        with open(out_path, 'wb') as handle:
            pickle.dump(split_ids, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return split_ids

In [4]:
# Smoke test on the first 100 rows only. save=False keeps this partial split
# from overwriting the pickle stored at split_data's default out_path.
split_data(n=100, save=False)

Split with 64 train images, 16 validation images, and 20 test images
Mine proportions
 Train: 1.0
 Validation: 1.0
 Test: 1.0


{'train': array(['lat_9--045__lon_-12--115', 'lat_8--925__lon_-12--835',
        'lat_7--765__lon_-11--875', 'lat_9--555__lon_-12--235',
        'lat_8--965__lon_-11--885', 'lat_8--665__lon_-11--885',
        'lat_8--605__lon_-11--965', 'lat_8--415__lon_-11--865',
        'lat_8--605__lon_-11--975', 'lat_9--445__lon_-12--005',
        'lat_7--645__lon_-11--965', 'lat_8--615__lon_-11--955',
        'lat_7--665__lon_-11--965', 'lat_9--425__lon_-12--025',
        'lat_9--425__lon_-12--035', 'lat_8--655__lon_-11--925',
        'lat_9--395__lon_-12--035', 'lat_9--535__lon_-12--235',
        'lat_9--445__lon_-12--015', 'lat_9--175__lon_-12--865',
        'lat_7--645__lon_-11--975', 'lat_8--595__lon_-11--975',
        'lat_7--685__lon_-11--905', 'lat_7--745__lon_-11--935',
        'lat_8--645__lon_-11--935', 'lat_7--675__lon_-12--055',
        'lat_9--545__lon_-12--205', 'lat_7--665__lon_-11--975',
        'lat_8--515__lon_-11--915', 'lat_9--565__lon_-12--185',
        'lat_7--665__lon_-11--9

In [None]:
# Read the pickled object back from disk. The path is the same string as
# split_data's default out_path, so this is presumably the saved split
# dictionary (verify which run wrote the file last).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/n/home07/kayan/asm/data/train_test_split", 'rb') as handle:
    b = pickle.load(handle)

In [2]:
# Same label file and stratification column as split_data's defaults,
# repeated here as notebook globals.
path="/n/home07/kayan/asm/data/filtered_labels.geojson"
stratify_col="country"

# Load the label file with geopandas.
label_df = gpd.read_file(path)

In [6]:
# Show only the rows whose label equals 1 (the rows split_data counts as mines).
label_df.loc[label_df["label"].eq(1)]

Unnamed: 0,lon,lat,unique_id,country,input_id,sample_type,mine_type,confidence,label,area_km2,proportion_inspected,proportion_mining,geometry
0,-12.885,9.155,lat_9--155__lon_-12--885,SLE,SLE-1279CL,IPIS,artisanal,3.0,1.0,0.000656,0.523869,0.000523,"POLYGON ((-12.88452 9.15891, -12.88457 9.15875..."
1,-12.875,9.175,lat_9--175__lon_-12--875,SLE,SLE-184CL,IPIS,artisanal,4.0,1.0,0.000587,0.350198,0.000468,"MULTIPOLYGON (((-12.87024 9.17627, -12.87035 9..."
2,-12.865,9.175,lat_9--175__lon_-12--865,SLE,SLE-184CL,IPIS,artisanal,4.0,1.0,0.003670,0.384139,0.002924,"POLYGON ((-12.86941 9.17610, -12.86930 9.17596..."
3,-12.855,9.125,lat_9--125__lon_-12--855,SLE,negatives-1428,CLU,artisanal,3.0,1.0,0.001271,0.784137,0.001013,"POLYGON ((-12.85544 9.12746, -12.85553 9.12733..."
4,-12.835,8.925,lat_8--925__lon_-12--835,SLE,negatives-638,CLU,artisanal,3.0,1.0,0.001501,0.784137,0.001197,"POLYGON ((-12.83892 8.92752, -12.83883 8.92757..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2947,34.655,-1.965,lat_-1--965__lon_34--655,TZA,Tanzania-69CL,IPIS,artisanal,4.0,1.0,0.005781,0.391740,0.004663,"MULTIPOLYGON (((34.65970 -1.96575, 34.65971 -1..."
2948,34.665,-1.965,lat_-1--965__lon_34--665,TZA,Tanzania-69CL,IPIS,artisanal,4.0,1.0,0.005814,0.386872,0.004689,"MULTIPOLYGON (((34.66061 -1.96512, 34.66044 -1..."
2949,35.795,-6.845,lat_-6--845__lon_35--795,TZA,negatives-7830,UAR,artisanal,3.0,1.0,0.004358,0.392069,0.003492,"MULTIPOLYGON (((35.79196 -6.84617, 35.79202 -6..."
2950,35.855,-4.015,lat_-4--015__lon_35--855,TZA,negatives-7911,UAR,artisanal,3.0,1.0,0.029574,0.784137,0.023807,"MULTIPOLYGON (((35.85781 -4.01795, 35.85773 -4..."
