In [None]:
#| default_exp trainer

# Trainer
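> Trainer API: a command-line entry point that loads an experiment config file (a `.py` file defining an `Exp` class) and fits its model with PyTorch Lightning.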

In [None]:
#| export
import os
import os.path as osp
from argparse import ArgumentParser
import mmcv
import torch
from pytorch_lightning import Trainer, seed_everything
from lit_classifier.all import get_trainer
import shutil
from fastcore.script import *

def get_exp_by_file(exp_file):
    """Load an experiment config object from a Python file.

    Adapted from YOLOX:
    https://github.com/Megvii-BaseDetection/YOLOX/blob/a5bb5ab12a61b8a25a5c3c11ae6f06397eb9b296/yolox/exp/build.py

    The file is executed as a module named after its basename (the text
    before the first "."), and that module's ``Exp`` class is instantiated
    with no arguments. The file's directory is added to ``sys.path`` (once)
    so the config can import sibling modules.

    Args:
        exp_file: path to the config ``.py`` file.

    Returns:
        A new instance of the file's ``Exp`` class.

    Raises:
        ImportError: if the file does not exist, cannot be imported, does not
            define ``Exp``, or ``Exp()`` fails. The underlying exception, if
            any, is chained as ``__cause__``.
    """
    import importlib.util
    import os
    import sys

    # Fail fast. Otherwise a same-named module elsewhere on sys.path (or left
    # in sys.modules) could be loaded in place of the requested file.
    if not os.path.isfile(exp_file):
        raise ImportError("{} doesn't exist or is not a file".format(exp_file))

    module_name = os.path.basename(exp_file).split(".")[0]
    exp_dir = os.path.dirname(os.path.abspath(exp_file))
    # Keep the config's directory importable, without piling up duplicates.
    if exp_dir not in sys.path:
        sys.path.append(exp_dir)

    # Load from the exact path rather than by module name, so nothing on
    # sys.path can shadow it. Executing a fresh module on every call also
    # picks up edits made since a previous load.
    spec = importlib.util.spec_from_file_location(module_name, exp_file)
    if spec is None or spec.loader is None:
        raise ImportError("cannot import {}: unsupported file type".format(exp_file))
    current_exp = importlib.util.module_from_spec(spec)
    # Register before executing, as the regular import system does.
    sys.modules[module_name] = current_exp
    try:
        spec.loader.exec_module(current_exp)
    except Exception as e:
        sys.modules.pop(module_name, None)
        raise ImportError("failed to import {}".format(exp_file)) from e

    exp_cls = getattr(current_exp, "Exp", None)
    if exp_cls is None:
        raise ImportError("{} doesn't contain a class named 'Exp'".format(exp_file))
    try:
        return exp_cls()
    except Exception as e:
        raise ImportError("failed to instantiate 'Exp' from {}".format(exp_file)) from e
        
@call_parse
def train(
    cfg_path:Param('Path to config'),
    devices: Param('Number of devices (GPUs) to train on', default=1, type=int),
):
    """Train the model described by an experiment config file.

    Loads the ``Exp`` object from ``cfg_path``, builds its data loader and
    model, creates a trainer with ``get_trainer`` and calls ``fit``. The
    config file is copied into ``trainer.log_dir`` so each run keeps the
    exact config it used.

    Args:
        cfg_path: path to the experiment config ``.py`` file. Its basename
            (before the first ".") is used as the experiment name.
        devices: device count passed to ``get_trainer``; a value greater
            than 1 enables distributed training.

    Raises:
        FileNotFoundError: if ``cfg_path`` does not exist.
    """
    # Check up front. The config is copied into the log dir at the very end,
    # so a missing file would otherwise surface only after all the setup work.
    if not osp.isfile(cfg_path):
        raise FileNotFoundError(f'Config file not found: {cfg_path}')

    cfg = get_exp_by_file(cfg_path)
    print(cfg)
    # Apply the config's seed only when it sets one; otherwise leave global
    # RNG state alone, as before.
    seed = getattr(cfg, 'seed', None)
    if seed is not None:
        seed_everything(seed)
    exp_name = osp.basename(cfg_path).split('.')[0]

    data = cfg.get_data_loader()

    model = cfg.get_model(create_lr_scheduler_fn=cfg.get_lr_scheduler(),
                          create_optimizer_fn=cfg.get_optimizer())
    # The epoch budget comes from the config; it was previously an undefined
    # module-level name that only resolved through leftover kernel state.
    trainer = get_trainer(exp_name,
                          devices,
                          distributed=devices > 1,
                          max_epochs=cfg.max_epochs)

    mmcv.mkdir_or_exist(trainer.log_dir)
    shutil.copy(cfg_path, osp.join(trainer.log_dir, osp.basename(cfg_path)))
    trainer.fit(model, data)


## Test

In [None]:
#| hide
#| eval: false
# Launches a full training run, so nbdev_test skips this cell (eval: false).
# The config path is relative to the notebook's working directory.
train('../configs/00_mnist_vanila.py', devices=1)

2022-07-31 18:00:59.934 | INFO     | lit_classifier.lit_model:fn_schedule_cosine_with_warmpup_decay_timm:78 - num_cycles=3
No pretrained weights exist or were found for this model. Using random initialization.
2022-07-31 18:00:59.957 | INFO     | lit_classifier.lit_model:get_trainer:164 - Root log directory: lightning_logs/00_mnist_vanila/06


╒════════════════╤═══════════════════╕
│ keys           │ values            │
╞════════════════╪═══════════════════╡
│ seed           │ None              │
├────────────────┼───────────────────┤
│ output_dir     │ './YOLOX_outputs' │
├────────────────┼───────────────────┤
│ print_interval │ 100               │
├────────────────┼───────────────────┤
│ eval_interval  │ 10                │
├────────────────┼───────────────────┤
│ max_epochs     │ 20                │
├────────────────┼───────────────────┤
│ lr             │ 0.15              │
├────────────────┼───────────────────┤
│ batch_size     │ 128               │
├────────────────┼───────────────────┤
│ num_lr_cycles  │ 3                 │
╘════════════════╧═══════════════════╛


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


FileNotFoundError: [Errno 2] No such file or directory: '../configs/00_mnist_vanila.py'

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()
!pip install -e ../

Obtaining file:///home/anhvth/gitprojects/litclassifier
Installing collected packages: lit-classifier
  Attempting uninstall: lit-classifier
    Found existing installation: lit-classifier 0.0.2
    Uninstalling lit-classifier-0.0.2:
      Successfully uninstalled lit-classifier-0.0.2
  Running setup.py develop for lit-classifier
Successfully installed lit-classifier-0.0.2
