In [13]:
import torchmetrics
import torch
ff = torchmetrics.Accuracy('multiclass', num_classes=10, average=None, top_k=1, threshold=0.5)
ff.update(torch.tensor([2]), torch.tensor([2]))
ff.update(torch.tensor([0]), torch.tensor([2]))
ff.update(torch.tensor([1]), torch.tensor([1]))
res = ff.compute()
ff.reset()
res
# output: tensor([0.0000, 1.0000, 0.5000])
# when using average='micro', the output: tensor(0.6667) (as 2 correct preds / 3 all preds = 0.6667)
# when using average='macro', the output: tensor(0.5000) (as (0+1+0.5)/3 = 0.5)

tensor([0.0000, 1.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000])

In [3]:
import os
from typing import List, Union
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl
import torchmetrics

class MNISTClassifier(pl.LightningModule):
    def __init__(self, the_device='cuda', num_classes=10, dummy_input_size=(32,3,306,306)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.LazyLinear(128)
        self.fc2 = nn.LazyLinear(10)
        
        # useful for writing computational graph in tensorboard, summary(), etc
        # self.example_input_array = torch.randn(*dummy_input_size) 

        self.the_device = the_device

        def _create_metric_func(metric_type:str):
            if 'acc' in metric_type.lower():
                return torchmetrics.Accuracy(task='multiclass', 
                                             num_classes=num_classes, 
                                             average="micro").to(the_device)
            elif 'f1' in metric_type.lower():
                return torchmetrics.F1Score(task="multiclass", 
                                            num_classes=num_classes, 
                                            average=None).to(the_device)
        
        # for calculating metric of a step
        self.acc_step_funcs = [_create_metric_func('acc') for _ in range(3)] # for computing train, val, test results respectively
        self.f1_score_step_funcs = [_create_metric_func('f1') for _ in range(3)]
        # for calculating metric of an epoch
        self.acc_epoch_funcs = [_create_metric_func('acc') for _ in range(3)]
        self.f1_score_epoch_funcs = [_create_metric_func('f1') for _ in range(3)]
        
        self.history = {
            # per epoch
            'loss':[],
            'acc':[],
            'f1_score':[],
            'val_loss':[],
            'val_acc':[],
            'val_f1_score':[],
            'test_loss':[],
            'test_acc':[],
            'test_f1_score':[],
            # per step
            'step_loss':[],
            'step_acc':[],
            'step_f1_score':[],
            'val_step_loss':[],
            'val_step_acc':[],
            'val_step_f1_score':[],
            'test_step_loss':[],
            'test_step_acc':[],
            'test_step_f1_score':[],
            # for tracking pred/true values for train/val/test sets
            'step_y_pred': [],
            'step_y_true': [],
            'val_step_y_pred': [],
            'val_step_y_true': [],
            'test_step_y_pred': [],
            'test_step_y_true': [],
        }

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
    
    def _store_metric(self, metric_name, metric_val):
        try:
            self.log(metric_name, metric_val)
        except Exception as e:
            # then metric_val is not scalar, then this means:
            # metric_name is an f1 score metric 
            # so, we just take the mean value
            self.log(metric_name, metric_val.mean())
        # putting this here ensures we don't include the mean of f1 score, 
        # but rather the tensor of shape (1, num_classes)
        self.history[metric_name].append(metric_val)

    def _step_logic(self, batch, prefix=''):
        if 'val' in prefix:
            func_idx = 1
        elif 'test' in prefix:
            func_idx = 2
        else:
            func_idx = 0
        
        x, y_true = batch
        y_pred = self(x)
        # log step metrics
        loss = F.cross_entropy(y_pred, y_true)
        self._store_metric(f'{prefix}step_loss', loss)

        acc = self.acc_step_funcs[func_idx](y_pred, y_true)
        self._store_metric(f'{prefix}step_acc', acc)
        self.acc_epoch_funcs[func_idx].update(y_pred, y_true)
    
        f1_score = self.f1_score_step_funcs[func_idx](y_pred, y_true)
        self._store_metric(f'{prefix}step_f1_score', f1_score)
        self.f1_score_epoch_funcs[func_idx].update(y_pred, y_true)

        # storing y_pred/y_true
        self.history[f'{prefix}step_y_pred'].extend(y_pred)
        self.history[f'{prefix}step_y_true'].extend(y_true)

        return loss, acc, f1_score
    
    def _epoch_end_logic(self, outs, prefix=''):
        if 'val' in prefix:
            func_idx = 1
        elif 'test' in prefix:
            func_idx = 2
        else:
            func_idx = 0

        # log epoch metrics
        loss_epoch = torch.tensor([x[f'{prefix}loss'] for x in outs], dtype=float).mean()
        self._store_metric(f'{prefix}loss', loss_epoch)

        acc_epoch = self.acc_epoch_funcs[func_idx].compute()
        self.acc_epoch_funcs[func_idx].reset()
        self._store_metric(f'{prefix}acc', acc_epoch)

        f1_score_epoch = self.f1_score_epoch_funcs[func_idx].compute()
        self.f1_score_epoch_funcs[func_idx].reset()
        self._store_metric(f'{prefix}f1_score', f1_score_epoch)

        return loss_epoch, acc_epoch, f1_score_epoch

    def training_step(self, batch, batch_idx):
        prefix = ''
        loss, acc, f1_score = self._step_logic(batch, prefix=prefix)
        return {f'{prefix}loss':loss, f'{prefix}acc':acc, f'{prefix}f1_score':f1_score}
    
    def training_epoch_end(self, outs) -> None:
        # 'outs' argument here contains values from what was returned from training_step()
        _,_,_ = self._epoch_end_logic(outs)

    def validation_step(self, batch, batch_idx):
        prefix = 'val_'
        loss, acc, f1_score = self._step_logic(batch, prefix=prefix)
        return {f'{prefix}loss':loss, f'{prefix}acc':acc, f'{prefix}f1_score':f1_score}

    def validation_epoch_end(self, outs) -> None:
        _,_,_ = self._epoch_end_logic(outs, prefix='val_')

    def test_step(self, batch, batch_idx):
        prefix = 'test_'
        loss, acc, f1_score = self._step_logic(batch, prefix=prefix)
        return {f'{prefix}loss':loss, f'{prefix}acc':acc, f'{prefix}f1_score':f1_score}

    def test_epoch_end(self, outs) -> None:
        _,_,_ = self._epoch_end_logic(outs, prefix='test_')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# init the autoencoder
cnnLightning = MNISTClassifier(the_device='cuda')
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [50000, 10000])
train_loader = utils.data.DataLoader(train_dataset, batch_size=32, num_workers=4)
val_loader = utils.data.DataLoader(val_dataset, batch_size=32, num_workers=4)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(log_every_n_steps=1, num_sanity_val_steps=0, max_epochs=3, accelerator='gpu', fast_dev_run=False)
trainer.fit(model=cnnLightning, train_dataloaders=train_loader, val_dataloaders=val_loader) # , val_dataloaders=val_loader

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | conv1    | Conv2d     | 320   
1 | conv2    | Conv2d     | 18.5 K
2 | dropout1 | Dropout2d  | 0     
3 | dropout2 | Dropout2d  | 0     
4 | fc1      | LazyLinear | 0     
5 | fc2      | LazyLinear | 0     
----------------------------------------
18.8 K    Trainable params
0         Non-trainable params
18.8 K    Total params
0.075     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
history = trainer.model.history.copy()
for metric, tensors_list in history.items():
    if isinstance(tensors_list, list) and len(tensors_list) > 0 and isinstance(tensors_list[0], torch.Tensor):
        history[metric] = np.array([np.array(tens.detach().cpu()) for tens in tensors_list]).round(5)

In [None]:
# debugging cell: dummy example for demonstration purposes:
# the code in this cell can be found in forward()

# len(hier_y_probs) == num_hierarchy_output_layers, 
# while each element is 2D tensor of shape (batch_size, num_classes of the i_th output layer in hierarchy)
hier_y_pred = []
for tensor_output in demo_model(a):
    hier_y_pred.append(F.softmax(tensor_output, dim=1))
hier_y_pred

[tensor([[0.5039, 0.4961],
         [0.5040, 0.4960],
         [0.5040, 0.4960],
         [0.5040, 0.4960]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.2513, 0.2444, 0.2263, 0.2780],
         [0.2514, 0.2444, 0.2262, 0.2780],
         [0.2513, 0.2444, 0.2263, 0.2780],
         [0.2513, 0.2444, 0.2262, 0.2780]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.2561, 0.2539, 0.2301, 0.2599],
         [0.2561, 0.2538, 0.2302, 0.2599],
         [0.2561, 0.2538, 0.2301, 0.2599],
         [0.2561, 0.2539, 0.2301, 0.2599]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3410, 0.3629, 0.2960],
         [0.3410, 0.3629, 0.2961],
         [0.3410, 0.3629, 0.2960],
         [0.3410, 0.3630, 0.2960]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3407, 0.3076, 0.3518],
         [0.3407, 0.3075, 0.3518],
         [0.3407, 0.3075, 0.3518],
         [0.3407, 0.3075, 0.3517]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3223, 0.3448, 0.3329],
         [0.3223, 0.3448, 0.3329],
         [0.3223, 0.3449, 0.3329],
         [0.3

In [None]:
# debugging cell: dummy example for demonstration purposes:
# the code in this cell can be found in train_step() (or actually, in _step_logic())

# x, y = batch
# hier_y_pred = self(x)
y = torch.tensor([1,2,5,8]) # remove this in train_step()
# y now has shape (batch_size, len(hier_y_pred))
# in other words, each row now consists of `len(hier_y_pred)` output layers' labels of 1 sample
y = torch.tensor([labelToHierarchy[i] for i in range(len(y))], dtype=int) # alternatively, range(y.size()[0])
print(y)
# transpose y to take each row (batch of labels) with its corresponding hier_y_pred row (batch of predicted labels)
# in other words, each row now consists of `batch_size` labels of the i_th output layer
y = y.T
print(y)
print()
# softmax_output_layer elements each has shape (batch_size, num_classes of the i_th output layer in hierarchy)
# side note: the wording of "i_th output layers" refers to the strings mentioned in 
# output_order argument of HierarchalModel() used in create_hnn_model_arch()
for i, softmax_output_layer in enumerate(hier_y_pred):
    # 'cur' refers to current output layer
    y_cur = y[i]
    y_pred_cur = softmax_output_layer
    print(y_cur)
    print(y_pred_cur)
    break

tensor([[0, 0, 3, 2, 2, 2],
        [0, 1, 3, 0, 2, 2],
        [0, 1, 3, 1, 2, 2],
        [1, 3, 0, 2, 0, 2]])
tensor([[0, 0, 0, 1],
        [0, 1, 1, 3],
        [3, 3, 3, 0],
        [2, 0, 1, 2],
        [2, 2, 2, 0],
        [2, 2, 2, 2]])

tensor([0, 0, 0, 1])
tensor([[0.5039, 0.4961],
        [0.5040, 0.4960],
        [0.5040, 0.4960],
        [0.5040, 0.4960]], grad_fn=<SoftmaxBackward0>)


In [None]:
# this function would've been useful if in training_step(), we returned 'y' as ohe (e.g., [0,0,1]) instead of integer labels (e.g., 2)
def ohe(value, size): # ohe == one hot encode
    identity = torch.eye(size)
    one_hot = identity[value]
    if (value == -1):
        # then we want the ohe list to be all zeros
        one_hot[-1] = 0
    # uncomment this line if you want to return 2D tensor of shape (1, size)
    # return one_hot.view(-1, size)
    return one_hot

classToHierarchy = {
    '00. selfies' : [ohe(0, 2), # diversity of colors: many or few
                     ohe(0, 4), # many colors: selfies, memes (fmm or emm), or eGreetingsAndMisc (+1 if none of them)
                     ohe(-1, 4), # few colors: socialMedia (fsm or esm), fTxtMssgs, academic (photos or digital) (+1 if none of them)
                     ohe(-1, 3), # meme type: fmm or emm (+1 if none of them)
                     ohe(-1, 3), # social media type: fsm or esm (+1 if none of them)
                     ohe(-1, 3)], # academic type: photos or digital (+1 if none of them)
    '10. fmemes' : [ohe(0, 2), ohe(1, 4), ohe(-1, 4), ohe(0, 3), ohe(-1, 3), ohe(-1, 3)],
    '20. ememes' : [ohe(0, 2), ohe(1, 4), ohe(-1, 4), ohe(1, 3), ohe(-1, 3), ohe(-1, 3)],
    '30. fSocialMedia' : [ohe(1, 2), ohe(-1, 4), ohe(0, 4), ohe(-1, 3), ohe(0, 3), ohe(-1, 3)],
    '40. eSocialMedia' : [ohe(1, 2), ohe(-1, 4), ohe(0, 4), ohe(-1, 3), ohe(1, 3), ohe(-1, 3)],
    '50. fTxtMssgs' : [ohe(1, 2), ohe(-1, 4), ohe(1, 4), ohe(-1, 3), ohe(-1, 3), ohe(-1, 3)],
    '70. eGreetingAndMisc' : [ohe(0, 2), ohe(2, 4), ohe(-1, 4), ohe(-1, 3), ohe(-1, 3), ohe(-1, 3)],
    '81. academicPhotos' : [ohe(1, 2), ohe(-1, 4), ohe(2, 4), ohe(-1, 3), ohe(-1, 3), ohe(0, 3)],
    '82. academicDigital' : [ohe(1, 2), ohe(-1, 4), ohe(2, 4), ohe(-1, 3), ohe(-1, 3), ohe(1, 3)],
}
classToHierarchy['00. selfies']

hierarchical model with old (not working) manual logging

In [None]:

import pytorch_lightning as pl
from torch.nn import functional as F
import torchmetrics

class HierarchalModelPL(pl.LightningModule):
   def __init__(self, hierarchical_model, num_classes_per_layer:list, loss_weights=None, dummy_input_size=(32, 3, 306, 306), the_device='cpu'):
      super(HierarchalModelPL, self).__init__()
      self.hierarchical_model = hierarchical_model # .to(the_device)
      self.num_output_layers = len(num_classes_per_layer)
      self.loss_weights = [1]*self.num_output_layers if loss_weights is None else loss_weights
      self.num_output_layers = self.num_output_layers
      
      # useful for writing computational graph in tensorboard, summary(), etc
      self.example_input_array = torch.randn(*dummy_input_size) # , device=the_device

      self.the_device = the_device if 'cpu' in the_device else 'cuda'
      self.metrics = ['acc', 'f1_score']
      def _create_metric_func(metric_type:str, num_classes):
         if 'acc' in metric_type.lower():
               return torchmetrics.Accuracy(task='multiclass', 
                                          num_classes=num_classes, 
                                          average="micro").to(the_device)
         elif 'f1' in metric_type.lower():
               return torchmetrics.F1Score(task="multiclass", 
                                          num_classes=num_classes, 
                                          average=None).to(the_device)
      
      #                                         step/epoch vvv
      # max indexing possible: self.metric_funcs['f1_score'][1][num_output_layers-1][2]
      #                                       softmax output layers ^^^            ^^^  train/val/test 
      # output example of metric_funcs['acc']
      # (note: MCA == MulticlassAccuracy()):
      # {'acc': [ 
      #   [ #step
      #    [MCA, MCA, MCA], # softmax_output_layer_0 ; train/val/test
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA], # ...
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA] # # softmax_output_layer_5 ; train/val/test
      #   ],
      #   [ #epoch
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA],
      #    [MCA, MCA, MCA]
      #   ]
      # ]
      
      self.metric_funcs = {}
      for metric in self.metrics:
         # i --> step/epoch, j --> num_output_layers-1, k --> train/val/test
         self.metric_funcs[metric] = [[[_create_metric_func(metric, num_classes_per_layer[j]) for k in range(3)] 
                                       for j in range(self.num_output_layers)] 
                                       for i in range(2)]
      
      # history_per_output_layer = {
      #    # per epoch
      #    'loss':[],
      #    'acc':[],
      #    'f1_score':[],
      #    'val_loss':[],
      #    'val_acc':[],
      #    'val_f1_score':[],
      #    'test_loss':[],
      #    'test_acc':[],
      #    'test_f1_score':[],
      #    # per step
      #    'step_loss':[],
      #    'step_acc':[],
      #    'step_f1_score':[],
      #    'val_step_loss':[],
      #    'val_step_acc':[],
      #    'val_step_f1_score':[],
      #    'test_step_loss':[],
      #    'test_step_acc':[],
      #    'test_step_f1_score':[],
      #    # for tracking pred/true values for train/val/test sets
      #    'step_y_pred': [],
      #    'step_y_true': [],
      #    'val_step_y_pred': [],
      #    'val_step_y_true': [],
      #    'test_step_y_pred': [],
      #    'test_step_y_true': [],
      # }

      history_per_output_layer = {
         # per epoch
         'loss':[],
         'acc':[],
         'f1_score':[],
         # these 2 will be replaced with step_y_pred and step_y_true respectively
         'y_pred':[],
         'y_true':[],
      }

      def pre_suf_fix_dict(dict, fix_val='output_layer', prefix=True):
         return {f"{fix_val}_{key}" if prefix else f"{key}_{fix_val}" : value for key, value in dict.items()}

      # to create val/test prefixed keys
      history_per_output_layer.update(
         **pre_suf_fix_dict(history_per_output_layer, fix_val='val', prefix=True),
         **pre_suf_fix_dict(history_per_output_layer, fix_val='test', prefix=True),
      )
      # to create step_ prefixed keys
      history_per_output_layer.update(
         **pre_suf_fix_dict(history_per_output_layer, fix_val='step', prefix=True),
      )
      # removing y_pred, y_true and keeping step_y_pred and step_y_true
      history_per_output_layer.pop('y_pred')
      history_per_output_layer.pop('y_true')
      
      # fix_val == suffix without "_"; to provide flexibility in using fix_val as prefix as well (if we wanted to!)
      self.history_layers_fix_val = 'output_layer' 

      # pre_suf_fix_dict function is used to create unique key names to merge the `num_output_layers` histories into one history dictionary
      history_ol_lists = [pre_suf_fix_dict(history_per_output_layer, 
                                           fix_val=f'{self.history_layers_fix_val}_{i}', 
                                           prefix=False) 
                           for i in range(self.num_output_layers)] # ol == output layers
      
      # 'loss' here is for the final loss calculated by the loss aggregator function which uses all `num_output_layers` losses
      self.history = {'loss':[], 'val_loss':[], 'test_loss':[]}
      self.history.update(**pre_suf_fix_dict(self.history, 'step', prefix=True))
      # merging the `num_output_layers` histories into one history dictionary
      for hist_dict in history_ol_lists:
         self.history.update(**hist_dict)


   def forward(self, x):
      # len(hier_y_probs) == num_hierarchy_output_layers, 
      # while each element is 2D tensor of shape (batch_size, num_classes of the i_th output layer in hierarchy)
      hier_y_pred = []
      for tensor_output in self.hierarchical_model(x):
         hier_y_pred.append(F.softmax(tensor_output, dim=1))
      hier_y_pred
      return hier_y_pred
   
   def training_step(self, batch, batch_idx):
      ds_prefix = ''
      metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
      # add other key:value pairs here or pass the entire metrics_dict if you want. just make sure 'loss' key is present
      return metrics_dict
    
   def training_epoch_end(self, outputs) -> None:
      # 'outputs' argument here contains values from what was returned from training_step()
      ds_prefix = ''
      self._epoch_end_logic(outputs, ds_prefix)

   def validation_step(self, batch, batch_idx):
      ds_prefix = 'val_'
      metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
      return metrics_dict
   
   def validation_epoch_end(self, outputs) -> None:
      ds_prefix = 'val_'
      self._epoch_end_logic(outputs, ds_prefix)
   
   def test_step(self, batch, batch_idx):
      ds_prefix = 'test_'
      metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
      return metrics_dict
   
   def test_epoch_end(self, outputs) -> None:
      ds_prefix = 'test_'
      self._epoch_end_logic(outputs, ds_prefix=ds_prefix)
   

   def _store_metric(self, metric_name, metric_val):
      try:
         self.log(metric_name, metric_val)
      except Exception:
         # then metric_val is not scalar, then this means:
         # metric_name is an f1 score metric 
         # so, we just take the mean value
         self.log(metric_name, metric_val.mean())
      # putting this here ensures we don't include the mean of f1 score, 
      # but rather the tensor of shape (1, num_classes)
      self.history[metric_name].append(metric_val)

   def _step_logic(self, batch, ds_prefix=''):
      # ds_prefix == dataset_prefix
      if 'val' in ds_prefix:
         ds_type_idx = 1
      elif 'test' in ds_prefix:
         ds_type_idx = 2
      else:
         ds_type_idx = 0
      ds_prefix = 'step_' + ds_prefix

      x, y = batch
      hier_y_pred = self(x)

      # y now has shape (batch_size, len(hier_y_pred))
      # in other words, each row now consists of `len(hier_y_pred)` output layers' labels of 1 sample
      # side note: len(hier_y_pred) == self.num_output_layers
      y = torch.tensor([labelToHierarchy[int(y[i])] for i in range(len(y))], dtype=int, device=self.the_device) # alternatively, range(y.size()[0])
      # transpose y to take each row (batch of labels) with its corresponding hier_y_pred row (batch of predicted labels)
      # in other words, each row now consists of `batch_size` labels of the i_th output layer
      y = y.T
      # softmax_output_layer elements each has shape (batch_size, num_classes of the i_th output layer in hierarchy)
      # side note: the wording of "i_th output layers" refers to the strings mentioned in 
      # output_order argument of HierarchalModel() used in create_hnn_model_arch()

      losses_dict = {}
      other_metrics_dict = {}
      for out_layer_idx, softmax_output_layer in enumerate(hier_y_pred): 
         # 'cur' refers to current output layer
         y_cur = y[out_layer_idx]
         y_pred_cur = softmax_output_layer.to(self.the_device)
         ol_suffix = f'_{self.history_layers_fix_val}_{out_layer_idx}'

         # log step metrics

         loss = F.cross_entropy(y_pred_cur, y_cur)
         loss_full_metric_name = f'{ds_prefix}loss{ol_suffix}'
         self._store_metric(loss_full_metric_name, loss)
         losses_dict[loss_full_metric_name] = loss

         for metric_name in self.metrics:
            # recall the indexing of self.metric_funcs:
            # metric_name, step/epoch, num_output_layers-1, train/val/test
            metric_val = self.metric_funcs[metric_name][0][out_layer_idx][ds_type_idx](y_pred_cur, y_cur)
            full_metric_name = f'{ds_prefix}{metric_name}{ol_suffix}'
            self._store_metric(full_metric_name, metric_val)
            self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].update(y_pred_cur, y_cur)
            other_metrics_dict[full_metric_name] = metric_val
      
         f1_score = self.metric_funcs[metric_name][0][out_layer_idx][ds_type_idx](y_pred_cur, y_cur)
         self._store_metric(f'{ds_prefix}{metric_name}{ol_suffix}', f1_score)
         self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].update(y_pred_cur, y_cur)

         # storing y_pred/y_true
         self.history[f'{ds_prefix}y_pred{ol_suffix}'].extend(y_pred_cur)
         self.history[f'{ds_prefix}y_true{ol_suffix}'].extend(y_cur)

      final_loss = self.metric_reduce_fx(losses_dict.values(), 'weighted_sum')
      self._store_metric(f'{ds_prefix}loss', final_loss)

      # Important note: 'loss' key must be present, or else you'll get this error:
      # MisconfigurationException: In automatic_optimization, 
      # when `training_step` returns a dict, the 'loss' key needs to be present
      # side note: add `.update(losses_dict)` and `.update(other_metrics_dict)` 
      # if you want to directly use other metrics in "..._epoch_end()" methods
      return {f'loss' : final_loss} 

   def _epoch_end_logic(self, outputs, ds_prefix=''):
      if 'val' in ds_prefix:
         ds_type_idx = 1
      elif 'test' in ds_prefix:
         ds_type_idx = 2
      else:
         ds_type_idx = 0

      # log epoch metrics
      final_loss_epoch = torch.tensor([x[f'loss'] for x in outputs], dtype=float).mean()
      self._store_metric(f'{ds_prefix}loss', final_loss_epoch)

      for out_layer_idx in range(self.num_output_layers):
         ol_suffix = f'_{self.history_layers_fix_val}_{out_layer_idx}'

         for metric_name in self.metrics:
            # recall: '1' for accessing epoch func (not step func)
            metric_val_epoch = self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].compute()
            self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].reset()
            full_metric_name = f'{ds_prefix}{metric_name}{ol_suffix}'
            self._store_metric(full_metric_name, metric_val_epoch)
            if out_layer_idx == 0:
               print(metric_name)
               print(metric_val_epoch)
               print()
   
   def metric_reduce_fx(self, metric_list, agg_type='weighted_sum'):
      '''
      aggregates metric values from all `num_output_layers` into a single value
      side note: called "reduce_fx" as a reference to PyTorch Lightning's reduce_fx parameter found in self.log()
      '''
      if 'weighted' in agg_type.lower() and 'sum' in agg_type.lower():
         weighted_sum_val = 0
         for i, layer_metric_val in enumerate(metric_list):
            weighted_sum_val += self.loss_weights[i] * layer_metric_val
         final_val = weighted_sum_val

      return final_val

   def configure_optimizers(self):
      # make it self.hierarchical_model.parameters() if you defined other parameters in __init()__ which you don't want to optimize
      # side note: under the hood, pl automatically gets gradients 
      # from final_loss returned from training_step() and adjusts the models' branches accordingly
      # source: https://github.com/Lightning-AI/lightning/issues/2645#issuecomment-660681760
      return torch.optim.Adam(self.parameters(), lr=1e-3)


side note: for some reason setting `num_workers` argument for dataloader in MNIST dataset code above works
while doing so for this cell does not work:

In [2]:

if __name__ == "__main__":
    import pickle
    def pklLoad(fullPath):
        with open(fullPath, 'rb') as f:
            content = pickle.load(f)
        return content

    dataDir = '../dataset/'
    dataReDir = '../dataset_related/'
    classNameToNum = pklLoad('../dataset_related/imgTypeToNum.pickle')

    import numpy as np
    import pytorch_lightning as pl
    from torch.nn import functional as F
    import torchmetrics
    import torch
    import pandas as pd
    import torchvision
    from torchvision import transforms as T
    from torch import nn
    from simple_hierarchy.hierarchal_model import HierarchalModel
    from torch.utils.tensorboard import SummaryWriter
    from PIL import Image


    class CustomDatasetForV03(torch.utils.data.Dataset):
        """
        Implementation notes: 
        you must have a column consisting of image names (with extension) as index of the dataframe (example: ['img0001.jpg', 'img0002.png', ...])
        each image name has to be unique across the entire dataset, not just its class

        This class is for PyTorch, unlike CustomDataGeneratorForV02, which was for TF. 
        """
        def __init__(self, img_labels_df, class_indices, target_size, batch_size=16, transforms_function=None):
            # attributes that were in flow_from_dataframe function used previously in CustomDataGeneratorForV02. Added here for compatibility with other functions like trainModel(), etc
            self.img_labels_df = img_labels_df
            self.filenames = list(img_labels_df['filepath'])
            self.labels = list(class_indices[class_name] for class_name in img_labels_df['label']) # 'labels' is actually not found in flow_from_dataframe's __dict__, but is a synonym of "classes" attribute found in the __dict__
            self.samples = len(img_labels_df) # number of rows of the df
            self.class_indices = class_indices
            self.target_size = target_size
            self.batch_size = batch_size
            self.transforms_function = transforms_function


        def __len__(self):
            return self.samples // min(1, self.batch_size)
        
        def __getitem__(self, index):
            # note: if you pass batch_size in DataLoader() instead of batch_sampler, then __getitem__() will run for a batch_size number of times consecutively,
            #       same thing happens if you pass a batch sampler to `sampler` argument.
            #       However, if you pass a non-batch (i.e., single) sampler to `batch_sampler` argument, 
            #       __getitem__() will run once, but "index" argument will be a list of indices (of length of the batch size set by batch_sampler)
            #       (side note: don't pass a batch sampler to `batch_sampler`, as this will cause error like this:)
            #       (TypeError: 'numpy.int64' object is not iterable)
            #       (source: https://github.com/pytorch/pytorch/issues/71872#issue-1115383219)
            # source 1: https://discuss.pytorch.org/t/how-to-use-batchsampler-with-getitem-dataset/78788/4#:~:text=The%20index%20inside%20__getitem__%20will%20contain%2010%20random%20indices%2C%20which%20are%20used%20to%20create%20the%20batch%20in%3A
            # source 2: https://stackoverflow.com/questions/65231299/load-csv-and-image-dataset-in-pytorch
            if type(index) in [int, np.int64]:
                filepath = self.filenames[index]
                img = Image.open(filepath)
                if self.transforms_function is not None:
                    img = self.transforms_function(img)
                y = torch.tensor(self.labels[index], dtype=torch.int8)
                return img, y
            else:        
                indices = index
                imgs = [Image.open(self.filenames[idx]) for idx in range(len(indices))]
                if self.transforms_function is not None:
                    imgs = [self.transforms_function(img) for img in imgs]
                ys = torch.tensor([self.labels[idx] for idx in range(len(indices))], dtype=torch.int8)
                return imgs, ys

    class CustomDataLoader(torch.utils.data.DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            if 'sampler' in kwargs.keys():
                kwargs.pop('batch_size', None)
                kwargs.pop('drop_last', None)
                kwargs.pop('shuffle', None)
            super().__init__(dataset, *args, **kwargs)
            self.__dict__.update(dataset.__dict__)
            self.__dict__.pop('img_labels_df')

    class CustomStratifiedSampler: # (torch.utils.data.BatchSampler)
        """Stratified Sampling

        Provides equal representation of target classes in each batch
        Source: https://github.com/ncullen93/torchsample/blob/master/torchsample/samplers.py#L22
        Found from here: https://discuss.pytorch.org/t/how-to-enable-the-dataloader-to-sample-from-each-class-with-equal-probability/911/2
        """
        def __init__(self, class_vector, batch_size, random_state=42):
            """
            Arguments
            ---------
            class_vector : torch tensor
                a vector of class labels
            batch_size : integer
                batch_size
            """
            self.n_splits = int(class_vector.size(0) / batch_size)
            self.class_vector = class_vector
            self.batch_size = batch_size
            self.steps_num = class_vector.size(0) // batch_size
            self.random_state = random_state

        def gen_sample_array(self):
            try:
                from sklearn.model_selection import StratifiedShuffleSplit
            except:
                print('Need scikit-learn for this functionality')
            import numpy as np
            s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5, random_state=self.random_state)
            X = torch.randn(self.class_vector.size(0),2)
            y = self.class_vector
            s.get_n_splits(X, y) 

            train_index, test_index = next(s.split(X, y))
            return np.hstack([train_index, test_index])

        def __iter__(self):
            return iter(self.gen_sample_array())

        def __len__(self):
            return self.steps_num
        

    from sklearn.model_selection import train_test_split
    from torch.utils.data import SequentialSampler, BatchSampler
    def get_dataloaders(classNameToNum, dataReDir, img_size=(306, 306), batch_size=16, 
                        aug_val=False, aug_test=False, aug_values=False, 
                        shuffle=True, stratify=True, drop_last=True, num_workers=0):
        """
        Args:
            classNameToNum: should map each class name to its index
            dataReDir: should be a path string of directory which has train/val/test CSVs, where each csv consists of 2 columns: 'filepath', and 'label'
            aug_val/aug_test/aug_values: should be a torchvision.transforms's Compose() function, or False

        Note: setting shuffle will not affect training; as train_data.csv was already shuffled 
        by train_test_split() in dataset_preprocessing_part_2.ipynb
        """

        # this cell is from dataset_preprocessing_part_2.ipynb and will be used in most model ipynb files to prepare images
        train_data = pd.read_csv(dataReDir+'train_datav01.csv')
        val_data = pd.read_csv(dataReDir+'val_datav01.csv')
        test_data = pd.read_csv(dataReDir+'test_datav01.csv')

        default_transformations = T.Compose([T.Resize(img_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST), T.ToTensor()])
        if aug_values == False:
            aug_values = default_transformations
        if aug_val == False:
            aug_val = default_transformations
        if aug_test == False:
            aug_test = default_transformations


        # batch_size here only affects what is returned by __len__()
        train_dataset = CustomDatasetForV03(train_data, classNameToNum, img_size, batch_size, aug_values)
        val_dataset = CustomDatasetForV03(val_data, classNameToNum, img_size, batch_size, aug_val)
        test_dataset = CustomDatasetForV03(test_data, classNameToNum, img_size, batch_size, aug_test)

        if stratify:
            train_stratify = CustomStratifiedSampler(torch.tensor(train_dataset.labels), batch_size, random_state=42)
            val_sampler = BatchSampler(SequentialSampler(val_dataset), batch_size=batch_size, drop_last=drop_last)
            test_sampler = BatchSampler(SequentialSampler(test_dataset), batch_size=batch_size, drop_last=drop_last)
        else:
            train_stratify = None
            val_sampler = None
            test_sampler = None

        train_dataloader = CustomDataLoader(train_dataset, batch_size=batch_size, drop_last=drop_last, 
                                            sampler=train_stratify, shuffle=shuffle, num_workers=num_workers)
        # always set to False if you'll use val_gen/test_gen.filenames list later to analyze the classification results of the val/test set
        val_dataloader = CustomDataLoader(val_dataset, batch_size=batch_size, drop_last=drop_last, 
                                            sampler=val_sampler, shuffle=False, num_workers=num_workers)
        test_dataloader = CustomDataLoader(test_dataset, batch_size=batch_size, drop_last=drop_last,
                                            sampler=test_sampler, shuffle=False, num_workers=num_workers) 
    
        spe = train_dataset.samples // batch_size
        vsteps = val_dataset.samples // batch_size
        tsteps = test_dataset.samples // batch_size
        print(train_dataset.samples, 'train images')
        print(val_dataset.samples, 'val images')
        print('so, #steps per epoch (if sampler is not batch):')
        print(train_dataset.samples + val_dataset.samples)
        print()
        print(spe, 'training steps')
        print(vsteps, 'validation steps')
        print('so #steps per epoch (if sampler is batch):')
        print(spe + vsteps)
        print()
        print(test_dataset.samples, 'test images')
        print(tsteps, 'testing steps')

        return train_dataloader, val_dataloader, test_dataloader, spe, vsteps, tsteps


    cgs = { # cgs <==> class groups
            "colorDiversity": ("colorDiversity", 2), # many or few
            "manyColors": ("manyColors", 4), # selfies, memes (e or f), or eGreetingsAndMisc (+1 if none of them)
            "fewColors": ("fewColors", 4), # socialMedia (e or f), fTxtMssgs, academic (photos or digital) (+1 if none of them)
            "memes": ("memes", 3), # e or f (+1 if none of them)
            "socialMedia": ("socialMedia", 3), # e or f (+1 if none of them)
            "academic": ("academic", 3) # photos or academic (+1 if none of them)
        }
    cs = { # cs <==> classes
        "selfies": ("selfies", 1),
        "fmemes": ("fmemes", 1),
        "ememes": ("ememes", 1),
        "fSocialMedia": ("fSocialMedia", 1),
        "eSocialMedia": ("eSocialMedia", 1),
        "fTxtMssgs": ("fTxtMssgs", 1),
        "eGreetingsAndMisc": ("eGreetingsAndMisc", 1),
        "academicPhotos": ("academicPhotos", 1),
        "academicDigital": ("academicDigital", 1),
    }
    # smaller hierarchy
    hierarchy = {
        cgs["colorDiversity"] : [cgs["manyColors"], cgs["fewColors"]],
        cgs["manyColors"] : [cgs["memes"]],
        cgs["fewColors"] : [cgs["socialMedia"], cgs["academic"]],
    }
    # mapping flat labels to hierarchical labels
    labelToHierarchy = {
        0 : [
                0, # diversity of colors: many or few (so classes: 0,1)
                0, # many colors: selfies, memes (fmm or emm), or eGreetingsAndMisc (+1 if none of them) (so classes: 0,1,2,3)
                3, # few colors: socialMedia (fsm or esm), fTxtMssgs, academic (photos or digital) (+1 if none of them) (so classes: 0,1,2,3)
                2, # meme type: fmm or emm (+1 if none of them) (so classes: 0,1,2)
                2, # social media type: fsm or esm (+1 if none of them) (so classes: 0,1,2)
                2, # academic type: photos or digital (+1 if none of them) (so classes: 0,1,2)
            ], 
        1 : [0, 1, 3, 0, 2, 2],
        2 : [0, 1, 3, 1, 2, 2],
        3 : [1, 3, 0, 2, 0, 2],
        4 : [1, 3, 0, 2, 1, 2],
        5 : [1, 3, 1, 2, 2, 2],
        6 : [0, 2, 3, 2, 2, 2],
        7 : [1, 3, 2, 2, 2, 0],
        8 : [1, 3, 2, 2, 2, 1],
    }
    # mapping hierarchical labels (integers) to hierarchical names (strings)
    hierLabelToClassName = {
        # diversity of colors: many or few (so classes: 0,1)
        0 : ['manyColors', 'fewColors'], 
        # many colors: selfies, memes (fmm or emm), or eGreetingsAndMisc (+1 if none of them) (so classes: 0,1,2,3)
        1 : ['00. selfies', 'memes', '70. eGreetingAndMisc', 'none'],
        # few colors: socialMedia (fsm or esm), fTxtMssgs, academic (photos or digital) (+1 if none of them) (so classes: 0,1,2,3)
        2 : ['socialMedia', '50. fTxtMssgs', 'academic', 'none'],
        # meme type: fmm or emm (+1 if none of them) (so classes: 0,1,2)
        3 : ['10. fmemes', '20. ememes', 'none'],
        # social media type: fsm or esm (+1 if none of them) (so classes: 0,1,2)
        4 : ['30. fSocialMedia', '40. fSocialMedia', 'none'],
        # academic type: photos or digital (+1 if none of them) (so classes: 0,1,2)
        5 : ['81. academicPhotos', '82. academicDigital', 'none'],
    }

    def create_hnn_model_arch(model_version:str, img_size:tuple, class_groups:dict, classes:dict, hierarchy:dict):
        base_model = nn.Sequential(
            # note that padding is 0 to re-produce TF's implementation of m01, where the default value for 'padding' argument is 'valid' which means no padding
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=0), 
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Flatten(),
            nn.LazyLinear(512),
            nn.ReLU(),
            # unlike TF implementation of m01, we'll not include this layer here, rather we'll add it later
            # nn.Softmax(dim=1) 
        )

        # these are the indepdent layers of parent and children
        model_layers = [
            # the output of this layer is feed forward from parent to child
            nn.Linear(512, 128),
            nn.Linear(128, 64),
            # we do not include the layer below, as it is automatically created at the end of each parent/child. check this for details:
            # https://github.com/rajivsarvepalli/SimpleHierarchy/blob/8e4c29f334928f43509b2d328b11b4a83f2d2af6/src/simple_hierarchy/hierarchal_model.py#L173
            # nn.Linear(64, num_classes), 
        ]

        # 512 is the output size of our base model
        # 512 is the input size of our additional indepdent layers (called model_layers)
        # 64 is the output size of our additional indepdent layers (called model_layers) (excluding the final automatically-created hidden layer)
        # 128 is the output size of third to last additional indepdent layer to feed (note: we say 'second to last' if we neglect talking about the final automatically-created hidden layer)
        # forward from parent to child (with concatenation)
        size = (512,512,64,128)
        # all 2 layers are distinct for each grouping of classes of model_layers (actually 5 layers if you add the automatically-created hidden layer)
        k = 2
        # we want to feed from the fourth to last layer (from parent to child (with concatenation)) (note: we say 'third to last' if we neglect talking about the final automatically-created hidden layer)
        feed_from = 1
        output_order = [*class_groups.values()]
        print(f'output_order:\n{output_order}')
        idx_to_class_name = {k : v for k, v in enumerate([*class_groups.keys(), *classes.keys()])}
        model = HierarchalModel(hierarchy=hierarchy, size=size, output_order=output_order, 
                                base_model=base_model, model=model_layers, k=k, feed_from=feed_from)
        
        a = torch.rand(3,*img_size).unsqueeze(0)
        # running an arbitrary forward pass to initialze weights/params (since LazyLinear was used)
        model(a) 

        writer = SummaryWriter(f"../models/hnn_model_v{model_version}_tb_graphs")
        writer.add_graph(model, a)
        writer.close()

        return model, idx_to_class_name
    
    img_size = (306,306)
    hier_model, idx_to_class_name = create_hnn_model_arch("03", img_size, cgs, cs, hierarchy)
    num_classes_per_layer = [num_outputs for layer_name, num_outputs in hier_model.output_order]


    class HierarchalModelPL(pl.LightningModule):
        def __init__(self, hierarchical_model, num_classes_per_layer:list, loss_weights=None, dummy_input_size=(32, 3, 306, 306), the_device='cpu'):
            super(HierarchalModelPL, self).__init__()
            self.hierarchical_model = hierarchical_model # .to(the_device)
            self.num_output_layers = len(num_classes_per_layer)
            self.loss_weights = [1]*self.num_output_layers if loss_weights is None else loss_weights
            self.num_output_layers = self.num_output_layers
            
            # useful for writing computational graph in tensorboard, summary(), etc
            self.example_input_array = torch.randn(*dummy_input_size) # , device=the_device

            self.the_device = the_device if 'cpu' in the_device else 'cuda'
            self.metrics = ['acc', 'f1_score']
            def _create_metric_func(metric_type:str, num_classes):
                if 'acc' in metric_type.lower():
                    return torchmetrics.Accuracy(task='multiclass', 
                                                num_classes=num_classes, 
                                                average="micro").to(the_device)
                elif 'f1' in metric_type.lower():
                    return torchmetrics.F1Score(task="multiclass", 
                                                num_classes=num_classes, 
                                                average=None).to(the_device)
            
            #                                         step/epoch vvv
            # max indexing possible: self.metric_funcs['f1_score'][1][num_output_layers-1][2]
            #                                       softmax output layers ^^^            ^^^  train/val/test 
            # output example of metric_funcs['acc']
            # (note: MCA == MulticlassAccuracy()):
            # {'acc': [ 
            #   [ #step
            #    [MCA, MCA, MCA], # softmax_output_layer_0 ; train/val/test
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA], # ...
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA] # # softmax_output_layer_5 ; train/val/test
            #   ],
            #   [ #epoch
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA],
            #    [MCA, MCA, MCA]
            #   ]
            # ]
            
            self.metric_funcs = {}
            for metric in self.metrics:
                # i --> step/epoch, j --> num_output_layers-1, k --> train/val/test
                self.metric_funcs[metric] = [[[_create_metric_func(metric, num_classes_per_layer[j]) for k in range(3)] 
                                            for j in range(self.num_output_layers)] 
                                            for i in range(2)]
            
            # history_per_output_layer = {
            #    # per epoch
            #    'loss':[],
            #    'acc':[],
            #    'f1_score':[],
            #    'val_loss':[],
            #    'val_acc':[],
            #    'val_f1_score':[],
            #    'test_loss':[],
            #    'test_acc':[],
            #    'test_f1_score':[],
            #    # per step
            #    'step_loss':[],
            #    'step_acc':[],
            #    'step_f1_score':[],
            #    'val_step_loss':[],
            #    'val_step_acc':[],
            #    'val_step_f1_score':[],
            #    'test_step_loss':[],
            #    'test_step_acc':[],
            #    'test_step_f1_score':[],
            #    # for tracking pred/true values for train/val/test sets
            #    'step_y_pred': [],
            #    'step_y_true': [],
            #    'val_step_y_pred': [],
            #    'val_step_y_true': [],
            #    'test_step_y_pred': [],
            #    'test_step_y_true': [],
            # }

            history_per_output_layer = {
                # per epoch
                'loss':[],
                'acc':[],
                'f1_score':[],
                # these 2 will be replaced with step_y_pred and step_y_true respectively
                'y_pred':[],
                'y_true':[],
            }

            def pre_suf_fix_dict(dict, fix_val='output_layer', prefix=True):
                return {f"{fix_val}_{key}" if prefix else f"{key}_{fix_val}" : value for key, value in dict.items()}

            # to create val/test prefixed keys
            history_per_output_layer.update(
                **pre_suf_fix_dict(history_per_output_layer, fix_val='val', prefix=True),
                **pre_suf_fix_dict(history_per_output_layer, fix_val='test', prefix=True),
            )
            # to create step_ prefixed keys
            history_per_output_layer.update(
                **pre_suf_fix_dict(history_per_output_layer, fix_val='step', prefix=True),
            )
            # removing y_pred, y_true and keeping step_y_pred and step_y_true
            history_per_output_layer.pop('y_pred')
            history_per_output_layer.pop('y_true')
            
            # fix_val == suffix without "_"; to provide flexibility in using fix_val as prefix as well (if we wanted to!)
            self.history_layers_fix_val = 'output_layer' 

            # pre_suf_fix_dict function is used to create unique key names to merge the `num_output_layers` histories into one history dictionary
            history_ol_lists = [pre_suf_fix_dict(history_per_output_layer, 
                                                fix_val=f'{self.history_layers_fix_val}_{i}', 
                                                prefix=False) 
                                for i in range(self.num_output_layers)] # ol == output layers
            
            # 'loss' here is for the final loss calculated by the loss aggregator function which uses all `num_output_layers` losses
            self.history = {'loss':[], 'val_loss':[], 'test_loss':[]}
            self.history.update(**pre_suf_fix_dict(self.history, 'step', prefix=True))
            # merging the `num_output_layers` histories into one history dictionary
            for hist_dict in history_ol_lists:
                self.history.update(**hist_dict)


        def forward(self, x):
            # len(hier_y_probs) == num_hierarchy_output_layers, 
            # while each element is 2D tensor of shape (batch_size, num_classes of the i_th output layer in hierarchy)
            hier_y_pred = []
            for tensor_output in self.hierarchical_model(x):
                hier_y_pred.append(F.softmax(tensor_output, dim=1))
            hier_y_pred
            return hier_y_pred
        
        def training_step(self, batch, batch_idx):
            ds_prefix = ''
            metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
            # add other key:value pairs here or pass the entire metrics_dict if you want. just make sure 'loss' key is present
            return metrics_dict
            
        def training_epoch_end(self, outputs) -> None:
            # 'outputs' argument here contains values from what was returned from training_step()
            ds_prefix = ''
            self._epoch_end_logic(outputs, ds_prefix)

        def validation_step(self, batch, batch_idx):
            ds_prefix = 'val_'
            metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
            return metrics_dict
        
        def validation_epoch_end(self, outputs) -> None:
            ds_prefix = 'val_'
            self._epoch_end_logic(outputs, ds_prefix)
        
        def test_step(self, batch, batch_idx):
            ds_prefix = 'test_'
            metrics_dict = self._step_logic(batch, ds_prefix=ds_prefix)
            return metrics_dict
        
        def test_epoch_end(self, outputs) -> None:
            ds_prefix = 'test_'
            self._epoch_end_logic(outputs, ds_prefix=ds_prefix)
        

        def _store_metric(self, metric_name, metric_val):
            try:
                self.log(metric_name, metric_val)
            except Exception:
                # then metric_val is not scalar, then this means:
                # metric_name is an f1 score metric 
                # so, we just take the mean value
                self.log(metric_name, metric_val.mean())
            # putting this here ensures we don't include the mean of f1 score, 
            # but rather the tensor of shape (1, num_classes)
            self.history[metric_name].append(metric_val)

        def _step_logic(self, batch, ds_prefix=''):
            # ds_prefix == dataset_prefix
            if 'val' in ds_prefix:
                ds_type_idx = 1
            elif 'test' in ds_prefix:
                ds_type_idx = 2
            else:
                ds_type_idx = 0
            ds_prefix = 'step_' + ds_prefix

            x, y = batch
            hier_y_pred = self(x)

            # y now has shape (batch_size, len(hier_y_pred))
            # in other words, each row now consists of `len(hier_y_pred)` output layers' labels of 1 sample
            # side note: len(hier_y_pred) == self.num_output_layers
            y = torch.tensor([labelToHierarchy[int(y[i])] for i in range(len(y))], dtype=int, device=self.the_device) # alternatively, range(y.size()[0])
            # transpose y to take each row (batch of labels) with its corresponding hier_y_pred row (batch of predicted labels)
            # in other words, each row now consists of `batch_size` labels of the i_th output layer
            y = y.T
            # softmax_output_layer elements each has shape (batch_size, num_classes of the i_th output layer in hierarchy)
            # side note: the wording of "i_th output layers" refers to the strings mentioned in 
            # output_order argument of HierarchalModel() used in create_hnn_model_arch()

            losses_dict = {}
            other_metrics_dict = {}
            for out_layer_idx, softmax_output_layer in enumerate(hier_y_pred): 
                # 'cur' refers to current output layer
                y_cur = y[out_layer_idx]
                y_pred_cur = softmax_output_layer.to(self.the_device)
                ol_suffix = f'_{self.history_layers_fix_val}_{out_layer_idx}'

                # log step metrics

                loss = F.cross_entropy(y_pred_cur, y_cur)
                loss_full_metric_name = f'{ds_prefix}loss{ol_suffix}'
                self._store_metric(loss_full_metric_name, loss)
                losses_dict[loss_full_metric_name] = loss

                for metric_name in self.metrics:
                    # recall the indexing of self.metric_funcs:
                    # metric_name, step/epoch, num_output_layers-1, train/val/test
                    metric_val = self.metric_funcs[metric_name][0][out_layer_idx][ds_type_idx](y_pred_cur, y_cur)
                    full_metric_name = f'{ds_prefix}{metric_name}{ol_suffix}'
                    self._store_metric(full_metric_name, metric_val)
                    self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].update(y_pred_cur, y_cur)
                    other_metrics_dict[full_metric_name] = metric_val
            
                f1_score = self.metric_funcs[metric_name][0][out_layer_idx][ds_type_idx](y_pred_cur, y_cur)
                self._store_metric(f'{ds_prefix}{metric_name}{ol_suffix}', f1_score)
                self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].update(y_pred_cur, y_cur)

                # storing y_pred/y_true
                self.history[f'{ds_prefix}y_pred{ol_suffix}'].extend(y_pred_cur)
                self.history[f'{ds_prefix}y_true{ol_suffix}'].extend(y_cur)

            final_loss = self.metric_reduce_fx(losses_dict.values(), 'weighted_sum')
            self._store_metric(f'{ds_prefix}loss', final_loss)

            # Important note: 'loss' key must be present, or else you'll get this error:
            # MisconfigurationException: In automatic_optimization, 
            # when `training_step` returns a dict, the 'loss' key needs to be present
            # side note: add `.update(losses_dict)` and `.update(other_metrics_dict)` 
            # if you want to directly use other metrics in "..._epoch_end()" methods
            return {f'loss' : final_loss} 

        def _epoch_end_logic(self, outputs, ds_prefix=''):
            if 'val' in ds_prefix:
                ds_type_idx = 1
            elif 'test' in ds_prefix:
                ds_type_idx = 2
            else:
                ds_type_idx = 0

            # log epoch metrics
            final_loss_epoch = torch.tensor([x[f'loss'] for x in outputs], dtype=float).mean()
            self._store_metric(f'{ds_prefix}loss', final_loss_epoch)

            for out_layer_idx in range(self.num_output_layers):
                ol_suffix = f'_{self.history_layers_fix_val}_{out_layer_idx}'

                for metric_name in self.metrics:
                    # recall: '1' for accessing epoch func (not step func)
                    metric_val_epoch = self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].compute()
                    self.metric_funcs[metric_name][1][out_layer_idx][ds_type_idx].reset()
                    full_metric_name = f'{ds_prefix}{metric_name}{ol_suffix}'
                    self._store_metric(full_metric_name, metric_val_epoch)
        
        def metric_reduce_fx(self, metric_list, agg_type='weighted_sum'):
            '''
            aggregates metric values from all `num_output_layers` into a single value
            side note: called "reduce_fx" as a reference to PyTorch Lightning's reduce_fx parameter found in self.log()
            '''
            if 'weighted' in agg_type.lower() and 'sum' in agg_type.lower():
                weighted_sum_val = 0
                for i, layer_metric_val in enumerate(metric_list):
                    weighted_sum_val += self.loss_weights[i] * layer_metric_val
                final_val = weighted_sum_val

            return final_val

        def configure_optimizers(self):
            # make it self.hierarchical_model.parameters() if you defined other parameters in __init()__ which you don't want to optimize
            # side note: under the hood, pl automatically gets gradients 
            # from final_loss returned from training_step() and adjusts the models' branches accordingly
            # source: https://github.com/Lightning-AI/lightning/issues/2645#issuecomment-660681760
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    train_gen, val_gen, test_gen, \
    spe, vsteps, tsteps = get_dataloaders(classNameToNum=classNameToNum,
                                        dataReDir=dataReDir,
                                        img_size=(306, 306), 
                                        batch_size=32,
                                        aug_val=False,
                                        aug_test=False,
                                        aug_values=False,
                                        shuffle=True,
                                        stratify=True,
                                        drop_last=True, 
                                        num_workers=2)

    model = HierarchalModelPL(hierarchical_model=hier_model, 
                            num_classes_per_layer=num_classes_per_layer, 
                            loss_weights=None, # same weight for all output layers' losses
                            dummy_input_size=(32, 3, 306, 306), 
                            the_device='cuda')
    trainer = pl.Trainer(log_every_n_steps=1, num_sanity_val_steps=0, max_epochs=2, max_steps=-1, accelerator='gpu', fast_dev_run=False)
    trainer.fit(model=model, train_dataloaders=train_gen, val_dataloaders=val_gen) # , val_dataloaders=val_loader



output_order:
[('colorDiversity', 2), ('manyColors', 4), ('fewColors', 4), ('memes', 3), ('socialMedia', 3), ('academic', 3)]
94787 train images
20311 val images
so, #steps per epoch (if sampler is not batch):
115098

2962 training steps
634 validation steps
so #steps per epoch (if sampler is batch):
3596

20312 test images
634 testing steps


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type            | Params | In sizes          | Out sizes                                             
------------------------------------------------------------------------------------------------------------------------------------
0 | hierarchical_model | HierarchalModel | 21.2 M | [32, 3, 306, 306] | [[32, 2], [32, 4], [32, 4], [32, 3], [32, 3], [32, 3]]
------------------------------------------------------------------------------------------------------------------------------------
21.2 M    Trainable params
0         Non-trainable params
21.2 M    Total params
84.640    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]