In [3]:
import os
from typing import List, Union
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl
import torchmetrics

class MNISTClassifier(pl.LightningModule):
    def __init__(self, the_device='cuda', num_classes=10, dummy_input_size=(32,3,306,306)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.LazyLinear(128)
        self.fc2 = nn.LazyLinear(10)
        
        # useful for writing computational graph in tensorboard, summary(), etc
        # self.example_input_array = torch.randn(*dummy_input_size) 

        self.the_device = the_device

        def _create_metric_func(metric_type:str):
            if 'acc' in metric_type.lower():
                return torchmetrics.Accuracy(task='multiclass', 
                                             num_classes=num_classes, 
                                             average="micro").to(the_device)
            elif 'f1' in metric_type.lower():
                return torchmetrics.F1Score(task="multiclass", 
                                            num_classes=num_classes, 
                                            average=None).to(the_device)
        
        # for calculating metric of a step
        self.acc_step_funcs = [_create_metric_func('acc') for _ in range(3)] # for computing train, val, test results respectively
        self.f1_score_step_funcs = [_create_metric_func('f1') for _ in range(3)]
        # for calculating metric of an epoch
        self.acc_epoch_funcs = [_create_metric_func('acc') for _ in range(3)]
        self.f1_score_epoch_funcs = [_create_metric_func('f1') for _ in range(3)]
        
        self.history = {
            # per epoch
            'loss':[],
            'acc':[],
            'f1_score':[],
            'val_loss':[],
            'val_acc':[],
            'val_f1_score':[],
            'test_loss':[],
            'test_acc':[],
            'test_f1_score':[],
            # per step
            'step_loss':[],
            'step_acc':[],
            'step_f1_score':[],
            'val_step_loss':[],
            'val_step_acc':[],
            'val_step_f1_score':[],
            'test_step_loss':[],
            'test_step_acc':[],
            'test_step_f1_score':[],
            # for tracking pred/true values for train/val/test sets
            'step_y_pred': [],
            'step_y_true': [],
            'val_step_y_pred': [],
            'val_step_y_true': [],
            'test_step_y_pred': [],
            'test_step_y_true': [],
        }

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
    
    def _store_metric(self, metric_name, metric_val):
        try:
            self.log(metric_name, metric_val)
        except Exception as e:
            # then metric_val is not scalar, then this means:
            # metric_name is an f1 score metric 
            # so, we just take the mean value
            self.log(metric_name, metric_val.mean())
        # putting this here ensures we don't include the mean of f1 score, 
        # but rather the tensor of shape (1, num_classes)
        self.history[metric_name].append(metric_val)

    def _step_logic(self, batch, prefix=''):
        if 'val' in prefix:
            func_idx = 1
        elif 'test' in prefix:
            func_idx = 2
        else:
            func_idx = 0
        
        x, y_true = batch
        y_pred = self(x)
        # log step metrics
        loss = F.cross_entropy(y_pred, y_true)
        self._store_metric(f'{prefix}step_loss', loss)

        acc = self.acc_step_funcs[func_idx](y_pred, y_true)
        self._store_metric(f'{prefix}step_acc', acc)
        self.acc_epoch_funcs[func_idx].update(y_pred, y_true)
    
        f1_score = self.f1_score_step_funcs[func_idx](y_pred, y_true)
        self._store_metric(f'{prefix}step_f1_score', f1_score)
        self.f1_score_epoch_funcs[func_idx].update(y_pred, y_true)

        # storing y_pred/y_true
        self.history[f'{prefix}step_y_pred'].extend(y_pred)
        self.history[f'{prefix}step_y_true'].extend(y_true)

        return loss, acc, f1_score
    
    def _epoch_end_logic(self, outs, prefix=''):
        if 'val' in prefix:
            func_idx = 1
        elif 'test' in prefix:
            func_idx = 2
        else:
            func_idx = 0

        # log epoch metrics
        loss_epoch = torch.tensor([x[f'{prefix}loss'] for x in outs], dtype=float).mean()
        self._store_metric(f'{prefix}loss', loss_epoch)

        acc_epoch = self.acc_epoch_funcs[func_idx].compute()
        self.acc_epoch_funcs[func_idx].reset()
        self._store_metric(f'{prefix}acc', acc_epoch)

        f1_score_epoch = self.f1_score_epoch_funcs[func_idx].compute()
        self.f1_score_epoch_funcs[func_idx].reset()
        self._store_metric(f'{prefix}f1_score', f1_score_epoch)

        return loss_epoch, acc_epoch, f1_score_epoch

    def training_step(self, batch, batch_idx):
        prefix = ''
        loss, acc, f1_score = self._step_logic(batch, prefix=prefix)
        return {f'{prefix}loss':loss, f'{prefix}acc':acc, f'{prefix}f1_score':f1_score}
    
    def training_epoch_end(self, outs) -> None:
        # 'outs' argument here contains values from what was returned from training_step()
        _,_,_ = self._epoch_end_logic(outs)

    def validation_step(self, batch, batch_idx):
        prefix = 'val_'
        loss, acc, f1_score = self._step_logic(batch, prefix=prefix)
        return {f'{prefix}loss':loss, f'{prefix}acc':acc, f'{prefix}f1_score':f1_score}

    def validation_epoch_end(self, outs) -> None:
        _,_,_ = self._epoch_end_logic(outs, prefix='val_')

    def test_step(self, batch, batch_idx):
        prefix = 'test_'
        loss, acc, f1_score = self._step_logic(batch, prefix=prefix)
        return {f'{prefix}loss':loss, f'{prefix}acc':acc, f'{prefix}f1_score':f1_score}

    def test_epoch_end(self, outs) -> None:
        _,_,_ = self._epoch_end_logic(outs, prefix='test_')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# init the autoencoder
cnnLightning = MNISTClassifier(the_device='cuda')
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [50000, 10000])
train_loader = utils.data.DataLoader(train_dataset, batch_size=32, num_workers=4)
val_loader = utils.data.DataLoader(val_dataset, batch_size=32, num_workers=4)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(log_every_n_steps=1, num_sanity_val_steps=0, max_epochs=3, accelerator='gpu', fast_dev_run=False)
trainer.fit(model=cnnLightning, train_dataloaders=train_loader, val_dataloaders=val_loader) # , val_dataloaders=val_loader

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | conv1    | Conv2d     | 320   
1 | conv2    | Conv2d     | 18.5 K
2 | dropout1 | Dropout2d  | 0     
3 | dropout2 | Dropout2d  | 0     
4 | fc1      | LazyLinear | 0     
5 | fc2      | LazyLinear | 0     
----------------------------------------
18.8 K    Trainable params
0         Non-trainable params
18.8 K    Total params
0.075     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
history = trainer.model.history.copy()
for metric, tensors_list in history.items():
    if isinstance(tensors_list, list) and len(tensors_list) > 0 and isinstance(tensors_list[0], torch.Tensor):
        history[metric] = np.array([np.array(tens.detach().cpu()) for tens in tensors_list]).round(5)

In [None]:
# debugging cell: dummy example for demonstration purposes:
# the code in this cell can be found in forward()

# len(hier_y_probs) == num_hierarchy_output_layers, 
# while each element is 2D tensor of shape (batch_size, num_classes of the i_th output layer in hierarchy)
hier_y_pred = []
for tensor_output in demo_model(a):
    hier_y_pred.append(F.softmax(tensor_output, dim=1))
hier_y_pred

[tensor([[0.5039, 0.4961],
         [0.5040, 0.4960],
         [0.5040, 0.4960],
         [0.5040, 0.4960]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.2513, 0.2444, 0.2263, 0.2780],
         [0.2514, 0.2444, 0.2262, 0.2780],
         [0.2513, 0.2444, 0.2263, 0.2780],
         [0.2513, 0.2444, 0.2262, 0.2780]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.2561, 0.2539, 0.2301, 0.2599],
         [0.2561, 0.2538, 0.2302, 0.2599],
         [0.2561, 0.2538, 0.2301, 0.2599],
         [0.2561, 0.2539, 0.2301, 0.2599]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3410, 0.3629, 0.2960],
         [0.3410, 0.3629, 0.2961],
         [0.3410, 0.3629, 0.2960],
         [0.3410, 0.3630, 0.2960]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3407, 0.3076, 0.3518],
         [0.3407, 0.3075, 0.3518],
         [0.3407, 0.3075, 0.3518],
         [0.3407, 0.3075, 0.3517]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3223, 0.3448, 0.3329],
         [0.3223, 0.3448, 0.3329],
         [0.3223, 0.3449, 0.3329],
         [0.3

In [None]:
# debugging cell: dummy example for demonstration purposes:
# the code in this cell can be found in train_step() (or actually, in _step_logic())

# x, y = batch
# hier_y_pred = self(x)
y = torch.tensor([1,2,5,8]) # remove this in train_step()
# y now has shape (batch_size, len(hier_y_pred))
# in other words, each row now consists of `len(hier_y_pred)` output layers' labels of 1 sample
y = torch.tensor([labelToHierarchy[i] for i in range(len(y))], dtype=int) # alternatively, range(y.size()[0])
print(y)
# transpose y to take each row (batch of labels) with its corresponding hier_y_pred row (batch of predicted labels)
# in other words, each row now consists of `batch_size` labels of the i_th output layer
y = y.T
print(y)
print()
# softmax_output_layer elements each has shape (batch_size, num_classes of the i_th output layer in hierarchy)
# side note: the wording of "i_th output layers" refers to the strings mentioned in 
# output_order argument of HierarchalModel() used in create_hnn_model_arch()
for i, softmax_output_layer in enumerate(hier_y_pred):
    # 'cur' refers to current output layer
    y_cur = y[i]
    y_pred_cur = softmax_output_layer
    print(y_cur)
    print(y_pred_cur)
    break

tensor([[0, 0, 3, 2, 2, 2],
        [0, 1, 3, 0, 2, 2],
        [0, 1, 3, 1, 2, 2],
        [1, 3, 0, 2, 0, 2]])
tensor([[0, 0, 0, 1],
        [0, 1, 1, 3],
        [3, 3, 3, 0],
        [2, 0, 1, 2],
        [2, 2, 2, 0],
        [2, 2, 2, 2]])

tensor([0, 0, 0, 1])
tensor([[0.5039, 0.4961],
        [0.5040, 0.4960],
        [0.5040, 0.4960],
        [0.5040, 0.4960]], grad_fn=<SoftmaxBackward0>)


In [None]:
# this function would've been useful if in training_step(), we returned 'y' as ohe (e.g., [0,0,1]) instead of integer labels (e.g., 2)
def ohe(value, size): # ohe == one hot encode
    identity = torch.eye(size)
    one_hot = identity[value]
    if (value == -1):
        # then we want the ohe list to be all zeros
        one_hot[-1] = 0
    # uncomment this line if you want to return 2D tensor of shape (1, size)
    # return one_hot.view(-1, size)
    return one_hot

classToHierarchy = {
    '00. selfies' : [ohe(0, 2), # diversity of colors: many or few
                     ohe(0, 4), # many colors: selfies, memes (fmm or emm), or eGreetingsAndMisc (+1 if none of them)
                     ohe(-1, 4), # few colors: socialMedia (fsm or esm), fTxtMssgs, academic (photos or digital) (+1 if none of them)
                     ohe(-1, 3), # meme type: fmm or emm (+1 if none of them)
                     ohe(-1, 3), # social media type: fsm or esm (+1 if none of them)
                     ohe(-1, 3)], # academic type: photos or digital (+1 if none of them)
    '10. fmemes' : [ohe(0, 2), ohe(1, 4), ohe(-1, 4), ohe(0, 3), ohe(-1, 3), ohe(-1, 3)],
    '20. ememes' : [ohe(0, 2), ohe(1, 4), ohe(-1, 4), ohe(1, 3), ohe(-1, 3), ohe(-1, 3)],
    '30. fSocialMedia' : [ohe(1, 2), ohe(-1, 4), ohe(0, 4), ohe(-1, 3), ohe(0, 3), ohe(-1, 3)],
    '40. eSocialMedia' : [ohe(1, 2), ohe(-1, 4), ohe(0, 4), ohe(-1, 3), ohe(1, 3), ohe(-1, 3)],
    '50. fTxtMssgs' : [ohe(1, 2), ohe(-1, 4), ohe(1, 4), ohe(-1, 3), ohe(-1, 3), ohe(-1, 3)],
    '70. eGreetingAndMisc' : [ohe(0, 2), ohe(2, 4), ohe(-1, 4), ohe(-1, 3), ohe(-1, 3), ohe(-1, 3)],
    '81. academicPhotos' : [ohe(1, 2), ohe(-1, 4), ohe(2, 4), ohe(-1, 3), ohe(-1, 3), ohe(0, 3)],
    '82. academicDigital' : [ohe(1, 2), ohe(-1, 4), ohe(2, 4), ohe(-1, 3), ohe(-1, 3), ohe(1, 3)],
}
classToHierarchy['00. selfies']