In [1]:
import torch
import pytorch_lightning as pl

from pytorch_lightning.loggers import TensorBoardLogger

In [2]:
import config
from dataset import DataModule
from model import ClassificationModel

In [3]:
# Seed every RNG the run depends on, once, before any data/model objects exist.
# `pl.seed_everything` already seeds Python's `random`, NumPy and torch, so a
# separate `torch.random.manual_seed` call is redundant.
# `workers=True` additionally gives each DataLoader worker process its own
# reproducible seed; this matters because the DataModule below is created with
# `num_workers=config.NUM_WORKERS`.
# The trailing `;` suppresses the echoed seed value in the cell output.
pl.seed_everything(config.RANDOM_SEED, workers=True);

Seed set to 42


42

In [4]:
# --- Data ---------------------------------------------------------------
# Location, split ratios and loader settings all come from `config`.
dm = DataModule(
    data_path=config.DATA_PATH,
    train_test_ratio=config.TRAIN_TEST_RATIO,
    train_val_ratio=config.TRAIN_VAL_RATIO,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
)

# --- Model --------------------------------------------------------------
# Built after the data module (keep this order: construction may consume
# RNG state, e.g. a freshly initialised head — NOTE(review): confirm).
# The pretrained SqueezeNet download in this cell's output presumably
# originates here.
model = ClassificationModel(learning_rate=config.LEARNING_RATE)

# --- Logging ------------------------------------------------------------
# TensorBoard event files go under tb_logs/my_model/version_<n>/.
logger = TensorBoardLogger(save_dir="tb_logs", name="my_model")

# --- Trainer ------------------------------------------------------------
# No checkpoint files are written; the only trained weights are the
# in-memory ones held by `model` after fitting.
trainer = pl.Trainer(
    logger=logger,
    enable_checkpointing=False,
    accelerator=config.ACCELERATOR,
    devices=config.DEVICES,
    min_epochs=config.MIN_EPOCHS,
    max_epochs=config.MAX_EPOCHS,
)

Downloading: "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth" to /Users/evlko/.cache/torch/hub/checkpoints/squeezenet1_0-b66bff10.pth
100%|██████████| 4.78M/4.78M [00:00<00:00, 20.8MB/s]
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


**Заметки про модели:**
* Inception v3 ожидает на вход изображения размером 299×299.
* Остальные модели — 224×224 (хотя некоторые принимают и другие размеры: например, для ResNet подходит любой размер, кратный 32).
* MobileNet показал лучшие результаты по всем метрикам с `LEARNING_RATE = 0.0001`, чем с `0.001`.

In [5]:
# Train, then score the final in-memory weights on the validation and test
# splits. Validation results are kept in `val_metrics`; the test results are
# the cell's displayed output.
trainer.fit(model, datamodule=dm)
val_metrics = trainer.validate(model, datamodule=dm)
test_metrics = trainer.test(model, datamodule=dm)
test_metrics


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | SqueezeNet       | 736 K 
1 | loss_fn   | CrossEntropyLoss | 0     
2 | accuracy  | BinaryAccuracy   | 0     
3 | precision | BinaryPrecision  | 0     
-----------------------------------------------
736 K     Trainable params
0         Non-trainable params
736 K     Total params
2.946     Total estimated model params size (MB)


Epoch 19: 100%|██████████| 56/56 [00:21<00:00,  2.59it/s, v_num=0, val_acc=0.944, val_precision=0.952]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 56/56 [00:21<00:00,  2.59it/s, v_num=0, val_acc=0.944, val_precision=0.952]
Validation DataLoader 0: 100%|██████████| 7/7 [00:01<00:00,  6.67it/s]


Testing DataLoader 0: 100%|██████████| 7/7 [00:01<00:00,  5.48it/s]


[{'test_loss': 0.2144719511270523,
  'test_acc': 0.9318181872367859,
  'test_precision': 0.9317901730537415}]

In [6]:
# Persist the trained weights.
# The file name is derived from the wrapped backbone's class (the `model`
# submodule listed in the fit summary; SqueezeNet -> "squeezenet_weights.pth"),
# so switching architectures does not silently save under a stale name.
# Falls back to the LightningModule's own class if it has no `.model`.
# NOTE: this is the LightningModule's state_dict, so keys are presumably
# prefixed with "model." — load it back into `ClassificationModel`, not into a
# bare torchvision network.
weights_path = f"{type(getattr(model, 'model', model)).__name__.lower()}_weights.pth"
torch.save(model.state_dict(), weights_path)