In [27]:
import os
import shutil
import tempfile
from collections import Counter
from urllib.parse import urlparse

from shapely import Point
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
# import planetary_computer
# import pystac
import torch
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler

from rasterio.transform import from_bounds

%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 12)

In [2]:
def get_region(shape, key, point, default='unknown'):
    """Return the ``key`` value of the first row whose geometry intersects ``point``.

    Parameters
    ----------
    shape : DataFrame-like
        Table with a ``geometry`` column whose values expose ``.intersects()``
        (presumably a GeoDataFrame of region polygons -- confirm at call site).
    key : hashable
        Column whose value is returned for the matching row.
    point : geometry
        Object passed to each row geometry's ``intersects`` method.
    default : object, optional
        Value returned when no geometry intersects ``point`` (``'unknown'``).

    Returns
    -------
    object
        ``shape[key]`` of the first intersecting row, else ``default``.
    """
    # Walk only the two columns we need instead of iterrows(), which builds a
    # Series per row; rows are still tested in order and the first match wins.
    for geom, value in zip(shape['geometry'], shape[key]):
        if geom.intersects(point):
            return value
    return default

In [17]:
# District boundary polygons, reprojected to EPSG:4326 and reduced to the
# rows whose STATE attribute equals the Karnataka spelling used in the file.
shp_file = './karnataka_footprints/DISTRICT_BOUNDARY.shp'
data = (
    gpd.read_file(shp_file)
    .to_crs(4326)
    # NOTE(review): the '>' in 'KARN>TAKA' looks like an encoding artefact in
    # this shapefile's STATE field; the filter must match the stored value
    # exactly, so confirm against the data before "correcting" the spelling.
    .pipe(lambda gdf: gdf.loc[gdf['STATE'] == 'KARN>TAKA'])
)

In [9]:
# Preview the top-10 GSV tree metadata, keeping only rows that have an image.
root = '../SpeciesMapping/DATA/karnataka_top_10_gsv/'
csv = pd.read_csv(f'{root}/tree_metadata.csv').dropna(subset=['image_filename'])
print(csv)

                                 tree_id             location_id  \
0             Davanagere_0_1_northwest_0          Davanagere_0_1   
1             Davanagere_0_5_northeast_0          Davanagere_0_5   
3             Davanagere_0_5_southeast_0          Davanagere_0_5   
4                 Davanagere_0_5_south_0          Davanagere_0_5   
5                  Davanagere_0_5_west_0          Davanagere_0_5   
...                                  ...                     ...   
7543                 Kolar_102_4_north_0             Kolar_102_4   
7544                 Kolar_102_4_north_1             Kolar_102_4   
7545             Kolar_103_2_northeast_0             Kolar_103_2   
7546       Dakshina_Kannada_103_3_east_0  Dakshina_Kannada_103_3   
7547  Dakshina_Kannada_103_3_southwest_0  Dakshina_Kannada_103_3   

                source  original_lat  original_lon    gsv_lat    gsv_lon  \
0           Davanagere     14.455391     75.882939  14.455417  75.882925   
1           Davanagere     14.3

In [21]:
# Maps each district name (as found in the metadata 'source' column) to the
# zone label used for bucketing. Built from ordered (district, zone) pairs so
# the dict's insertion order follows this listing.
district_to_zone = dict([
    ('Shimoga',          'SOUTHERN TRANSITION'),
    ('Uttara Kannada',   'HILL'),
    ('Davanagere',       'CENTRAL DRY'),
    ('Chitradurga',      'CENTRAL DRY'),
    ('Ballari',          'NORTH EAST DRY'),
    ('Dharwad',          'WESTERN TRANSITION'),
    ('Gadag',            'NORTHERN DRY'),
    ('Koppal',           'NORTH EAST DRY'),
    ('Raichur',          'NORTH EAST DRY'),
    ('Mandya',           'SOUTHERN DRY'),
    ('Dakshina Kannada', 'COASTAL'),
    ('Hassan',           'SOUTHERN TRANSITION'),
    ('Kolar',            'EASTERN DRY'),
    ('Bengaluru South',  'EASTERN DRY'),
    ('Udupi',            'COASTAL'),
    ('Tumakuru',         'CENTRAL DRY'),
    ('Chikkaballapura',  'EASTERN DRY'),
    ('Vijayanagara',     'NORTHERN DRY'),
])

In [36]:
# Per-zone image counts for the top-10 GSV dataset. Starts a fresh by_zone
# dict; the following cell adds the top-25 subset on top of it.
root = '../SpeciesMapping/DATA/karnataka_top_10_gsv/'
csv = pd.read_csv(f'{root}/tree_metadata.csv')
csv = csv[csv['image_filename'].notnull()]

# Check every district up front so a missing mapping fails with a readable
# message instead of a bare KeyError partway through the tally.
unmapped = set(csv['source']) - set(district_to_zone)
if unmapped:
    raise KeyError(f'district_to_zone has no zone for: {sorted(map(str, unmapped))}')

# Counter keeps first-seen key order, so by_zone prints in the same order a
# row-by-row tally would.
by_zone = dict(Counter(district_to_zone[d] for d in csv['source']))
print(by_zone)

{'CENTRAL DRY': 1430, 'EASTERN DRY': 1956, 'SOUTHERN TRANSITION': 995, 'NORTHERN DRY': 530, 'COASTAL': 542, 'NORTH EAST DRY': 439, 'WESTERN TRANSITION': 1108, 'SOUTHERN DRY': 261, 'HILL': 12}


In [37]:
# Add the top-25-subset images onto the per-zone totals in by_zone.
# NOTE: this accumulates into the existing by_zone from the previous cell, so
# it is not idempotent -- run it once, right after that cell; re-running it
# double counts this dataset.
root = '../SpeciesMapping/DATA/karnataka_top_25_subset_gsv/'
csv = pd.read_csv(f'{root}/tree_metadata.csv')
csv = csv[csv['image_filename'].notnull()]

# Validate before touching by_zone so an unmapped district cannot leave the
# totals half-updated.
unmapped = set(csv['source']) - set(district_to_zone)
if unmapped:
    raise KeyError(f'district_to_zone has no zone for: {sorted(map(str, unmapped))}')

# Counter iterates in first-seen order, so any new zones are appended to
# by_zone in the same order a row-by-row update would add them.
for zone, n in Counter(district_to_zone[d] for d in csv['source']).items():
    by_zone[zone] = by_zone.get(zone, 0) + n
print(by_zone)

{'CENTRAL DRY': 1430, 'EASTERN DRY': 1956, 'SOUTHERN TRANSITION': 995, 'NORTHERN DRY': 1377, 'COASTAL': 1273, 'NORTH EAST DRY': 439, 'WESTERN TRANSITION': 2038, 'SOUTHERN DRY': 775, 'HILL': 47}


In [26]:
# Per-group image quotas keyed by zone. group_csv walks these groups in order
# and sends an image to the first group whose quota for its zone is still > 0,
# decrementing that quota. Each group is a separate dict because quotas are
# consumed in place.
_TIER_ZONES = [
    'CENTRAL DRY',
    'EASTERN DRY',
    'SOUTHERN TRANSITION',
    'NORTHERN DRY',
    'COASTAL',
    'NORTH EAST DRY',
    'WESTERN TRANSITION',
    'SOUTHERN DRY',
]

wanted_bucket_tiers = [
    # Group 0: 50 per zone; HILL gets 47, which matches the combined HILL count
    # printed above.
    {**dict.fromkeys(_TIER_ZONES, 50), 'HILL': 47},
    # Groups 1-4: 100 per zone. HILL is listed with quota 0 so a lookup by zone
    # never raises KeyError if more HILL images turn up; they fall through to
    # the overflow folder instead.
    *({**dict.fromkeys(_TIER_ZONES, 100), 'HILL': 0} for _ in range(4)),
]

In [29]:
def group_csv(csv, tiers=None, zone_map=None, src_root=None, out_dir='cvat_uploads'):
    """Copy every image listed in ``csv`` into a CVAT upload folder by zone quota.

    For each row, ``row['source']`` is mapped to a zone. The image is copied to
    ``<out_dir>/group_<i>`` for the first tier ``i`` whose remaining quota for
    that zone is > 0, and that quota is then decremented. When no tier has
    quota left for the zone, the image goes to ``<out_dir>/<zone>``.

    Parameters
    ----------
    csv : pandas.DataFrame
        Metadata with at least ``source`` and ``image_filename`` columns.
    tiers : list of dict, optional
        Per-group quotas keyed by zone, mutated in place so remaining quota
        carries over between calls. Defaults to the global
        ``wanted_bucket_tiers``. A zone missing from a tier counts as quota 0.
    zone_map : dict, optional
        District -> zone mapping. Defaults to the global ``district_to_zone``.
    src_root : str, optional
        Dataset directory that contains ``images/``. Defaults to the global
        ``root``.
    out_dir : str, optional
        Directory the group / overflow folders are created under.

    Returns
    -------
    dict
        Number of images processed per zone, in first-seen order. Counts are
        returned rather than added to a notebook-global tally, so re-running
        the notebook top to bottom does not double count.

    Raises
    ------
    KeyError
        If a ``source`` value is missing from ``zone_map``.
    """
    if tiers is None:
        tiers = wanted_bucket_tiers
    if zone_map is None:
        zone_map = district_to_zone
    if src_root is None:
        src_root = root

    counts = {}
    for source, fname in zip(csv['source'], csv['image_filename']):
        zone = zone_map[source]
        counts[zone] = counts.get(zone, 0) + 1

        # First tier with quota left for this zone; .get() treats a tier that
        # does not list the zone as having no quota instead of raising.
        tier_idx = next(
            (i for i, quota in enumerate(tiers) if quota.get(zone, 0) > 0), None
        )
        subdir = zone if tier_idx is None else f'group_{tier_idx}'
        dest_dir = os.path.join(out_dir, subdir)
        os.makedirs(dest_dir, exist_ok=True)
        # os.path.join avoids the nested same-quote f-string, which is a
        # SyntaxError before Python 3.12.
        shutil.copy(os.path.join(src_root, 'images', fname),
                    os.path.join(dest_dir, fname))
        # Only consume quota once the copy has succeeded.
        if tier_idx is not None:
            tiers[tier_idx][zone] -= 1

    return counts

# Fill the upload groups from the top-10 dataset first, then from the top-25
# subset. group_csv reads the global `root`, which the loop variable sets, and
# quota left over after the first dataset carries into the second.
for root in ('../SpeciesMapping/DATA/karnataka_top_10_gsv/',
             '../SpeciesMapping/DATA/karnataka_top_25_subset_gsv/'):
    csv = pd.read_csv(f'{root}/tree_metadata.csv')
    csv = csv[csv['image_filename'].notnull()]
    group_csv(csv)

In [38]:
# Final per-zone image counts (by_zone from the counting cells) and the grand
# total across all zones.
print(by_zone)
tot = sum(by_zone.values())
print(tot)

{'CENTRAL DRY': 1430, 'EASTERN DRY': 1956, 'SOUTHERN TRANSITION': 995, 'NORTHERN DRY': 1377, 'COASTAL': 1273, 'NORTH EAST DRY': 439, 'WESTERN TRANSITION': 2038, 'SOUTHERN DRY': 775, 'HILL': 47}
10330
