In [3]:
import geopandas as gpd
from shapely.geometry import Polygon
from shapely.ops import transform
from shapely.ops import unary_union
from rasterio.transform import from_bounds
import rasterio
from rasterio import mask

# Utils
import json
import os
import numpy as np
from tqdm import tqdm
import glob

### Open and re-project

In [None]:
# Root of the tree dataset; the source raster and the output label raster are
# both addressed relative to it.
tree_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data'
raster_path = os.path.join(
    tree_path,
    'Tree_data/2021-09-02-sbl-cloutier-z3-MS/2021-09-02-sbl-cloutier-z3-UTM18-MS.tif',
)
save_path = os.path.join(tree_path, 'labels/zone3_merged.tif')

In [None]:
# Load the region polygons from GeoJSON into a GeoDataFrame.
region_geojson = os.path.join('/home/mila/v/venkatesh.ramesh/scratch/tree_data/geojson_polygons/individual_regions/val_region_poly.geojson')
gdf = gpd.read_file(region_geojson)

In [None]:
gdf.crs

In [None]:
# Only the metadata dictionary is kept; the dataset is closed when the
# `with` block exits.
with rasterio.open(raster_path, "r") as raster:
    raster_meta = raster.meta

print('Raster CRS: ', raster_meta['crs'])

In [None]:
# Reproject the polygons into the CRS recorded in the raster metadata.
target_crs = raster_meta['crs']
gdf_projected = gdf.to_crs(target_crs)

In [None]:
gdf_projected.head()

In [None]:
count_val = gdf_projected['Label'].value_counts()

In [None]:
# Per-label row counts of gdf_projected.
# NOTE(review): this is the same expression as `count_val` and `count`; presumably
# `gdf` was reloaded from a different region file before each of these cells was
# run. On a straight top-to-bottom run all three hold identical counts - confirm.
count_test = gdf_projected['Label'].value_counts()

In [None]:
# Per-label row counts of gdf_projected.
# NOTE(review): this is the same expression as `count_val` and `count_test`;
# presumably `gdf` was reloaded from a different region file before each of these
# cells was run. On a straight top-to-bottom run all three are identical - confirm.
count = gdf_projected['Label'].value_counts()

In [None]:
count_val

In [None]:
count_test

In [None]:
count

In [None]:
# Distinct values of the 'Label' column and the number of rows carrying each.
label_column = gdf_projected['Label']
labels, frequency = np.unique(label_column, return_counts=True)

In [None]:
frequency

In [None]:
len(labels)

### Read CSVs

In [None]:
import pandas as pd
import os

In [None]:
train_dataset = pd.read_csv('/home/mila/v/venkatesh.ramesh/scratch/tree_data/splits/geographic_train_256.csv')

In [None]:
# print(train_dataset['la'])
train_dataset

In [None]:
x = train_dataset['tiles'][0]

In [None]:
print(x)

In [None]:
splits = x.split('/')[:-3]
print(splits)

In [None]:
print('/' + os.path.join(*splits))

### Removing classes and ignore index

remove_classes = QURU, PRPE, POBA, BEPO

ignore_index = Autres, Betula 

coarser_labels = Acer, Betula, Conifere, Feuillus

ignore_index = Autres, Betula, QURU, PRPE, POBA, BEPO


ignore_index_maybe = Acer, Conifere, Feuillus (they have a high frequency of appearance, so maybe in iteration 2)


TSCA has least annotation but it is used.

**Iteration 1:**

1. remove_classes = QURU, PRPE, POBA, BEPO, Betula, Autres, OSVI (very low frequency classes)

**Iteration 2:**

1. remove_classes = QURU, PRPE, POBA, BEPO, Betula, Autres, OSVI (very low frequency classes)
2. ignore_index = Acer, Conifere, Feuillus

In [2]:
# Class name -> integer label id for the full class set (ids 1..30).
class_dict = {
    "ABBA": 1,
    "ACPE": 2,
    "ACRU": 3,
    "ACSA": 4,
    "Acer": 5,
    "BEAL": 6,
    "BEPA": 7,
    "BEPO": 8,
    "Betula": 9,
    "Conifere": 10,
    "FAGR": 11,
    "FRNI": 12,
    "Feuillus": 13,
    "LALA": 14,
    "Mort": 15,
    "PIGL": 16,
    "PIMA": 17,
    "PIRU": 18,
    "PIST": 19,
    "POBA": 20,
    "POGR": 21,
    "POTR": 22,
    "PRPE": 23,
    "Picea": 24,
    "Populus": 25,
    "QURU": 26,
    "THOC": 27,
    "TSCA": 28,
    "Autres": 29,
    "OSVI": 30,
}

In [2]:
# Label tiles under labels_removed_classes_2 of the train split.
image_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/labels_removed_classes_2/'

# Every PNG below image_path, at any directory depth.
png_pattern = image_path + '/**/*.png'
images_list = glob.glob(png_pattern, recursive=True)

In [3]:
len(images_list)

1076

### Check max

In [16]:
from PIL import Image

# Highest label id allowed in these masks; anything outside [0, MAX_LABEL_ID]
# is reset to 0.
MAX_LABEL_ID = 23

for label_file in tqdm(images_list):
    mask = np.array(Image.open(label_file))

    # The `< 0` arm can only match for signed dtypes; it is kept so the reset
    # agrees with the range test the original cell performed.
    out_of_range = (mask < 0) | (mask > MAX_LABEL_ID)
    if not out_of_range.any():
        # Nothing to fix: leave the file on disk untouched instead of
        # re-encoding it.
        continue

    mask[out_of_range] = 0
    # Convert and write exactly once per modified file.
    Image.fromarray(mask).save(label_file)

100%|██████████| 280/280 [00:00<00:00, 310.32it/s]


In [None]:
test = Image.open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/labels_removed_classes_1/zone1/1234940_1492354.png')

In [10]:
unq = np.unique(np.array(test))

In [11]:
print(unq)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 23 31]


### Make labels ordering sequential

In [5]:
# Class name -> label id before any class is removed. Same content as the
# commented-out `myriam_dict` cell further down; defined here, at its first
# use, so this cell also runs on a fresh kernel.
myriam_dict = {
    'ABBA': 1,
    'ACPE': 2,
    'ACRU': 3,
    'ACSA': 4,
    'Acer': 5,
    'BEAL': 6,
    'BEPA': 7,
    'FAGR': 8,
    'Feuillus': 9,
    'LALA': 10,
    'Mort': 11,
    'PIMA': 12,
    'PIST': 13,
    'POGR': 14,
    'Picea': 15,
    'Populus': 16,
    'THOC': 17,
    'TSCA': 18,
}

# Classes to drop from the label masks, and their ids in myriam_dict.
remove_list = ["Acer", "Feuillus", "PIMA", "POGR"]

remove_labels = [myriam_dict[name] for name in remove_list]

In [7]:
from PIL import Image

# Rewrite every label tile in place with the ids in remove_labels set to 0.
for label_file in tqdm(images_list):
    mask = np.array(Image.open(label_file))

    # One vectorised membership test instead of a per-id loop; this also avoids
    # binding the builtin name `id` in the notebook namespace.
    mask[np.isin(mask, remove_labels)] = 0

    Image.fromarray(mask).save(label_file)

100%|██████████| 1076/1076 [00:16<00:00, 64.61it/s]


In [8]:
# Build a filtered copy rather than aliasing myriam_dict: deleting keys through
# an alias would also strip them from myriam_dict itself, and re-running the
# cell would raise KeyError on the already-deleted keys.
new_dict = {
    name: label_id
    for name, label_id in myriam_dict.items()
    if name not in remove_list
}

In [9]:
# Write the reduced class mapping to disk as JSON.
with open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/classes_onlymyriam.json', 'w') as out_file:
    json.dump(new_dict, out_file)

In [10]:
len(new_dict)

14

In [11]:
# Reload the class mapping from the JSON file.
with open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/classes_onlymyriam.json', 'r') as in_file:
    new_dict = json.loads(in_file.read())

In [12]:
new_dict

{'ABBA': 1,
 'ACPE': 2,
 'ACRU': 3,
 'ACSA': 4,
 'BEAL': 6,
 'BEPA': 7,
 'FAGR': 8,
 'LALA': 10,
 'Mort': 11,
 'PIST': 13,
 'Picea': 15,
 'Populus': 16,
 'THOC': 17,
 'TSCA': 18}

In [18]:
# Class names in id order; ids are assigned consecutively starting at 1.
ordered_class_names = [
    'ABBA', 'ACPE', 'ACRU', 'ACSA', 'Acer', 'BEAL', 'BEPA', 'Conifere',
    'FAGR', 'FRNI', 'Feuillus', 'LALA', 'Mort', 'PIGL', 'PIMA', 'PIRU',
    'PIST', 'POGR', 'POTR', 'Picea', 'Populus', 'THOC', 'TSCA',
]
changed_dict = {
    name: label_id
    for label_id, name in enumerate(ordered_class_names, start=1)
}

In [13]:
# myriam_dict =  {'ABBA': 1,
#  'ACPE': 2,
#  'ACRU': 3,
#  'ACSA': 4,
#  'Acer': 5, #Ignore this
#  'BEAL': 6,
#  'BEPA': 7,
#  'FAGR': 8,
#  'Feuillus': 9, # Ignore this
#  'LALA': 10,
#  'Mort': 11,
#  'PIMA': 12, # Ignore this
#  'PIST': 13,
#  'POGR': 14, # Ignore this
#  'Picea': 15,
#  'Populus': 16,
#  'THOC': 17,
#  'TSCA': 18 }
# #  'Conifere': 19, Ignore index
# #  'FRNI': 20, Ignore index
# #  'PIGL': 21, Ignore index
# #  'PIRU': 16,
# #  'POTR': 19, # Ignore index
 
 

In [17]:
# Write myriam_dict to the "modified" classes JSON file.
with open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/classes_myriam_modified.json', 'w') as out_file:
    json.dump(myriam_dict, out_file)

In [20]:
# Label tiles under labels_removed_classes_2 of the train split.
image_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/labels_removed_classes_2/'

# Every PNG below image_path, at any directory depth.
png_pattern = image_path + '/**/*.png'
images_list = glob.glob(png_pattern, recursive=True)

In [21]:
len(images_list)

1076

In [22]:
# Old label id -> new label id. The listed old ids are packed, in order, into
# the consecutive range 5..14; ids without an entry (such as 1-4) are not remapped.
surviving_old_ids = [6, 7, 8, 10, 11, 13, 15, 16, 17, 18]
change_list = dict(zip(surviving_old_ids, range(5, 15)))

In [23]:
x = change_list.keys()

In [24]:
x

dict_keys([6, 7, 8, 10, 11, 13, 15, 16, 17, 18])

In [25]:
from PIL import Image

# Largest id produced by change_list; any id above it after remapping is reset to 0.
MAX_REMAPPED_ID = 14

# NOTE: files are rewritten in place, so running this cell twice on the same
# tiles applies the remapping twice.
for label_file in tqdm(images_list):
    original = np.array(Image.open(label_file))
    remapped = original.copy()

    # Match against the untouched copy so every pixel is remapped exactly once,
    # whatever order change_list is iterated in.
    for old_id, new_id in change_list.items():
        remapped[original == old_id] = new_id

    # Ids that were not remapped and fall outside [0, MAX_REMAPPED_ID] are
    # reset to 0 (the `< 0` arm can only match for signed dtypes).
    remapped[(remapped < 0) | (remapped > MAX_REMAPPED_ID)] = 0

    Image.fromarray(remapped).save(label_file)

100%|██████████| 1076/1076 [00:28<00:00, 37.35it/s]


### Make images and labels list

In [2]:
import csv
import random
import pandas as pd
from PIL import Image

In [104]:
# Image tiles and merged label tiles of the hierarchical train split.
image_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/images'
label_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_merged/'

# Collect every PNG below each directory, at any depth.
png_glob = '/**/*.png'
images_list = glob.glob(image_path + png_glob, recursive=True)
labels_list = glob.glob(label_path + png_glob, recursive=True)

In [87]:
# for img in labels_list:
#     x = np.array(Image.open(img))
#     print(np.unique(x))

In [105]:
x = Image.open(labels_list[30])

In [106]:
# x

In [107]:
x = np.array(x)
x.shape

(768, 768, 3)

In [108]:
print(np.unique(x))

[0 1 2 3 4]


In [109]:
len(images_list), len(labels_list)

(1076, 1076)

In [110]:
images_list.sort(), labels_list.sort()

(None, None)

In [111]:
images_list[4], labels_list[4]

('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/images/zone1/1234916_1492357.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_merged/zone1/1234916_1492357.png')

In [112]:
# Seed for the train/val shuffle so the split is the same on every run.
SPLIT_SEED = 0

# Pair tiles with labels by sorted position, sorting here so the pairing does
# not depend on another cell having been run first.
sorted_images = sorted(images_list)
sorted_labels = sorted(labels_list)

# Fail fast if the two lists do not line up file-for-file; a silent mismatch
# would pair tiles with the wrong labels.
assert len(sorted_images) == len(sorted_labels), "image/label counts differ"
assert all(
    os.path.basename(image_file) == os.path.basename(label_file)
    for image_file, label_file in zip(sorted_images, sorted_labels)
), "image and label file names do not match pairwise"

data = list(zip(sorted_images, sorted_labels))
# A dedicated, seeded generator keeps the shuffle independent of the global
# random state.
random.Random(SPLIT_SEED).shuffle(data)

images, labels = zip(*data)

In [113]:
# The first VAL_COUNT shuffled pairs form the validation split; the remainder
# is the training split.
VAL_COUNT = 215

images_val, images_train = list(images[:VAL_COUNT]), list(images[VAL_COUNT:])
labels_val, labels_train = list(labels[:VAL_COUNT]), list(labels[VAL_COUNT:])

In [98]:
images_test = list(images)
labels_test = list(labels)

In [99]:
images_test[4], labels_test[4]

('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/images/zone3/1234876_1492306.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone3/1234876_1492306.png')

In [114]:
images_train[4], labels_train[4]

('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/images/zone2/1234969_1492303.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_merged/zone2/1234969_1492303.png')

In [116]:
# Write the training pairs as a two-column CSV (tiles, labels) without the index.
train_csv_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/splits/hierarchical/train_768.csv'
df = pd.DataFrame({"tiles": images_train, "labels": labels_train})
df.to_csv(train_csv_path, sep=',', index=False)

In [102]:
# Write the (tiles, labels) pairs as a two-column CSV without the index.
# NOTE(review): the "Change labels path in CSVs" section at the end of the
# notebook writes to this same test_768.csv path.
test_csv_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/splits/hierarchical/test_768.csv'
df = pd.DataFrame({"tiles": images_list, "labels": labels_list})
df.to_csv(test_csv_path, sep=',', index=False)

In [66]:
d1 = pd.read_csv('/home/mila/v/venkatesh.ramesh/scratch/tree_data/splits/pretiled_train_768.csv')

In [67]:
images, labels = list(d1['tiles']), list(d1['labels'])

In [57]:
# Sort the two lists loaded from the CSV. `labels` is the list read alongside
# `images`; `tiles` is not defined at this point in the notebook.
images.sort()
labels.sort()

# Sorting the lists independently only keeps the pairs aligned when the file
# names match one-to-one, so check that explicitly.
assert all(
    os.path.basename(image_file) == os.path.basename(label_file)
    for image_file, label_file in zip(images, labels)
), "image and label file names do not match pairwise after sorting"

(None, None)

In [68]:
# images[120], labels[120]
d1.iloc[120]['tiles'], d1.iloc[120]['labels']

('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/images/zone1/1234976_1492363.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/labels_removed_classes_1/zone1/1234976_1492363.png')

In [15]:
d1['tiles'][0]

'/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/images/zone1/1234970_1492336.png'

In [16]:
d1['labels'][0]

'/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/train/labels_removed_classes_1/zone1/1234958_1492384.png'

### Create masks for hierarchical loss

In [9]:
import csv
import random
import pandas as pd
from PIL import Image
import glob
import numpy as np

In [3]:
# image_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/test/images'
# images_list = glob.glob(image_path + '/**/*.png', recursive=True)

# label_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/test/labels_removed_classes_2/'
# labels_list = glob.glob(label_path + '/**/*.png', recursive=True)

In [4]:
# print(len(labels_list), len(images_list))

### Merge spe, gen, fam

In [21]:
# Label directories of the hierarchical test split, one per level:
# species, genus, family.
species_path, genus_path, family_path = (
    '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels',
    '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_genus',
    '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_family',
)

In [22]:
# Every PNG below each level's label directory, at any depth.
png_glob = '/**/*.png'
species, genus, family = (
    glob.glob(level_dir + png_glob, recursive=True)
    for level_dir in (species_path, genus_path, family_path)
)

In [23]:
# Sort each level's file list in place so the three lists can be walked in
# the same order.
for level_files in (species, genus, family):
    level_files.sort()

In [24]:
print(len(species), len(genus), len(family))

280 280 280


In [25]:
# The three sorted lists are paired by position, so they must describe the
# same tiles in the same order.
assert len(species) == len(genus) == len(family), "level file counts differ"

# Output root: a `labels_merged` directory next to the species label directory.
merged_root = os.path.join(os.path.dirname(os.path.normpath(species_path)), 'labels_merged')

for spe, gen, fam in zip(species, genus, family):
    # Path of the tile relative to its level directory (zone folder + file name).
    tile_rel = os.path.relpath(spe, species_path)
    assert tile_rel == os.path.relpath(gen, genus_path) == os.path.relpath(fam, family_path), (
        "species/genus/family tiles do not match: " + spe
    )

    spe_img = np.array(Image.open(spe))
    gen_img = np.array(Image.open(gen))
    fam_img = np.array(Image.open(fam))

    # Channel 0 of each level becomes one channel of the merged image, in the
    # order (species, genus, family). Assumes every label PNG loads with a
    # channel axis - indexing would fail on 2-D arrays.
    merged = np.dstack((spe_img[:, :, 0], gen_img[:, :, 0], fam_img[:, :, 0]))

    # Build the output path from the directory structure instead of a
    # whole-path string replace, and create the zone folder if it is missing.
    out_path = os.path.join(merged_root, tile_rel)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    Image.fromarray(merged).save(out_path)

In [36]:
rgb = np.array(Image.open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone1/1234940_1492336.png'))

In [37]:
print(np.unique(rgb[:, :, 0]), np.unique(rgb[:, :, 1]), np.unique(rgb[:, :, 2]))

[ 0  1  2  3  4  5  6  7  8  9 10 11 13 14] [ 0  1  2  3  4  5  6  7  9 10] [0 1 2 3 4]


In [38]:
rgb = Image.fromarray(rgb * 15)

In [40]:
# rgb

### Change labels path in CSVs

In [36]:
train_csv = pd.read_csv('/home/mila/v/venkatesh.ramesh/scratch/tree_data/splits/cross_validation/split3/train_cv3_768.csv')

In [37]:
# Rewrite each label path: first the dataset folder name, then the label
# sub-folder name.
mod_labels = [
    label_file
    .replace('dataset_myriam_split', 'dataset_myriam_split_hierarchical')
    .replace('labels_removed_classes_2', 'labels_merged')
    for label_file in train_csv['labels']
]

In [39]:
mod_labels[:5]

['/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone1/1234931_1492333.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone1/1234931_1492336.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone1/1234931_1492339.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone1/1234931_1492342.png',
 '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone1/1234934_1492333.png']

In [40]:
# Tile paths from the CSV as a plain list, in row order.
tiles = list(train_csv['tiles'])

In [41]:
print(tiles[100], mod_labels[100])

/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split/test/images/zone3/1234885_1492336.png /home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/test/labels_merged/zone3/1234885_1492336.png


In [42]:
print(len(tiles), len(mod_labels))

280 280


In [43]:
# Write the tiles with their rewritten label paths as a two-column CSV
# without the index.
# NOTE(review): the rows are read from train_cv3_768.csv but written to
# test_768.csv - confirm this is the intended target.
hierarchical_csv_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/splits/hierarchical/test_768.csv'
df = pd.DataFrame({"tiles": tiles, "labels": mod_labels})
df.to_csv(hierarchical_csv_path, sep=',', index=False)