In [5]:
import json
import os
import warnings
from functools import partial

import geopandas as gpd
import pyproj
from geopandas import GeoDataFrame
from shapely.geometry import LineString
from shapely.geometry import Polygon
from shapely.geometry import shape, GeometryCollection
from shapely.ops import transform

warnings.filterwarnings('ignore')

In [6]:
def add_trailing_slash(path):
    """Return ``path`` with a single '/' appended if it does not already end with one.

    Parameters
    ----------
    path : str
        Directory path using forward slashes.

    Returns
    -------
    str
        ``path`` unchanged when it already ends with '/', otherwise ``path + '/'``.
        An empty string is returned unchanged so that "no path" is never
        silently turned into the filesystem root '/'.
    """
    # endswith() is safe on an empty string, whereas path[-1] raises IndexError
    if path == '' or path.endswith('/'):
        return path
    return path + '/'


def create_dir(output_dir):
    """Create ``output_dir`` if it does not exist yet.

    Missing parent directories are created as well, and calling this on an
    existing directory is a no-op.

    Parameters
    ----------
    output_dir : str
        Path of the directory to create.

    Raises
    ------
    FileExistsError
        If ``output_dir`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) replaces the racy "isdir() then mkdir()" pair and
    # also handles nested paths, which os.mkdir() rejects.
    os.makedirs(output_dir, exist_ok=True)

        
def poly_area(geom):
    """Return the area of ``geom`` after reprojecting it to an equal-area projection.

    The geometry is reprojected from EPSG:4326 to an Albers equal-area ('aea')
    projection whose two standard parallels are the geometry's own minimum and
    maximum y bounds, and the projected area is divided by 1,000,000.

    Parameters
    ----------
    geom : shapely geometry
        Geometry exposing ``.bounds``; coordinates are treated as EPSG:4326.

    Returns
    -------
    float
        Projected area / 1e6 (presumably m^2 -> km^2, assuming the projected
        units are metres -- TODO confirm).

    NOTE(review): ``pyproj.Proj(init=...)`` and ``pyproj.transform`` are the
    legacy pyproj 1.x API and are, as far as I know, deprecated in pyproj 2.x;
    confirm the installed version before upgrading.
    """
    source_proj = pyproj.Proj(init='EPSG:4326')

    # bounds is (minx, miny, maxx, maxy); the y extremes become the standard parallels
    bounds = geom.bounds
    equal_area_proj = pyproj.Proj(proj='aea', lat_1=bounds[1], lat_2=bounds[3])

    reproject = partial(pyproj.transform, source_proj, equal_area_proj)
    projected_geom = transform(reproject, geom)

    return projected_geom.area / 1000000

In [17]:
def train_test_split(regions_loc, polygons_loc, skip_poly_ids=None):
    """Group label polygons by the first region they intersect.

    Both inputs are GeoJSON files with a top-level ``"features"`` list. Regions
    and polygons are identified by their 1-based position in that list. Regions
    are visited in file order and each polygon is assigned to at most one
    region: the first one it intersects. Polygons that intersect no region, or
    whose id is in ``skip_poly_ids``, do not appear in the result.

    Parameters
    ----------
    regions_loc : str
        Path to the regions GeoJSON file.
    polygons_loc : str
        Path to the polygons GeoJSON file.
    skip_poly_ids : iterable of int, optional
        1-based polygon ids to exclude from every group. Defaults to none.

    Returns
    -------
    dict
        ``{region_id: [{polygon_id: area}, ...]}`` where ``area`` is the value
        returned by ``poly_area`` for that polygon. Regions with no assigned
        polygon are omitted.
    """
    # Only the parsed features are needed, so release both file handles
    # before the (slow) geometry work starts.
    with open(regions_loc) as r, open(polygons_loc) as p:
        regions = json.load(r)["features"]
        polygons = json.load(p)["features"]

    # buffer(0) is applied to every geometry before the intersection tests
    r_col = GeometryCollection([shape(feature["geometry"]).buffer(0) for feature in regions])
    p_col = GeometryCollection([shape(feature["geometry"]).buffer(0) for feature in polygons])

    # None instead of a mutable [] default; list() also copies the caller's list
    poly_list = [] if skip_poly_ids is None else list(skip_poly_ids)

    print('poly_list:', poly_list)

    # Set membership keeps the "already assigned / skipped" check O(1) per pair
    assigned = set(poly_list)

    pol_groups = {}

    # Iterate via .geoms: iterating a GeometryCollection directly is not
    # supported by Shapely 2.x, while .geoms works on both 1.x and 2.x.
    for i, poly_1 in enumerate(r_col.geoms, 1):
        for j, poly_2 in enumerate(p_col.geoms, 1):

            if j in assigned or not poly_1.intersects(poly_2):
                continue

            assigned.add(j)
            pol_groups.setdefault(i, []).append({j: poly_area(poly_2)})

    return pol_groups
    
    
    
    
    
# Input GeoJSON files.
# NOTE(review): absolute, machine-specific paths -- consider a configurable data directory.
regions_loc = "D:/canopy_data/geojsons/congo_basin_custom_regions.geojson"
polygons_loc = "D:/canopy_data/geojsons/labels.geojson"

# Group polygons by region; polygon ids 95-98 are excluded from every group.
pol_groups = train_test_split(
    regions_loc=regions_loc,
    polygons_loc=polygons_loc,
    skip_poly_ids=[95, 96, 97, 98],
)

poly_list: [95, 96, 97, 98]


In [18]:
# Inspect the grouping: {region_id: [{polygon_id: area}, ...]}
pol_groups

{1: [{1: 206.47482676728058},
  {2: 88.97169557433214},
  {6: 130.2101420622314},
  {12: 891.1834990635282}],
 2: [{70: 124.44686326273228},
  {72: 410.69131316719375},
  {74: 1196.9796451542381}],
 5: [{3: 28.334586252297168},
  {4: 7.528261189037973},
  {5: 168.64926917058818},
  {7: 879.7143196442063},
  {9: 367.81642026806276},
  {10: 718.983272562892},
  {11: 15.311577736438252},
  {13: 139.79630502818733},
  {14: 305.7903730397029},
  {19: 467.73004857890083},
  {20: 296.10011320313174},
  {21: 332.3091624014897},
  {26: 105.51083288864231},
  {28: 800.3803582166872}],
 6: [{27: 252.49973320471142},
  {30: 190.77524205140537},
  {31: 398.31553560429444},
  {32: 489.65345272545363},
  {33: 206.56955336080165},
  {34: 24.86065064172461},
  {35: 88.02860489336733},
  {37: 160.2565547825048},
  {41: 93.7460926236144},
  {43: 44.92247032955615},
  {46: 58.803620931797134},
  {47: 68.83435923622781},
  {53: 216.7509238285581},
  {58: 299.4020515949391},
  {61: 5.167511936313324},
  {63

In [19]:
# Split every region's polygons into train/test by cumulative area.
# Within a region, polygons are taken largest-first and go to "train" while the
# running area stays strictly below 80% of the region total; the polygon that
# reaches the threshold, and every one after it, goes to "test". A region with
# a single polygon therefore lands entirely in "test".
train_test = {"train": [], "test": []}

# Number of polygons visited across all regions
counter = 0

for r_id, region_pols in pol_groups.items():

    # Each entry is a {polygon_id: area} dict; only its first item is read
    region_areas = [list(entry.values())[0] for entry in region_pols]
    r_thresh = sum(region_areas) * .8

    by_area_desc = sorted(
        region_pols,
        key=lambda entry: list(entry.values())[0],
        reverse=True,
    )

    running_area = 0

    for entry in by_area_desc:
        counter += 1

        poly_id, area = list(entry.items())[0]
        running_area += area

        split_name = "train" if running_area < r_thresh else "test"
        train_test[split_name].append(poly_id)
                

In [20]:
# Total number of polygons visited by the split loop
counter

97

In [21]:
# Sanity check: every visited polygon went to exactly one list, so this should equal `counter`
len(train_test['train']) + len(train_test['test']) 

97

In [22]:
# Flatten the per-region lists into one list of {polygon_id: area} dicts,
# keeping region order and the order of entries within each region.
pols = [pol_dict for region_pols in pol_groups.values() for pol_dict in region_pols]

In [23]:
# Total polygon area per split. For every id in a split, add the area from each
# {polygon_id: area} dict in `pols` that contains that id.
split_totals = {}

for split_name in ("train", "test"):
    running_total = 0
    for pol_id in train_test[split_name]:
        for pol_dict in pols:
            if pol_id in pol_dict:
                running_total += pol_dict[pol_id]
    split_totals[split_name] = running_total

tot_train = split_totals["train"]
tot_test = split_totals["test"]

In [24]:
# Summed area of the polygons in the train split (same units as poly_area's return value)
tot_train

32040.71098652535

In [25]:
# Summed area of the polygons in the test split (same units as poly_area's return value)
tot_test

11953.443324665766

In [26]:
# Fraction of the total assigned area that ended up in the test split
tot_test / (tot_train + tot_test)

0.27170526429746783

In [162]:
# Show the final split.
# NOTE(review): the saved output below carries execution count 162 and contains
# repeated ids, ids present in both lists, and the skipped ids 95-98. The split
# code above assigns each polygon once and never assigns skipped ids, so this
# output appears to come from an earlier run -- re-run the notebook top to bottom.
train_test

{'train': [1,
  2,
  6,
  70,
  72,
  3,
  4,
  5,
  7,
  9,
  10,
  11,
  13,
  14,
  19,
  20,
  21,
  26,
  27,
  28,
  30,
  31,
  32,
  33,
  34,
  35,
  37,
  41,
  43,
  46,
  47,
  53,
  58,
  61,
  63,
  64,
  69,
  76,
  77,
  8,
  15,
  16,
  17,
  18,
  22,
  23,
  24,
  22,
  23,
  25,
  29,
  32,
  36,
  38,
  39,
  40,
  42,
  44,
  45,
  48,
  55,
  59,
  66,
  68,
  71,
  75,
  78,
  83,
  85,
  87,
  85,
  88,
  89,
  96,
  22,
  49,
  50,
  51,
  52,
  54,
  56,
  57,
  60,
  80,
  81,
  90,
  91,
  92,
  93,
  94,
  95,
  97,
  98,
  99,
  100,
  62,
  65,
  67],
 'test': [12, 74, 26, 28, 84, 86, 25, 88, 89, 97, 22, 82, 101, 73, 79]}

In [168]:
def intersection(lst1, lst2):
    """Return the items of ``lst1`` that also occur in ``lst2``.

    Order and duplicates follow ``lst1``: an item repeated in ``lst1`` is
    repeated in the result as long as it occurs at least once in ``lst2``.
    """
    shared = []
    for item in lst1:
        if item in lst2:
            shared.append(item)
    return shared


# Leakage check: ids that appear in both splits (an empty list means the splits are disjoint).
# NOTE(review): this cell's execution count (168) is out of sequence with the
# cells around it, so the saved non-empty output may be stale -- re-run to confirm.
intersection(train_test["train"],train_test["test"])

[26, 28, 22, 22, 25, 88, 89, 22, 97]

In [27]:
# Serialise the split to a JSON string.
# NOTE(review): `json_object` is not used by any other cell visible here; the
# file is written separately with json.dump -- confirm whether this is still needed.
json_object = json.dumps(train_test)

In [28]:
import json

# Persist the train/test polygon ids as JSON in the current working directory
with open('train_test_polygons.json', 'w') as out_file:
    json.dump(train_test, out_file)