In [7]:
import os
import tempfile
from urllib.parse import urlparse
import geopandas as gpd
from shapely import Point

import matplotlib.pyplot as plt
# import planetary_computer
# import pystac
import torch
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler

from rasterio.transform import from_bounds

%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 12)

## Input

#### This notebook uses the shapefile created in task11_skysat_scene_filtering, which contains the chosen scenes that do not overlap.

In [8]:
class Skysat(RasterDataset):
    """Raster dataset over ``*.tif`` files that adds an affine transform to each sample.

    Every sample returned by ``__getitem__`` is the parent class's sample dict
    plus a ``"transform"`` entry, so a chip can later be written out with its
    georeferencing intact.
    """

    filename_glob = "*.tif"
    filename_regex = r"()"
    date_format = "%Y%m%d"
    is_image = True
    # NOTE(review): separate_files=True is combined with a regex that has no
    # `band` group. Verify against the installed torchgeo version whether this
    # makes each file get read once per band; later cells slice channels
    # [0, 1, 2] from the image, which would be consistent with that.
    separate_files = True
    all_bands = ["red", "green", "blue"]
    rgb_bands = ["red", "green", "blue"]

    def __getitem__(self, index):
        """Fetch the parent sample for ``index`` and attach its affine transform."""
        sample = super().__getitem__(index)

        # The image is indexed as (channels, height, width); the last two axes
        # give the pixel grid the transform has to cover.
        pixels = sample["image"]
        n_rows, n_cols = pixels.shape[1], pixels.shape[2]

        sample["transform"] = self.get_transform(sample["bounds"], n_rows, n_cols)
        return sample

    def get_transform(self, bounds, height, width):
        """Return the affine transform for a ``height`` x ``width`` grid spanning ``bounds``.

        ``bounds`` must expose ``minx``, ``miny``, ``maxx`` and ``maxy``; they are
        handed to ``rasterio.transform.from_bounds`` as west, south, east, north.
        """
        west, east = bounds.minx, bounds.maxx
        south, north = bounds.miny, bounds.maxy
        return from_bounds(west, south, east, north, width, height)

    def plot(self, sample):
        """Draw the RGB channels of ``sample["image"]`` and return the figure."""
        # Positions of the RGB bands within the dataset's band order.
        band_order = [self.all_bands.index(name) for name in self.rgb_bands]

        # Channels-first tensor -> (height, width, channels) integer array for imshow.
        channels_first = sample["image"][band_order]
        rgb = channels_first.permute(1, 2, 0).numpy().astype(int)

        fig, ax = plt.subplots()
        ax.imshow(rgb)
        return fig

In [3]:
!ls /Users/siddharthsachdeva/repos/agroforestry/notebooks/exploration/karnataka_clearannotationset1/030b0be1-a128-4a58-805d-def84e5dab92/

[34mSkySatScene[m[m   manifest.json


In [4]:
# Scene identifiers kept as a plain list; the cell displays how many entries it holds.
# NOTE(review): presumably these are the scenes judged cloud-free (going by the
# variable name) -- confirm. The list is not referenced by the other cells shown here.
clear_imgs = ['20220418_050008_ssc2d2_0061_visual',
 '20220415_083158_ssc9d2_0002_visual',
 '20220418_051036_ssc13d1_0004_visual',
 '20220416_065524_ssc19d1_0054_visual',
 '20220427_053104_ss01d1_0022_visual',
 '20220429_053415_ss01d3_0007_visual',
 '20220424_052803_ss01d2_0009_visual',
 '20220411_084414_ssc8d3_0013_visual',
 '20220428_051006_ssc2d2_0012_visual',
 '20220428_051006_ssc2d2_0012_visual',  # NOTE(review): exact duplicate of the previous entry, so len() counts it twice -- confirm whether intended
 '20220429_083200_ssc10d2_0006_visual',
 '20220417_051528_ssc12d1_0025_visual',
 '20220427_083312_ssc7d3_0010_visual',
 '20220416_083956_ssc7d3_0027_visual',
 '20220422_084107_ssc8d3_0014_visual',
 '20220425_052817_ss01d2_0026_visual'
]
len(clear_imgs)

16

In [21]:
# Directories used for earlier runs, kept for reference:
# dataset = Skysat('/Users/siddharthsachdeva/repos/agroforestry/notebooks/exploration/data/test_quality/images/Punjab_testcollects/')
# dataset = Skysat('./cali/')

# Directory of GeoTIFFs that the rest of the notebook chips up.
dataset_root = '../../SpeciesMapping/DATA/karnataka_north_east_dry/'
dataset = Skysat(dataset_root)
print(dataset)

Skysat Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=693271.5, maxx=699644.0, miny=1842451.0, maxy=1856252.5, mint=0.0, maxt=9.223372036854776e+18)
    size: 15


In [10]:
# Walk the dataset on a regular grid: windows of size 400 placed every 400,
# i.e. the stride equals the window size so consecutive windows do not overlap.
sampler = GridGeoSampler(
    dataset,
    size=400,
    stride=400,
)
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    collate_fn=stack_samples,
)

# Number of windows the sampler reports for this dataset.
print(sampler.length)

8288


In [9]:
!mkdir /Users/siddharthsachdeva/Downloads/karnataka_skysatimages/

['1205Z2BBA3_chipid1666.tiff',
 'H6H7RH92PJ_chipid512.tiff',
 '229PJFARSE_chipid266.tiff',
 'WXDWLL549S_chipid3260.tiff',
 '7SICQVL0NP_chipid1239.tiff',
 '2HM3ML8UZJ_chipid128.tiff',
 '0Z5NLNHAU2_chipid330.tiff',
 'KTQZTERFG1_chipid519.tiff',
 '1UFVFKT4CR_chipid414.tiff',
 'IB2T8IDDHD_chipid1429.tiff',
 'V2B9QPUDIP_chipid1374.tiff',
 'XJMGY1GQT3_chipid953.tiff',
 '5Q0M6QYLDZ_chipid342.tiff',
 'Q4QT07BX9A_chipid783.tiff',
 'GX62B66UK5_chipid2925.tiff',
 'BK7BTRFSBS_chipid615.tiff',
 '6EGZF4IQUT_chipid2575.tiff',
 '44O3R4PQGJ_chipid408.tiff',
 'KICM5PACNK_chipid2128.tiff',
 'FA1R09V5JW_chipid3129.tiff',
 '1F4D94JSWF_chipid1982.tiff',
 'M7F4XVUPWN_chipid2178.tiff',
 '2RZZIERKKD_chipid1438.tiff',
 'R428B9PRM1_chipid1340.tiff',
 'TR7Z2K496W_chipid2899.tiff',
 '5RNX38WO6Q_chipid2489.tiff',
 'TDPCG7SKS1_chipid2977.tiff',
 '6ML4NP3Z4D_chipid2976.tiff',
 'SWH5GHI7AD_chipid2794.tiff',
 'XSZ7UUTBQQ_chipid2524.tiff',
 'RDODWN65SL_chipid957.tiff',
 'ATTGY2ST4Z_chipid2008.tiff',
 'G9TEJXQDJ0_chipid1

In [10]:
! mkdir broken

In [11]:
!mv /Users/siddharthsachdeva/repos/agroforestry/notebooks/exploration/data/filteredtiffs/Goa/3d66a1c8-d2cd-4c82-91b4-411897db3329/ broken/

In [9]:
!mkdir /Users/siddharthsachdeva/Downloads/punjab_skysatimages/

In [11]:
def get_region(shape, key, point, default='unknown'):
    """Return ``row[key]`` for the first row of ``shape`` whose geometry intersects ``point``.

    Args:
        shape: Table exposing ``iterrows()`` whose rows have a ``'geometry'``
            entry (e.g. a GeoDataFrame).
        key: Name of the column whose value is returned for the matching row.
        point: Geometry passed to each row geometry's ``intersects`` method.
        default: Value returned when no row matches. Defaults to ``'unknown'``,
            which is what this function has always returned in that case.

    Returns:
        The ``key`` value of the first intersecting row, in iteration order,
        or ``default`` when nothing intersects (including an empty table).
    """
    for _, row in shape.iterrows():
        geometry = row['geometry']
        # A row without a geometry cannot contain the point; skip it rather
        # than crash with AttributeError on ``None.intersects``.
        if geometry is None:
            continue
        if geometry.intersects(point):
            return row[key]

    return default

In [12]:
# Per-zone quota of chips still wanted. The bucket-filling export loop later in
# this notebook decrements a zone's entry by one for every chip it writes to
# ./extra_karnataka/fill/, and sends chips to ./extra_karnataka/overflow/ once
# the zone's entry is 0 (or the zone is not a key here).
# The dict is mutated in place, so re-running the export loop without
# re-running this cell continues from the already-reduced counts.
# NOTE(review): no district in `district_to_zone` maps to
# 'NORTH EAST TRANSITION', so that quota can never be consumed by the loop --
# confirm whether a district mapping is missing.
wanted_remaining = {
    'NORTH EAST TRANSITION': 800,
    'NORTH EAST DRY': 516,
    'NORTHERN DRY': 192,
    'CENTRAL DRY': 0,
    'EASTERN DRY': 0,
    'SOUTHERN DRY': 450,
    'SOUTHERN TRANSITION': 0,
    'WESTERN TRANSITION': 508,
    'HILL': 760,
    'COASTAL': 538
}

In [23]:
# District boundary polygons, reprojected to EPSG:32643 and restricted to the
# rows whose STATE attribute is Karnataka.
# NOTE(review): the '>', '@' and '|' characters in the names below presumably
# stand in for accented letters in the shapefile's attribute encoding. The
# strings must match the file's values exactly, so do not "correct" them.
shp_file = '../karnataka_footprints/DISTRICT_BOUNDARY.shp'
data = (
    gpd.read_file(shp_file)
    .to_crs(32643)
    .loc[lambda districts: districts['STATE'] == 'KARN>TAKA']
)

# District name (as spelled in the shapefile) -> zone label. 'unknown' is the
# value get_region returns when no polygon matches; it maps to the string
# 'None', which is not a zone label.
district_to_zone = {
    'SHIVAMOGGA': 'SOUTHERN TRANSITION',
    'UTTARA  KANNADA': 'HILL',
    'D>VANGERE': 'CENTRAL DRY',
    'CHITRADURGA': 'CENTRAL DRY',
    'BALL>RI': 'NORTH EAST DRY',
    'DH>RWAD': 'WESTERN TRANSITION',
    'GADAG': 'NORTHERN DRY',
    'KOPPAL': 'NORTH EAST DRY',
    'R>ICH@R': 'NORTH EAST DRY',
    'Y>DG|R': 'NORTH EAST DRY',
    'KODAGU': 'SOUTHERN DRY',
    'MANDYA': 'SOUTHERN DRY',
    'R>MANAGARAM': 'EASTERN DRY',
    'DAKSHINA  KANNADA': 'COASTAL',
    'HASSAN': 'SOUTHERN TRANSITION',
    'KOLAR': 'EASTERN DRY',
    'BENGAL@RU RURAL': 'EASTERN DRY',
    'UDUPI': 'COASTAL',
    'TUMAK@RU': 'CENTRAL DRY',
    'CHIK BALL>PUR': 'EASTERN DRY',
    'unknown': 'None',
}

In [18]:
from PIL import Image
import torch
import random
import string
import rasterio

# Cut the dataset into grid-aligned chips and write every chip that is mostly
# non-zero to ./extra_karnataka/ as a georeferenced 3-band uint8 GeoTIFF.
CHIP_SIZE = 400          # sampler window size; stride is the same, so windows do not overlap
MAX_ZERO_FRACTION = 0.2  # chips whose share of zero values reaches this are skipped
OUT_DIR = './extra_karnataka/'

os.makedirs(OUT_DIR, exist_ok=True)

# Kept from the original cell. `random.choice` below does not use torch's RNG.
# NOTE(review): the grid sampler presumably does not either, in which case this
# seed has no effect on the export -- confirm.
torch.manual_seed(3)
sampler = GridGeoSampler(dataset, size=CHIP_SIZE, stride=CHIP_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

c = 0  # chips written so far; also embedded in each output filename
for batch in dataloader:
    # Only the first sample of each batch is used (the DataLoader is created
    # without a batch_size argument).
    sample = unbind_samples(batch)[0]
    im = sample['image']

    # Share of exactly-zero values across all channels of the chip.
    zero_fraction = (im == 0).sum().item() / im.numel()
    if zero_fraction >= MAX_ZERO_FRACTION:
        continue

    # First three channels, channels-first, as bytes for the GeoTIFF writer.
    image = im[[0, 1, 2]].numpy().astype('uint8')
    n_bands, n_rows, n_cols = image.shape

    # Random 10-character prefix keeps names distinct across runs of this cell.
    rand_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
    p = f'{OUT_DIR}{rand_str}_chipid{c}.tiff'

    # Take the raster dimensions from the array being written instead of
    # hard-coding them, so the file header always matches the data.
    with rasterio.open(
        p,
        'w',
        driver='GTiff',
        height=n_rows,
        width=n_cols,
        count=n_bands,
        dtype='uint8',
        crs=sample['crs'],
        transform=sample['transform'],
    ) as dst:
        dst.write(image)

    c += 1

print(c)

280


In [24]:
from PIL import Image
import torch
import random
import string
import rasterio

# Cut the dataset into grid-aligned chips and sort every mostly non-zero chip
# into one of two folders:
#   fill/      -- the chip's zone still has quota left in `wanted_remaining`
#   overflow/  -- the zone's quota is used up, or the chip has no known zone
# NOTE: `wanted_remaining` is decremented in place, so running this cell again
# continues from the reduced quotas.
CHIP_SIZE = 400          # sampler window size; stride is the same, so windows do not overlap
MAX_ZERO_FRACTION = 0.2  # chips whose share of zero values reaches this are skipped
FILL_DIR = './extra_karnataka/fill/'
OVERFLOW_DIR = './extra_karnataka/overflow/'

os.makedirs(FILL_DIR, exist_ok=True)
os.makedirs(OVERFLOW_DIR, exist_ok=True)

# Kept from the original cell. `random.choice` below does not use torch's RNG.
# NOTE(review): the grid sampler presumably does not either, in which case this
# seed has no effect on the export -- confirm.
torch.manual_seed(3)
sampler = GridGeoSampler(dataset, size=CHIP_SIZE, stride=CHIP_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

c = 0  # chips written so far (fill + overflow); also embedded in each filename
for batch in dataloader:
    # Only the first sample of each batch is used (the DataLoader is created
    # without a batch_size argument).
    sample = unbind_samples(batch)[0]
    im = sample['image']

    # Share of exactly-zero values across all channels of the chip.
    zero_fraction = (im == 0).sum().item() / im.numel()
    if zero_fraction >= MAX_ZERO_FRACTION:
        continue

    # Centre of the chip. The dataset repr prints the box as
    # (minx, maxx, miny, maxy, ...), so indices 0/1 are x and 2/3 are y.
    bds = sample['bounds']
    pt = Point((bds[0] + bds[1]) / 2, (bds[2] + bds[3]) / 2)

    # A district that is missing from `district_to_zone` used to raise KeyError
    # part-way through the export. Treat it like the 'unknown' district: zone
    # 'None', which has no quota and therefore lands in overflow/.
    district = get_region(data, 'District', pt)
    zone = district_to_zone.get(district, 'None')

    # First three channels, channels-first, as bytes for the GeoTIFF writer.
    image = im[[0, 1, 2]].numpy().astype('uint8')
    n_bands, n_rows, n_cols = image.shape

    # Random 10-character prefix keeps names distinct across runs of this cell.
    rand_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))

    # Fill the zone's bucket while it still has quota; everything else overflows.
    if wanted_remaining.get(zone, 0) > 0:
        p = f'{FILL_DIR}{rand_str}_chipid{c}.tiff'
        wanted_remaining[zone] -= 1
    else:
        p = f'{OVERFLOW_DIR}{rand_str}_chipid{c}.tiff'

    # Take the raster dimensions from the array being written instead of
    # hard-coding them, so the file header always matches the data.
    with rasterio.open(
        p,
        'w',
        driver='GTiff',
        height=n_rows,
        width=n_cols,
        count=n_bands,
        dtype='uint8',
        crs=sample['crs'],
        transform=sample['transform'],
    ) as dst:
        dst.write(image)

    c += 1

print(c)

732


In [25]:
# Show the per-zone quotas left over after the export loop decremented them.
print(wanted_remaining)

{'NORTH EAST TRANSITION': 800, 'NORTH EAST DRY': 0, 'NORTHERN DRY': 0, 'CENTRAL DRY': 0, 'EASTERN DRY': 0, 'SOUTHERN DRY': 0, 'SOUTHERN TRANSITION': 0, 'WESTERN TRANSITION': 0, 'HILL': 0, 'COASTAL': 0}


In [33]:
# Number of directory entries in ./relabeled_tifs/cali_8/ (os.listdir counts files and sub-directories alike).
print(len(os.listdir('./relabeled_tifs/cali_8/')))

1219


In [47]:
# Number of directory entries in ./relabeled_tifs/rajasthan/ (os.listdir counts files and sub-directories alike).
print(len(os.listdir('./relabeled_tifs/rajasthan/')))

15045
