In [None]:
# | default_exp base_exp


# Base_exp
> Base_exp API: the abstract `BaseExp` class that experiment definitions subclass.

In [12]:
# | export

import ast
import pprint
from abc import ABCMeta, abstractmethod
from typing import Dict
from tabulate import tabulate

import torch
from torch.nn import Module
import pytorch_lightning as pl
from torch.optim.lr_scheduler import _LRScheduler
from ple.all import get_trainer
import os.path as osp

class BaseExp(metaclass=ABCMeta):
    """Basic class for any experiment.

    Subclasses implement the abstract factory methods (model, data module,
    optimizer, LR scheduler) and keep their hyper-parameters as plain instance
    attributes. ``merge`` overrides those attributes from a flat
    ``[key, value, ...]`` list, and ``__repr__`` renders the public ones as a
    table.

    ``get_trainer`` reads ``self.exp_name`` and ``self.max_epochs``. This base
    class does not set them, so subclasses must define them before calling it.

    Return annotations are quoted (forward references). Defining the class
    therefore never evaluates third-party attributes such as
    ``pl.LightningDataModule``, while type checkers still resolve them against
    the module-level imports.
    """

    def __init__(self):
        # Accelerator name forwarded to the trainer by ``get_trainer``.
        self.accelerator = 'gpu'

    @abstractmethod
    def get_model(self) -> "Module":
        """Build and return the model for this experiment."""

    @abstractmethod
    def get_data_loader(
        self,
    ) -> "pl.LightningDataModule":
        """Build and return the data for this experiment.

        Despite the method name, the annotated return type is a
        ``LightningDataModule`` rather than a single ``DataLoader``.
        """

    @abstractmethod
    def get_optimizer(self) -> "torch.optim.Optimizer":
        """Build and return the optimizer for this experiment."""

    @abstractmethod
    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> "_LRScheduler":
        """Build and return the learning-rate scheduler.

        Args:
            lr: learning rate the schedule is built around.
            iters_per_epoch: number of iterations per epoch. Presumably used to
                convert epoch-based settings into per-step ones; this is up to
                the subclass.
            **kwargs: subclass-specific scheduler options.
        """

    def __repr__(self):
        """Render public (non-underscore) instance attributes as a grid table."""
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    @staticmethod
    def _cast_like(src_value, v, key=None):
        """Return ``v`` converted towards the exact type of ``src_value``.

        - If ``src_value`` is ``None`` or already has the same exact type as
          ``v``, return ``v`` unchanged.
        - If the target is ``bool`` and ``v`` is a string, parse it
          case-insensitively from true/false/1/0/yes/no. ``bool("False")``
          would otherwise be ``True``. Any other string raises ``ValueError``.
        - If the target is list/tuple/set/dict and ``v`` is a string, parse it
          with ``ast.literal_eval``. ``list("[1, 2]")`` would otherwise split
          the string into characters. A parsed list/tuple/set is converted to
          the target container type.
        - Otherwise try ``type(src_value)(v)`` and fall back to
          ``ast.literal_eval(v)``.

        Args:
            src_value: current attribute value whose type guides the cast.
            v: incoming override value (typically a command-line string).
            key: attribute name, used only in error messages.

        Raises:
            ValueError: for an unrecognised bool string, or when
                ``ast.literal_eval`` cannot parse ``v``.
            SyntaxError: from ``ast.literal_eval`` on malformed input.
        """
        src_type = type(src_value)
        if src_value is None or src_type == type(v):
            return v
        if isinstance(v, str):
            if src_type is bool:
                lowered = v.strip().lower()
                if lowered in ("true", "1", "yes"):
                    return True
                if lowered in ("false", "0", "no"):
                    return False
                raise ValueError(f"Cannot interpret {v!r} as a bool for key {key!r}")
            if src_type in (list, tuple, set, dict):
                parsed = ast.literal_eval(v)
                if src_type is not dict and isinstance(parsed, (list, tuple, set)):
                    return src_type(parsed)
                return parsed
        try:
            return src_type(v)
        except Exception:
            return ast.literal_eval(v)

    def merge(self, cfg_list):
        """Override existing attributes from a flat ``[k1, v1, k2, v2, ...]`` list.

        Keys that are not already attributes of ``self`` are ignored. Each
        value is cast towards the current attribute's type (see
        ``_cast_like``), and every applied override is printed.

        Args:
            cfg_list: sequence with an even number of items, alternating
                keys and values.

        Raises:
            AssertionError: if ``cfg_list`` has an odd number of items.
            ValueError, SyntaxError: if a value cannot be converted.
        """
        assert len(cfg_list) % 2 == 0, (
            f"cfg_list must alternate keys and values, got {len(cfg_list)} items"
        )
        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
            # only update value with same key
            if hasattr(self, k):
                v = self._cast_like(getattr(self, k), v, key=k)
                print(f'Set {k}={v}')
                setattr(self, k, v)

    def get_trainer(self, devices: int):
        """Build a trainer for this experiment.

        Delegates to the module-level ``get_trainer`` imported from
        ``ple.all``. Inside a method body, the bare name resolves to that
        global and not to this method. The result is whatever that helper
        returns (presumably a Lightning ``Trainer``; verify in ``ple``).

        Args:
            devices: number of devices passed through to the helper.

        Requires ``self.exp_name`` and ``self.max_epochs`` to be set by the
        subclass. If they are missing, ``AttributeError`` is raised.
        """
        return get_trainer(
            self.exp_name,
            devices,
            max_epochs=self.max_epochs,
            trainer_kwargs=dict(
                accelerator=self.accelerator,
            ),
        )

In [13]:
#| hide
# Export the project's notebooks (including this notebook's `# | export` cells)
# to Python modules; the `#| hide` directive keeps this cell out of the rendered docs.
from nbdev import nbdev_export
nbdev_export()