In [1]:
%load_ext autoreload
%autoreload 2

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset.mnist_data_module import MNISTDataModule
from model.lenet_edl import LeNetEDL

from settings import consts, model_settings

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


In [2]:
LOGITS_TO_EVIDENCE = 'relu'
EDL_LOSS = 'mse'

data_module = MNISTDataModule()
model = LeNetEDL(
    logits_to_evidence=LOGITS_TO_EVIDENCE,
    loss_function=EDL_LOSS
    )
callbacks = [
    ModelCheckpoint(
        filename=f'LeNetEDL-{LOGITS_TO_EVIDENCE}-{EDL_LOSS}'+'-{epoch}-{validation/accuracy:.3f}',
        monitor='validation/accuracy',
        mode='max',
        save_top_k=1,
        verbose=True,
    )
]
trainer = Trainer(
    max_epochs=model_settings.NUM_EPOCHS,
    fast_dev_run=False,
    default_root_dir=consts.SAVE_PATH,
    callbacks=callbacks
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [3]:
trainer.fit(
    model,
    train_dataloader=data_module.train_dataloader(),
    val_dataloaders=data_module.val_dataloader()
)


  | Name          | Type       | Params
---------------------------------------------
0 | base_model    | BaseModel  | 10.0 M
1 | accuracy      | Accuracy   | 0     
2 | loss_function | EDLMSELoss | 0     
---------------------------------------------
10.0 M    Trainable params
0         Non-trainable params
10.0 M    Total params
40.124    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 937: validation/accuracy reached 0.85370 (best 0.85370), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=0-validation/accuracy=0.854.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 1875: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 2813: validation/accuracy reached 0.86660 (best 0.86660), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=2-validation/accuracy=0.867.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 3751: validation/accuracy reached 0.95740 (best 0.95740), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=3-validation/accuracy=0.957.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 4689: validation/accuracy reached 0.97270 (best 0.97270), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=4-validation/accuracy=0.973.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 5627: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 6565: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 7503: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 8441: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 9379: validation/accuracy reached 0.97350 (best 0.97350), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=9-validation/accuracy=0.974.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 10317: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 11255: validation/accuracy reached 0.97630 (best 0.97630), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=11-validation/accuracy=0.976.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 12193: validation/accuracy reached 0.97840 (best 0.97840), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=12-validation/accuracy=0.978.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 13131: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 14069: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 15007: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 16, global step 15945: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 17, global step 16883: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 18, global step 17821: validation/accuracy reached 0.98250 (best 0.98250), saving model to "../output/lightning_logs/version_1/checkpoints/LeNetEDL-relu-mse-epoch=18-validation/accuracy=0.983.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 18759: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 20, global step 19697: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 21, global step 20635: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 22, global step 21573: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 23, global step 22511: validation/accuracy was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 24, global step 23449: validation/accuracy was not in top 1
