In [2]:
#| default_exp trainer

# Trainer
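> Command-line entry point (`train`) that loads an experiment (`Exp`) file, trains the model it defines, and copies the experiment file into the run's log directory.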

In [13]:
#| export
import os
import os.path as osp
from argparse import ArgumentParser
import mmcv
import torch
from pytorch_lightning import Trainer, seed_everything
from ple.all import get_trainer
import shutil
from fastcore.script import *
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from loguru import logger

def get_rank():
    """Return this process's rank in the default ``torch.distributed`` group.

    Returns 0 when ``torch.distributed`` is not available in this torch build
    or no default process group has been initialised (e.g. a single-process
    run), so ``get_rank() == 0`` works as a "main process" check either way.
    """
    # Check explicitly instead of using a bare ``except:``. That would also
    # swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

def get_exp_by_file(exp_file):
    """Load an experiment file and return an instance of its ``Exp`` class.

    The file is loaded by path and executed fresh on every call, so edits are
    picked up between calls. It is registered in ``sys.modules`` under its
    file stem (dots replaced by underscores). Its directory is appended to
    ``sys.path`` (once) so the experiment file can import sibling modules.

    Params:
        exp_file: Path to the experiment ``.py`` file.

    Returns:
        An instance of the ``Exp`` class defined in ``exp_file``.

    Raises:
        ImportError: if the file cannot be loaded or executed, defines no
            ``Exp``, or ``Exp()`` fails. The underlying error is chained as
            ``__cause__`` instead of being discarded.
    """
    import importlib.util
    import os
    import sys

    exp_dir = os.path.dirname(os.path.abspath(exp_file))
    if exp_dir not in sys.path:
        sys.path.append(exp_dir)

    # splitext keeps "exp.v2.py" as "exp.v2". Dots are not valid in a
    # top-level module name, so they are swapped for underscores.
    module_name = os.path.splitext(os.path.basename(exp_file))[0].replace(".", "_")

    # Load by path rather than by name. Otherwise a same-named module earlier
    # on sys.path (or already cached) could be imported instead of this file.
    spec = importlib.util.spec_from_file_location(module_name, exp_file)
    if spec is None or spec.loader is None:
        raise ImportError("Cannot load exp file {}".format(exp_file))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        # Mirror the import system: don't leave a half-initialised module behind.
        sys.modules.pop(module_name, None)
        raise ImportError(
            "Failed to import exp file {}: {}".format(exp_file, e)) from e

    if not hasattr(module, "Exp"):
        raise ImportError(
            "{} doesn't contain class named 'Exp'".format(exp_file))
    try:
        return module.Exp()
    except Exception as e:
        raise ImportError(
            "Failed to instantiate Exp from {}: {}".format(exp_file, e)) from e
        
def _copy_cfg_to_log_dir(cfg_path, log_dir):
    """Copy the experiment file into ``log_dir`` so the run records its config.

    Skips with a warning when ``log_dir`` is None, and creates the directory
    if it does not exist yet. This keeps a run that failed early from raising
    here and masking the original training error.
    """
    if log_dir is None:
        logger.warning('Trainer has no log_dir; not copying {}', cfg_path)
        return
    os.makedirs(log_dir, exist_ok=True)
    out_path = osp.join(log_dir, osp.basename(cfg_path))
    logger.info('cp {} {}', cfg_path, out_path)
    shutil.copy(cfg_path, out_path)


@call_parse
def train(
    cfg_path:Param('Path to config'),
    devices: Param('GPUS indices', default=1, type=int),
    accelerator: Param('cpu or gpu', default='gpu', type=str),
):
    """Train the model described by the experiment file at ``cfg_path``.

    Loads the ``Exp`` from ``cfg_path``, builds its data and model, and fits
    them with the trainer returned by ``get_trainer``. Afterwards, whether or
    not training succeeded, rank 0 copies the experiment file into
    ``trainer.log_dir``. Exceptions from ``trainer.fit`` propagate, so a
    failed run no longer exits with status 0.
    """
    cfg = get_exp_by_file(cfg_path)
    logger.info('Experiment: {}', cfg)
    # splitext (not split('.')) so "exp.v2.py" yields "exp.v2", not "exp".
    exp_name = osp.splitext(osp.basename(cfg_path))[0]

    data = cfg.get_data_loader()

    # NOTE(review): get_lr_scheduler()/get_optimizer() are called here and
    # their results passed as `*_fn`; presumably they return factories. Confirm.
    model = cfg.get_model(create_lr_scheduler_fn=cfg.get_lr_scheduler(),
                          create_optimizer_fn=cfg.get_optimizer())
    # NOTE(review): `devices` is a single int despite the "indices" help text.
    trainer = get_trainer(exp_name,
                          devices,
                          max_epochs=cfg.max_epochs,
                          trainer_kwargs=dict(accelerator=accelerator))
    try:
        trainer.fit(model, data)
    finally:
        # Read log_dir on every rank. NOTE(review): Lightning's Trainer.log_dir
        # may broadcast across ranks, so calling it on rank 0 only could hang
        # multi-process runs. Confirm for the installed version.
        log_dir = trainer.log_dir
        if get_rank() == 0:
            _copy_cfg_to_log_dir(cfg_path, log_dir)
        


## Test

Example invocation (commented out; uncomment to launch a training run):

In [14]:
#| hide
# train('./configs/00_mnist_vanila.py', devices=1)

In [15]:
#| hide
# Write this notebook's `#| export` cells to the `default_exp` module.
import nbdev
nbdev.nbdev_export()