In [3]:
from typing import Optional

import torch
from transformers import PreTrainedModel, get_scheduler

from src import datasets, models, retrievers
from src.configs import (
    DataLoaderConfigs,
    LRSchedulerConfigs,
    ModelConfigs,
    OptimizerConfigs,
    RetrieverConfigs,
)


def get_optimizer(
    optimizer_config: OptimizerConfigs, model: PreTrainedModel
) -> torch.optim.Optimizer:
    """Instantiate a ``torch.optim`` optimizer over every parameter of ``model``.

    ``optimizer_config.name`` is resolved as an attribute of ``torch.optim``
    and called with ``model.parameters()`` as its first positional argument,
    followed by ``optimizer_config.configs`` expanded as keyword arguments.

    Args:
        optimizer_config: Supplies ``name`` (attribute of ``torch.optim``) and
            ``configs`` (keyword arguments for the optimizer constructor).
        model: Model whose parameters the optimizer will update.

    Returns:
        The constructed optimizer.

    Raises:
        AttributeError: If ``torch.optim`` has no attribute called ``name``.
    """
    optimizer_cls = getattr(torch.optim, optimizer_config.name)
    trainable = model.parameters()
    return optimizer_cls(trainable, **optimizer_config.configs)


def get_lr_scheduler(
    lr_scheduler_config: LRSchedulerConfigs,
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
    """Build the learning-rate scheduler described by ``lr_scheduler_config``.

    Resolution order:

    1. ``lr_scheduler_config.name is None``: no scheduler, return ``None``.
    2. ``name`` is a scheduler class defined in ``torch.optim.lr_scheduler``:
       instantiate it as ``cls(optimizer, **lr_scheduler_config.configs)``.
       ``num_training_steps`` is NOT forwarded on this path, so any step
       counts the class needs must be supplied through ``configs``.
    3. Otherwise: delegate to ``transformers.get_scheduler``, passing
       ``num_training_steps`` together with ``lr_scheduler_config.configs``.

    Args:
        lr_scheduler_config: Supplies ``name`` (str or None) and ``configs``
            (keyword arguments for the scheduler).
        optimizer: Optimizer whose learning rate is scheduled.
        num_training_steps: Total number of training steps; only used on the
            ``transformers`` path.

    Returns:
        The scheduler instance, or ``None`` when no scheduler is configured.
    """
    name = lr_scheduler_config.name
    if name is None:
        return None

    # A plain hasattr() on the module would also match non-scheduler names the
    # module exposes (e.g. things it merely imports), which would then be
    # called as if they were schedulers. Only accept classes that are actually
    # defined in torch.optim.lr_scheduler; anything else falls through to
    # transformers, which reports unknown scheduler names itself.
    torch_schedulers = torch.optim.lr_scheduler
    scheduler_cls = getattr(torch_schedulers, name, None)
    if (
        isinstance(scheduler_cls, type)
        and scheduler_cls.__module__ == torch_schedulers.__name__
    ):
        return scheduler_cls(optimizer, **lr_scheduler_config.configs)

    return get_scheduler(
        name=name,
        optimizer=optimizer,
        num_training_steps=num_training_steps,
        **lr_scheduler_config.configs,
    )


def get_dataset(dataloader_config: DataLoaderConfigs):
    """Return the attribute of ``src.datasets`` named by ``dataset_loader``.

    The looked-up object is returned as-is; it is not called or instantiated.

    Raises:
        AttributeError: If ``src.datasets`` has no attribute with that name.
    """
    loader_name = dataloader_config.dataset_loader
    dataset_loader = getattr(datasets, loader_name)
    return dataset_loader


def get_pipeline(model_config: ModelConfigs):
    """Return the attribute of ``src.models`` named by ``model_config.pipeline``.

    The looked-up object is returned as-is; it is not called or instantiated.

    Raises:
        AttributeError: If ``src.models`` has no attribute with that name.
    """
    return getattr(models, model_config.pipeline)


def get_retriever(retriever_config: RetrieverConfigs):
    """Return the attribute of ``src.retrievers`` named by ``retriever_config.name``.

    The looked-up object is returned as-is; it is not called or instantiated.

    Raises:
        AttributeError: If ``src.retrievers`` has no attribute with that name.
    """
    retriever_name = retriever_config.name
    return getattr(retrievers, retriever_name)


ModuleNotFoundError: No module named 'src'

In [4]:
pip install src

Collecting src
  Using cached src-0.0.7.zip (6.3 kB)
Building wheels for collected packages: src
  Building wheel for src (setup.py) ... [?25lerror
[31m  ERROR: Command errored out with exit status 1:
   command: /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-77696dww/src/setup.py'"'"'; __file__='"'"'/tmp/pip-install-77696dww/src/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-awmwtnmr
       cwd: /tmp/pip-install-77696dww/src/
  Complete output (41 lines):
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build/lib
  creating build/lib/src
  copying src/__init__.py -> build/lib/src
  running egg_info
  writing src.egg-info/PKG-INFO
  writing dependency_links to src.egg-info/dependency_links.txt
  writing entry points to src.egg-info/entry_points.tx