# Main Training Script

> The main entry point for training. It orchestrates the project's modules to train a model.

## Description:
The `main` module is the project's primary training entry point. It combines the task definitions from the `core` module with the data loading provided by the `data` module, and trains the model through PyTorch Lightning's `Trainer`. By editing the configuration classes, users can quickly switch between datasets, models, and training strategies to run different experiments.
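
To make the configuration-driven workflow concrete, here is a minimal sketch of nested configuration classes. `DatasetConfig` and `TaskConfig` are hypothetical stand-ins built with standard-library dataclasses. The real `ClassificationTaskConfig` in `namable_classify.core` may be structured differently, and the field names and defaults below are assumptions.

```python
from dataclasses import dataclass, field


@dataclass
class DatasetConfig:
    # Hypothetical fields; the real dataset config may expose different ones.
    name: str = "CIFAR-100"
    batch_size: int = 32


@dataclass
class TaskConfig:
    # Hypothetical stand-in for ClassificationTaskConfig.
    learning_rate: float = 3e-4
    # default_factory gives each TaskConfig its own DatasetConfig instance.
    dataset_config: DatasetConfig = field(default_factory=DatasetConfig)


# Override only the settings that differ for this experiment.
config = TaskConfig()
config.learning_rate = 1e-3
config.dataset_config.batch_size = 64

# A second config keeps the defaults, so instances do not share state.
baseline = TaskConfig()
```

Keeping the dataset settings in their own nested object is what allows an attribute path such as `config.dataset_config.batch_size` to be changed without touching the task-level settings.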

## Main symbols:

- `Trainer`: the PyTorch Lightning controller that manages the training process.

- `ClassificationTask`: imported from `core`; the main task class used for model training.

- `CIFAR100DataModule`: the data loading module imported from `data`.

In [1]:
#| default_exp __main__

In [2]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [1]:
#| export
from namable_classify.core import ClassificationTask, ClassificationTaskConfig
config = ClassificationTaskConfig()
# config.learning_rate = 1e-1
# config.learning_rate = 1
config.learning_rate = 1e-3
# config.learning_rate = 3e-4
# config.learning_rate = 1e-6
config.dataset_config.batch_size = 64
cls_task = ClassificationTask(config)
cls_task.print_model_pretty()
import torch
# cls_task.cls_model = torch.compile(cls_task.cls_model, mode='reduce-overhead')
#  fullgraph=True

Seed set to 0


In [4]:
# #| export
# import lightning as L
# trainer = L.Trainer()
# from lightning.pytorch.tuner import Tuner
# tuner = Tuner(trainer)
# found_batch_size = tuner.scale_batch_size(cls_task, datamodule=cls_task.lit_data, 
#                                         #   mode='binsearch', 
#                                           mode='power', 
#                                           init_val=64)
# # found_batch_size, cls_task.lit_data.hparams.batch_size
# print(f"Found batch size: {found_batch_size}")
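
The commented-out cell above asks the tuner for a `mode='power'` batch-size search starting from `init_val=64`. As a rough illustration of that idea (not Lightning's actual implementation), the sketch below doubles a candidate batch size until a trial fails and then keeps the last size that worked. `power_scale` and the `fits` predicate are hypothetical names. A real tuner would run a few training steps and catch out-of-memory errors instead of calling a predicate.

```python
from typing import Callable


def power_scale(fits: Callable[[int], bool], init_val: int = 64, max_trials: int = 25) -> int:
    """Return the largest batch size of the form init_val * 2**k that `fits` accepts.

    Raises ValueError if even `init_val` does not fit.
    """
    if not fits(init_val):
        raise ValueError(f"initial batch size {init_val} does not fit")
    size = init_val
    for _ in range(max_trials):
        candidate = size * 2
        if not fits(candidate):
            break  # the doubled size failed, so keep the last working size
        size = candidate
    return size


# Pretend that anything above 300 samples per batch runs out of memory.
found = power_scale(lambda batch_size: batch_size <= 300, init_val=64)
print(f"Found batch size: {found}")
```

Because the search only tries powers of two times the initial value, it can stop well below the true limit; a binary-search mode trades extra trials for a tighter result.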

In [5]:
# #| export

# lr_finder = tuner.lr_find(cls_task, datamodule=cls_task.lit_data, 
#                         #   max_lr=1e-2
#                           )
# print(lr_finder.results)

# fig = lr_finder.plot(suggest=True)
# from matplotlib import pyplot as plt
# from namable_classify.utils import runs_figs_path
# plt.savefig(runs_figs_path/'lr_finder.png')
# # fig.show()
# new_lr = lr_finder.suggestion()
# # new_lr, cls_task.hparams.learning_rate
# print("New learning rate: ", new_lr)

In [2]:
#| export
import lightning as L
from namable_classify.utils import runs_path
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary, StochasticWeightAveraging, DeviceStatsMonitor
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

trainer = L.Trainer(
    default_root_dir=runs_path,
    enable_checkpointing=True,
    enable_model_summary=True,
    # Run 2 validation batches before training so a broken validation loop
    # fails immediately rather than after a long stretch of training.
    num_sanity_val_steps=2,
    callbacks=[
        # EarlyStopping(monitor="val_loss", mode="min"),
        EarlyStopping(
            monitor="val_acc1",
            mode="max",
            check_finite=True,
            patience=5,
            check_on_train_epoch_end=False,  # check at the end of validation
            verbose=True,
        ),
        ModelSummary(max_depth=3),
        # StochasticWeightAveraging(swa_lrs=1e-2),
        DeviceStatsMonitor(cpu_stats=True),
    ],
    logger=[
        TensorBoardLogger(save_dir=runs_path / "tensorboard"),
        CSVLogger(save_dir=runs_path),
    ],
    # gradient_clip_val=1.0, gradient_clip_algorithm="value",
    # profiler="simple",
    # fast_dev_run=True,
    # limit_train_batches=10, limit_val_batches=5,
    # strategy="ddp", accelerator="gpu", devices=4,
)
trainer.fit(cls_task, datamodule=cls_task.lit_data)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


NameError: name 'cls_task' is not defined
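
The `EarlyStopping` callback in the training cell monitors `val_acc1` with `mode="max"`, `patience=5`, and `check_finite=True`. The sketch below is a simplified, pure-Python reading of that patience rule, not Lightning's code: training stops once the metric has failed to strictly improve for `patience` consecutive validation checks, or as soon as it becomes non-finite. The function name `stop_epoch` and the accuracy history are made up for illustration.

```python
import math
from typing import Iterable, Optional


def stop_epoch(val_acc: Iterable[float], patience: int = 5) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    best = -math.inf
    wait = 0  # consecutive checks without improvement
    for epoch, acc in enumerate(val_acc):
        if not math.isfinite(acc):
            return epoch  # mirrors the intent of check_finite=True
        if acc > best:
            best, wait = acc, 0  # strict improvement resets the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Illustrative accuracies: the best value (0.30) is reached at epoch 2 and
# only tied, never beaten, over the next five checks.
history = [0.10, 0.20, 0.30, 0.29, 0.30, 0.28, 0.27, 0.26, 0.40]
stopped_at = stop_epoch(history, patience=5)
```

Note that a tie with the best value does not count as an improvement in this sketch, so the late jump to 0.40 is never observed; a larger `patience` would be needed to reach it.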

In [8]:
#| hide
import nbdev; nbdev.nbdev_export()