# NN Training

Plan:

* Fine-tune a pretrained image classification model
    * Compare a few small efficient ones
    * Use old PyTorch utility functions
    * If enough time, redo in Lightning and write utility functions for it


[Nets in torchvision](https://pytorch.org/vision/stable/models.html) are the starting point, though the same architectures can also be loaded through timm or directly from Hugging Face.

The image resolution (224x224) matches many standard nets (e.g. [EfficientNet](https://pytorch.org/vision/stable/models/efficientnet.html) B0, [ResNet](https://pytorch.org/vision/stable/models/resnet.html)(s) 34 and 50).
The [RexNet family](https://github.com/clovaai/rexnet) also uses that resolution throughout and, although less common, is reportedly more efficient to train than EfficientNet.

EfficientNet B0 and RexNet 1.0 have around 5M parameters, and EfficientNet B2 and RexNet 1.5 have around 10M.
If venturing farther, [EfficientNetV2](https://pytorch.org/vision/stable/models/efficientnetv2.html) (the small variant, which would still require upscaling to 384x384) should be even more efficient, but has around 20M parameters (between EfficientNet B4 and B5).

In [1]:
# import sys
# sys.path.append('..')
from pytorch_utils import *
from lightning_utils import *
from pytorch_vision_utils import *
%load_ext autoreload
%autoreload 2

In [2]:
data_path = r'E:\Data_and_Models\Kaggle_Cards'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Pure PyTorch Version

### Model Creation

In [11]:
# The HWC -> CHW permutation seems to happen on its own (or the data info on Kaggle is wrong)
# dataloaders 0, 1, 2 are train, test and valid

# # EfficientNet B0 - 5.3M parameters
# dataloaders, classes = image_dataloaders(data_path, (transforms := (weights := tv.models.EfficientNet_B0_Weights.DEFAULT).transforms()), batch_size = 32)
# model = tv.models.efficientnet_b0(weights = weights).to(device)

# # EfficientNet B2 - 9.2M parameters
# dataloaders, classes = image_dataloaders(data_path, (transforms := (weights := tv.models.EfficientNet_B2_Weights.DEFAULT).transforms()), batch_size = 32)
# model = tv.models.efficientnet_b2(weights = weights).to(device)

# RexNet 1.0 - 4.8M parameters - https://huggingface.co/timm/rexnet_100.nav_in1k
model = timm.create_model('rexnet_100.nav_in1k', pretrained = True, num_classes = 53).eval().to(device) # Cannot use len(classes) yet
dataloaders, classes = image_dataloaders(data_path, (transforms := timm.data.create_transform(**timm.data.resolve_model_data_config(model), is_training = False)), batch_size = 32)

# # RexNet 1.5 - 9.7M parameters - https://huggingface.co/timm/rexnet_150.nav_in1k
# model = timm.create_model('rexnet_150.nav_in1k', pretrained = True, num_classes = 53).eval().to(device) # Cannot use len(classes) yet
# dataloaders, classes = image_dataloaders(data_path, (transforms := timm.data.create_transform(**timm.data.resolve_model_data_config(model), is_training = False)), batch_size = 32)

In [58]:
summ(model, input_size = (32, 3, 224, 224))

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
RexNet (RexNet)                               [32, 3, 224, 224]    [32, 53]             --                   True
├─ConvNormAct (stem)                          [32, 3, 224, 224]    [32, 32, 112, 112]   --                   True
│    └─Conv2d (conv)                          [32, 3, 224, 224]    [32, 32, 112, 112]   864                  True
│    └─BatchNormAct2d (bn)                    [32, 32, 112, 112]   [32, 32, 112, 112]   64                   True
│    │    └─Identity (drop)                   [32, 32, 112, 112]   [32, 32, 112, 112]   --                   --
│    │    └─SiLU (act)                        [32, 32, 112, 112]   [32, 32, 112, 112]   --                   --
├─Sequential (features)                       [32, 32, 112, 112]   [32, 1280, 7, 7]     --                   True
│    └─LinearBottleneck (0)                   [32, 32, 112, 112]   [32, 16, 112, 112]  
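
The `Param #` column can be checked by hand for the first two rows. A minimal sketch, assuming the stem convolution uses a 3x3 kernel without bias (not stated in the summary, but consistent with the 864 it reports) and that only the BatchNorm scale and shift count as parameters (the running statistics are buffers):

```python
def conv2d_params(in_channels, out_channels, kernel_size, bias=False, groups=1):
    """Parameter count of a 2D convolution with a square kernel."""
    weights = (in_channels // groups) * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def batchnorm2d_params(channels):
    """Learnable scale (weight) and shift (bias), one of each per channel."""
    return 2 * channels


# Assumed stem layout: Conv2d(3 -> 32, 3x3, no bias) followed by BatchNorm over 32 channels
stem_conv = conv2d_params(in_channels=3, out_channels=32, kernel_size=3)
stem_bn = batchnorm2d_params(32)

print(stem_conv, stem_bn)  # 864 64
```

Both numbers agree with the `Conv2d (conv)` and `BatchNormAct2d (bn)` rows above.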

Optionally freeze feature extraction layers


In [59]:
# Make the feature extractor layers ("features" in the summary) non-trainable (re-run summary above to check)
for param in model.features.parameters(): param.requires_grad = False

# RexNet models also have a separate stem to freeze
for param in model.stem.parameters(): param.requires_grad = False
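
To confirm that freezing took effect without re-running the full summary, trainable and total parameters can be counted directly. A minimal sketch on a hypothetical stand-in network (`TinyNet` is not the RexNet; only its `stem` and `features` attribute names mirror it, and `head` is an illustrative name):

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in with the same top-level attribute names as the summary above."""

    def __init__(self, num_classes=53):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.features = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.head = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.head(self.features(self.stem(x)))


def count_parameters(model):
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


tiny = TinyNet()
for module in (tiny.stem, tiny.features):
    for param in module.parameters():
        param.requires_grad = False

trainable, total = count_parameters(tiny)
print(trainable, total)  # 901 2269

# Optionally hand only the still-trainable parameters to the optimiser
optimiser = torch.optim.Adam([p for p in tiny.parameters() if p.requires_grad], lr=0.001)
```

Frozen parameters receive no gradients, but BatchNorm layers inside a frozen block still update their running statistics while the module is in training mode; calling `.eval()` on those blocks is one way to avoid that, if desired.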

### Classifier Layers Replacement (already done for RexNets)

In [6]:
# NOT NEEDED FOR REXNETS
# Inspect the classifier layers to replicate its structure
model.classifier

Sequential(
  (0): Dropout(p=0.2, inplace=True)
  (1): Linear(in_features=1280, out_features=1000, bias=True)
)
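
The size of the classifier head follows from its shape alone. A short worked count, assuming a plain linear layer with bias and the 53 card classes used elsewhere in this notebook:

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


imagenet_head = linear_params(1280, 1000)  # pretrained head printed above
cards_head = linear_params(1280, 53)       # head after switching to 53 classes

print(imagenet_head, cards_head)  # 1281000 67893
```

The second figure rounds to 67.9 K, which is the trainable parameter count that the Lightning summary reports further down; that would be consistent with a run where only a 1280-to-53 linear head is trained.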

In [7]:
# NOT NEEDED FOR REXNETS

# Set the number of classes to the card ones (and reset the other parameters in the classifier layer)
model.classifier = torch.nn.Sequential(
    # EfficientNet B0
    torch.nn.Dropout(p = 0.2, inplace = True),
    torch.nn.Linear(in_features = 1280, out_features = len(classes), bias = True)
    # # EfficientNet B2
    # torch.nn.Dropout(p = 0.3, inplace = True),
    # torch.nn.Linear(in_features = 1408, out_features = len(classes), bias = True)
).to(device)


# The following simpler option raises no error but does not replace the actual parameter tensor:
#   model.classifier[1].out_features = len(classes)
# Replacing only the linear layer would also work, since dropout has no learned parameters and only its p carries over:
#   model.classifier[1] = torch.nn.Linear(in_features = 1280, out_features = len(classes), bias = True).to(device)
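
The pitfall with assigning `out_features` can be reproduced on a small layer. A minimal sketch with made-up sizes (8 inputs, 10 outputs, 3 target classes), not the notebook's model:

```python
import torch
from torch import nn

layer = nn.Linear(in_features=8, out_features=10)

# Assigning the attribute only changes the stored number, not the weight tensor
layer.out_features = 3
stale_weight_shape = tuple(layer.weight.shape)
stale_output_shape = tuple(layer(torch.zeros(1, 8)).shape)
print(stale_weight_shape, stale_output_shape)  # (10, 8) (1, 10)

# Building a new layer is what actually changes the output size
replaced = nn.Linear(in_features=8, out_features=3)
new_output_shape = tuple(replaced(torch.zeros(1, 8)).shape)
print(new_output_shape)  # (1, 3)
```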

### Training

In [4]:
# Model pipeline functions

loss_fn = nn.CrossEntropyLoss()

# Define an extra metric besides the loss
f1_fn = torchmetrics.F1Score(task = 'multiclass', num_classes = len(classes)).to(device)
accuracy_fn = torchmetrics.Accuracy(task = 'multiclass', num_classes = len(classes)).to(device)
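
When comparing the two metrics, the averaging mode matters. Assuming the `F1Score` wrapper defaults to micro averaging (worth checking against the installed torchmetrics version), micro-F1 on a single-label multiclass problem equals plain accuracy, while macro-F1 does not. A pure-Python sketch on made-up labels:

```python
def per_class_counts(y_true, y_pred, num_classes):
    """True positives, false positives and false negatives for each class."""
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1
            fn[t] += 1
    return tp, fp, fn


def f1_from_counts(tp, fp, fn):
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


def micro_f1(y_true, y_pred, num_classes):
    tp, fp, fn = per_class_counts(y_true, y_pred, num_classes)
    return f1_from_counts(sum(tp), sum(fp), sum(fn))


def macro_f1(y_true, y_pred, num_classes):
    tp, fp, fn = per_class_counts(y_true, y_pred, num_classes)
    return sum(f1_from_counts(*c) for c in zip(tp, fp, fn)) / num_classes


# Made-up, deliberately imbalanced labels for three classes
y_true = [0, 0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
micro = micro_f1(y_true, y_pred, 3)
macro = macro_f1(y_true, y_pred, 3)
print(accuracy, micro, macro)
```

If a metric that differs from accuracy is wanted, passing `average = 'macro'` explicitly would be the thing to try.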

In [6]:
# Train the model

model_name = 'RexNet10'
extra = '0_FullRetrain_Adam001_10_epochs' # Mimicking the train_combination function format
print(f'Training {model_name}_{extra}')

set_seeds(42)
results = fit(model, train_dataloader = dataloaders[0], test_dataloader = dataloaders[1],
    optimiser = torch.optim.Adam(model.parameters(), lr = 0.001), loss_fn = loss_fn,
    metric_name_and_fn = ('F1', f1_fn),
    # metric_name_and_fn = ('Accuracy', accuracy_fn),
    epochs = 10,
    writer = tensorboard_writer(experiment_name = 'Cards', model_name = model_name, extra = extra, save_dir = fr'{data_path}\runs')
    # writer = None
)

Training RexNet10_0_FullRetrain_Adam001_10_epochs
[INFO] Created SummaryWriter, saving to: runs\2024-04-08\Cards\RexNet10\0_FullRetrain_Adam001_10_epochs...


Epoch: 1 | train_loss: 1.7339 | train_metric: 0.5124 | test_loss: 0.5602 | test_metric: 0.8314
Epoch: 2 | train_loss: 0.5716 | train_metric: 0.8264 | test_loss: 0.2966 | test_metric: 0.9147
Epoch: 3 | train_loss: 0.3233 | train_metric: 0.9004 | test_loss: 0.2243 | test_metric: 0.9198
Epoch: 4 | train_loss: 0.2093 | train_metric: 0.9327 | test_loss: 0.1722 | test_metric: 0.9495
Epoch: 5 | train_loss: 0.1696 | train_metric: 0.9484 | test_loss: 0.2843 | test_metric: 0.9217
Epoch: 6 | train_loss: 0.1520 | train_metric: 0.9536 | test_loss: 0.2251 | test_metric: 0.9267
Epoch: 7 | train_loss: 0.0993 | train_metric: 0.9685 | test_loss: 0.1524 | test_metric: 0.9668
Epoch: 8 | train_loss: 0.0992 | train_metric: 0.9704 | test_loss: 0.1931 | test_metric: 0.9475
Epoch: 9 | train_loss: 0.0868 | train_metric: 0.9748 | test_loss: 0.2121 | test_metric: 0.9425
Epoch: 10 | train_loss: 0.0903 | train_metric: 0.9732 | test_loss: 0.2218 | test_metric: 0.9390


In [7]:
# save_model(model, fr'{data_path}\Models', f'{model_name}_{extra}.pth')

[INFO] Saving model to: E:\Data_and_Models\Kaggle_Cards\Models\RexNet10_0_FullRetrain_Adam001_10_epochs.pth


WindowsPath('E:/Data_and_Models/Kaggle_Cards/Models/RexNet10_0_FullRetrain_Adam001_10_epochs.pth')
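
The run and model paths in this notebook are built with raw f-strings and backslashes. A possible alternative, sketched with `pathlib`; `PureWindowsPath` is used here only so the example behaves the same on any OS, and whether the utility functions accept path objects is an assumption:

```python
from pathlib import PureWindowsPath

data_path = PureWindowsPath(r'E:\Data_and_Models\Kaggle_Cards')
model_name = 'RexNet10'
extra = '0_FullRetrain_Adam001_10_epochs'

# The / operator joins path components with the right separator
runs_dir = data_path / 'runs'
model_file = data_path / 'Models' / f'{model_name}_{extra}.pth'

print(runs_dir)               # E:\Data_and_Models\Kaggle_Cards\runs
print(model_file.as_posix())  # E:/Data_and_Models/Kaggle_Cards/Models/RexNet10_0_FullRetrain_Adam001_10_epochs.pth
```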

## PyTorch Lightning Version

### Model Creation

In [3]:
model = timm.create_model('rexnet_100.nav_in1k', pretrained = True, num_classes = 53)
# for param in model.features.parameters(): param.requires_grad = False
# for param in model.stem.parameters(): param.requires_grad = False

In [4]:
transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(model), is_training = False)
ldata = LocalImageDataModule(data_path, transform = transforms, batch_size = 32)

In [5]:
# Model pipeline functions

loss_fn = nn.CrossEntropyLoss()

# Define an extra metric besides the loss
f1_fn = torchmetrics.F1Score(task = 'multiclass', num_classes = 53)
accuracy_fn = torchmetrics.Accuracy(task = 'multiclass', num_classes = 53)

# # General prediction pipeline
# with torch.inference_mode(): pred_logit = model(transforms(img).unsqueeze(dim = 0).to(device)) # Prepend "batch" dimension (-> [batch_size, color_channels, height, width])
# pred_prob = torch.softmax(pred_logit, dim = 1)
# return torch.argmax(pred_prob, dim = 1)

def prediction_fn(logits): return torch.argmax(torch.softmax(logits, dim = 1), dim = 1)
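
Softmax preserves the ordering of the logits, so the predicted class is the same with or without it; the softmax step only matters when the probabilities themselves are needed. A pure-Python check on made-up logits:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


logits = [1.2, -0.3, 4.5, 0.0]  # made-up values for one sample
probs = softmax(logits)

print(argmax(logits), argmax(probs))  # 2 2
```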


In [6]:
lmod = Strike(model, loss_fn = loss_fn, metric_name_and_fn = ('F1', f1_fn),
               optimiser_factory = lambda m: torch.optim.Adam(m.parameters(), lr = m.learning_rate),
               prediction_fn = prediction_fn, learning_rate = 0.001, log_at_every_step = False)

summ(lmod, input_size = (32, 3, 224, 224))

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
Strike (Strike)                                    [32, 3, 224, 224]    [32, 53]             --                   Partial
├─RexNet (model)                                   [32, 3, 224, 224]    [32, 53]             --                   Partial
│    └─ConvNormAct (stem)                          [32, 3, 224, 224]    [32, 32, 112, 112]   --                   False
│    │    └─Conv2d (conv)                          [32, 3, 224, 224]    [32, 32, 112, 112]   (864)                False
│    │    └─BatchNormAct2d (bn)                    [32, 32, 112, 112]   [32, 32, 112, 112]   (64)                 False
│    └─Sequential (features)                       [32, 32, 112, 112]   [32, 1280, 7, 7]     --                   False
│    │    └─LinearBottleneck (0)                   [32, 32, 112, 112]   [32, 16, 112, 112]   (896)                False
│    │    └─LinearBottleneck (1)

### Training

In [7]:
experiment_name = 'FullRetrain_EarlyStop'
# experiment_name = 'ClassRetrain_EarlyStop'
model_name = 'RexNet10'
extra = 'Adam001_max10_epochs' # Mimicking the train_combination function format
# extra = 'ClassRetrain_AdamAutoLR_max10_epochs' # Mimicking the train_combination function format
print(f'Training {experiment_name}_{model_name}_{extra}')

trainer = L.Trainer(
    accelerator = 'gpu', devices = 1, 
    min_epochs = 1, max_epochs = 10,
    callbacks = [
        EarlyStopping(monitor = 'val_loss', mode = 'min', min_delta = 0.01, patience = 3, verbose = True),
        ModelCheckpoint(monitor = 'val_loss', mode = 'min', save_top_k = 3, verbose = True,
                        dirpath = (checkpoints_path := fr'{data_path}\Models\{experiment_name}_{model_name}_{extra}'),
                        filename = '{epoch}-{val_loss:.2f}-{val_F1:.2f}', enable_version_counter = True),
                        
        ## Automatic LR finding has to be dropped to use early stopping properly (i.e. on validation loss), since its steps are counted and throw off the checkpointing: https://github.com/Lightning-AI/pytorch-lightning/issues/19575
        # EarlyStopping(monitor = 'train_F1', min_delta = 0.01, mode = 'max', patience = 3, verbose = True),
        # LearningRateFinder(num_training_steps = 100),
        # # IteratedLearningRateFinder(at_epochs = [3, 6, 9], num_training_steps = 100),
        # # LearningRateMonitor(logging_interval = 'epoch'), # No need if not readjusting it after epoch 0
    ],
    logger = TBLogger(experiment_name = experiment_name, model_name = model_name, extra = extra, save_dir = fr'{data_path}\runs'), # The default logger is also good
    profiler = PyTorchProfiler(filename = f'{model_name}_{extra}'),

    ## Pre-training checks
    # fast_dev_run = False, # When True, runs a couple of batches of training, validation and testing first, just to check that everything works
    # overfit_batches = 10, # Checks that the model can overfit a few batches; int for count or float for proportion
)

Training FullRetrain_EarlyStop_RexNet10_Adam001_max10_epochs


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
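
The `EarlyStopping` settings above (`min_delta = 0.01`, `patience = 3`, `mode = 'min'`) can be replayed by hand. A simplified sketch of the stopping rule, not Lightning's actual implementation, applied to the `test_loss` values logged by the pure PyTorch run earlier (those are test losses rather than validation losses, so this is only an illustration):

```python
def early_stopping_epoch(losses, min_delta=0.01, patience=3):
    """Return (stop_epoch, best_epoch), both 1-based; stop_epoch is None if never triggered."""
    best = float('inf')
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - min_delta:  # counts as an improvement
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# test_loss per epoch from the pure PyTorch training log
test_losses = [0.5602, 0.2966, 0.2243, 0.1722, 0.2843, 0.2251, 0.1524, 0.1931, 0.2121, 0.2218]

stop_epoch, best_epoch = early_stopping_epoch(test_losses)
print(stop_epoch, best_epoch)  # 10 7
```

Under this rule the earlier run would have stopped only at its final epoch, with epoch 7 as the best one.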


In [26]:
print(f'Training {experiment_name}_{model_name}_{extra}')

set_seeds(42)
# trainer.fit(lmod, dataloaders[0], dataloaders[2])
trainer.fit(lmod, ldata,
            # ckpt_path = trainer.checkpoint_callback.best_model_path # Continue training from a checkpoint
)

# NOTE: the trainer will refuse to re-fit, so if the model changed, re-declare the trainer above before training again

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | RexNet            | 3.6 M 
1 | loss_fn   | CrossEntropyLoss  | 0     
2 | metric_fn | MulticlassF1Score | 0     
------------------------------------------------
67.9 K    Trainable params
3.5 M     Non-trainable params
3.6 M     Total params
14.335    Total estimated model params size (MB)


Training ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs


Metric val_loss improved. New best score: 2.933
Epoch 0, global step 239: 'val_loss' reached 2.93327 (best 2.93327), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=0-val_loss=2.93-val_F1=0.11.ckpt' as top 3


Epoch 1, global step 478: 'val_loss' reached 3.07336 (best 2.93327), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=1-val_loss=3.07-val_F1=0.22.ckpt' as top 3


Epoch 2, global step 717: 'val_loss' reached 3.30281 (best 2.93327), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=2-val_loss=3.30-val_F1=0.11.ckpt' as top 3


Metric val_loss improved by 0.104 >= min_delta = 0.01. New best score: 2.829
Epoch 3, global step 956: 'val_loss' reached 2.82916 (best 2.82916), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=3-val_loss=2.83-val_F1=0.22.ckpt' as top 3


Epoch 4, global step 1195: 'val_loss' reached 2.86953 (best 2.82916), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=4-val_loss=2.87-val_F1=0.22.ckpt' as top 3


Epoch 5, global step 1434: 'val_loss' reached 2.90167 (best 2.82916), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=5-val_loss=2.90-val_F1=0.33.ckpt' as top 3


Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 2.813
Epoch 6, global step 1673: 'val_loss' reached 2.81331 (best 2.81331), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=6-val_loss=2.81-val_F1=0.33.ckpt' as top 3


Metric val_loss improved by 0.279 >= min_delta = 0.01. New best score: 2.534
Epoch 7, global step 1912: 'val_loss' reached 2.53411 (best 2.53411), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=7-val_loss=2.53-val_F1=0.33.ckpt' as top 3


Epoch 8, global step 2151: 'val_loss' reached 2.73769 (best 2.53411), saving model to 'E:\\Data_and_Models\\Kaggle_Cards\\Models\\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs\\epoch=8-val_loss=2.74-val_F1=0.22.ckpt' as top 3


Epoch 9, global step 2390: 'val_loss' was not in top 3
`Trainer.fit` stopped: `max_epochs=10` reached.
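
The `save_top_k = 3` bookkeeping in the log above can be replayed from the reported `val_loss` values. A simplified sketch of the keep-the-k-lowest rule, not Lightning's actual `ModelCheckpoint` code; epoch 9 is left out because its loss is not printed:

```python
def replay_top_k(val_losses, k=3):
    """Return {epoch: loss} for the k lowest losses, updated epoch by epoch."""
    kept = {}
    for epoch, loss in enumerate(val_losses):
        if len(kept) < k:
            kept[epoch] = loss
            continue
        worst_epoch = max(kept, key=kept.get)
        if loss < kept[worst_epoch]:  # new loss displaces the worst kept checkpoint
            del kept[worst_epoch]
            kept[epoch] = loss
    return kept


# val_loss for epochs 0 to 8, copied from the log above
val_losses = [2.93327, 3.07336, 3.30281, 2.82916, 2.86953, 2.90167, 2.81331, 2.53411, 2.73769]

kept = replay_top_k(val_losses)
best_epoch = min(kept, key=kept.get)
threshold = max(kept.values())
print(sorted(kept), best_epoch, threshold)  # [6, 7, 8] 7 2.81331
```

Under this simplified rule, epoch 9 would have needed a `val_loss` below 2.81331 to be saved, which fits the "was not in top 3" message, and epoch 7 is the checkpoint that `best_model_path` would point to.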


In [13]:
# The constructor arguments must be passed again, since many of them cannot be pickled by .save_hyperparameters
#   One fix would be to move those into a factory function that builds the class
bestmod = Strike.load_from_checkpoint(trainer.checkpoint_callback.best_model_path,
    model = model, loss_fn = loss_fn, metric_name_and_fn = ('F1', f1_fn),
    optimiser_factory = lambda m: torch.optim.Adam(m.parameters(), lr = m.learning_rate),
    prediction_fn = prediction_fn, learning_rate = 0.001, log_at_every_step = False)

trainer.validate(bestmod, ldata)
trainer.test(bestmod, ldata)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_F1             0.3333333432674408
        val_loss            2.5341060161590576
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test_F1            0.2222222238779068
        test_loss            2.985792636871338
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 2.985792636871338, 'test_F1': 0.2222222238779068}]
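
One way to put these loss values in context is the chance-level baseline: a classifier that spreads its probability uniformly over all classes has a cross-entropy of ln(K). A small worked check, assuming a roughly balanced 53-class problem:

```python
import math

num_classes = 53

# Cross-entropy of a uniform prediction over all classes
chance_loss = math.log(num_classes)
# Expected accuracy of uniform random guessing on balanced classes
chance_accuracy = 1 / num_classes

reported_test_loss = 2.985792636871338  # from the test output above

print(f'{chance_loss:.4f} {chance_accuracy:.4f}')  # 3.9703 0.0189
print(reported_test_loss < chance_loss)            # True
```

The reported `test_loss` sits below that baseline but far above the roughly 0.2 reached by the pure PyTorch run.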

In [30]:

# save_model(bestmod, fr'{data_path}\Models', f'{experiment_name}_{model_name}_{extra}.pth')

[INFO] Saving model to: E:\Data_and_Models\Kaggle_Cards\Models\ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs.pth


WindowsPath('E:/Data_and_Models/Kaggle_Cards/Models/ClassRetrain_EarlyStop_RexNet10_Adam001_max10_epochs.pth')