In [223]:
import os
import boto3
import rioxarray
import rasterio
from tqdm import tqdm
import pandas as pd
import random
import numpy as np
from scipy.ndimage import label
from PIL import Image
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path
import glob
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
# Credentials are read from the environment instead of being hard-coded in
# the notebook (notebooks and their outputs get shared and committed).
# NOTE(review): the key pair that was previously hard-coded here must be
# treated as leaked and revoked/rotated with the provider.
# If a variable is unset the value is None; boto3's Session only installs
# explicit credentials when one is given, otherwise it uses its default
# credential chain.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

In [196]:
# Root directory for all local dataset artifacts (images, labels, derived
# masks, CSVs). Defaults to the original /app/dataset/; set KENYA_DATA_DIR to
# relocate it without editing the notebook.
data_dir = Path(os.environ.get('KENYA_DATA_DIR', '/app/dataset/'))

In [187]:
def get_keys(bucket_name, prefix, client=None):
    """Return every object key under ``prefix`` in ``bucket_name``.

    Walks all pages of ``list_objects_v2`` so the full listing is returned,
    in the order the service lists it.

    Parameters
    ----------
    bucket_name : str
        Bucket to list.
    prefix : str
        Key prefix passed through as ``Prefix``.
    client : S3 client, optional
        Client used for the listing. When omitted, the module-level ``s3``
        client is used (looked up at call time, so existing two-argument
        calls behave exactly as before).

    Returns
    -------
    list of str
        The object keys.
    """
    if client is None:
        client = s3
    paginator = client.get_paginator('list_objects_v2')
    return [
        obj['Key']
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix)
        # A page may carry no 'Contents' entry (e.g. nothing matched).
        for obj in page.get('Contents', [])
    ]

def check_tif_values(label_path, valid_values=(0, 1, 2, 3)):
    """Return True if every pixel in band 1 of ``label_path`` is an allowed value.

    Parameters
    ----------
    label_path : path-like
        Label GeoTIFF to check.
    valid_values : iterable of int, optional
        Allowed pixel values. Defaults to 0, 1, 2, 3, the set the check has
        always used.

    Returns
    -------
    bool
        ``True`` when no pixel of the first band lies outside ``valid_values``
        (vacuously ``True`` for an empty band), ``False`` otherwise.
    """
    with rasterio.open(label_path) as src:
        data = src.read(1)  # only the first band is checked
    # np.isin needs an array-like of values (a set would be treated as a
    # single object), hence list(); bool() turns numpy.bool_ into a plain
    # Python bool, as the original returned.
    return bool(np.isin(data, list(valid_values)).all())
        
class KenyaData(Dataset):
    """Map-style dataset that yields every band of a GeoTIFF chip as a float32 tensor.

    Parameters
    ----------
    paths : sequence of path-like
        One GeoTIFF per item; item ``i`` is read from ``paths[i]``.
    """

    def __init__(self,
                 paths):
        # Kept as given (not copied); len() and indexing go straight to it.
        self.tif_paths = paths

    def __len__(self):
        """Number of chips, i.e. ``len(paths)``."""
        return len(self.tif_paths)

    def __getitem__(self, index):
        """Read all bands of chip ``index`` and return them as a float32 tensor.

        The layout is whatever rasterio's ``read()`` (no band argument)
        returns, i.e. (bands, rows, cols).
        """
        chip_path = self.tif_paths[index]
        with rasterio.open(chip_path) as raster:
            pixels = raster.read()
        return torch.tensor(pixels, dtype=torch.float32)

In [197]:
# Connection settings are defined BEFORE the client is built. Previously
# `endpoint_url` was assigned two lines after `session.client(...)` used it,
# so this cell raised NameError on a fresh kernel.
endpoint_url = 'https://data.source.coop'
bucket_name = "ksa"
prefix = "kenol-section"
# Derived from `prefix` so the listing filter and the download key always
# agree. The trailing '/' keeps sibling folders such as "Labels_old/" from
# matching and then being fetched under the wrong key.
image_key_prefix = f"{prefix}/Images/"
label_key_prefix = f"{prefix}/Labels/"

session = boto3.session.Session(aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
s3 = session.client('s3', endpoint_url=endpoint_url)

image_dir = data_dir / 'images'
label_dir = data_dir / 'labels'

# parents=True so the cell also works when data_dir's parent does not exist.
for directory in (data_dir, image_dir, label_dir):
    directory.mkdir(parents=True, exist_ok=True)

# File names of the chips already present locally (only *.tif is considered).
images = [chip.name for chip in image_dir.glob('*.tif')]
labels = [chip.name for chip in label_dir.glob('*.tif')]

print(f'{len(images)} images detected')
print(f'{len(labels)} labels detected')

keys = get_keys(bucket_name, prefix)

labelkeys = [key.split("/")[-1] for key in keys if key.startswith(label_key_prefix)]
imagekeys = [key.split("/")[-1] for key in keys if key.startswith(image_key_prefix)]

# Sorted so download order (and progress output) is reproducible.
missing_images = sorted(set(imagekeys) - set(images))
missing_labels = sorted(set(labelkeys) - set(labels))

print(f'{len(missing_images)} images missing from local data will be downloaded')
print(f'{len(missing_labels)} labels missing from local data will be downloaded')

for chip_id in tqdm(missing_images, desc='downloading images...'):
    # download_file documents Filename as a str, so convert the Path.
    s3.download_file(bucket_name, image_key_prefix + chip_id, str(image_dir / chip_id))

for chip_id in tqdm(missing_labels, desc='downloading labels...'):
    s3.download_file(bucket_name, label_key_prefix + chip_id, str(label_dir / chip_id))

# Full paths for the downstream cells. Sorted because glob order depends on
# the filesystem, and later cells (e.g. the seeded split) iterate over these.
images = sorted(image_dir.glob('*.tif'))
labels = sorted(label_dir.glob('*.tif'))

print(f'{len(images)} images after download operation')
print(f'{len(labels)} labels after download operation')

2074 images detected
2074 labels detected
0 images missing from local data will be downloaded
0 labels missing from local data will be downloaded


downloading images...: 0it [00:00, ?it/s]
downloading labels...: 0it [00:00, ?it/s]

2074 images after download operation
2074 labels after download operation





In [198]:
harmonized_label_dir = data_dir / 'harmonized_labels'
harmonized_label_dir.mkdir(exist_ok=True)

# For every label chip, merge class 3 into class 1 (the progress label
# describes this as combining the building classes) and write the result as
# a single-band GeoTIFF with the source CRS and transform.
for label_path in tqdm(labels, desc="combining building classes into harmonized labels..."):

    # Bug fix: open the loop's own file. The previous code passed `label`,
    # which is scipy.ndimage.label (imported at the top), not a path.
    with rasterio.open(label_path) as src:
        data = src.read(1)  # first band only
        crs = src.crs
        transform = src.transform

    data[data == 3] = 1

    output_filepath = harmonized_label_dir / label_path.name

    with rasterio.open(
        output_filepath,
        'w',
        driver='GTiff',
        height=data.shape[0],
        width=data.shape[1],
        count=1,
        dtype=data.dtype,
        crs=crs,
        transform=transform
        ) as dst:
        dst.write(data, 1)

combining building classes into harmonized labels...: 100%|██████████| 2074/2074 [00:09<00:00, 212.28it/s]


In [208]:
# Output layout for per-object masks:
#   <data_dir>/object_labels/buildings    -> one mask per building object
#   <data_dir>/object_labels/crop_fields  -> one mask per field object
object_label_dir = data_dir / 'object_labels'
buildingdir = object_label_dir / 'buildings'
cropsdir = object_label_dir / 'crop_fields'

# Parent first, then the two sub-directories; re-running is a no-op.
for directory in (object_label_dir, buildingdir, cropsdir):
    directory.mkdir(exist_ok=True)

In [221]:
def _write_object_masks(class_mask, out_dir, stem, suffix):
    """Write one GeoTIFF per connected component of a 0/1 class mask.

    Parameters
    ----------
    class_mask : xarray.DataArray
        0/1 mask of one class. Its coords, dims and attrs are copied onto each
        per-object mask so the georeferencing carries over.
    out_dir : pathlib.Path
        Destination directory.
    stem, suffix : str
        Files are named ``{stem}_{suffix}{k}.tif`` with k = 0, 1, ... in the
        order scipy.ndimage.label numbers the components.
    """
    # Components use scipy.ndimage.label's default structuring element.
    component_ids, num_components = label(class_mask)
    for component_id in range(1, num_components + 1):
        object_mask = xr.DataArray(
            np.where(component_ids == component_id, 1, 0),
            coords=class_mask.coords,
            dims=class_mask.dims,
            attrs=class_mask.attrs)
        # File numbering starts at 0 while component ids start at 1.
        # Each mask is written immediately instead of collecting all of them
        # in memory first.
        object_mask.rio.to_raster(out_dir / f'{stem}_{suffix}{component_id - 1}.tif')


for mask_file in tqdm(labels, desc='extracting individual object masks from labels...'):
    # The context manager closes the file handle opened by open_rasterio;
    # previously one handle per chip was left to the garbage collector.
    with rioxarray.open_rasterio(mask_file) as example_groundtruth:
        # Odd class ids (1 and 3) are buildings; class 2 is crop fields.
        building_groundtruth = xr.where(example_groundtruth % 2 == 1, 1, 0)
        field_groundtruth = xr.where(example_groundtruth == 2, 1, 0)

        # mask_file.stem equals the former name[:-4] for the *.tif inputs.
        _write_object_masks(building_groundtruth, buildingdir, mask_file.stem, 'b')
        _write_object_masks(field_groundtruth, cropsdir, mask_file.stem, 'f')

extracting individual object masks from labels...: 100%|██████████| 2074/2074 [01:15<00:00, 27.46it/s]


In [227]:
# Keep only chips whose label holds nothing but the class ids 0-3. Sorted by
# file name: Path.glob yields files in filesystem order, so without sorting
# the seeded split below would assign chips differently on another machine.
valid = sorted(
    (label_path for label_path in labels if check_tif_values(label_path)),
    key=lambda label_path: label_path.name,
)
print(f'{len(labels) - len(valid)} chips will be filtered from the chip tracker due to containing invalid label integers')

# Each chip gets a uniform draw p in [0, 1):
#   p < TRAIN_CUTOFF                    -> train    (~70%)
#   TRAIN_CUTOFF <= p < VALIDATE_CUTOFF -> validate (~20%)
#   p >= VALIDATE_CUTOFF                -> test     (~10%)
TRAIN_CUTOFF = 0.7
VALIDATE_CUTOFF = 0.9

# A dedicated generator seeded with 42 produces the same sequence as
# random.seed(42) + random.random(), without resetting the module-level
# `random` state.
split_rng = random.Random(42)


def _usage_for(p):
    """Map a uniform draw ``p`` to its split name."""
    if p < TRAIN_CUTOFF:
        return 'train'
    if p < VALIDATE_CUTOFF:
        return 'validate'
    return 'test'


df = pd.DataFrame(
    [(label_path.name, _usage_for(split_rng.random())) for label_path in valid],
    columns=['chip_id', 'usage'],
)

chip_tracker_path = data_dir / 'chip_tracker.csv'

df.to_csv(chip_tracker_path, index=False)
print(f'chip tracker with training, validation, and test split saved to {chip_tracker_path}')

33 chips will be filtered from the chip tracker due to containing invalid label integers
chip tracker with training, validation, and test split saved to /app/dataset/chip_tracker.csv


In [236]:
# Image chip for every label that passed validation (same file name).
valid_images = [image_dir / label_path.name for label_path in valid]

kenya_data = KenyaData(valid_images)

# data loader
kenya_dataloader = DataLoader(kenya_data,
                              batch_size=8,
                              shuffle=False)

bands = ['red', 'green', 'blue', 'alpha']
n_bands = len(bands)
# Batches are (batch, band, rows, cols); reduce everything except the band axis.
reduce_dims = (0, 2, 3)

# float64 accumulators: float32 sums over millions of pixels lose precision,
# which the E[x^2] - E[x]^2 variance formula amplifies.
psum = torch.zeros(n_bands, dtype=torch.float64)
psum_sq = torch.zeros(n_bands, dtype=torch.float64)
# Start at +/-inf so the first batch defines the extremes. Starting at 0 (as
# before) forced the reported minimum to be <= 0 whatever the data held.
pmin = torch.full((n_bands,), float('inf'), dtype=torch.float64)
pmax = torch.full((n_bands,), float('-inf'), dtype=torch.float64)
count = 0

# loop through images
for batch in tqdm(kenya_dataloader, desc='calculating dataset statistics...'):
    batch = batch.to(torch.float64)
    psum += batch.sum(dim=reduce_dims)
    psum_sq += (batch ** 2).sum(dim=reduce_dims)
    pmin = torch.minimum(pmin, torch.amin(batch, dim=reduce_dims))
    pmax = torch.maximum(pmax, torch.amax(batch, dim=reduce_dims))
    # Count the pixels actually seen instead of assuming 512x512 chips.
    count += batch.shape[0] * batch.shape[2] * batch.shape[3]

assert count > 0, 'no valid images to compute dataset statistics from'

total_mean = psum / count
# Rounding can push E[x^2] - E[x]^2 slightly below zero; clamp before sqrt
# so a constant band yields std 0 instead of NaN.
total_var = ((psum_sq / count) - (total_mean ** 2)).clamp(min=0)
total_std = torch.sqrt(total_var)

stats_dict = {'bands' : bands,
              'mean': total_mean.numpy(),
              'std': total_std.numpy(),
              'min': pmin.numpy(),
              'max': pmax.numpy()
             }
stats_df = pd.DataFrame(stats_dict)

stats_path = data_dir / 'dataset_stats.csv'

stats_df.to_csv(stats_path, index=False)
print(f'dataset statistics saved to {stats_path}')

calculating dataset statistics...: 100%|██████████| 256/256 [00:09<00:00, 27.67it/s]

dataset statistics saved to /app/dataset/dataset_stats.csv





In [None]:
# TODO (12/28): Use samgeo to process all the chips as needed.
# This will complete the data-processing pipeline, after which we can
# turn it into a Python script and test it.