In [None]:
import xarray as xr
import numpy as np
import rioxarray
from pathlib import Path
from matplotlib import pyplot as plt
from geocoded_object_extractor.utils import hash_classname

In [None]:
# Map each integer class code to the descriptive class name.
# The name is used as a suffix of the zarr store written for that class.
description_labels = {
    0: 'netflora_murumuru_embrapa00',
    1: 'netflora_buriti_embrapa00',   # spelling fixed: was 'emprapa00'
    2: 'netflora_tucuma_embrapa00',   # spelling fixed: was 'emprapa00'
    3: 'reforestree_banana',
    4: 'reforestree_cacao',
    5: 'reforestree_fruit',
}

## Unify the netflora cutout sizes

In [None]:
# Root directory of the netflora cutouts; each class lives in its own
# sub-directory (e.g. 'label_0') that is joined onto this path further below.
dir_netflora = Path('./output_cutouts')

In [None]:
def get_padsize(taget_size, img_size_x, img_size_y):
    """Compute the padding that centres an image in a square canvas.

    The total padding per axis is ``taget_size - img_size``. It is split as
    evenly as possible; when the difference is odd, the extra pixel goes to the
    right / bottom side.

    Parameters
    ----------
    taget_size : int
        Edge length of the square target canvas in pixels. (The misspelt
        parameter name is kept so existing keyword callers keep working.)
    img_size_x : int
        Image size along x in pixels.
    img_size_y : int
        Image size along y in pixels.

    Returns
    -------
    tuple of int
        ``(xleft, xright, ytop, ybottom)`` pad widths, all non-negative.

    Raises
    ------
    ValueError
        If the image is larger than the target size along either axis, since
        that would require a negative pad width.
    """
    if img_size_x > taget_size or img_size_y > taget_size:
        raise ValueError(
            f'Image of size ({img_size_x}, {img_size_y}) does not fit into a '
            f'{taget_size}x{taget_size} target; padding would be negative.'
        )

    total_x = taget_size - img_size_x
    total_y = taget_size - img_size_y

    xleft = total_x // 2
    xright = total_x - xleft
    ytop = total_y // 2
    ybottom = total_y - ytop

    return xleft, xright, ytop, ybottom

In [None]:
# Load the cutouts of each class as xarray objects, zero-pad them to a
# (target_size x target_size) square, stack them along a new 'sample' dimension
# and save each class as a zarr store.
target_size = 128

# Full netflora class list: ['label_0', 'label_2', 'label_6'] -> label_idx 0, 1, 2
for label_idx, label_dir in enumerate(['label_0']):
    dir_class = dir_netflora / label_dir
    # Sorted so the sample order does not depend on the file-system listing order
    list_imgs = sorted(dir_class.glob('*.jpg'))

    # Collect the padded images and concatenate once at the end; concatenating
    # inside the loop re-copies the growing array for every image.
    padded_imgs = []
    for f_img in list_imgs:
        img = rioxarray.open_rasterio(f_img)

        # Skip images larger than the target size along either axis. This check
        # must use the current image (`img`), and it also covers the first file.
        if img.sizes['x'] > target_size or img.sizes['y'] > target_size:
            continue

        xleft, xright, ytop, ybottom = get_padsize(target_size, img.sizes['x'], img.sizes['y'])
        img_pad = img.pad(pad_width={'x': (xleft, xright), 'y': (ytop, ybottom)}, mode='constant', constant_values=0)
        img_pad = img_pad.expand_dims('sample', axis=0)
        # Remove the per-image coordinates (and their indexes) before stacking so
        # that concat does not try to align images on differing x/y values.
        img_pad = img_pad.drop_vars(['band', 'x', 'y', 'spatial_ref'], errors='ignore')
        padded_imgs.append(img_pad)

    if not padded_imgs:
        raise ValueError(f'No image of at most {target_size}x{target_size} pixels found in {dir_class}')

    imgs = xr.concat(padded_imgs, dim='sample')

    # Make a dataset with both images and labels. The number of labels follows
    # the number of images actually kept, not the number of files found.
    n_samples = imgs.sizes['sample']
    ds = xr.Dataset({'X': imgs, 'Y': xr.DataArray([label_idx] * n_samples, dims='sample', name='label')})
    ds = ds.transpose('sample', 'x', 'y', 'band')
    ds = ds.rename({'band': 'channel'})
    ds = ds.chunk('auto')

    ds.to_zarr(f'./label{label_idx}_{description_labels[label_idx]}.zarr', mode='w')

In [None]:
# Re-open the store written for class 0. The path is built the same way as in
# the export loop (f'./label{label_idx}_{description_labels[label_idx]}.zarr'),
# so it matches the file that was actually written.
ds_label1 = xr.open_zarr(f'./label0_{description_labels[0]}.zarr')
np.unique(ds_label1['Y']) # check the labels, should be 0

## Downsampling reforestree
Downsample the reforestree dataset from 1 cm to 6 cm resolution.

In [None]:
# Check the pixel resolution of one reforestree orthomosaic
f_ortho_reforestree = '/home/oku/Developments/XAI4GEO/data/reforestree/wwf_ecuador/RGB Orthomosaics/Carlos Vera Arteaga RGB.tif'
data = rioxarray.open_rasterio(f_ortho_reforestree)

# Reprojection is computationally expensive, so only a 1000 x 1000 pixel window is reprojected
subset_window = {'x': range(8000, 9000), 'y': range(8000, 9000)}
data_subset = data.isel(subset_window)

# Reproject the window to UTM Zone 17S
data_subset_meter = data_subset.rio.reproject('EPSG:32717')

# This gives the resolution of ~ 1cm
data_subset_meter.rio.resolution()

In [None]:
# Check the pixel resolution of one orthomosaic of the brazil data
f_ortho_brazil = '/home/oku/Developments/XAI4GEO/data/brazil_data/original_data/PNM/PROCESSADOS/Map1_Orthomosaic_export_SatJun10172428194829.tif'
data = rioxarray.open_rasterio(f_ortho_brazil)

# This gives the resolution of ~ 6cm
data.rio.resolution()

In [None]:
# Path to the processed reforestree zarr store holding the cutouts that are
# larger than 384x384 pixels; it is opened in the next cell.
f_reforestree =  Path('/home/oku/Developments/XAI4GEO/data/reforestree/processed/larger_than_384/foresttree_largerthan_384.zarr')
# Alternative input (cutouts larger than 200 pixels), kept for reference:
# f_reforestree =  Path('/home/oku/Developments/XAI4GEO/data/reforestree/processed/larger_than_200/foresttree_largerthan_200.zarr')

In [None]:
# Open the reforestree zarr store selected above and display its summary
data_reforestree = xr.open_zarr(f_reforestree)
data_reforestree

In [None]:
# Downsample the reforestree data by averaging over (6, 6) pixel windows;
# boundary='trim' drops the pixels that do not fill a complete window.
coarsen_window = {'x': 6, 'y': 6}
data_reforestree_downsampled = data_reforestree.coarsen(coarsen_window, boundary='trim').mean()
data_reforestree_downsampled

## Resizing the image

In [None]:
# Zero-pad the down-sampled cutouts to a (target_size x target_size) square.
# The pad widths are computed per axis and per side with get_padsize, so the
# result is exactly target_size even when the size difference is odd or the
# x and y sizes differ (a single symmetric pad width would be one pixel short).
target_size = 128
xleft, xright, ytop, ybottom = get_padsize(
    target_size,
    data_reforestree_downsampled.sizes['x'],
    data_reforestree_downsampled.sizes['y'],
)
data_reforestree_downsampled_padded = data_reforestree_downsampled.pad(
    mode='constant',
    pad_width={'x': (xleft, xright), 'y': (ytop, ybottom)},
    constant_values=0,
)
data_reforestree_downsampled_padded

In [None]:
# Class labels (variable 'Y') of the down-sampled reforestree data as a NumPy array
labels = data_reforestree_downsampled["Y"].values

In [None]:
# Show the distinct class labels present in the reforestree data
np.unique(labels)

In [None]:
# Bananas: keep only the samples whose label 'Y' equals 1
is_banana = data_reforestree_downsampled_padded["Y"].compute() == 1
ds_lebel1 = data_reforestree_downsampled_padded.where(is_banana, drop=True)
ds_lebel1

In [None]:
# Cacao: keep only the samples whose label 'Y' equals 2
is_cacao = data_reforestree_downsampled_padded["Y"].compute() == 2
ds_lebel2 = data_reforestree_downsampled_padded.where(is_cacao, drop=True)
ds_lebel2

In [None]:
# Randomly select 104 banana samples, then overwrite a random subset of them
# with the cutouts stored in the previously selected banana zarr store.
# NOTE(review): despite the 'label2' in the variable names below, the samples
# come from ds_lebel1 (bananas).
# NOTE(review): the per-class export loop further down writes
# './label<hash>_banana.zarr'; if hash_classname('banana') equals 274675723264
# it targets the same path as this cell -- confirm the intended execution order.
n_select = 104

# Seed the generator itself: np.random.seed() only affects the legacy global
# state and has no effect on default_rng(), so the selection was not reproducible.
rng = np.random.default_rng(0)

idx = rng.choice(ds_lebel1.sizes['sample'], n_select, replace=False)
ds_label2_sel = ds_lebel1.isel(sample=idx)
ds_label2_end = xr.open_zarr('/home/oku/Developments/XAI4GEO/data/cleaned_data/selected_cutouts/label274675723264_banana.zarr')

# Pick (without replacement) one slot in the selection for every sample of ds_label2_end
idx_insert = rng.choice(ds_label2_sel.sizes['sample'], ds_label2_end.sizes['sample'], replace=False)
idx_insert = np.sort(idx_insert)
print(idx_insert)

# Update ds_label2_sel at locations idx_insert
# (positional indexing: assumes 'X' is ordered as (sample, x, y, channel) -- TODO confirm)
ds_label2_sel_update = ds_label2_sel.copy()
ds_label2_sel_update['X'][idx_insert, :, :, :] = xr.DataArray(ds_label2_end['X'].values, dims=['sample', 'x', 'y', 'channel'])
ds_label2_sel_update = ds_label2_sel_update.chunk({"sample": 10, "y": -1, "x": -1, "channel": -1})
ds_label2_sel_update.to_zarr('./label274675723264_banana.zarr', mode='w')

In [None]:
# Fruit: keep only the samples whose label 'Y' equals 4
is_fruit = data_reforestree_downsampled_padded["Y"].compute() == 4
ds_lebel4 = data_reforestree_downsampled_padded.where(is_fruit, drop=True)
ds_lebel4

In [None]:
# Relabel every class dataset with the hash of its class name and write it to
# its own zarr store. Assigning ds['Y'] also updates the dataset object that
# ds_lebel1 / ds_lebel2 / ds_lebel4 refer to.
class_datasets = [ds_lebel1, ds_lebel2, ds_lebel4]
class_names = ['banana', 'cacao', 'fruit']
for ds, name in zip(class_datasets, class_names):
    class_hash = hash_classname(f'{name}')
    n_samples = ds.sizes["sample"]
    ds['Y'] = xr.DataArray(np.array([class_hash] * n_samples), dims='sample')

    chunk_sizes = {'sample': 50, 'x': 128, 'y': 128, 'channel': 3}
    ds = ds.chunk(chunk_sizes)

    # The file name carries the label value stored in the dataset
    label_value = ds["Y"].values[0]
    ds.to_zarr(f'./label{label_value}_{name}.zarr', mode='w')

In [None]:
# Plot 50 randomly chosen (without replacement) samples of ds_lebel1 in a 5-column grid
rng = np.random.default_rng()
n_available = ds_lebel1.sizes['sample']
idx_plot = rng.choice(n_available, 50, replace=False)
ds_plot = ds_lebel1.isel(sample=idx_plot)
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)

In [None]:
# Plot 50 randomly chosen (without replacement) samples of ds_lebel2 in a 5-column grid
rng = np.random.default_rng()
n_available = ds_lebel2.sizes['sample']
idx_plot = rng.choice(n_available, 50, replace=False)
ds_plot = ds_lebel2.isel(sample=idx_plot)
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)

In [None]:
# Plot 50 randomly chosen (without replacement) samples of ds_lebel4 in a 5-column grid
rng = np.random.default_rng()
n_available = ds_lebel4.sizes['sample']
idx_plot = rng.choice(n_available, 50, replace=False)
ds_plot = ds_lebel4.isel(sample=idx_plot)
ds_plot['X'].plot.imshow(col='sample', col_wrap=5)

In [None]:
# # Save the downsampled data to separated zarr files per class
# labels = data_reforestree_downsampled["Y"].compute()
# for label in np.unique(labels.values):
#     species_name = data_reforestree_downsampled.attrs[label.astype(str)]['ESPECIE']
#     ds = data_reforestree_downsampled.where(labels==label,drop=True)
#     ds = ds.chunk('auto')
#     print(f'{species_name}: {ds.sizes["sample"]}')
#     ds.to_zarr(f'./label{label}_{species_name}.zarr', mode='w')

In [None]:
# np.unique(labels.values)

In [None]:
# # Plot some random samples per class
# n_samples_plot = 5
# rng = np.random.default_rng()
# fig, axs = plt.subplots(len(np.unique(labels.values)), n_samples_plot, figsize=(15, 15))
# ax_row=0
# labels = data_reforestree_downsampled["Y"].compute()
# for label in np.unique(labels.values): # Loop through classes
#     ds = xr.open_zarr(f'./label{label}_{data_reforestree_downsampled.attrs[label.astype(str)]["ESPECIE"]}.zarr')
#     idx = rng.integers(0, ds.sizes["sample"] , size=n_samples_plot)
#     for i, ax in enumerate(axs[ax_row]):
#         ax.imshow(ds['X'][idx[i]].values/255.)
#         ax.axis('off')
#         ax.set_title(f'{data_reforestree_downsampled.attrs[label.astype(str)]["ESPECIE"]}')
#     ax_row+=1

## Change the label value of the old data (only performed once for the old data)

In [None]:
# dir_cutouts = Path('/home/oku/Developments/XAI4GEO/data/cleaned_data/all_cutouts')

# list_zarr = ['label142377591163_murumuru.zarr/', 'label244751236943_tucuma.zarr/', ]
# for file in list_zarr:
#     data = xr.open_zarr(dir_cutouts/file)
#     print(file)
#     print(f"shape:{data['X'].sizes}")
#     print(f"label:{np.unique(data['Y'].values)}")
#     label = float(file[5:17])
#     data['Y'] = xr.DataArray(np.array([label]*data.sizes['sample']), dims='sample')
#     print(f"UPDATED label:{np.unique(data['Y'].values)}")
#     print("---")
#     data = data.chunk('auto')
#     data.to_zarr(Path('.')/file, mode='w')

In [None]:
# # Check the saved zarr files
# dir_cutouts = Path('.')
# list_zarr = ['label142377591163_murumuru.zarr/', 'label244751236943_tucuma.zarr/', ]
# for file in list_zarr:
#     data = xr.open_zarr(dir_cutouts/file)
#     print(file)
#     print(f"shape:{data['X'].sizes}")
#     print(f"label:{np.unique(data['Y'].values)}")