In [1]:
import os

# # Set the proxy environment variables
# os.environ['http_proxy'] = 'http://127.0.0.1:7890'
# os.environ['https_proxy'] = 'http://127.0.0.1:7890'

In [2]:
import ee
import datetime
import os
import itertools
import sys
import urllib.request

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import geemap
from tqdm.auto import tqdm

import subprocess
from subprocess import PIPE

In [3]:
ee.Initialize()

##### Define basic parameters

In [4]:
# define the gee-asset path 
sample_path = 'users/wangjinzhulala/China_built_up/03_sample_extract_img_value'
export_path = 'users/wangjinzhulala/China_built_up/04_sample_train_test_split'

In [5]:
# define the year range
year_range = '2020_2022'

# create the region_dict
# region_dict = dict(zip(['华东','东北','中南','华北','西北','西南'],
#                        ['huadong','dongbei','zhongnan','huabei','xibei','xinan']))
region_dict = dict(zip(['西南'],
                       ['xinan']))
region_en2cn = {v:k for k,v in region_dict.items()}

##### Using grid to extract sample points

In [7]:
# function to split all_sample into hold_sample and train_sample  
def get_hold_in_out(sample_path,region_en,year):
    
    # get all samplt pts
    sample_pt = ee.FeatureCollection(f"{sample_path}/Control_sample_ext_img_{region_en}_{year}") 
    
    # saplit all sample to built and non_built pts
    pt_built     = sample_pt.filterMetadata('Built','equals',1)
    pt_non_built = sample_pt.filterMetadata('Built','equals',0)


    #_____________________________1: select one point from each grid________________________

    # Define a spatial filter as geometries that intersect.
    spatialFilter = ee.Filter.intersects(
                                          leftField= '.geo',
                                          rightField= '.geo',
                                          maxError= 1
                                        )
    # Define a save all join.
    saveAllJoin = ee.Join.saveAll(matchesKey= 'sample_pts')
    
    # the 40km grid cell; which will be used to hold out one built_pt sample each grid
    grid_shp  =  ee.FeatureCollection("users/wangjinzhulala/China_built_up/01_Boundary_shp/China_zone_grid_40km")\
                   .filterMetadata('NAME1','equals',region_en2cn[region_en])

    # Apply the join.
    grid_intersect_bulit = saveAllJoin.apply(grid_shp, pt_built, spatialFilter)

    # select one point from each grid
    choose_one_built_pt = grid_intersect_bulit.map(lambda fe: ee.List(fe.get('sample_pts')).get(0) )

    # balance the size of choose_one sample [built = non-built]
    choose_one_non_built_pt = pt_non_built\
                              .randomColumn()\
                              .sort('random')\
                              .limit(choose_one_built_pt.size())

    # merge to get the hold_out sample
    hold_out = choose_one_built_pt.merge(choose_one_non_built_pt)  

    #_____________________________2: exclude the selected points________________________

    # using the invert join for exclusion
    invertedJoin = ee.Join.inverted()
    hold_in = invertedJoin.apply(sample_pt, hold_out, spatialFilter)
    # this process makes sure the featurecollection is exportable
    hold_in = ee.FeatureCollection(hold_in.toList(hold_in.size()))
    
    return hold_out,hold_in

In [9]:
# function to download the csv to local drive & gee-asset_______________
def download_hold_in_out(hold_out,hold_in,year,export_path=export_path):
    # construct export names
    name_hold_out = f'Grid_select_{region_en}_{year}'
    name_hold_in  = f'Training_sample_{region_en}_{year}' 

    # downloding
    for table_name,pts in zip([name_hold_out,name_hold_in],
                              [hold_out,hold_in]):

        size = hold_out.size().getInfo()

        # to loacl drive
        ulr = pts.getDownloadURL('csv') 
        urllib.request.urlretrieve(ulr, f'./Data/{table_name}.csv')

        # to GEE-asset
        task = ee.batch.Export.table.toAsset(collection=pts,
                                             description=table_name,
                                             assetId=f'{export_path}/{table_name}')

        task.start()

    # print out the process
    print(f'{region_en}_{year} of {size} points downloaded!')

In [10]:
# hold_out and export
for region_cn,region_en in tqdm(region_dict.items(),total=len(region_dict)):

    # the hold_out process 
    hold_out,hold_in = get_hold_in_out(sample_path,region_en,year_range)

    # export
    download_hold_in_out(hold_out,hold_in,year_range)

  0%|          | 0/1 [00:00<?, ?it/s]

xinan_2020_2022 of 1230 points downloaded!
