In [None]:
import numpy as np
from pathlib import Path
import os
import tempfile
from substra import Client
from substra.sdk.schemas import Permissions

from connectlib.dependency import Dependency
from connectlib.algorithms import Algo
from connectlib.remote import remote_data
from connectlib.schemas import StrategyName

# Workaround: we can't import the shared test helpers as a package in the CI (it
# would interfere with substra/tests), and nbmake doesn't support relative
# imports, so temporarily move to the parent directory to import them.
_notebook_dir = os.getcwd()
os.chdir("..")
try:
    import utils
    import assets_factory
finally:
    # Always come back to the notebook's own directory, even if an import fails,
    # so later relative paths (test_dependency, local_code_file.py) still resolve.
    # Restoring the saved cwd avoids hard-coding this directory's name.
    os.chdir(_notebook_dir)

from test_dependency import TestLocalDependency

os.environ["DEBUG_SPAWNER"] = "subprocess"

In [None]:
"""Test that you can import a file"""
class MyAlgo(Algo):
    # this class must be within the test, otherwise the Docker will not find it correctly (ie because of the way
    # pytest calls it)
    @property
    def model(self):
        return None

    @property
    def strategies(self):
        return list(StrategyName)

    @remote_data
    def train(
        self,
        x: np.ndarray,
        y: np.ndarray,
        shared_state,
    ):
        from local_code_file import combine_strings

        some_strings = combine_strings("Foo", "Bar")
        assert some_strings == "FooBar"  # For flake8 purposes

        return dict(test=np.array(x), n_samples=len(x))

    @remote_data
    def predict(self, x: np.array, shared_state):
        return shared_state["test"]

    def load(self, path: Path):
        return self

    def save(self, path: Path):
        assert path.parent.exists()
        with path.open("w") as f:
            f.write("test")

with tempfile.TemporaryDirectory(dir=Path.cwd()) as tmp_dir:
    # Debug-mode client (the spawner is selected by DEBUG_SPAWNER, set earlier).
    client = Client(debug=True)

    # Register MyAlgo with its dependencies; local_code_file.py is the module
    # MyAlgo.train imports at runtime.
    my_algo = MyAlgo()
    algo_deps = Dependency(local_code=["local_code_file.py"], pypi_dependencies=["pytest"])
    testlocaldep = TestLocalDependency()
    algo_key = testlocaldep._register_algo(my_algo, algo_deps, client, tmp_dir)

    # One public dataset holding a single np.zeros((1, 2)) sample. Both helpers
    # return sequences of keys (indexed with [0] below).
    default_permissions = Permissions(authorized_ids=[], public=True)
    dataset_key = assets_factory.add_numpy_datasets(
        clients=[client],
        datasets_permissions=[default_permissions],
        msp_ids=["msp_id"],
        tmp_folder=tmp_dir,
    )
    sample_key = assets_factory.add_numpy_samples(
        clients=[client],
        contents=[np.zeros((1, 2))],
        dataset_keys=dataset_key,
        tmp_folder=tmp_dir,
    )

    # Launch the composite traintuple and wait on it while tmp_dir still exists.
    composite_traintuple = testlocaldep._register_composite(
        algo_key, dataset_key[0], sample_key[0], client
    )
    utils.wait(client, composite_traintuple)