In [19]:
import os
import tempfile
from urllib.parse import urlparse

import geopandas as gpd
import matplotlib.pyplot as plt
# import planetary_computer
# import pystac
import torch
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler

from rasterio.transform import from_bounds

%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 12)

In [20]:
class Skysat(RasterDataset):
    """Three-band (red, green, blue) ``*.tif`` dataset built on ``RasterDataset``.

    Samples returned by ``__getitem__`` are the parent class's sample dict
    with one extra key, ``"transform"``, holding the affine transform computed
    from the sample's bounds and the image's height and width.
    """

    # File discovery / parsing settings read by the RasterDataset base class.
    filename_glob = "*.tif"
    filename_regex = r"()"
    date_format = "%Y%m%d"
    is_image = True
    separate_files = True
    # Band names of the stored rasters, and the subset/order used by plot().
    all_bands = ["red", "green", "blue"]
    rgb_bands = ["red", "green", "blue"]

    def __getitem__(self, index):
        """Return the parent sample for ``index`` with a ``"transform"`` entry added.

        Args:
            index: Query forwarded unchanged to ``RasterDataset.__getitem__``.

        Returns:
            The parent's sample dict (its ``"image"`` and ``"bounds"`` entries
            are read here) plus ``"transform"``.
        """
        sample = super().__getitem__(index)

        pixels = sample['image']
        # Axis 1 is passed as height and axis 2 as width, i.e. the image is
        # treated as channel-first.
        n_rows, n_cols = pixels.shape[1], pixels.shape[2]

        sample["transform"] = self.get_transform(sample["bounds"], n_rows, n_cols)
        return sample

    def get_transform(self, bounds, height, width):
        """Build the affine transform mapping a ``height`` x ``width`` grid onto ``bounds``.

        Args:
            bounds: Object exposing ``minx``, ``miny``, ``maxx`` and ``maxy``.
                NOTE(review): these look like CRS coordinates rather than pixel
                coordinates -- confirm against the dataset's CRS.
            height: Number of raster rows.
            width: Number of raster columns.

        Returns:
            The result of ``rasterio.transform.from_bounds``.
        """
        # from_bounds takes (west, south, east, north, width, height).
        return from_bounds(
            bounds.minx, bounds.miny, bounds.maxx, bounds.maxy, width, height
        )

    def plot(self, sample):
        """Draw the RGB bands of ``sample["image"]`` and return the figure.

        Args:
            sample: Dict whose ``"image"`` entry is a channel-first tensor.

        Returns:
            The matplotlib figure containing the image.
        """
        # Positions of the RGB bands within the stored band order.
        band_order = [self.all_bands.index(name) for name in self.rgb_bands]

        # Select the bands, move channels last for imshow, and cast to int
        # (no rescaling is applied).
        rgb = sample["image"][band_order]
        rgb = rgb.permute(1, 2, 0).numpy().astype(int)

        fig, ax = plt.subplots()
        ax.imshow(rgb)

        return fig

In [21]:
def get_region(shape, key, point):
    """Return ``row[key]`` for the first row whose geometry intersects ``point``.

    Args:
        shape: Table with an ``iterrows()`` method and a ``'geometry'`` column
            whose values provide ``intersects(point)``.
        key: Column whose value is returned for the matching row.
        point: Geometry passed to each row's ``intersects``.

    Returns:
        The ``key`` value of the first intersecting row in iteration order, or
        the string ``'unknown'`` when no row intersects.
    """
    hits = (
        record[key]
        for _, record in shape.iterrows()
        if record['geometry'].intersects(point)
    )
    # The generator is lazy, so scanning stops at the first hit.
    return next(hits, 'unknown')

In [22]:
# Index the Karnataka GeoTIFFs with the Skysat dataset class and print its
# summary. Alternative root kept from an earlier run:
# dataset = Skysat('../SpeciesMapping/DATA/karnataka_north_east_dry/')
dataset = Skysat('./NUTMGS/tifs_for_labeling/karnataka_top_10_tifs/')
print(dataset)

# Load the district boundary polygons and reproject them to EPSG:32643
# (presumably the CRS of the imagery, so tile centres can be tested against
# the polygons directly -- TODO confirm).
shp_file = './karnataka_footprints/DISTRICT_BOUNDARY.shp'
data = gpd.read_file(shp_file).to_crs(32643)
# Keep only rows whose STATE attribute equals 'KARN>TAKA'.
# NOTE(review): the '>' looks like the shapefile's own encoding of an accented
# letter; the literal must match the attribute table byte-for-byte.
data = data[data['STATE'] == 'KARN>TAKA']

Skysat Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=433875.0, maxx=875573.6321397871, miny=1345080.5, maxy=1729841.5, mint=0.0, maxt=9.223372036854776e+18)
    size: 194


In [23]:
from PIL import Image
import imagehash
import torch
import random
import string
import rasterio
from shapely import Point
from torchvision.transforms.functional import to_pil_image

torch.manual_seed(3)
# Walk the dataset on a regular grid: size=400 with stride=400, so
# consecutive tiles do not overlap.
sampler = GridGeoSampler(dataset, size=400, stride=400)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

# Maps str(dHash of a tile) -> 'District' value found for the tile's centre.
# A later tile with the same hash overwrites the earlier entry.
hashes = {}
for batch in dataloader:
    # Only the first sample of each batch is used.
    tile = unbind_samples(batch)[0]
    pixels = tile['image']

    n_values = pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
    zero_fraction = (pixels == 0).sum().item() / n_values
    if zero_fraction >= 0.2:
        # Skip tiles where 20% or more of the values are exactly zero
        # (presumably nodata / outside the footprint -- TODO confirm).
        continue

    # The first four entries are used as (minx, maxx, miny, maxy), so this is
    # the centre of the tile.
    box = tile['bounds']
    centre = Point((box[0] + box[1]) / 2, (box[2] + box[3]) / 2)
    district = get_region(data, 'District', centre)

    # First three bands, channels last, cast to uint8 for PIL.
    rgb = pixels[0:3].permute(1, 2, 0).numpy().astype('uint8')
    hashes[str(imagehash.dhash(Image.fromarray(rgb)))] = district

print(hashes)

{'060303a1b8d8667e': 'SHIVAMOGGA', '2924d25e393d3635': 'SHIVAMOGGA', '7952f164684a8686': 'SHIVAMOGGA', 'd3b22b979557338a': 'SHIVAMOGGA', '74d1609bbb67b3b2': 'SHIVAMOGGA', '3031b2a2880060e3': 'SHIVAMOGGA', 'dbc977db793ecbb6': 'SHIVAMOGGA', '1911101c0c4ccdcc': 'SHIVAMOGGA', '08283839498c3511': 'SHIVAMOGGA', '185ba34124904031': 'SHIVAMOGGA', '9d277737dede9c3d': 'SHIVAMOGGA', '24468981d3a382ef': 'SHIVAMOGGA', '86711404203e6386': 'SHIVAMOGGA', 'e2240f4ec62604c4': 'SHIVAMOGGA', '79785028eccecbdf': 'SHIVAMOGGA', '3c3e9b1f8e3a1333': 'SHIVAMOGGA', 'ccb430303276783c': 'SHIVAMOGGA', '888c122c2c9298d0': 'SHIVAMOGGA', 'd2a1259c981b8fdf': 'SHIVAMOGGA', '70334b0980482424': 'SHIVAMOGGA', 'aeaebb256faf1968': 'SHIVAMOGGA', 'db1a53fc6ea795c9': 'SHIVAMOGGA', '22c495282c4e46b1': 'SHIVAMOGGA', '193c0cd072232202': 'SHIVAMOGGA', 'cd493b9b81c9cbea': 'SHIVAMOGGA', '3f351c7c7078fc70': 'SHIVAMOGGA', '58210913d3d3b278': 'SHIVAMOGGA', '897132a2606060d1': 'SHIVAMOGGA', '1c246360e4c4c3c3': 'SHIVAMOGGA', '7353d31a0ec8

In [24]:
def find_closest(im_hash, hashes):
    """Find the stored hash string nearest to ``im_hash``.

    Distance is whatever ``imagehash.hex_to_hash(key) - im_hash`` yields
    (presumably the number of differing bits -- confirm in the imagehash docs).

    Args:
        im_hash: Hash object to compare against.
        hashes: Iterable of hex hash strings (e.g. the keys of a dict).

    Returns:
        ``(best_key, distance)`` for the first key with the smallest distance.
        If ``hashes`` is empty, or no distance is below 1000, returns
        ``(None, 1000)``.
    """
    closest_key = None
    closest_dist = 1000  # sentinel: candidates must beat this to be kept
    for candidate in hashes:
        dist = imagehash.hex_to_hash(candidate) - im_hash
        # Strict comparison keeps the earliest candidate on ties.
        if dist < closest_dist:
            closest_key, closest_dist = candidate, dist

    return closest_key, closest_dist

In [27]:
import cv2

# Largest distance returned by find_closest that still counts as a match.
MAX_HASH_DISTANCE = 12

# Folder holding the exported KarnatakaTrees* image sets.
# NOTE(review): absolute, machine-specific path -- point this at your own copy.
IMG_ROOT = '/Users/Hugo/Coding_Projects/Sidd Lab/Treework/filtered_imported_data/'

total = 0      # images hashed
not_found = 0  # images whose hash is not an exact key of `hashes`
unknown = 0    # images whose nearest stored hash is further than MAX_HASH_DISTANCE

# District name (as spelled in the shapefile attribute table) -> zone label.
district_to_zone = {
    'SHIVAMOGGA': 'SOUTHERN TRANSITION', 
    'UTTARA  KANNADA': 'HILL', 
    'D>VANGERE': 'CENTRAL DRY', 
    'CHITRADURGA': 'CENTRAL DRY', 
    'BALL>RI': 'NORTH EAST DRY', 
    'DH>RWAD': 'WESTERN TRANSITION', 
    'GADAG': 'NORTHERN DRY', 
    'KOPPAL': 'NORTH EAST DRY', 
    'R>ICH@R': 'NORTH EAST DRY', 
    'Y>DG|R': 'NORTH EAST DRY',
    'KODAGU': 'SOUTHERN DRY', 
    'MANDYA': 'SOUTHERN DRY', 
    'R>MANAGARAM': 'EASTERN DRY', 
    'DAKSHINA  KANNADA': 'COASTAL', 
    'HASSAN': 'SOUTHERN TRANSITION', 
    'KOLAR': 'EASTERN DRY', 
    'BENGAL@RU RURAL': 'EASTERN DRY', 
    'UDUPI': 'COASTAL', 
    'TUMAK@RU': 'CENTRAL DRY', 
    'CHIK BALL>PUR': 'EASTERN DRY'
}

# Sets 0-6 follow the numbered pattern; the eighth set has a different name.
image_sets = [f'KarnatakaTrees{i}' for i in range(7)] + ['KarnatakaTrees7Partial']

by_zone = {}
# `with` closes (and flushes) the CSV even if an image fails part-way through.
with open('karnataka_image_zones.csv', 'w') as f:
    f.write('filename,zone\n')

    for image_set in image_sets:
        img_src = f'{IMG_ROOT}{image_set}/images/default/'

        # sorted() makes the CSV row order reproducible; os.listdir order is arbitrary.
        for file in sorted(os.listdir(img_src)):
            bgr = cv2.imread(os.path.join(img_src, file))
            if bgr is None:
                # cv2.imread returns None for files it cannot decode
                # (e.g. hidden or non-image files); skip them instead of crashing.
                print(f'Skipping unreadable file: {os.path.join(img_src, file)}')
                continue
            # Reverse the channel axis (OpenCV loads BGR) before hashing.
            im = bgr[:, :, ::-1]

            im_hash = imagehash.dhash(Image.fromarray(im))
            total += 1

            hash_key = str(im_hash)
            # Either lookup into district_to_zone raises KeyError if the stored
            # district has no entry in the table, including the 'unknown'
            # value get_region returns when no polygon matches.
            if hash_key in hashes:
                zone = district_to_zone[hashes[hash_key]]
            else:
                not_found += 1
                best, dist = find_closest(im_hash, hashes)
                if dist > MAX_HASH_DISTANCE:
                    zone = 'UNKNOWN'
                    unknown += 1
                else:
                    zone = district_to_zone[hashes[best]]

            f.write(f'{file},{zone}\n')
            by_zone[zone] = by_zone.get(zone, 0) + 1

print(f'Total images: {total}')
print(f'Without exact match: {not_found}')
print(f'Without close match: {unknown}')
print(by_zone)

Total images: 8109
Without exact match: 1320
Without close match: 398
{'EASTERN DRY': 1157, 'CENTRAL DRY': 1649, 'NORTHERN DRY': 589, 'SOUTHERN TRANSITION': 1142, 'WESTERN TRANSITION': 313, 'COASTAL': 921, 'HILL': 181, 'NORTH EAST DRY': 1290, 'UNKNOWN': 398, 'SOUTHERN DRY': 469}


In [112]:
# Rebind `dataset` to the GeoTIFFs under ./tifs_done/ and print its summary.
# This replaces the dataset previously bound to the same name, so cells that
# read `dataset` depend on which of the two cells ran last.
dataset = Skysat('./tifs_done/')
print(dataset)

Skysat Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=-124.13750190128431, maxx=-114.19532113749236, miny=32.80281803283856, maxy=41.93959626234747, mint=0.0, maxt=9.223372036854776e+18)
    size: 89


In [90]:
# California state boundary, reprojected to EPSG:4269.
cali = './ca_state/CA_State.shp'
cali_borders = gpd.read_file(cali).to_crs(4269)

# Zone polygons (the file name suggests US plant hardiness zones, 2023 --
# TODO confirm). Rebinds `data`, which an earlier cell used for the Karnataka
# districts; the polygons are kept in the file's native CRS.
shp_file = './phzm_us_zones_shp_2023/phzm_us_zones_shp_2023.shp'
data = gpd.read_file(shp_file)

# Overlay the zones with the California boundary and reproject to EPSG:4326.
# NOTE(review): no visible cell reads `intersection`; the tile lookup passes
# the unclipped `data` to get_region -- confirm which one was intended.
intersection = data.overlay(cali_borders).to_crs(4326)

Skysat Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=-123.41130382560209, maxx=-116.0024844524377, miny=33.80101699339933, maxy=41.59384369264271, mint=0.0, maxt=9.223372036854776e+18)
    size: 10


In [113]:
from PIL import Image
import imagehash
import torch
import random
import string
import rasterio
from shapely import Point
from torchvision.transforms.functional import to_pil_image

torch.manual_seed(3)
# Walk the dataset on a regular grid: size=400 with stride=400, so
# consecutive tiles do not overlap.
sampler = GridGeoSampler(dataset, size=400, stride=400)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

# Maps str(dHash of a tile) -> 'zone' value found for the tile's centre.
# A later tile with the same hash overwrites the earlier entry.
cali_hashes = {}
for batch in dataloader:
    # Only the first sample of each batch is used.
    tile = unbind_samples(batch)[0]
    pixels = tile['image']

    n_values = pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
    zero_fraction = (pixels == 0).sum().item() / n_values
    if zero_fraction >= 0.2:
        # Skip tiles where 20% or more of the values are exactly zero
        # (presumably nodata / outside the footprint -- TODO confirm).
        continue

    # The first four entries are used as (minx, maxx, miny, maxy), so this is
    # the centre of the tile.
    box = tile['bounds']
    centre = Point((box[0] + box[1]) / 2, (box[2] + box[3]) / 2)
    zone = get_region(data, 'zone', centre)

    # First three bands, channels last, cast to uint8 for PIL.
    # NOTE(review): the recorded output shows a warning on this cast; values
    # outside the uint8 range (or non-finite) are not handled here.
    rgb = pixels[0:3].permute(1, 2, 0).numpy().astype('uint8')
    cali_hashes[str(imagehash.dhash(Image.fromarray(rgb)))] = zone

print(cali_hashes)

  image = im[0:3].permute(1, 2, 0).numpy().astype('uint8')


{'63703430b4722261': '10a', 'f39110901257d356': '10a', 'b34c4c968c8c5049': '10a', '8b3030b430309692': '10a', '4db8e068ec9ce0f0': '10a', '37d6e755491c5647': '10a', 'e42b0ba6e6e6e626': '10a', 'c362a44b00044050': '10a', 'dccd4d4d5f7535b1': '10a', 'ab2b1915313c199b': '10a', '35b4d12c1c4cd9d9': '10a', 'ad8db13236366c99': '10a', '4b8e8a4541f1f2f2': '10a', '98546662369ade29': '10a', 'dc7870b4b4bcd95a': '10a', '6161a617a4ccc806': '10a', '4244994e464636b2': '10a', '919233b3b2b667c9': '10a', 'dc5d3991e20cb257': '10a', '070b495d59a4a48b': '10a', '2626d7891eaa2960': '10a', '037823e664884872': '10a', '121284111216644b': '10a', '8aca68b949c8888f': '10a', '5492f25466b4b0b8': '10a', '9a9a98cbd4d8d8d8': '10a', 'c0c8b2bc9cb0eeb1': '10a', 'a6a6566d61676566': '10a', 'c149789a8a3a92b0': '10a', '6d059891b282a191': '10a', '9a222a09b6064c99': '10a', '2d17964ee4ed6667': '10a', '906c65ba18945a2d': '10a', 'dbb33839216122dd': '10a', '59393b226a8fa75a': '10a', '325201e383cbba51': '10a', 'c4c32f2c2c2c749c': '10a', 

In [109]:
from tifffile import imread
# Compute a crop-resistant hash for every file in ./tifs_done/.
# The mapped value is the fixed placeholder 'test'; only the keys carry data.
src_folder = './tifs_done/'
crop_hashes = {}
for tif_name in os.listdir(src_folder):
    tif_path = os.path.join(src_folder, tif_name)
    # Cast to uint8 so PIL can build an image from the array.
    pixels = imread(tif_path).astype('uint8')
    crop_key = str(imagehash.crop_resistant_hash(Image.fromarray(pixels)))
    crop_hashes[crop_key] = 'test'

print(crop_hashes)

  im = imread(os.path.join(src_folder, file)).astype('uint8')


{'da7c38a91c8d5c5a': 'test', '8c78783e8ee32726,024342422669683c,6361d0da2c8e6636,01c19108b4d8e5e5,d8ccb81ee3371fb1': 'test', '63138ecc25258c3e': 'test', '372d24948e0f8d44,2e4ca62e0cac810c,dc3c1c78c6922081,8c2c78f5caf83080,7cdc6ae2d218e4ac,d8f9d4969af2e293,6060d8988ccfc9e5,4c65656d8ca4060c,0327333634e1e3c7': 'test', '3431ecf41b074342': 'test', 'd4d4ddd7c6c64339,39c8c8c8cccc8c8c,2624b232684a6666,89a484a0aaecec2c,d7e7f57535353567,9c118b0b666666e8': 'test', 'c252632b92318106': 'test', '349ccbc1a29199ee,361bad31d9d8d4de,7d361b49a593590c,04482292c866390c,ee775b0c8643b1b0,ce7f3e3799991c06,c671384eb7d15fff,fc7e2e96cf63a148,685c7b3d8c879308,314019374e9d3a74,32194c269fcb6198,91d1c0683018cea3': 'test', 'c8cc1c9c9bacec5c,38a4e56012db9f8b,26a58483b874a484,e86e6e6afa84a484,b0b4629282a6a008,24c50a4c20d8ed6c': 'test', '0e1c28c1e2004913,e7a5d6dac1c9c8cc,d8d8a6bcb490c0d2,c1c7836f2c3e7169,2441498397061531,4161c74363cc1e23': 'test'}


In [116]:
import cv2

# Largest distance returned by find_closest that still counts as a match.
MAX_HASH_DISTANCE = 12

# Folder holding the exported CaliforniaTrees* image sets.
# NOTE(review): absolute, machine-specific path -- point this at your own copy.
IMG_ROOT = '/Users/Hugo/Coding_Projects/Sidd Lab/Treework/filtered_imported_data/'

total = 0      # images hashed
not_found = 0  # images whose hash is not an exact key of `cali_hashes`
unknown = 0    # images whose nearest stored hash is further than MAX_HASH_DISTANCE

by_zone = {}
# `with` closes (and flushes) the CSV even if an image fails part-way through.
with open('california_image_zones.csv', 'w') as f:
    f.write('filename,zone\n')

    for i in range(9):
        img_src = f'{IMG_ROOT}CaliforniaTrees{i}/images/default/'

        # sorted() makes the CSV row order reproducible; os.listdir order is arbitrary.
        for file in sorted(os.listdir(img_src)):
            bgr = cv2.imread(os.path.join(img_src, file))
            if bgr is None:
                # cv2.imread returns None for files it cannot decode
                # (e.g. hidden or non-image files); skip them instead of crashing.
                print(f'Skipping unreadable file: {os.path.join(img_src, file)}')
                continue
            # Reverse the channel axis (OpenCV loads BGR) before hashing.
            im = bgr[:, :, ::-1]

            im_hash = imagehash.dhash(Image.fromarray(im))
            total += 1

            hash_key = str(im_hash)
            if hash_key in cali_hashes:
                zone = cali_hashes[hash_key]
            else:
                not_found += 1
                best, dist = find_closest(im_hash, cali_hashes)
                if dist > MAX_HASH_DISTANCE:
                    zone = 'UNKNOWN'
                    unknown += 1
                else:
                    zone = cali_hashes[best]

            f.write(f'{file},{zone}\n')
            by_zone[zone] = by_zone.get(zone, 0) + 1

print(f'Total images: {total}')
print(f'Without exact match: {not_found}')
print(f'Without close match: {unknown}')
print(by_zone)

Total images: 10841
Without exact match: 121
Without close match: 118
{'UNKNOWN': 118, '7a': 1204, '6b': 1079, '8b': 1038, '6a': 1083, '10a': 1153, '10b': 885, '9b': 1122, '8a': 1065, '9a': 1080, '7b': 970, '11a': 20, '5b': 24}
