In [1]:
import json
from shapely.geometry import shape, GeometryCollection

In [12]:
def get_polygons_per_region(regions_loc, polygons_loc, chip_count, skip_poly_ids=None):
    """Group labelled polygons by the first region they intersect.

    Regions and polygons are numbered from 1 in the order their features
    appear in the respective GeoJSON files. A polygon is considered only
    when its number is a key of ``chip_count`` and it is not listed in
    ``skip_poly_ids``. Each polygon is assigned to at most one region: the
    first region (in file order) that it intersects.

    Parameters
    ----------
    regions_loc : str or path-like
        GeoJSON file whose ``features`` are the regions.
    polygons_loc : str or path-like
        GeoJSON file whose ``features`` are the labelled polygons.
    chip_count : mapping
        Maps 1-based polygon numbers to their chip counts. Polygons whose
        number is not a key are ignored.
    skip_poly_ids : iterable of int, optional
        Polygon numbers that must never be assigned. Any iterable is
        accepted and it is not modified.

    Returns
    -------
    dict
        ``{region_number: [{polygon_number: chip_count}, ...]}``. Regions
        with no assigned polygon are absent from the result.
    """
    # GeoJSON is UTF-8 (RFC 7946); be explicit so the platform default
    # encoding (e.g. cp1252 on Windows) is never used to decode it.
    with open(regions_loc, encoding="utf-8") as r, open(polygons_loc, encoding="utf-8") as p:
        regions = json.load(r)["features"]
        polygons = json.load(p)["features"]

    # The files are closed from here on; only the parsed features are needed.
    # buffer(0) is applied to every geometry before testing intersections
    # (commonly used with Shapely to clean up invalid geometries).
    r_col = GeometryCollection([shape(feature["geometry"]).buffer(0) for feature in regions])
    p_col = GeometryCollection([shape(feature["geometry"]).buffer(0) for feature in polygons])

    # Polygon numbers that are skipped or already assigned to a region.
    # A set gives O(1) membership tests and works for any input iterable.
    assigned = set(skip_poly_ids) if skip_poly_ids is not None else set()

    pol_groups = {}

    # Iterate through .geoms: iterating a multi-part geometry directly is
    # deprecated in Shapely 1.8 and no longer supported in Shapely 2.0.
    for i, poly_1 in enumerate(r_col.geoms, 1):
        for j, poly_2 in enumerate(p_col.geoms, 1):

            if j in assigned or j not in chip_count:
                continue

            if poly_1.intersects(poly_2):
                assigned.add(j)
                pol_groups.setdefault(i, []).append({j: chip_count[j]})

    return pol_groups

In [7]:
# Input locations: labelled polygons and the custom regions used to group them.
polygons_loc = "D:/canopy_data/geojsons/labels.geojson"
regions_loc = "D:/canopy_data/geojsons/congo_basin_custom_regions.geojson"

# chip_count.json is read from the current working directory; its keys are
# strings here (see the displayed output).
with open('chip_count.json') as chip_file:
    chip_count = json.load(chip_file)

chip_count

{'1': 559,
 '100': 336,
 '11': 56,
 '13': 550,
 '15': 1400,
 '16': 418,
 '17': 799,
 '2': 262,
 '20': 1025,
 '21': 1152,
 '22': 1710,
 '24': 210,
 '25': 1240,
 '26': 256,
 '29': 1548,
 '3': 110,
 '30': 509,
 '32': 1591,
 '33': 845,
 '34': 100,
 '35': 304,
 '36': 792,
 '37': 899,
 '38': 500,
 '39': 1620,
 '40': 306,
 '41': 272,
 '42': 456,
 '43': 110,
 '44': 1824,
 '46': 210,
 '47': 182,
 '48': 1833,
 '5': 399,
 '53': 1185,
 '54': 90,
 '55': 3036,
 '59': 1189,
 '62': 195,
 '64': 994,
 '65': 294,
 '68': 961,
 '70': 420,
 '72': 1292,
 '73': 36,
 '74': 3658,
 '75': 4680,
 '76': 440,
 '77': 780,
 '79': 528,
 '82': 899,
 '83': 456,
 '84': 611,
 '85': 3424,
 '88': 858,
 '89': 2977,
 '90': 88,
 '91': 1184,
 '93': 211,
 '94': 48,
 '99': 456,
 '10': 2843,
 '27': 708,
 '31': 1440,
 '49': 1545,
 '52': 1411,
 '58': 832,
 '6': 450,
 '87': 48,
 '9': 1292,
 '4': 34,
 '61': 20,
 '63': 72,
 '69': 460,
 '8': 121,
 '86': 832,
 '101': 1980,
 '14': 812,
 '19': 1955,
 '45': 256,
 '56': 1233,
 '66': 1822,
 '9

In [9]:
# JSON object keys are always strings; rebuild the mapping with integer keys
# so it can be looked up by integer polygon number.
chip_count_2 = dict(
    (int(poly_id), chips) for poly_id, chips in chip_count.items()
)

chip_count_2

{1: 559,
 100: 336,
 11: 56,
 13: 550,
 15: 1400,
 16: 418,
 17: 799,
 2: 262,
 20: 1025,
 21: 1152,
 22: 1710,
 24: 210,
 25: 1240,
 26: 256,
 29: 1548,
 3: 110,
 30: 509,
 32: 1591,
 33: 845,
 34: 100,
 35: 304,
 36: 792,
 37: 899,
 38: 500,
 39: 1620,
 40: 306,
 41: 272,
 42: 456,
 43: 110,
 44: 1824,
 46: 210,
 47: 182,
 48: 1833,
 5: 399,
 53: 1185,
 54: 90,
 55: 3036,
 59: 1189,
 62: 195,
 64: 994,
 65: 294,
 68: 961,
 70: 420,
 72: 1292,
 73: 36,
 74: 3658,
 75: 4680,
 76: 440,
 77: 780,
 79: 528,
 82: 899,
 83: 456,
 84: 611,
 85: 3424,
 88: 858,
 89: 2977,
 90: 88,
 91: 1184,
 93: 211,
 94: 48,
 99: 456,
 10: 2843,
 27: 708,
 31: 1440,
 49: 1545,
 52: 1411,
 58: 832,
 6: 450,
 87: 48,
 9: 1292,
 4: 34,
 61: 20,
 63: 72,
 69: 460,
 8: 121,
 86: 832,
 101: 1980,
 14: 812,
 19: 1955,
 45: 256,
 56: 1233,
 66: 1822,
 92: 4994,
 51: 154,
 81: 16577,
 12: 2538,
 18: 2091,
 23: 960,
 28: 1702,
 50: 2450,
 57: 2544,
 60: 7839,
 67: 3657,
 7: 2322,
 71: 81,
 78: 3888,
 80: 1656}

In [11]:
# Number of entries in the integer-keyed chip-count mapping.
len(chip_count_2)

97

In [13]:
# Group the counted polygons by region; skip_poly_ids is left at its default.
pol_groups = get_polygons_per_region(
    regions_loc=regions_loc,
    polygons_loc=polygons_loc,
    chip_count=chip_count_2,
)

pol_groups

{1: [{1: 559}, {2: 262}, {6: 450}, {12: 2538}],
 2: [{70: 420}, {72: 1292}, {74: 3658}],
 5: [{3: 110},
  {4: 34},
  {5: 399},
  {7: 2322},
  {9: 1292},
  {10: 2843},
  {11: 56},
  {13: 550},
  {14: 812},
  {19: 1955},
  {20: 1025},
  {21: 1152},
  {26: 256},
  {28: 1702}],
 6: [{27: 708},
  {30: 509},
  {31: 1440},
  {32: 1591},
  {33: 845},
  {34: 100},
  {35: 304},
  {37: 899},
  {41: 272},
  {43: 110},
  {46: 210},
  {47: 182},
  {53: 1185},
  {58: 832},
  {61: 20},
  {63: 72},
  {64: 994},
  {69: 460},
  {76: 440},
  {77: 780},
  {84: 611},
  {86: 832}],
 9: [{8: 121},
  {15: 1400},
  {16: 418},
  {17: 799},
  {18: 2091},
  {22: 1710},
  {23: 960},
  {24: 210},
  {25: 1240}],
 10: [{29: 1548},
  {36: 792},
  {38: 500},
  {39: 1620},
  {40: 306},
  {42: 456},
  {44: 1824},
  {45: 256},
  {48: 1833},
  {55: 3036},
  {59: 1189},
  {66: 1822},
  {68: 961},
  {71: 81},
  {75: 4680},
  {78: 3888},
  {83: 456},
  {85: 3424},
  {87: 48},
  {88: 858},
  {89: 2977}],
 14: [{49: 1545},
  {50

In [18]:
# Total chip count over every polygon entry in every region group.
sum_chips = sum(
    chips
    for region_polys in pol_groups.values()
    for entry in region_polys
    for chips in entry.values()
)

sum_chips

128992

In [21]:
# Per-region train/test split driven by cumulative chip counts.
train_test = {"train": [], "test": []}

# Counts every polygon visited by the split loop.
counter = 0

for region_polys in pol_groups.values():

    # Each entry is a single-item {polygon_id: chips} dict; unpack to pairs.
    id_chip_pairs = [list(entry.items())[0] for entry in region_polys]

    # Threshold is 84% of the region's total chips.
    r_thresh = sum(chips for _, chips in id_chip_pairs) * .84

    # Largest polygons first; the sort is stable, so ties keep their order.
    id_chip_pairs.sort(key=lambda pair: pair[1], reverse=True)

    temp_sum = 0

    for polygon_id, chips in id_chip_pairs:

        counter += 1
        temp_sum += chips

        # A polygon goes to "train" only while the running total (including
        # this polygon) is still below the threshold; the polygon that reaches
        # the threshold and everything after it go to "test".
        bucket = "train" if temp_sum < r_thresh else "test"
        train_test[bucket].append(polygon_id)

In [22]:
# Display the polygon ids placed in each split.
train_test

{'train': [12,
  1,
  74,
  10,
  7,
  19,
  28,
  9,
  21,
  32,
  31,
  53,
  64,
  37,
  33,
  58,
  86,
  77,
  27,
  84,
  30,
  18,
  22,
  15,
  25,
  23,
  75,
  78,
  85,
  55,
  89,
  48,
  44,
  66,
  39,
  29,
  81,
  60,
  57,
  50,
  92,
  101],
 'test': [6,
  2,
  72,
  70,
  20,
  14,
  13,
  5,
  26,
  3,
  11,
  4,
  69,
  76,
  35,
  41,
  46,
  47,
  43,
  34,
  63,
  61,
  17,
  16,
  24,
  8,
  59,
  68,
  88,
  36,
  38,
  42,
  83,
  40,
  45,
  71,
  87,
  80,
  49,
  52,
  56,
  82,
  51,
  54,
  91,
  99,
  100,
  93,
  90,
  94,
  67,
  65,
  62,
  73,
  79]}

In [23]:
# Flatten the per-region groups into one list of {polygon_id: chips} dicts,
# keeping region order and the order within each region.
pols = [entry for region_polys in pol_groups.values() for entry in region_polys]

pols

[{1: 559},
 {2: 262},
 {6: 450},
 {12: 2538},
 {70: 420},
 {72: 1292},
 {74: 3658},
 {3: 110},
 {4: 34},
 {5: 399},
 {7: 2322},
 {9: 1292},
 {10: 2843},
 {11: 56},
 {13: 550},
 {14: 812},
 {19: 1955},
 {20: 1025},
 {21: 1152},
 {26: 256},
 {28: 1702},
 {27: 708},
 {30: 509},
 {31: 1440},
 {32: 1591},
 {33: 845},
 {34: 100},
 {35: 304},
 {37: 899},
 {41: 272},
 {43: 110},
 {46: 210},
 {47: 182},
 {53: 1185},
 {58: 832},
 {61: 20},
 {63: 72},
 {64: 994},
 {69: 460},
 {76: 440},
 {77: 780},
 {84: 611},
 {86: 832},
 {8: 121},
 {15: 1400},
 {16: 418},
 {17: 799},
 {18: 2091},
 {22: 1710},
 {23: 960},
 {24: 210},
 {25: 1240},
 {29: 1548},
 {36: 792},
 {38: 500},
 {39: 1620},
 {40: 306},
 {42: 456},
 {44: 1824},
 {45: 256},
 {48: 1833},
 {55: 3036},
 {59: 1189},
 {66: 1822},
 {68: 961},
 {71: 81},
 {75: 4680},
 {78: 3888},
 {83: 456},
 {85: 3424},
 {87: 48},
 {88: 858},
 {89: 2977},
 {49: 1545},
 {50: 2450},
 {51: 154},
 {52: 1411},
 {54: 90},
 {56: 1233},
 {57: 2544},
 {60: 7839},
 {80: 1656

In [24]:
# Chip totals per split: for each polygon id in the split, add the count from
# every flattened entry that contains that id.
tot_train = sum(
    entry[split_id]
    for split_id in train_test["train"]
    for entry in pols
    if split_id in entry
)

tot_test = sum(
    entry[split_id]
    for split_id in train_test["test"]
    for entry in pols
    if split_id in entry
)

In [25]:
# Fraction of all split chips that ended up in the test set.
tot_test / (tot_train + tot_test)

0.22720788886132473

In [26]:
import json

# Write the split to train_test_polygons.json in the working directory.
with open('train_test_polygons.json', 'w') as split_file:
    json.dump(train_test, split_file)