### Hierarchical loss with merged picea and populus + Acer genus class

In [125]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image
import numpy as np
import glob
from tqdm import tqdm

## Labels Creation (According to before)

In [126]:
myriam_classes_sp = {'ABBA': 1,
 'ACPE': 2,
 'ACRU': 3,
 'ACSA': 4,
 'Acer' : 5,
 'BEAL': 6,
 'BEPA': 7,
 'FAGR': 8,
 'LALA': 9,
 'Mort': 10,
 'Picea': 11,
 'PIST': 12,
 'Populus': 13,
 'THOC': 14,
 'TSCA': 15}

In [127]:
myriam_classes_ge = {"ABBA": 1, "Acer_sp": 2, "Betula": 3, "FAGR": 4, "LALA": 5, "Mort": 6, "Picea": 7, "PIST": 8, "Populus": 9, "THOC": 10, "TSCA": 11}
ge_sp_mapping = {1: 1, 2: [2, 3, 4, 5], 3: [6, 7], 4: 8, 5 : 9, 6 : 10, 7: 11, 8: 12, 9: 13, 10: 14, 11: 15}

In [128]:
myriam_classes_fa = {"Conifer": 1, "Non_Conifer": 2, "Palm": 3, "Mort": 4}
fa_ge_mapping = {1: [1, 7, 8, 10, 11], 2: [2, 3, 4, 9], 3: 5, 4: 6}

In [129]:
def create_level_masks(img):
    img = np.array(Image.open(img))
    
    img_sp = make_sp_lvl(img)
    img_sp_copy = img_sp.copy()
    img_fam = make_fa_lvl(img_sp_copy)
    
    return Image.fromarray(img_sp), Image.fromarray(img_fam)
    
def make_sp_lvl(img):
    vals_sp = list(ge_sp_mapping.values())
    keys_sp = list(ge_sp_mapping.keys())
    for sp, gen in zip(keys_sp, vals_sp):
        if isinstance(gen, list):
            for i in gen:
                img[img == i] = sp
        else:
            img[img == gen] = sp
            
    return img

def make_fa_lvl(img):
    vals_fa = list(fa_ge_mapping.values())
    keys_fa = list(fa_ge_mapping.keys())
    for ge, fam in zip(keys_fa, vals_fa):
        if isinstance(fam, list):
            for i in fam:
                img[img == i] = ge
        else:
            img[img == fam] = ge
            
    return img

## Probabilities aggregation

In [18]:
preds = torch.rand((1, 14, 768, 768))
softmax = nn.LogSoftmax(dim=1) #nn.Softmax2d() 
y_hat = softmax(preds)

### Genus level

In [19]:
y_hat_list = [y_hat[:, 0, :, :],y_hat[:, 1:4, :, :],y_hat[:, 4:6, :, :],y_hat[:, 6, :, :],y_hat[:, 7, :, :], y_hat[:, 8, :, :], y_hat[:, 9, :, :],y_hat[:, 10, :, :],y_hat[:, 11, :, :]
,y_hat[:, 12, :, :], y_hat[:, 13, :, :]]

In [20]:
for i, tensor in enumerate(y_hat_list):
    if len(tensor.shape) == 3:
        y_hat_list[i] = tensor.unsqueeze(1)
    print(tensor.shape)

torch.Size([1, 768, 768])
torch.Size([1, 3, 768, 768])
torch.Size([1, 2, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])


In [21]:
ge_sp_mapping

{1: 1,
 2: [2, 3, 4],
 3: [5, 6],
 4: 7,
 5: 8,
 6: 9,
 7: 10,
 8: 11,
 9: 12,
 10: 13,
 11: 14}

In [22]:
def aggregate_probabilites(tensor_list):
    first_el = tensor_list[0]
    if len(first_el.shape) == 4 and first_el.shape[1] > 1:
            first_el = torch.sum(first_el, dim=1).unsqueeze(1)

    for tensor in y_hat_list[1:]:
        if len(tensor.shape) == 4 and tensor.shape[1] > 1:
            tensor = torch.sum(tensor, dim=1).unsqueeze(1)
        first_el = torch.cat((first_el, tensor), dim=1)
    return first_el

In [23]:
sp_tensor = aggregate_probabilites(y_hat_list)

In [24]:
sp_tensor.shape

torch.Size([1, 11, 768, 768])

In [25]:
print(print(sp_tensor.sum(1)))

tensor([[[-37.5506, -37.6991, -37.4220,  ..., -37.4855, -37.4066, -37.3186],
         [-37.6034, -37.4572, -37.6428,  ..., -37.3591, -37.8006, -37.7853],
         [-37.6374, -37.5842, -37.3989,  ..., -37.5589, -37.4889, -37.6028],
         ...,
         [-37.4531, -37.7999, -37.4906,  ..., -37.5646, -37.3306, -37.4322],
         [-37.4706, -37.5696, -37.5055,  ..., -37.4378, -37.4557, -37.2901],
         [-37.6603, -37.5226, -37.3923,  ..., -37.3896, -37.4362, -37.6500]]])
None


### Family level

In [26]:
conifer_fa = sp_tensor[:, 0, :, :] + sp_tensor[:, 6, :, :] + sp_tensor[:, 7, :, :] + sp_tensor[:, 9, :, :] + sp_tensor[:, 10, :, :]
nonconifera_fa = sp_tensor[:, 1, :, :] + sp_tensor[:, 2, :, :] + sp_tensor[:, 3, :, :] + sp_tensor[:, 8, :, :] 
palm_fa = sp_tensor[:, 4, :, :]
dead_fa = sp_tensor[:, 5, :, :]

family_tensor_list = [conifer_fa, nonconifera_fa, palm_fa, dead_fa]

In [27]:
for i, tensor in enumerate(family_tensor_list):
    if len(tensor.shape) == 3:
        family_tensor_list[i] = tensor.unsqueeze(1)
    print(tensor.shape)

torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])
torch.Size([1, 768, 768])


In [28]:
# fa_tensor = aggregate_probabilites(family_tensor_list)

In [29]:
fa_tensor = torch.cat((family_tensor_list[0], family_tensor_list[1], family_tensor_list[2], family_tensor_list[3]), dim=1)

In [30]:
fa_tensor.shape

torch.Size([1, 4, 768, 768])

In [35]:
print(print(fa_tensor.sum(1)))

tensor([[[-37.5506, -37.6991, -37.4220,  ..., -37.4855, -37.4066, -37.3186],
         [-37.6034, -37.4572, -37.6428,  ..., -37.3590, -37.8006, -37.7853],
         [-37.6374, -37.5842, -37.3989,  ..., -37.5589, -37.4889, -37.6028],
         ...,
         [-37.4531, -37.7999, -37.4906,  ..., -37.5646, -37.3306, -37.4322],
         [-37.4706, -37.5696, -37.5055,  ..., -37.4378, -37.4557, -37.2901],
         [-37.6603, -37.5226, -37.3923,  ..., -37.3896, -37.4362, -37.6500]]])
None


### Test Loss Calculation w/ Log-Softmax+NLL Loss

### In implementation for the loss, will need to add the 0th channel (output will have 15 channels)

In [36]:
spe_target_tensor = torch.tensor(np.array(Image.open(label_list[1]))).unsqueeze(0)

In [37]:
gen_target_tensor = torch.tensor(np.array(sp_img)).unsqueeze(0)
fam_target_tensor = torch.tensor(np.array(fam_img)).unsqueeze(0) 

In [38]:
print(torch.unique(spe_target_tensor), torch.unique(gen_target_tensor), torch.unique(fam_target_tensor))

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12], dtype=torch.uint8) tensor([0, 1, 2, 3, 4, 5, 6, 7, 9], dtype=torch.uint8) tensor([0, 1, 2, 3, 4], dtype=torch.uint8)


In [39]:
spe_preds = y_hat
gen_preds = sp_tensor
fam_preds = fa_tensor

In [40]:
loss_sp = nn.NLLLoss()
loss_ge = nn.NLLLoss()
loss_fa = nn.NLLLoss()

In [41]:
print(spe_preds.shape, spe_target_tensor.shape)
print(gen_preds.shape, gen_target_tensor.shape)
print(fam_preds.shape, fam_target_tensor.shape)

torch.Size([1, 14, 768, 768]) torch.Size([1, 768, 768])
torch.Size([1, 11, 768, 768]) torch.Size([1, 768, 768])
torch.Size([1, 4, 768, 768]) torch.Size([1, 768, 768])


In [42]:
output_sp = loss_sp(spe_preds, spe_target_tensor.long())
output_ge = loss_ge(gen_preds, gen_target_tensor.long())
output_fa = loss_fa(fam_preds, fam_target_tensor.long())

In [43]:
print(output_ge, output_ge)

tensor(3.9690) tensor(3.9690)


### Generating three level labels

In [137]:
train_labels_path = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels'
save_path_genus = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_genus'
save_path_family = '/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_family'

In [138]:
hierarchical_label_list = glob.glob(train_labels_path + '/**/*.png', recursive=True)

In [139]:
# print(hierarchical_label_list[0].replace('labels', 'labels_genus'))
len(hierarchical_label_list)

1076

In [140]:
for item in tqdm(hierarchical_label_list):
    sp_img, fam_img = create_level_masks(item)
    sp_img.save(item.replace('labels', 'labels_genus'))
    fam_img.save(item.replace('labels', 'labels_family'))

100%|██████████| 1076/1076 [01:15<00:00, 14.27it/s]


In [32]:
hierarchical_label_list[100]

'/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels/zone1/1234934_1492306.png'

In [123]:
# x1 = Image.open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels/zone2/1234948_1492258.png')
# x2 = Image.open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_genus/zone2/1234948_1492258.png')
# x3 = Image.open('/home/mila/v/venkatesh.ramesh/scratch/tree_data/dataset_myriam_split_hierarchical/train/labels_family/zone2/1234948_1492258.png')

### Vis test

In [1]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image
import numpy as np
import glob
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
