In [2]:
%config Completion.use_jedi = False
%reload_ext autoreload
%autoreload 2

import sys
import os
sys.path.append("../")
#os.environ["MLFLOW_TRACKING_URI"] = 'http://localhost:5000/'

In [3]:
from torchvision import transforms
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from src.custom_datasets import MultiLabelDataModule
from src.model import MultiLabelClassifier

# 1. Introduction

As mentioned in the readme and in the data exploration notebook, I'll build a multi-label classifier, so the model can predict more than one class per image (e.g. the colour of the card and whether it's a creature or a special card). For this purpose I'll use the binary cross-entropy loss with [logits](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html). This loss treats the outputs/logits independently and is [suitable for multi-label classification](https://discuss.pytorch.org/t/is-there-an-example-for-multi-class-multilabel-classification-in-pytorch/53579/7). For this experiment I'll use a ResNet network as a backbone.


<img src="../img/NeuralNetwork_MultiLabel_Concept.svg" width=500 height=400 align="center"/>

**Training:** via pytorch Lightning with torch 2.0.1  <br/>
**Accelerator:** MPS (M1 Max) <br/>
**Logger:** MLFlow (local - `mlflow ui --backend-store-uri Coding_Projects/MLFlow_runs`) <br/>
**Datamodule:** Custom Module in src (built on the CustomDataset [MultiLabelImageFolder](../src/custom_datasets.py)) <br/>
**Metrics:** MultiLabel Accuracy will be the benchmark. I will also track Precision and Recall with the [MetricCollection](https://torchmetrics.readthedocs.io/en/stable/pages/overview.html?highlight=metriccollection#metriccollection) of torchmetrics <br/>

# 2. Experiment Design

I'll run 3 different experiments with the same augmentations, image input size and hyperparameters. The backbone will be ResNet18 to keep the model small. Each experiment trains for the number of epochs set by `n_epochs` in the General Settings cell below.

1. Backbone pretrained frozen
2. Backbone pretrained unfrozen
3. Backbone untrained unfrozen (normalized on training data)

At the end I'll take the best model and will use this model in a streamlit app to visualize the model's inference.

## General Settings

In [9]:
# MLflow experiment under which all runs of this notebook are grouped
project_name = "Magic The Gathering - Multilabel Classification"

# Root folder that contains the local MLflow store. The MTG_PROJECTS_ROOT
# environment variable overrides the machine-specific default, so the notebook
# can run on another machine without editing this cell.
root_dir = os.environ.get("MTG_PROJECTS_ROOT", "/Users/ryoshibata/Coding_projects")
local_mlflow_uri = f"{root_dir}/MLFlow_runs/"
#artifact_path = f"{local_mlflow_uri}/models/"

# Local paths where the train/test/val splits are stored
data_dirs = {"train": "../data/0.7-0.15-0.15_split/train/",
             "test": "../data/0.7-0.15-0.15_split/test/",
             "val": "../data/0.7-0.15-0.15_split/val/"}

# Seed torch's global RNG once, up front, so that a Restart & Run All of the
# notebook draws the same random numbers on every execution.
seed = 42
torch.manual_seed(seed)

# Experiment settings shared by all experiments
batch_size = 32
hidden_size = 1024  # used for both hidden layers of the classifier head
lr = 0.001
num_classes = 10
n_epochs = 50


# Transform/Augmentations

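- Training: random rotation (0–180°), random horizontal flip, resize to 312×445, conversion to a tensor and normalization.
- Inference: resize to 312×445, conversion to a tensor and normalization only (no random augmentation).
- Experiments 1 and 2 normalize with the ImageNet channel mean/std that match the pretrained backbone weights.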

In [5]:
# Preprocessing values shared by the training and the inference pipeline:
# target image size and the per-channel normalization statistics.
resize_to = (312, 445)
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training pipeline: random rotation and horizontal flip, followed by the
# shared resize -> tensor -> normalize steps.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=(0, 180)),
        transforms.RandomHorizontalFlip(),
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Inference pipeline: the same steps without any random augmentation.
inference_transform = transforms.Compose(
    [
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Experiment 1 - Backbone pretrained and frozen weights

In [8]:
# Name of this run; used for the MLflow run and for the checkpoint folder.
run_name = "Experiment_1-Resnet18-pretrained_frozen-weights"

# Keep the two checkpoints with the highest epoch-level validation accuracy.
checkpoint_callback = ModelCheckpoint(dirpath=f"./checkpoints/{run_name}/",
                                      save_top_k=2,
                                      monitor="val_MultilabelAccuracy_epoch",
                                      mode="max")

# Log metrics, parameters and model checkpoints to the local MLflow store.
mlf_logger = MLFlowLogger(experiment_name=project_name,
                          run_name=run_name,
                          tracking_uri=local_mlflow_uri,
                          log_model=True)

# Data module that serves the train/val/test splits with the given transforms.
mtg_data = MultiLabelDataModule(data_dirs=data_dirs,
                                train_transform=train_transform,
                                inference_transform=inference_transform,
                                batch_size=batch_size)

# Log every training transform as its own parameter. Logging
# `train_transform.__dict__` stores the whole pipeline as one string, and the
# MLflow logger truncates long parameter values (250 characters in
# Lightning 2.0.x), which cuts off the trailing Normalize statistics.
mlf_logger.log_hyperparams(
    {
        f"train_transform_{position}_{type(step).__name__}": str(step)
        for position, step in enumerate(train_transform.transforms)
    }
)

In [10]:
# Backbone: ResNet18 initialised with the IMAGENET1K_V1 weights; its parameters
# are frozen, so only the classifier head is trained.
backbone_config = {
    "freeze_params": True,
    "backbone": "resnet18",
    "weights": "IMAGENET1K_V1",
}

# Model: backbone plus a classifier head with two hidden layers of equal size.
multilabel_model = MultiLabelClassifier(
    backbone_config=backbone_config,
    num_classes=num_classes,
    hidden_size_1=hidden_size,
    hidden_size_2=hidden_size,
    lr=lr,
)

# torch.compile returns a new, optimized module and leaves its argument
# untouched, so a bare `torch.compile(multilabel_model)` call has no effect.
# Compilation is therefore an explicit opt-in; it stays off by default, which
# matches what the bare call effectively did.
# NOTE(review): torch.compile's default backend is not expected to support the
# MPS accelerator on torch 2.0.x; verify before enabling.
use_torch_compile = False
model_to_fit = torch.compile(multilabel_model) if use_torch_compile else multilabel_model

# Train on the Apple-silicon GPU, logging to MLflow and checkpointing on the
# monitored validation metric.
trainer = Trainer(
    max_epochs=n_epochs,
    log_every_n_steps=5,
    logger=mlf_logger,
    accelerator="mps",
    profiler="simple",
    callbacks=[checkpoint_callback],
)

trainer.fit(model=model_to_fit, datamodule=mtg_data)


Using cache found in /Users/ryoshibata/.cache/torch/hub/pytorch_vision_main
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type              | Params
----------------------------------------------------
0 | backbone      | ResnetBackbone    | 11.2 M
1 | classifier    | ClassifierHead    | 1.6 M 
2 | criterion     | BCEWithLogitsLoss | 0     
3 | train_metrics | MetricCollection  | 0     
4 | valid_metrics | MetricCollection  | 0     
5 | test_metrics  | MetricCollection  | 0     
----------------------------------------------------
1.6 M     Trainable params
11.2 M    Non-trainable params
12.8 M    Total params
51.047    Total estimated model params size (MB)


Epoch 49: 100%|██████████| 42/42 [00:04<00:00,  8.46it/s, v_num=f455]      

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 42/42 [00:05<00:00,  8.31it/s, v_num=f455]


FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                        

In [11]:
# Evaluate on the test data with the best checkpoint recorded during training.
trainer.test(model=multilabel_model, datamodule=mtg_data, ckpt_path="best")

Restoring states from the checkpoint path at /Users/ryoshibata/PycharmProjects/MultiLabelClassification/notebooks/checkpoints/Experiment_1-Resnet18-pretrained_frozen-weights/epoch=49-step=2100.ckpt
Loaded model weights from the checkpoint at /Users/ryoshibata/PycharmProjects/MultiLabelClassification/notebooks/checkpoints/Experiment_1-Resnet18-pretrained_frozen-weights/epoch=49-step=2100.ckpt


Testing DataLoader 0: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s]


TEST Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                       

[{'test_loss': 0.17613567411899567,
  'test_MultilabelAccuracy': 0.9242525100708008,
  'test_MultilabelPrecision': 0.861706018447876,
  'test_MultilabelRecall': 0.8271732330322266}]

In [12]:
torch._C._mps_emptyCache()

In [7]:
#multilabel_model.load_from_checkpoint("./checkpoints/epoch=2-step=63.ckpt")

# Experiment 2 - Backbone pretrained, not frozen parameters

In [13]:
# Name of this run; used for the MLflow run and for the checkpoint folder.
run_name = "Experiment_2-Resnet18-pretrained_unfrozen-weights"

# Keep the two checkpoints with the highest epoch-level validation accuracy.
checkpoint_callback = ModelCheckpoint(dirpath=f"./checkpoints/{run_name}/",
                                      save_top_k=2,
                                      monitor="val_MultilabelAccuracy_epoch",
                                      mode="max")

# Log metrics, parameters and model checkpoints to the local MLflow store.
mlf_logger = MLFlowLogger(experiment_name=project_name,
                          run_name=run_name,
                          tracking_uri=local_mlflow_uri,
                          log_model=True)

# Data module that serves the train/val/test splits with the given transforms.
mtg_data = MultiLabelDataModule(data_dirs=data_dirs,
                                train_transform=train_transform,
                                inference_transform=inference_transform,
                                batch_size=batch_size)

# Log every training transform as its own parameter. Logging
# `train_transform.__dict__` stores the whole pipeline as one string, and the
# MLflow logger truncates long parameter values (250 characters in
# Lightning 2.0.x), which cuts off the trailing Normalize statistics.
mlf_logger.log_hyperparams(
    {
        f"train_transform_{position}_{type(step).__name__}": str(step)
        for position, step in enumerate(train_transform.transforms)
    }
)

In [14]:
# Backbone: ResNet18 initialised with the IMAGENET1K_V1 weights; parameters are
# not frozen, so the backbone is fine-tuned together with the classifier head.
backbone_config = {
    "freeze_params": False,
    "backbone": "resnet18",
    "weights": "IMAGENET1K_V1",
}

# Model: backbone plus a classifier head with two hidden layers of equal size.
multilabel_model = MultiLabelClassifier(
    backbone_config=backbone_config,
    num_classes=num_classes,
    hidden_size_1=hidden_size,
    hidden_size_2=hidden_size,
    lr=lr,
)

# torch.compile returns a new, optimized module and leaves its argument
# untouched, so a bare `torch.compile(multilabel_model)` call has no effect.
# Compilation is therefore an explicit opt-in; it stays off by default, which
# matches what the bare call effectively did.
# NOTE(review): torch.compile's default backend is not expected to support the
# MPS accelerator on torch 2.0.x; verify before enabling.
use_torch_compile = False
model_to_fit = torch.compile(multilabel_model) if use_torch_compile else multilabel_model

# Train on the Apple-silicon GPU, logging to MLflow and checkpointing on the
# monitored validation metric.
trainer = Trainer(
    max_epochs=n_epochs,
    log_every_n_steps=5,
    logger=mlf_logger,
    accelerator="mps",
    profiler="simple",
    callbacks=[checkpoint_callback],
)

trainer.fit(model=model_to_fit, datamodule=mtg_data)

Using cache found in /Users/ryoshibata/.cache/torch/hub/pytorch_vision_main
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type              | Params
----------------------------------------------------
0 | backbone      | ResnetBackbone    | 11.2 M
1 | classifier    | ClassifierHead    | 1.6 M 
2 | criterion     | BCEWithLogitsLoss | 0     
3 | train_metrics | MetricCollection  | 0     
4 | valid_metrics | MetricCollection  | 0     
5 | test_metrics  | MetricCollection  | 0     
----------------------------------------------------
12.8 M    Trainable params
0         Non-trainable params
12.8 M    Total params
51.047    Total estimated model params size (MB)


Epoch 49: 100%|██████████| 42/42 [00:11<00:00,  3.61it/s, v_num=d6c7]      

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 42/42 [00:11<00:00,  3.60it/s, v_num=d6c7]


FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                        

In [15]:
# Evaluate on the test data with the best checkpoint recorded during training.
trainer.test(model=multilabel_model, datamodule=mtg_data, ckpt_path="best")

Restoring states from the checkpoint path at /Users/ryoshibata/PycharmProjects/MultiLabelClassification/notebooks/checkpoints/Experiment_2-Resnet18-pretrained_unfrozen-weights/epoch=48-step=2058.ckpt
Loaded model weights from the checkpoint at /Users/ryoshibata/PycharmProjects/MultiLabelClassification/notebooks/checkpoints/Experiment_2-Resnet18-pretrained_unfrozen-weights/epoch=48-step=2058.ckpt


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 11.81it/s]


TEST Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                       

[{'test_loss': 0.053050845861434937,
  'test_MultilabelAccuracy': 0.9833887219429016,
  'test_MultilabelPrecision': 0.9800772070884705,
  'test_MultilabelRecall': 0.9765597581863403}]

In [11]:
torch._C._mps_emptyCache()

# Experiment 3 - Backbone untrained

In [16]:
# Preprocessing values for the untrained backbone: same target size as before,
# but different per-channel normalization statistics.
# NOTE(review): presumably the mean/std of the training images (see the
# experiment outline); not verifiable from this notebook.
resize_to = (312, 445)
channel_mean = [0.50476729, 0.48440304, 0.46218942]
channel_std = [0.30981703, 0.3034715 , 0.30258951]

# Training pipeline: random rotation and horizontal flip, followed by the
# shared resize -> tensor -> normalize steps.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=(0, 180)),
        transforms.RandomHorizontalFlip(),
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Inference pipeline: the same steps without any random augmentation.
inference_transform = transforms.Compose(
    [
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

In [13]:
# Name of this run; used for the MLflow run and for the checkpoint folder.
run_name = "Experiment_3-Resnet18-untrained"

# Keep the two checkpoints with the highest epoch-level validation accuracy.
checkpoint_callback = ModelCheckpoint(dirpath=f"./checkpoints/{run_name}/",
                                      save_top_k=2,
                                      monitor="val_MultilabelAccuracy_epoch",
                                      mode="max")

# Log metrics, parameters and model checkpoints to the local MLflow store.
mlf_logger = MLFlowLogger(experiment_name=project_name,
                          run_name=run_name,
                          tracking_uri=local_mlflow_uri,
                          log_model=True)

# Data module that serves the train/val/test splits with the given transforms.
mtg_data = MultiLabelDataModule(data_dirs=data_dirs,
                                train_transform=train_transform,
                                inference_transform=inference_transform,
                                batch_size=batch_size)

# Log every training transform as its own parameter. Logging
# `train_transform.__dict__` stores the whole pipeline as one string, and the
# MLflow logger truncates long parameter values (250 characters in
# Lightning 2.0.x), which cuts off the trailing Normalize statistics, the
# setting that distinguishes this experiment from the previous two.
mlf_logger.log_hyperparams(
    {
        f"train_transform_{position}_{type(step).__name__}": str(step)
        for position, step in enumerate(train_transform.transforms)
    }
)

In [14]:
# Backbone: ResNet18 without pretrained weights (`weights=None`); parameters
# are not frozen, so the whole network is trained from scratch.
backbone_config = {
    "freeze_params": False,
    "backbone": "resnet18",
    "weights": None,
}

# Model: backbone plus a classifier head with two hidden layers of equal size.
multilabel_model = MultiLabelClassifier(
    backbone_config=backbone_config,
    num_classes=num_classes,
    hidden_size_1=hidden_size,
    hidden_size_2=hidden_size,
    lr=lr,
)

# torch.compile returns a new, optimized module and leaves its argument
# untouched, so a bare `torch.compile(multilabel_model)` call has no effect.
# Compilation is therefore an explicit opt-in; it stays off by default, which
# matches what the bare call effectively did.
# NOTE(review): torch.compile's default backend is not expected to support the
# MPS accelerator on torch 2.0.x; verify before enabling.
use_torch_compile = False
model_to_fit = torch.compile(multilabel_model) if use_torch_compile else multilabel_model

# Train on the Apple-silicon GPU, logging to MLflow and checkpointing on the
# monitored validation metric.
trainer = Trainer(
    max_epochs=n_epochs,
    log_every_n_steps=5,
    logger=mlf_logger,
    accelerator="mps",
    profiler="simple",
    callbacks=[checkpoint_callback],
)

trainer.fit(model=model_to_fit, datamodule=mtg_data)

  rank_zero_warn(
Using cache found in /Users/ryoshibata/.cache/torch/hub/pytorch_vision_main
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type              | Params
----------------------------------------------------
0 | backbone      | ResnetBackbone    | 11.2 M
1 | classifier    | ClassifierHead    | 1.6 M 
2 | criterion     | BCEWithLogitsLoss | 0     
3 | train_metrics | MetricCollection  | 0     
4 | valid_metrics | MetricCollection  | 0     
5 | test_metrics  | MetricCollection  | 0     
----------------------------------------------------
12.8 M    Trainable params
0         Non-trainable params
12.8 M    Total params
51.047    Total estimated model params size (MB)


Epoch 29: 100%|██████████| 21/21 [00:13<00:00,  1.57it/s, v_num=6cea]      

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 21/21 [00:13<00:00,  1.57it/s, v_num=6cea]


FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                        

In [None]:
# Evaluate on the test data with the best checkpoint recorded during training.
trainer.test(model=multilabel_model, datamodule=mtg_data, ckpt_path="best")

In [16]:
torch._C._mps_emptyCache()

: 