In [None]:
import numpy as np
from pathlib import Path
import os
import tempfile
import substra
from substra.sdk.schemas import Permissions

import shutil
from substrafl.dependency import Dependency
from substrafl.algorithms import Algo
from substrafl.remote import remote
from substrafl.remote import remote_data
from substrafl.strategies.schemas import StrategyName

# Workaround to import the helper modules `utils` and `assets_factory`, which live in the
# parent directory. Importing them through the `tests` package interferes with substra/tests
# in the CI, and relative imports are not possible with nbmake. So we temporarily move to the
# parent directory, import the helpers, and move back.
# NOTE(review): this relies on the current working directory being searched by `import`
# (e.g. the '' entry on sys.path in a notebook kernel). Confirm for other runners.
_notebook_dir = Path.cwd()
os.chdir("..")
try:
    import utils
    import assets_factory
finally:
    # Return to the exact directory we started from, rather than a hard-coded
    # "./dependency". Do it even if an import fails, so later cells (e.g. the
    # `test_dependency` import and the relative "local_code_file.py" path) still
    # resolve against the notebook's own folder.
    os.chdir(_notebook_dir)

from test_dependency import TestLocalDependency

In [None]:
"""Test that you can import a file"""
class MyAlgo(Algo):
    """Minimal algo whose ``train`` step imports a module shipped as local code.

    Used by the cell below to check that a file listed in
    ``Dependency(local_code=[...])`` can be imported from inside a remote task.
    """

    # This class must be defined inline (within the test / notebook cell), otherwise the
    # Docker will not find it correctly (i.e. because of the way pytest calls it).

    @property
    def model(self):
        """No underlying model: always ``None``."""
        return None

    @property
    def strategies(self):
        """Declare compatibility with every member of ``StrategyName``."""
        return list(StrategyName)

    @remote
    @staticmethod
    def initialize(shared_states):
        """No-op initialization step; returns ``None``."""
        return

    @remote_data
    def train(
        self,
        datasamples: np.ndarray,
        shared_state,
    ):
        """Import and call ``local_code_file.combine_strings``, then echo the data back.

        Returns:
            A dict with ``"test"`` (``datasamples`` converted with ``np.array``) and
            ``"n_samples"`` (``len(datasamples)``).
        """
        # Imported inside the method on purpose: ``local_code_file`` is the file passed
        # as ``local_code`` to the Dependency, and importing it here is what the test checks.
        from local_code_file import combine_strings

        some_strings = combine_strings("Foo", "Bar")
        assert some_strings == "FooBar"  # For flake8 purposes (uses the imported name)

        return {"test": np.array(datasamples), "n_samples": len(datasamples)}

    @remote_data
    def predict(self, datasamples: np.ndarray, shared_state):
        """Return ``shared_state["test"]`` (presumably the array stored by ``train``)."""
        return shared_state["test"]

    def load_local_state(self, path: Path):
        """Ignore ``path`` and return this instance unchanged."""
        return self

    def save_local_state(self, path: Path):
        """Write the placeholder text ``"test"`` to ``path``.

        Asserts that the parent directory of ``path`` already exists.
        """
        assert path.parent.exists()
        with path.open("w") as f:
            f.write("test")

# Register MyAlgo (with its dependencies) as a function, add one numpy dataset and one
# data sample, then run a single train task and block until it completes. All files are
# written under a temporary directory in the current working directory; the directory is
# removed when the `with` block exits.
with tempfile.TemporaryDirectory(dir=Path.cwd()) as tmp_dir:
    client = substra.Client(backend_type=substra.BackendType.LOCAL_SUBPROCESS)
    my_algo = MyAlgo()

    # "local_code_file.py" is shipped alongside the function so MyAlgo.train can import it.
    algo_deps = Dependency(
        local_code=["local_code_file.py"],
        pypi_dependencies=["pytest"],
    )

    testlocaldep = TestLocalDependency()
    function_key = testlocaldep._register_function(my_algo, algo_deps, client, tmp_dir)

    default_permissions = Permissions(authorized_ids=[], public=True)

    # NOTE(review): both helpers appear to return sequences of keys (only the first
    # element of each is used below). Confirm against assets_factory.
    dataset_key = assets_factory.add_numpy_datasets(
        datasets_permissions=[default_permissions],
        clients=[client],
        tmp_folder=tmp_dir,
    )
    sample_key = assets_factory.add_numpy_samples(
        contents=[np.zeros((1, 2))],
        dataset_keys=dataset_key,
        tmp_folder=tmp_dir,
        clients=[client],
    )

    first_dataset_key, first_sample_key = dataset_key[0], sample_key[0]
    train_task = testlocaldep._register_train_task(
        function_key,
        first_dataset_key,
        first_sample_key,
        client,
    )
    utils.wait(client, train_task)