In [12]:
import os
import tempfile
from urllib.parse import urlparse

import geopandas as gpd
import matplotlib.pyplot as plt
# import planetary_computer
# import pystac
import torch
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler

from rasterio.transform import from_bounds

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size (width, height) for every plot created after this cell.
plt.rcParams["figure.figsize"] = (12, 12)

In [2]:
class Skysat(RasterDataset):
    """Raster dataset over every ``*.tif`` file in a directory.

    Samples are the parent class's samples with one extra entry,
    ``"transform"``: the affine transform that maps the sample's pixel grid
    onto its bounding box.
    """

    # The regex consists of a single empty group, so nothing (date, band, ...)
    # is parsed out of the file names.
    filename_glob = "*.tif"
    filename_regex = r"()"
    date_format = "%Y%m%d"
    is_image = True
    # NOTE(review): separate_files=True usually means "one file per band", but
    # the regex above has no band group -- confirm this flag is intended.
    separate_files = True
    all_bands = ["red", "green", "blue"]
    rgb_bands = ["red", "green", "blue"]

    def __getitem__(self, index):
        """Return the parent sample for ``index`` with its affine transform added."""
        sample = super().__getitem__(index)

        # Axis 1 is used as the height and axis 2 as the width of the image.
        pixels = sample['image']
        rows, cols = pixels.shape[1], pixels.shape[2]

        sample["transform"] = self.get_transform(sample["bounds"], rows, cols)
        return sample

    def get_transform(self, bounds, height, width):
        """Build the affine transform for a ``height`` x ``width`` grid covering ``bounds``.

        ``bounds`` must expose ``minx``, ``miny``, ``maxx`` and ``maxy``; they are
        handed to ``from_bounds`` as west, south, east and north respectively.
        """
        return from_bounds(
            bounds.minx, bounds.miny, bounds.maxx, bounds.maxy, width, height
        )

    def plot(self, sample):
        """Show the RGB bands of ``sample["image"]`` and return the figure."""
        # Positions of the RGB bands within the dataset's band list.
        band_order = [self.all_bands.index(name) for name in self.rgb_bands]

        # Select the bands, move channels last, and cast to int for imshow.
        # No rescaling is applied to the pixel values.
        rgb = sample["image"][band_order]
        rgb = rgb.permute(1, 2, 0).numpy().astype(int)

        fig, ax = plt.subplots()
        ax.imshow(rgb)

        return fig

In [32]:
# Index the rasters under ./tifs_done/ as a Skysat dataset and print its
# summary (bounding box and size).
dataset = Skysat('./tifs_done/')
print(dataset)

Skysat Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=-124.13750190128431, maxx=-114.19532113749236, miny=32.80281803283856, maxy=41.93959626234747, mint=0.0, maxt=9.223372036854776e+18)
    size: 89


In [13]:
# California state outline, reprojected to EPSG:4269 before the overlay.
cali = './ca_state/CA_State.shp'
cali_borders = gpd.read_file(cali)
cali_borders = cali_borders.to_crs(epsg=4269)

# Zone polygons, read in whatever CRS the shapefile declares.
# NOTE(review): presumably already EPSG:4269, since the outline above is
# reprojected to match -- confirm against data.crs.
shp_file = './phzm_us_zones_shp_2023/phzm_us_zones_shp_2023.shp'
data = gpd.read_file(shp_file)

# Clip the zone polygons to the state outline, then convert the result to
# EPSG:4326 for the point-in-polygon lookups done later in the notebook.
intersection = gpd.overlay(data, cali_borders, how='intersection')
intersection = intersection.to_crs(epsg=4326)

In [29]:
def get_region(shape, key, point, default='unknown'):
    """Return the ``key`` value of the first row whose geometry intersects ``point``.

    Args:
        shape: Table with a ``'geometry'`` column and a ``key`` column
            (for example a GeoDataFrame). Rows are checked in table order.
        key: Name of the column whose value is returned for the matching row.
        point: Object passed to each row geometry's ``intersects`` method.
        default: Value returned when no row intersects ``point``.

    Returns:
        The ``key`` value of the first intersecting row, or ``default`` when
        there is none.
    """
    # Walk only the two columns that are needed instead of materialising a
    # full Series per row with iterrows().
    for geometry, value in zip(shape['geometry'], shape[key]):
        # A row without a geometry cannot contain the point; skip it rather
        # than failing on None.intersects(...).
        if geometry is None:
            continue
        if geometry.intersects(point):
            return value

    return default

In [33]:
from PIL import Image
import torch
import random
import string
import rasterio
from shapely import Point

# Count the usable image chips per zone.
# Fix torch's RNG seed for this cell.
torch.manual_seed(3)

# Chip side length handed to the sampler; the same value is used as the
# stride, so consecutive chips do not overlap.
CHIP_SIZE = 400
# A chip is skipped once this share of its tensor elements equals zero.
MAX_ZERO_FRACTION = 0.2

sampler = GridGeoSampler(dataset, size=CHIP_SIZE, stride=CHIP_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

c = 0       # number of chips that pass the zero-value filter
zones = {}  # zone label -> number of accepted chips centred in that zone
for batch in dataloader:
    # Visit every sample of the batch (not only the first one) so the counts
    # stay correct if a batch_size is ever passed to the DataLoader.
    for sample in unbind_samples(batch):
        im = sample['image']
        zero_fraction = (im == 0).sum().item() / im.numel()
        if zero_fraction >= MAX_ZERO_FRACTION:
            continue

        c += 1
        bds = sample['bounds']
        # Bounds are ordered (minx, maxx, miny, maxy, ...), so these are the
        # x and y midpoints, i.e. the chip centre.
        pt = Point((bds[0] + bds[1]) / 2, (bds[2] + bds[3]) / 2)
        zone = get_region(intersection, 'zone', pt)
        zones[zone] = zones.get(zone, 0) + 1

print(c)
print(zones)

10775
{'10a': 1152, '10b': 885, '11a': 20, '6a': 1082, '6b': 1079, '5b': 24, '7a': 1204, '7b': 970, '8a': 1115, '8b': 1042, '9a': 1080, '9b': 1122}


In [44]:
# Re-point `dataset` at the rasters under the labelling folder and print its
# summary.
dataset = Skysat('./NUTMGS/filtered_tifs_for_labeling/')
print(dataset)

# District polygons, reprojected to EPSG:32643.
shp_file = './karnataka_footprints/DISTRICT_BOUNDARY.shp'
data = gpd.read_file(shp_file)
data = data.to_crs(epsg=32643)

# Keep a single state's districts.
# NOTE(review): the '>' looks like a mis-encoded character, but the district
# names printed from this table further down show the same pattern, so the
# literal presumably matches the shapefile's attribute text -- leave as is.
data = data.loc[data['STATE'] == 'KARN>TAKA']

Skysat Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=433875.0, maxx=875573.6321397871, miny=1361800.0, maxy=1729841.5, mint=0.0, maxt=9.223372036854776e+18)
    size: 141


In [46]:
from PIL import Image
import torch
import random
import string
import rasterio
from shapely import Point

# Count the usable image chips per district.
# Fix torch's RNG seed for this cell.
torch.manual_seed(3)

# Chip side length handed to the sampler; the same value is used as the
# stride, so consecutive chips do not overlap.
CHIP_SIZE = 400
# A chip is skipped once this share of its tensor elements equals zero.
MAX_ZERO_FRACTION = 0.2

sampler = GridGeoSampler(dataset, size=CHIP_SIZE, stride=CHIP_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

c = 0       # number of chips that pass the zero-value filter
zones = {}  # district name -> number of accepted chips centred in it
for batch in dataloader:
    # Visit every sample of the batch (not only the first one) so the counts
    # stay correct if a batch_size is ever passed to the DataLoader.
    for sample in unbind_samples(batch):
        im = sample['image']
        zero_fraction = (im == 0).sum().item() / im.numel()
        if zero_fraction >= MAX_ZERO_FRACTION:
            continue

        c += 1
        bds = sample['bounds']
        # Bounds are ordered (minx, maxx, miny, maxy, ...), so these are the
        # x and y midpoints, i.e. the chip centre.
        pt = Point((bds[0] + bds[1]) / 2, (bds[2] + bds[3]) / 2)
        zone = get_region(data, 'District', pt)
        zones[zone] = zones.get(zone, 0) + 1

print(c)
print(zones)

8540
{'SHIVAMOGGA': 721, 'UTTARA  KANNADA': 178, 'D>VANGERE': 1260, 'CHITRADURGA': 299, 'BALL>RI': 661, 'DH>RWAD': 414, 'GADAG': 588, 'KOPPAL': 559, 'R>ICH@R': 324, 'KODAGU': 48, 'MANDYA': 419, 'R>MANAGARAM': 411, 'DAKSHINA  KANNADA': 443, 'HASSAN': 543, 'KOLAR': 247, 'BENGAL@RU RURAL': 35, 'UDUPI': 474, 'TUMAK@RU': 323, 'CHIK BALL>PUR': 593}


In [48]:
# District name -> zone label. The keys reproduce the district names exactly
# as they come out of the district table (including the '>' / '@' characters
# and the double spaces), because they are matched against those values.
district_to_zone = {
    'SHIVAMOGGA': 'SOUTHERN TRANSITION',
    'UTTARA  KANNADA': 'HILL',
    'D>VANGERE': 'CENTRAL DRY',
    'CHITRADURGA': 'CENTRAL DRY',
    'BALL>RI': 'NORTH EAST DRY',
    'DH>RWAD': 'WESTERN TRANSITION',
    'GADAG': 'NORTHERN DRY',
    'KOPPAL': 'NORTH EAST DRY',
    'R>ICH@R': 'NORTH EAST DRY',
    'KODAGU': 'SOUTHERN DRY',
    'MANDYA': 'SOUTHERN DRY',
    'R>MANAGARAM': 'EASTERN DRY',
    'DAKSHINA  KANNADA': 'COASTAL',
    'HASSAN': 'SOUTHERN TRANSITION',
    'KOLAR': 'EASTERN DRY',
    'BENGAL@RU RURAL': 'EASTERN DRY',
    'UDUPI': 'COASTAL',
    'TUMAK@RU': 'CENTRAL DRY',
    'CHIK BALL>PUR': 'EASTERN DRY',
}

# Seed the zone that no listed district maps to, so it is reported with a
# count of 0 instead of being absent.
by_zone = {'NORTH EAST TRANSITION': 0}

# get_region() labels chips that fall outside every polygon as 'unknown'.
# Those have no zone, so tally them separately instead of failing with a
# KeyError on the mapping lookup.
unassigned = 0
for district, count in zones.items():
    if district == 'unknown':
        unassigned += count
        continue

    # Any other unmapped district still raises KeyError on purpose: it means
    # the table above is missing an entry.
    zone = district_to_zone[district]
    by_zone[zone] = by_zone.get(zone, 0) + count

print(len(by_zone))
print(by_zone)
if unassigned:
    print(f'{unassigned} chips fell outside every district and were not assigned to a zone')

10
{'NORTH EAST TRANSITION': 0, 'SOUTHERN TRANSITION': 1264, 'HILL': 178, 'CENTRAL DRY': 1882, 'NORTH EAST DRY': 1544, 'WESTERN TRANSITION': 414, 'NORTHERN DRY': 588, 'SOUTHERN DRY': 467, 'EASTERN DRY': 1286, 'COASTAL': 917}


In [53]:
# Entries of the imported-data folder whose name contains 'RajasthanTrees'.
dirs = [
    name
    for name in os.listdir('../Treework/filtered_imported_data/')
    if 'RajasthanTrees' in name
]
print(dirs)

['RajasthanTreesHumidSoutheastern1', 'RajasthanTreesGanganagar2', 'RajasthanTreesJaisalmer2', 'RajasthanTreesTransitionalInlandDrainage4', 'RajasthanTreesTransitionalInlandDrainage3', 'RajasthanTreesFloodProneEastern1', 'RajasthanTreesSubhumidSouthernPlains1', 'RajasthanTreesTransitionalInlandDrainage2', 'RajasthanTreesJodhpur2', 'RajasthanTreesHumidSoutheastern2', 'RajasthanTreesUdaipur', 'RajasthanTreesFloodProneEastern2', 'RajasthanTreesGanganagar1', 'RajasthanTreesJaisalmer1', 'RajasthanTreesSubhumidSouthernPlains2', 'RajasthanTreesTransitionalInlandDrainage1', 'RajasthanTreesJodhpur1']


In [61]:
# Count the images per zone across the RajasthanTrees* folders.
data_root = '../Treework/filtered_imported_data/'
dirs = [name for name in os.listdir(data_root) if 'RajasthanTrees' in name]

# Folder name -> zone label.
# NOTE(review): the folders are called "...InlandDrainage" while the label
# reads "ISLAND DRAINAGE" -- possibly a typo for "INLAND"; left unchanged in
# case the label has to match an external source. Confirm.
dir_to_zone = {
    'RajasthanTreesHumidSoutheastern1': 'SOUTH EASTERN HUMID PLAIN',
    'RajasthanTreesGanganagar2': 'IRRIGATED NORTH WESTERN PLAIN',
    'RajasthanTreesJaisalmer2': 'ARID WESTERN PLAIN AND HYPER ARID PARTIAL IRRIGATED',
    'RajasthanTreesTransitionalInlandDrainage4': 'TRANSITIONAL PLAIN ZONE OF ISLAND DRAINAGE',
    'RajasthanTreesTransitionalInlandDrainage3': 'TRANSITIONAL PLAIN ZONE OF ISLAND DRAINAGE',
    'RajasthanTreesFloodProneEastern1': 'FLOOD PRONE EASTERN PLAIN',
    'RajasthanTreesSubhumidSouthernPlains1': 'SUB HUMID SOUTHERN PLAIN AND ALLUVIAL HILL',
    'RajasthanTreesTransitionalInlandDrainage2': 'TRANSITIONAL PLAIN ZONE OF ISLAND DRAINAGE',
    'RajasthanTreesJodhpur2': 'ARID WESTERN PLAIN AND HYPER ARID PARTIAL IRRIGATED',
    'RajasthanTreesHumidSoutheastern2': 'SOUTH EASTERN HUMID PLAIN',
    'RajasthanTreesUdaipur': 'SOUTHERN HUMID PLAIN',
    'RajasthanTreesFloodProneEastern2': 'FLOOD PRONE EASTERN PLAIN',
    'RajasthanTreesGanganagar1': 'IRRIGATED NORTH WESTERN PLAIN',
    'RajasthanTreesJaisalmer1': 'ARID WESTERN PLAIN AND HYPER ARID PARTIAL IRRIGATED',
    'RajasthanTreesSubhumidSouthernPlains2': 'SUB HUMID SOUTHERN PLAIN AND ALLUVIAL HILL',
    'RajasthanTreesTransitionalInlandDrainage1': 'TRANSITIONAL PLAIN ZONE OF ISLAND DRAINAGE',
    'RajasthanTreesJodhpur1': 'ARID WESTERN PLAIN AND HYPER ARID PARTIAL IRRIGATED',
}

# Seed the zones that no folder maps to, so they are reported with a count
# of 0 instead of being absent.
by_zone = {
    'TRANSITIONAL PLAIN ZONE OF LUNI BASIN': 0,
    'SEMI ARID EASTERN PLAIN': 0,
}

# The loop variable is deliberately not called `dir`: at notebook top level
# that would overwrite the builtin dir() for every later cell.
for folder in dirs:
    count = len(os.listdir(f'{data_root}{folder}/images/default'))
    zone = dir_to_zone[folder]
    by_zone[zone] = by_zone.get(zone, 0) + count

print(len(by_zone))
print(by_zone)

9
{'TRANSITIONAL PLAIN ZONE OF LUNI BASIN': 0, 'SEMI ARID EASTERN PLAIN': 0, 'SOUTH EASTERN HUMID PLAIN': 220, 'IRRIGATED NORTH WESTERN PLAIN': 295, 'ARID WESTERN PLAIN AND HYPER ARID PARTIAL IRRIGATED': 2076, 'TRANSITIONAL PLAIN ZONE OF ISLAND DRAINAGE': 567, 'FLOOD PRONE EASTERN PLAIN': 311, 'SUB HUMID SOUTHERN PLAIN AND ALLUVIAL HILL': 637, 'SOUTHERN HUMID PLAIN': 91}
