In [None]:
# Draft code that used to be in class MyLogger(Logger)'s __init__() method.
# It was used to set up the initial history keys.
from types import SimpleNamespace

# This draft was lifted out of a method, so no real instance exists at cell level.
# A plain namespace stands in for `self` so the attribute assignments below still run.
self = SimpleNamespace()

num_output_layers = 6
self.num_output_layers = num_output_layers

# Metric keys tracked for ONE output layer (train split, epoch level).
history_per_output_layer = {
    # per epoch
    'loss': [],
    'acc': [],
    'f1_score': [],
    # these 2 will be replaced with step_y_pred and step_y_true respectively
    'y_pred': [],
    'y_true': [],
}


def pre_suf_fix_dict(dict, fix_val='output_layer', prefix=True):
    """Return a new dictionary whose keys are the keys of `dict` joined to `fix_val` by "_".

    Args:
        dict: mapping whose keys are renamed (the name shadows the builtin, but it is kept
            so that existing keyword calls keep working).
        fix_val: text to attach to every key, without the "_" separator.
        prefix: if True the new key is "<fix_val>_<key>", otherwise "<key>_<fix_val>".

    Returns:
        A new dictionary with the renamed keys. The values are NOT copied: each new key
        points to the very same object as the original key.
    """
    return {f"{fix_val}_{key}" if prefix else f"{key}_{fix_val}": value for key, value in dict.items()}


def _with_fresh_lists(keys):
    """Map every key to its own new empty list, so histories never share one list object."""
    return {key: [] for key in keys}


# to create val/test prefixed keys
# (the argument is fully evaluated before update() runs, so both calls only see the un-prefixed keys)
history_per_output_layer.update(_with_fresh_lists([
    *pre_suf_fix_dict(history_per_output_layer, fix_val='val', prefix=True),
    *pre_suf_fix_dict(history_per_output_layer, fix_val='test', prefix=True),
]))
# to create step_ prefixed keys
history_per_output_layer.update(_with_fresh_lists(
    pre_suf_fix_dict(history_per_output_layer, fix_val='step', prefix=True)
))
# removing y_pred, y_true and keeping step_y_pred and step_y_true
history_per_output_layer.pop('y_pred')
history_per_output_layer.pop('y_true')

# fix_val == suffix without "_"; to provide flexibility in using fix_val as prefix as well (if we wanted to!)
self.out_layers_fix_val = 'output_layer'

# pre_suf_fix_dict function is used to create unique key names to merge the `num_output_layers` histories into one history dictionary
# every layer gets its own lists; reusing the template's lists would make all layers append to the same objects
history_ol_lists = [_with_fresh_lists(pre_suf_fix_dict(history_per_output_layer,
                                                       fix_val=f'{self.out_layers_fix_val}_{i}',
                                                       prefix=False))
                    for i in range(self.num_output_layers)]  # ol == output layers

# 'loss' here is for the final loss calculated by the loss aggregator function which uses all `num_output_layers` losses
self.history = {'loss': [], 'val_loss': [], 'test_loss': []}
self.history.update(_with_fresh_lists(pre_suf_fix_dict(self.history, 'step', prefix=True)))
# merging the `num_output_layers` histories into one history dictionary
for hist_dict in history_ol_lists:
    self.history.update(**hist_dict)

self.hist_keys = list(self.history.keys())

In [None]:
# removed manual logging
import pytorch_lightning as pl
import torch
import torchmetrics
from torch.nn import functional as F

class HierarchalModelPL(pl.LightningModule):
   '''
   Lightning wrapper around a hierarchical classifier that has several output layers.

   Every output layer gets its own cross-entropy loss, accuracy and per-class F1 score.
   The per-layer losses are merged into the single loss that is optimized by `metric_reduce_fx()`.

   Logged metric names follow the pattern:
      [step_][val_|test_]<metric>_<history_layers_fix_val>_<output layer index>
   plus `[step_][val_|test_]loss` for the aggregated loss.
   '''

   def __init__(self, hierarchical_model, num_classes_per_layer:list, loss_weights=None, dummy_input_size=(32, 3, 306, 306), the_device='cpu'):
      '''
      hierarchical_model: callable module; iterating over `hierarchical_model(x)` must yield one 2D tensor
         (batch_size, num_classes of the i_th output layer) per output layer.
         NOTE(review): these tensors are assumed to be raw logits — TODO confirm against the model definition.
      num_classes_per_layer: number of classes of each output layer, in output order.
      loss_weights: one weight per output layer, used by the weighted-sum loss (default: 1 for every layer).
      dummy_input_size: shape of the random tensor stored as `example_input_array`.
      the_device: device string; any value that does not contain 'cpu' is mapped to 'cuda'.

      raises ValueError if fewer loss weights than output layers are given.
      '''
      super().__init__()
      self.hierarchical_model = hierarchical_model # .to(the_device)
      self.num_output_layers = len(num_classes_per_layer)
      self.loss_weights = [1]*self.num_output_layers if loss_weights is None else loss_weights
      if len(self.loss_weights) < self.num_output_layers:
         # fail here instead of with an IndexError in the middle of the first training step
         raise ValueError(f'expected at least {self.num_output_layers} loss weights, got {len(self.loss_weights)}')

      # useful for writing computational graph in tensorboard, summary(), etc
      self.example_input_array = torch.randn(*dummy_input_size) # , device=the_device

      self.the_device = the_device if 'cpu' in the_device else 'cuda'
      self.metrics = ['acc', 'f1_score']

      # the metric objects live in plain python lists, so they are placed on the device by hand.
      # they use the same normalized device as the labels/predictions built in _step_logic()
      metrics_device = self.the_device
      def _create_metric_func(metric_type:str, num_classes):
         metric_type = metric_type.lower()
         if 'acc' in metric_type:
            return torchmetrics.Accuracy(task='multiclass',
                                         num_classes=num_classes,
                                         average="micro").to(metrics_device)
         if 'f1' in metric_type:
            # average=None --> one score per class
            return torchmetrics.F1Score(task="multiclass",
                                        num_classes=num_classes,
                                        average=None).to(metrics_device)
         raise ValueError(f'unsupported metric type: {metric_type!r}')

      # self.metric_funcs[metric_name][i][j][k] where:
      #    i --> 0: step metric, 1: epoch metric
      #    j --> index of the output layer (0 .. num_output_layers-1)
      #    k --> 0: train, 1: val, 2: test
      # max indexing possible: self.metric_funcs['f1_score'][1][num_output_layers-1][2]
      # example for 6 output layers (note: MCA == MulticlassAccuracy()):
      # {'acc': [
      #    [ # step
      #       [MCA, MCA, MCA], # output_layer_0 ; train/val/test
      #       ...
      #       [MCA, MCA, MCA]  # output_layer_5 ; train/val/test
      #    ],
      #    [ # epoch
      #       [MCA, MCA, MCA],
      #       ...
      #       [MCA, MCA, MCA]
      #    ]
      #  ]
      # }
      self.metric_funcs = {}
      for metric in self.metrics:
         self.metric_funcs[metric] = [[[_create_metric_func(metric, num_classes_per_layer[j]) for k in range(3)]
                                       for j in range(self.num_output_layers)]
                                       for i in range(2)]

      # fix_val == suffix without "_"; to provide flexibility in using fix_val as prefix as well (if we wanted to!)
      self.history_layers_fix_val = 'output_layer'


   @staticmethod
   def _ds_type_idx(ds_prefix):
      '''maps a dataset prefix ('', 'val_', 'test_') to its train/val/test index inside self.metric_funcs'''
      if 'val' in ds_prefix:
         return 1
      if 'test' in ds_prefix:
         return 2
      return 0

   def _forward_logits(self, x):
      '''returns the list of per-output-layer tensors produced by the wrapped model, without softmax'''
      return list(self.hierarchical_model(x))

   def forward(self, x):
      '''
      returns a list with len == num_output_layers,
      while each element is 2D tensor of shape (batch_size, num_classes of the i_th output layer in hierarchy)
      holding the softmax of the corresponding model output
      '''
      return [F.softmax(tensor_output, dim=1) for tensor_output in self._forward_logits(x)]

   def training_step(self, batch, batch_idx):
      ds_prefix = ''
      metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
      # add other key:value pairs here or pass the entire metrics_dict if you want. just make sure 'loss' key is present
      return metrics_dict

   def training_epoch_end(self, outputs) -> None:
      # 'outputs' argument here contains values from what was returned from training_step()
      ds_prefix = ''
      self._epoch_end_logic(outputs, ds_prefix)

   def validation_step(self, batch, batch_idx):
      ds_prefix = 'val_'
      metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
      return metrics_dict

   def validation_epoch_end(self, outputs) -> None:
      ds_prefix = 'val_'
      self._epoch_end_logic(outputs, ds_prefix)

   def test_step(self, batch, batch_idx):
      ds_prefix = 'test_'
      metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
      return metrics_dict

   def test_epoch_end(self, outputs) -> None:
      ds_prefix = 'test_'
      self._epoch_end_logic(outputs, ds_prefix=ds_prefix)


   def _store_metric(self, metric_name, metric_val):
      '''
      hands one metric value to the attached logger.
      NOTE(review): this relies on the logger exposing a `log(name, value)` method — confirm against the logger class in use.
      side note: metric_val is not always a scalar; the f1 score is created with average=None,
      so it is a tensor with one value per class (an earlier draft logged its mean through self.log() instead)
      '''
      self.logger.log(metric_name, metric_val)

   def _step_logic(self, batch, ds_prefix=''):
      '''
      shared logic of training/validation/test steps: computes and stores the loss and the metrics
      of every output layer for one batch, and feeds the epoch-level metric objects.

      ds_prefix == dataset_prefix ('' for train, 'val_', 'test_')
      returns {'loss': aggregated loss of all output layers}
      '''
      ds_type_idx = self._ds_type_idx(ds_prefix)
      ds_prefix = 'step_' + ds_prefix

      x, y = batch
      # the loss needs the raw model outputs: F.cross_entropy() applies log-softmax itself,
      # so giving it the softmax output of forward() would normalize the scores twice
      hier_logits = self._forward_logits(x)

      # y now has shape (batch_size, len(hier_logits))
      # in other words, each row now consists of `len(hier_logits)` output layers' labels of 1 sample
      # side note: len(hier_logits) == self.num_output_layers
      # labelToHierarchy is defined outside of this class
      y = torch.tensor([labelToHierarchy[int(label)] for label in y], dtype=int, device=self.the_device)
      # transpose y to take each row (batch of labels) with its corresponding hier_logits row (batch of predictions)
      # in other words, each row now consists of `batch_size` labels of the i_th output layer
      y = y.T
      # side note: the wording of "i_th output layers" refers to the strings mentioned in
      # output_order argument of HierarchalModel() used in create_hnn_model_arch()

      losses_dict = {}
      other_metrics_dict = {}
      for out_layer_idx, layer_logits in enumerate(hier_logits):
         # 'cur' refers to current output layer
         y_cur = y[out_layer_idx]
         logits_cur = layer_logits.to(self.the_device)
         # same values forward() would return for this layer; shape (batch_size, num_classes of this layer)
         y_pred_cur = F.softmax(logits_cur, dim=1)
         ol_suffix = f'_{self.history_layers_fix_val}_{out_layer_idx}'

         # log step metrics

         loss = F.cross_entropy(logits_cur, y_cur)
         loss_full_metric_name = f'{ds_prefix}loss{ol_suffix}'
         self._store_metric(loss_full_metric_name, loss)
         losses_dict[loss_full_metric_name] = loss

         for metric_name in self.metrics:
            # recall the indexing of self.metric_funcs:
            # metric_name, step/epoch, num_output_layers-1, train/val/test
            metric_val = self.metric_funcs[metric_name][0][out_layer_idx][ds_type_idx](y_pred_cur, y_cur)
            full_metric_name = f'{ds_prefix}{metric_name}{ol_suffix}'
            self._store_metric(full_metric_name, metric_val)
            # the epoch metric object only accumulates here; it is computed and reset in _epoch_end_logic()
            self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].update(y_pred_cur, y_cur)
            other_metrics_dict[full_metric_name] = metric_val

         # storing y_pred/y_true
         # self.history[f'{ds_prefix}y_pred{ol_suffix}'].extend(y_pred_cur)
         # self.history[f'{ds_prefix}y_true{ol_suffix}'].extend(y_cur)

      final_loss = self.metric_reduce_fx(losses_dict.values(), 'weighted_sum')
      self._store_metric(f'{ds_prefix}loss', final_loss)

      # Important note: 'loss' key must be present, or else you'll get this error:
      # MisconfigurationException: In automatic_optimization,
      # when `training_step` returns a dict, the 'loss' key needs to be present
      # side note: add `.update(losses_dict)` and `.update(other_metrics_dict)`
      # if you want to directly use other metrics in "..._epoch_end()" methods
      return {'loss' : final_loss}

   def _epoch_end_logic(self, outputs, ds_prefix=''):
      '''
      shared logic of the "..._epoch_end()" methods: stores the mean of the step losses and
      the epoch value of every metric of every output layer, then resets the epoch metric objects.

      outputs: list of the dictionaries returned by the corresponding step method (each has a 'loss' key)
      ds_prefix == dataset_prefix ('' for train, 'val_', 'test_')
      '''
      ds_type_idx = self._ds_type_idx(ds_prefix)

      # log epoch metrics
      final_loss_epoch = torch.tensor([x['loss'] for x in outputs], dtype=float).mean()
      self._store_metric(f'{ds_prefix}loss', final_loss_epoch)

      for out_layer_idx in range(self.num_output_layers):
         ol_suffix = f'_{self.history_layers_fix_val}_{out_layer_idx}'

         for metric_name in self.metrics:
            # recall: '1' for accessing epoch func (not step func)
            epoch_metric_func = self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx]
            metric_val_epoch = epoch_metric_func.compute()
            # start the next epoch from an empty state
            epoch_metric_func.reset()
            full_metric_name = f'{ds_prefix}{metric_name}{ol_suffix}'
            self._store_metric(full_metric_name, metric_val_epoch)
            # quick look at the epoch metrics of the first output layer only
            if out_layer_idx == 0:
               print(metric_name)
               print(metric_val_epoch)
               print()

   def metric_reduce_fx(self, metric_list, agg_type='weighted_sum'):
      '''
      aggregates metric values from all `num_output_layers` into a single value
      side note: called "reduce_fx" as a reference to PyTorch Lightning's reduce_fx parameter found in self.log()

      metric_list: iterable with one value per output layer, in output order
      agg_type: only weighted sum is supported (any string containing both 'weighted' and 'sum');
         the i_th value is multiplied by self.loss_weights[i]
      raises ValueError for any other agg_type
      '''
      agg_type_lower = agg_type.lower()
      if 'weighted' in agg_type_lower and 'sum' in agg_type_lower:
         weighted_sum_val = 0
         for i, layer_metric_val in enumerate(metric_list):
            weighted_sum_val += self.loss_weights[i] * layer_metric_val
         return weighted_sum_val

      # without this, an unknown agg_type ended in an UnboundLocalError on the return statement
      raise ValueError(f"unsupported agg_type: {agg_type!r} (expected 'weighted_sum')")

   def configure_optimizers(self):
      # make it self.hierarchical_model.parameters() if you defined other parameters in __init__() which you don't want to optimize
      # side note: under the hood, pl automatically gets gradients
      # from final_loss returned from training_step() and adjusts the models' branches accordingly
      # source: https://github.com/Lightning-AI/lightning/issues/2645#issuecomment-660681760
      return torch.optim.Adam(self.parameters(), lr=1e-3)
