In [None]:
#| default_exp base

In [None]:
#| export
from pydantic import BaseModel as BasePydanticModel
import json
from pathlib import Path


In [None]:
#| hide
import shutil
import os
from fastcore.test import *

In [None]:
#| export
class BaseConfig(BasePydanticModel):
    """Pydantic model that can be saved to, and restored from, a JSON file."""

    def save(self, path):
        """Write this config's fields to `path` as indented JSON.

        Missing parent directories are created.

        Args:
            path: Destination file (str or Path); must end with ``.json``.

        Raises:
            ValueError: If `path` does not end with ``.json``.
        """
        p = Path(path)
        if not str(p).endswith('.json'):
            raise ValueError(f"Path must end with `.json`, but got: {p}")
        # `exist_ok` avoids the check-then-create race of testing `exists()` first.
        p.parent.mkdir(parents=True, exist_ok=True)
        # Pydantic v2 renamed `.dict()` to `.model_dump()` and deprecated the old
        # name; use the new API when available and fall back for pydantic v1.
        data = self.model_dump() if hasattr(self, 'model_dump') else self.dict()
        # Explicit UTF-8 keeps the file independent of the platform's locale encoding.
        with open(p, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)

    @classmethod
    def load_from_json(cls, path):
        """Build an instance of `cls` from the JSON object stored at `path`.

        Args:
            path: Source file (str or Path) previously written by `save`.

        Returns:
            A new `cls` instance constructed from the file's top-level keys.

        Raises:
            FileNotFoundError: If `path` does not exist.
        """
        p = Path(path)
        # EAFP: open directly instead of checking `exists()` first, then
        # re-raise with the same message callers/tests match on.
        try:
            with open(p, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {p}") from err
        return cls(**data)

In [None]:
import tempfile


class ConfigTest(BaseConfig):
    """Small config with one field of each basic type, used by the tests below."""
    a: int = 1
    b: str = 'b'
    c: float = 3.14


# Run inside a throwaway directory so these checks never overwrite, leave
# behind, or `rmtree` a real `test.json` / `tmp/` in the working directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    tmp = Path(tmp_dir)

    # Round-trip through a file whose parent directory already exists.
    conf = ConfigTest()
    conf.save(tmp / 'test.json')
    conf2 = ConfigTest.load_from_json(tmp / 'test.json')
    assert conf == conf2

    # `save` must create missing parent directories.
    conf = ConfigTest()
    conf.save(tmp / 'tmp' / 'test.json')
    conf2 = ConfigTest.load_from_json(tmp / 'tmp' / 'test.json')
    assert conf == conf2

    # Non-`.json` destinations are rejected; missing files raise FileNotFoundError.
    test_fail(lambda: conf.save(tmp / 'test'), contains="Path must end with `.json`,")
    test_fail(lambda: ConfigTest.load_from_json(tmp / 'missing.json'), contains="File not found")


In [None]:
#| export
class BaseModule:
    """Common base for config-driven modules with an optional display name.

    Subclasses are expected to override `save` and `load_from_path`; the
    base versions only raise NotImplementedError.
    """

    def __init__(self, config, *, name=None):
        """Keep `config` and an optional explicit `name` (keyword-only)."""
        self.config = config
        self._name = name

    @property
    def name(self):
        """The explicit name when one was given and is truthy, else the class name."""
        if self._name:
            return self._name
        return self.__class__.__name__

    def save(self, path):
        """Persist this module under `path` (subclass responsibility)."""
        raise NotImplementedError()

    def load_from_path(self, path):
        """Restore this module's state from `path` (subclass responsibility)."""
        raise NotImplementedError()

In [None]:
import tempfile


class TestModule(BaseModule):
    """Minimal concrete module used to exercise BaseModule's save/load contract."""

    def save(self, path):
        # Persist only the config, as `<path>/config.json`.
        self.config.save(Path(path) / 'config.json')

    def load_from_path(self, path):
        self.config = ConfigTest.load_from_json(Path(path) / 'config.json')


conf = ConfigTest()
module = TestModule(conf)
assert module.name == 'TestModule'
assert TestModule(conf, name='custom').name == 'custom'

# Use a throwaway directory so the test never writes to, or `rmtree`s,
# a real `tmp/` folder in the working directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    save_dir = Path(tmp_dir) / 'module'  # not created yet: `save` must create it
    module.save(save_dir)
    # Load into a module holding a *different* config so the assertion proves
    # the config was actually read back from disk.
    restored = TestModule(ConfigTest(a=2, b='x', c=0.0))
    restored.load_from_path(save_dir)
    assert restored.config == conf

In [None]:
#| export
class PredFnMixedin:
    """Mixin that declares a `pred_fn` hook for subclasses to implement."""

    def pred_fn(self, x):
        """Hook applied to input `x`; this base version always raises NotImplementedError."""
        raise NotImplementedError()

In [None]:
#| export
class TrainableMixedin:
    """Mixin that declares a training interface: an `is_trained` flag and `train`.

    Both members raise NotImplementedError here; subclasses must override them.
    """

    @property
    def is_trained(self) -> bool:
        """Whether the object has been trained; to be provided by subclasses."""
        raise NotImplementedError()

    def train(self, data, **kwargs):
        """Train on `data`; extra keyword options are subclass-specific."""
        raise NotImplementedError()