In [1]:
import geopandas as gpd
from shapely.geometry import Polygon
from shapely.ops import transform
from shapely.ops import unary_union
from rasterio.transform import from_bounds
import rasterio
from rasterio import mask

# Utils
import json
import os
import numpy as np
from tqdm import tqdm
import glob

### Open and re-project

In [20]:
# Root folder of the tree dataset. It defaults to the original cluster location
# but can be redirected without editing the notebook by setting TREE_DATA_DIR.
tree_path = os.environ.get('TREE_DATA_DIR', '/home/mila/v/venkatesh.ramesh/scratch/tree_data')
# Source raster for zone z3; its CRS and pixel grid define the label raster.
raster_path = os.path.join(tree_path, 'tiles_hierarchical/images/z3/z3_cropped_raster.tif')
# Destination of the rasterised label mask for zone z3.
save_path = os.path.join(tree_path, 'tiles_hierarchical/labels/z3/labels_z3.tif')

In [21]:
# Load the annotated polygons for zone Z3 into a GeoDataFrame.
polygons_path = os.path.join(tree_path, 'geojson_polygons/Z3_polygone_2021_09_02_view.geojson')
gdf = gpd.read_file(polygons_path)

In [22]:
# Display the CRS the polygons were loaded with, for comparison with the raster's CRS.
gdf.crs

<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

In [23]:
# Open the raster only to grab its metadata; `raster_meta` is later indexed for
# 'crs', 'height' and 'width' and reused as the profile of the label raster.
# NOTE(review): `src` is closed once this `with` block exits.
with rasterio.open(raster_path, "r") as src:
    raster_meta = src.meta
    
# Show the raster CRS so it can be compared with the polygons' CRS.
print('Raster CRS: ', raster_meta['crs'])

Raster CRS:  EPSG:32618


In [24]:
# Reproject the polygons into the raster's CRS so both share one coordinate system.
gdf_projected = gdf.to_crs(raster_meta['crs'])

In [25]:
#gdf_projected.head()
# Number of polygons per 'Label' value, i.e. the class distribution of the annotations.
gdf_projected['Label'].value_counts()

ACRU        1201
ABBA         977
BEPA         972
THOC         344
ACPE         234
ACSA         153
Picea        142
Mort         136
FAGR         122
Acer         115
BEAL          84
POGR          70
Conifere      28
PIRU          23
PIST          21
Feuillus      14
TSCA           7
PIGL           7
PIMA           4
LALA           4
Populus        2
POTR           1
Name: Label, dtype: int64

### Create the label mask TIFF

In [26]:
def write_dataset(mask, save_path, raster_meta):
    """
    Write a label mask to disk as a single-band raster that reuses the
    original raster's metadata.

    Inputs:

    mask -> Full prediction mask numpy array; it is cast to uint16 and written to band 1.
    save_path -> Path to write the mask to.
    raster_meta -> Metadata from the original raster. It is copied, so the
                   caller's dict is left unmodified.
    """

    out_dtype = "uint16"
    mask = mask.astype(out_dtype)

    bin_mask_meta = raster_meta.copy()
    # Besides forcing a single band, declare the dtype of the array actually
    # written. Without this the file inherits the source raster's dtype, which
    # need not be uint16.
    # NOTE(review): 'nodata' is still inherited from the source raster --
    # confirm it is a sensible value for a label mask.
    bin_mask_meta.update({'count': 1, 'dtype': out_dtype})
    with rasterio.open(save_path, 'w', **bin_mask_meta) as dst:
        dst.write(mask, 1)

In [27]:
def write_classdict(unq_classes, write_path):
    """
    Save the class-name -> integer-id lookup table as a JSON file.

    Inputs:
    unq_classes -> Array of unique classes; ids 1..N are assigned in this order.
    write_path -> Path to write the json file to.
    """

    # Ids start at 1; they are plain Python ints so they serialise to JSON.
    class_to_id = {name: idx for idx, name in enumerate(unq_classes, start=1)}

    with open(write_path, 'w') as handle:
        json.dump(class_to_id, handle)

In [28]:
# Remove None and empty polygons
gdf_cleaned = gdf_projected[~(gdf_projected['geometry'].is_empty | gdf_projected['geometry'].isna())]

# Unique class names, sorted; class ids are assigned 1..N in this order
unq_classes = np.unique(gdf_cleaned['Label'])

# Label canvas of the same size as the tiff, allocated directly as uint16.
# 0 is left for pixels not covered by any polygon.
prediction_merged = np.zeros((raster_meta['height'], raster_meta['width']), dtype='uint16')

# Re-open the raster for the duration of the loop: the handle opened earlier
# to read the metadata was closed when its `with` block exited.
with rasterio.open(raster_path, "r") as src:
    for cls_label, cls_id in enumerate(tqdm(unq_classes), start=1):
        gdf_subset = gdf_cleaned.loc[gdf_cleaned['Label'] == cls_id, 'geometry']

        print('The class id and the corresponding label is: ', cls_id, cls_label)

        # With the default invert=False the mask is True OUTSIDE the shapes
        # and False inside them.
        outside, _, _ = rasterio.mask.raster_geometry_mask(src, gdf_subset, crop=False)

        # Assign the class id instead of adding it. Summing would turn pixels
        # where polygons of different classes overlap into the sum of their
        # ids, i.e. another class or an id that does not exist. With
        # assignment, the class processed last wins on overlapping pixels.
        prediction_merged[~outside] = cls_label

        # Clear to free up memory
        del outside

# Write things to disk
write_dataset(prediction_merged, save_path, raster_meta)
# write_classdict(unq_classes, '/home/mila/v/venkatesh.ramesh/scratch/tree_data/tiles_512/labels/z2/classes_z2.json')

  0%|          | 0/22 [00:00<?, ?it/s]

The class id and the corresponding label is:  ABBA 1


  5%|▍         | 1/22 [00:00<00:15,  1.34it/s]

The class id and the corresponding label is:  ACPE 2


  9%|▉         | 2/22 [00:01<00:13,  1.53it/s]

The class id and the corresponding label is:  ACRU 3


 14%|█▎        | 3/22 [00:02<00:14,  1.34it/s]

The class id and the corresponding label is:  ACSA 4


 18%|█▊        | 4/22 [00:02<00:12,  1.46it/s]

The class id and the corresponding label is:  Acer 5


 23%|██▎       | 5/22 [00:03<00:10,  1.56it/s]

The class id and the corresponding label is:  BEAL 6


 27%|██▋       | 6/22 [00:03<00:09,  1.62it/s]

The class id and the corresponding label is:  BEPA 7


 32%|███▏      | 7/22 [00:04<00:10,  1.44it/s]

The class id and the corresponding label is:  Conifere 8


 36%|███▋      | 8/22 [00:05<00:09,  1.55it/s]

The class id and the corresponding label is:  FAGR 9


 41%|████      | 9/22 [00:05<00:08,  1.61it/s]

The class id and the corresponding label is:  Feuillus 10


 45%|████▌     | 10/22 [00:06<00:07,  1.68it/s]

The class id and the corresponding label is:  LALA 11


 50%|█████     | 11/22 [00:06<00:06,  1.73it/s]

The class id and the corresponding label is:  Mort 12


 55%|█████▍    | 12/22 [00:07<00:05,  1.72it/s]

The class id and the corresponding label is:  PIGL 13


 59%|█████▉    | 13/22 [00:08<00:05,  1.76it/s]

The class id and the corresponding label is:  PIMA 14


 64%|██████▎   | 14/22 [00:08<00:04,  1.79it/s]

The class id and the corresponding label is:  PIRU 15


 68%|██████▊   | 15/22 [00:09<00:03,  1.80it/s]

The class id and the corresponding label is:  PIST 16


 73%|███████▎  | 16/22 [00:09<00:03,  1.81it/s]

The class id and the corresponding label is:  POGR 17


 77%|███████▋  | 17/22 [00:10<00:02,  1.79it/s]

The class id and the corresponding label is:  POTR 18


 82%|████████▏ | 18/22 [00:10<00:02,  1.82it/s]

The class id and the corresponding label is:  Picea 19


 86%|████████▋ | 19/22 [00:11<00:01,  1.80it/s]

The class id and the corresponding label is:  Populus 20


 91%|█████████ | 20/22 [00:11<00:01,  1.82it/s]

The class id and the corresponding label is:  THOC 21


 95%|█████████▌| 21/22 [00:12<00:00,  1.76it/s]

The class id and the corresponding label is:  TSCA 22


100%|██████████| 22/22 [00:13<00:00,  1.68it/s]
