In [1]:
import torch
import numpy as np
import random
import os

import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.loggers.wandb import WandbLogger

In [2]:
from modules.lightningCNN import ResNet_pl
from modules.dataModule import CIFAR10_pl

In [3]:
# Fix every random source so that repeated runs are reproducible
SEED = 42

# Export the hash seed through the process environment.
# NOTE(review): this is set after the interpreter has started, so it presumably
# only matters for subprocesses spawned later — confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Seed Python, NumPy and PyTorch (default generator + all CUDA devices) with the same value
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Turn off cuDNN benchmarking and request deterministic cuDNN behaviour
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Trainerの準備
Trainerとは，PyTorch Lightningにおいて学習・テスト，ログの記録，モデルの保存などを自動で行ってくれるクラスのことです．
また，GPUの管理を一括で行ってくれるため，GPUを意識せずに学習モデルやデータローダーなどを作成できます（`.to(device)` などをわざわざ記述する必要がありません）．

In [4]:
# CSV logger: created with save directory 'logs' and experiment name 'cifar10'
csv_logger = CSVLogger(save_dir='logs', name='cifar10')

In [5]:
# Weights & Biases logger: run 'common resnet torchmetrics' in project 'cifar10'
wandb_logger = WandbLogger(name='common resnet torchmetrics', project='cifar10')

In [6]:
# Every logger that is passed to the Trainer: CSV first, then Weights & Biases
loggers = [
    csv_logger,
    wandb_logger,
]

In [7]:
# Checkpoint callback: monitors val_loss (lower is better) and keeps the top-1
# checkpoint inside the 'best_models' directory.
checkpoint = ModelCheckpoint(
    dirpath='best_models',
    filename='cifar10-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

In [8]:
# Early-stopping callback: monitors val_loss with a patience of 5.
early_stopping = EarlyStopping(
    monitor='val_loss',
    # Stated explicitly (equal to the library default) so the direction of
    # improvement visibly agrees with the ModelCheckpoint callback, which also
    # monitors val_loss with mode='min'.
    mode='min',
    patience=5,
)

In [9]:
# Trainer: CUDA accelerator on the listed devices, at most 10 epochs,
# with the checkpoint / early-stopping callbacks and both loggers attached.
DEVICES = [0]  # indices of the GPUs to use, given as a list

trainer = Trainer(
    logger=loggers,
    callbacks=[checkpoint, early_stopping],
    max_epochs=10,
    accelerator='cuda',
    devices=DEVICES,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# 学習

In [10]:
# CIFAR-10 data module, built with batch_size=512 and download=True
dataset = CIFAR10_pl(download=True, batch_size=512)

In [11]:
# Lightning ResNet module: 10 output classes, learning rate 1e-3, batch size 512
model = ResNet_pl(lr=0.001, batch_size=512, num_class=10)

In [12]:
# Start training: fit the model on the CIFAR-10 data module
trainer.fit(model, datamodule=dataset)

You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33myammer2[0m. Use [1m`wandb login --relogin`[0m to force relogin


Files already downloaded and verified


/home/yamamoto/.local/share/virtualenvs/LightningTemplate-1mmk8bif/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory best_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type             | Params
---------------------------------------------------
0 | resnet        | ResNet           | 23.5 M
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.190    Total estimated model params size (MB)


len(trainset): 40001
len(valset): 9999
Epoch 9: 100%|██████████| 79/79 [00:04<00:00, 18.04it/s, v_num=iy0c, val_loss=1.050]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 79/79 [00:04<00:00, 16.80it/s, v_num=iy0c, val_loss=1.050]


In [13]:
# Keep the path of the best checkpoint reported by the callback, then show it
best_model_path = checkpoint.best_model_path
print(best_model_path)

best_models/cifar10-epoch=09-val_loss=1.05-v13.ckpt


# テスト

In [14]:
# Reload the model from the best checkpoint and run the test loop on it
model = ResNet_pl.load_from_checkpoint(checkpoint_path=best_model_path)
trainer.test(model, datamodule=dataset)

Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


len(testset): 10000
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 32.97it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.6503207087516785
         test_f1            0.6409515738487244
        test_loss           1.0940154790878296
     test_precision         0.6586108207702637
       test_recall          0.6503207087516785
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.0940154790878296,
  'test_acc': 0.6503207087516785,
  'test_f1': 0.6409515738487244,
  'test_precision': 0.6586108207702637,
  'test_recall': 0.6503207087516785}]