In [3]:
import os
from pathlib import Path
import random
import pickle
import geopandas as gpd
import numpy as np
from sklearn.model_selection import train_test_split

In [4]:
# Set the random seeds for reproducibility.
# Python's `random` module and NumPy keep separate global generators,
# so both are seeded with the same value.
seed = 42
random.seed(seed)
np.random.seed(seed)

In [13]:
# Image directory and label file used by the rest of the notebook.
data_path = "/n/holyscratch01/tambe_lab/kayan/karena/images/"
path = "/n/home07/kayan/asm/data/filtered_labels.geojson"

# Read the label file with geopandas.
label_df = gpd.read_file(path)

In [14]:
# Drop every label whose `unique_id` has no matching file in the image
# directory; an image's id is its file name without the extension.
dir_ids = list(map(lambda file_name: Path(file_name).stem, os.listdir(data_path)))
label_df = label_df.loc[label_df["unique_id"].isin(dir_ids)]

In [10]:
# Distinct values of the `country` column, in order of first appearance.
countries = label_df.loc[:, "country"].unique()

In [22]:
# Display the distinct `country` values collected in `countries`.
countries

array(['SLE', 'COD', 'CAF', 'ZWE', 'TZA'], dtype=object)

In [21]:
# Build one leave-one-country-out (LOCO) split per country: the held-out
# country becomes the test set and the remaining rows are divided 80/20 into
# train and val, stratified by country. Each split is pickled as a dict of
# `unique_id` arrays under the keys "train", "val" and "test".
split_dir = "/n/home07/kayan/asm/data/splits"

# Create the output directory up front so the pickle writes below cannot
# fail with FileNotFoundError on a fresh checkout.
os.makedirs(split_dir, exist_ok=True)

for country in countries:
    # unique file path for this split
    out_path = split_dir + "/split_LOCO_" + str(country)

    # leave this country out as test dataset
    is_test = label_df["country"] == country
    test_ids = label_df[is_test]["unique_id"].values
    train_val_df = label_df[~is_test]

    # split remaining data into train and val; `random_state=seed` makes the
    # partition reproducible across runs
    train, val = train_test_split(
        train_val_df,
        stratify=train_val_df["country"],
        test_size=0.2,
        random_state=seed,
    )

    # get unique identifiers for each split
    train_ids = train["unique_id"].values
    val_ids = val["unique_id"].values

    print(f"Split for {country} with {len(train_ids)} train samples, {len(val_ids)} val samples, and {len(test_ids)} test samples.")

    split_ids = {"train": train_ids, "val": val_ids, "test": test_ids}

    # save as pickle file
    with open(out_path, 'wb') as handle:
        pickle.dump(split_ids, handle, protocol=pickle.HIGHEST_PROTOCOL)

Split for SLE with 7978 train samples, 1995 val samples, and 4668 test samples.
Split for COD with 5643 train samples, 1411 val samples, and 7587 test samples.
Split for CAF with 10872 train samples, 2719 val samples, and 1050 test samples.
Split for ZWE with 11420 train samples, 2855 val samples, and 366 test samples.
Split for TZA with 10936 train samples, 2735 val samples, and 970 test samples.
