# Hierarchical training of an Inception V3 model on aerial data
Instead of training a model to classify images into species directly, we first train a model to classify them by leaf type (broadleaf vs. conifer), then two further models that classify broadleaf and needle-leaf (conifer) trees into their respective families. This should hopefully give a more general and more accurate classification.

### Loading the aerial and sentinel hierarchy dictionaries and the image data

In [23]:
import json
from pathlib import Path

import numpy as np

# Directory holding the hierarchy JSON files, two levels above this notebook.
# Built with pathlib so the path works on any OS (a backslash-separated string
# such as r'..\..\data\x.json' only resolves as a path on Windows).
DATA_DIR = Path("..") / ".." / "data"

# Loading hierarchy data. Later cells index these as
# hierarchy["broadleaf"][family][species]["index"].
with open(DATA_DIR / "aerial_hierarchy.json", encoding="utf-8") as file:
    aerial_hierarchy = json.load(file)

with open(DATA_DIR / "sentinel_hierarchy.json", encoding="utf-8") as file:
    sentinel_hierarchy = json.load(file)

In [None]:

#loading image data
import os

# Location of the pre-processed aerial image array. Set the AERIAL_IMAGES_PATH
# environment variable to run the notebook on another machine; otherwise the
# original author's local path is used.
aerial_images_path = os.environ.get(
    "AERIAL_IMAGES_PATH",
    r"C:\Users\bench\OneDrive\Documents\EMAT Year 3\MDM3\Phase C\ratio_adjusted_aerial_dataset\aerial_99_images.npy",
)
aerial_images = np.load(aerial_images_path)

In [24]:
# Sanity check: compare the number of loaded aerial images with the largest
# image index referenced by the hierarchy files. An index >= the image count
# would be out of range when used to label images.
print(aerial_images.shape)

# Index entries may be stored as strings (other cells convert them with int()),
# so convert before taking the max: max() over strings compares
# lexicographically, e.g. "9" > "17522".
print(max(map(int, aerial_hierarchy["broadleaf"]["oak"]["Quercus_rubra"]["index"])))

# Largest index of every broadleaf species, one entry per species.
# NOTE(review): this reads sentinel_hierarchy while the rest of the cell looks
# at aerial data; confirm that is intentional.
max_list = []
for species_dict in sentinel_hierarchy["broadleaf"].values():
    for species_info in species_dict.values():
        max_list.append(max(int(index) for index in species_info["index"]))

print(max_list)
print(max(max_list))




(17618, 304, 304, 4)
17522
[17698, 17277, 15942, 5158, 1194, 6014, 17793, 14126, 1746, 14030, 2316]
17793


# Training the initial binary classifier

### Important training parameters

In [None]:
import torch

# Two output classes for the foliage classifier: broadleaf vs. conifer.
num_classes = 2

# When True, only parameters that are still trainable (requires_grad) after
# model initialisation are optimised; full fine-tuning is not implemented.
feature_extract = True

In [36]:
# Forming binary labels, one per image: 1 = broadleaf, 0 = conifer.
# Gather every image index listed under any broadleaf family/species...
broadleaf_indices = []
for family_dict in aerial_hierarchy["broadleaf"].values():
    for species_info in family_dict.values():
        broadleaf_indices.extend(int(idx) for idx in species_info["index"])

# ...then start every image at 0 and switch the broadleaf ones to 1 in one step.
foliage_labels = np.zeros(aerial_images.shape[0], dtype=int)
foliage_labels[np.asarray(broadleaf_indices, dtype=int)] = 1

print(foliage_labels[0:1000])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 0 1 1 1 1 1 1 1 1 0 

### Forming the train/validation/test split (80/10/10)

In [None]:
from sklearn.model_selection import train_test_split

# 80/10/10 split, stratified on the foliage label at each stage so every
# subset keeps the broadleaf/conifer proportions.

# First split: 80% training, 20% held out.
train_images, rest_images, train_labels, rest_labels = train_test_split(
    aerial_images, foliage_labels,
    test_size=0.2, random_state=2,
    stratify=foliage_labels,
)

# Second split: the held-out 20% is halved into test and validation sets.
# Stratify on rest_labels, the labels of the arrays actually being split.
# Passing the full-length foliage_labels here makes train_test_split raise
# because of inconsistent sample counts.
test_images, val_images, test_labels, val_labels = train_test_split(
    rest_images, rest_labels,
    test_size=0.5, random_state=2,
    stratify=rest_labels,
)

### Initialising the model

In [None]:
from training_funcs import initialize_inception_model

# Build the Inception model with a num_classes-way output. input_size is
# later passed to CustomDataset when building the datasets.
foliage_model, input_size = initialize_inception_model(num_classes, feature_extract, use_pretrained=True)

# Setting grads for training: collect the parameters the optimiser will update.
# This must use foliage_model. model_ft is only created in the training cell,
# so referencing it here raised a NameError on a fresh run.
if feature_extract:
    # Only parameters that still have requires_grad=True are collected
    # (presumably initialize_inception_model freezes the rest; verify).
    params_to_update = []
    for name, param in foliage_model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    # Full fine-tuning of all parameters is not supported yet.
    raise NotImplementedError("not yet implemented")

### Creating data loaders

In [None]:
from training_funcs import CustomDataset

# Creating training and validation PyTorch datasets.
# Each dataset must pair its images with its own labels. Previously the
# validation set was given train_labels, so validation images were matched
# against the wrong (and differently sized) label array.
training_dataset = CustomDataset(train_images, train_labels, input_size)
val_dataset = CustomDataset(val_images, val_labels, input_size)

In [None]:
batch_size = 32

# One mini-batch loader per phase, keyed 'train' and 'val'; both reshuffle
# their dataset each time they are iterated.
dataloaders_dict = {
    phase: torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for phase, dataset in (('train', training_dataset), ('val', val_dataset))
}


### Training the binary model

In [None]:
# Checking the params we are training: keep (name, tensor) pairs for every
# parameter that will receive gradients, then report them.
trainable = [
    (param_name, tensor)
    for param_name, tensor in foliage_model.named_parameters()
    if tensor.requires_grad
]
for param_name, _ in trainable:
    print("\t", param_name)

params_to_update = [tensor for _, tensor in trainable]

print("number of parameters to train =", len(params_to_update))
for position, tensor in enumerate(params_to_update):
    print(f"parameter {position}:", tensor.shape)

In [None]:
from torch import nn, optim

# Prefer the first CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("sending model to", device)

# Move the model onto the selected device.
model_ft = foliage_model.to(device)

# Cross-entropy loss over the class outputs.
criterion = nn.CrossEntropyLoss()

# AdamW over the trainable parameters, with hyperparameters taken from an
# earlier random search.
optimizer_ft = optim.AdamW(
    params_to_update,
    lr=0.0005,
    betas=(0.9, 0.995),
    weight_decay=0.015,
)

# Total number of training epochs.
num_epochs = 300

In [None]:
from training_funcs import train_model

# Fix the torch random seed so repeated runs give consistent results.
torch.manual_seed(0)

# Keyword options forwarded to train_model.
training_options = {
    "num_epochs": num_epochs,
    "is_inception": True,
    "tensorboard_writer": None,
    "early_stopping": False,
    "device": device,
}

# Train and evaluate; train_model returns a (model, convergence_dict) pair.
model_ft, convergence_dict = train_model(
    model_ft, dataloaders_dict, criterion, optimizer_ft, **training_options
)

# Save the trained weights (state_dict) to disk.
torch.save(model_ft.state_dict(), 'foliage_model.pth')