0.36.0

@Milouu Milouu released this 15 May 14:57
· 106 commits to main since this release
9289ae2

Fixed

  • Close issue #114. A batch size larger than the number of samples is now set to the number of samples in predict for NR and FedPCA. (#115)
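
The fix above amounts to clamping the batch size. A minimal sketch, assuming a hypothetical helper named `effective_batch_size` (illustrative only, not part of the SubstraFL API):

```python
def effective_batch_size(batch_size: int, n_samples: int) -> int:
    """Hypothetical helper: never use a batch larger than the dataset."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # A batch size above the number of samples falls back to n_samples.
    return min(batch_size, n_samples)


clamped = effective_batch_size(1024, 100)   # larger than the dataset
unchanged = effective_batch_size(32, 100)   # already small enough
```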

Changed

  • BREAKING: Metrics are now given as metric_functions and not as metric_keys. The functions passed as metric functions to test data nodes are automatically registered in a new Substra function by SubstraFL. (#117)
    The new metric_functions argument of the TestDataNode class replaces metric_keys. It accepts a dictionary (each key is used as the identifier of the function given as its value), a list of functions, or a single function if there is only one metric to compute (function.__name__ is then used as the identifier).
    The dependencies installed for the metric functions are the algo_dependencies passed to execute_experiment, and the permissions are the same as those of the predict function.

    From a user's point of view, the metric registration changes from:

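    import numpy as np
    from sklearn.metrics import accuracy_score

    def accuracy(datasamples, predictions_path):
        # Compare the true labels with the argmax of the saved predictions.
        y_true = datasamples["labels"]
        y_pred = np.load(predictions_path)

        return accuracy_score(y_true, np.argmax(y_pred, axis=1))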
    
    metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])
    
    permissions_metric = Permissions(public=False, authorized_ids=DATA_PROVIDER_ORGS_ID)
    
    metric_key = add_metric(
        client=client,
        metric_function=accuracy,
        permissions=permissions_metric,
        dependencies=metric_deps,
    )
    
    test_data_nodes = [
        TestDataNode(
            organization_id=org_id,
            data_manager_key=dataset_keys[org_id],
            test_data_sample_keys=[test_datasample_keys[org_id]],
            metric_keys=[metric_key],
        )
        for org_id in DATA_PROVIDER_ORGS_ID
    ]

    to:

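    import numpy as np
    from sklearn.metrics import accuracy_score

    def accuracy(datasamples, predictions_path):
        # Compare the true labels with the argmax of the saved predictions.
        y_true = datasamples["labels"]
        y_pred = np.load(predictions_path)

        return accuracy_score(y_true, np.argmax(y_pred, axis=1))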
    
    test_data_nodes = [
        TestDataNode(
            organization_id=org_id,
            data_manager_key=dataset_keys[org_id],
            test_data_sample_keys=[test_datasample_keys[org_id]],
            metric_functions={"Accuracy": accuracy},
        )
        for org_id in DATA_PROVIDER_ORGS_ID
    ]
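
The three accepted forms of metric_functions can be thought of as reducing to one identifier-to-function mapping. The sketch below illustrates that idea with a hypothetical helper, `normalize_metric_functions`; it is not SubstraFL's actual implementation, and it assumes that functions passed in a list are also identified by their `__name__`.

```python
from typing import Callable, Dict, List, Union

MetricInput = Union[Dict[str, Callable], List[Callable], Callable]


def normalize_metric_functions(metric_functions: MetricInput) -> Dict[str, Callable]:
    """Hypothetical helper: map each metric identifier to its function."""
    if isinstance(metric_functions, dict):
        # Dictionary: the keys are used as identifiers.
        return dict(metric_functions)
    if callable(metric_functions):
        # Single function: wrap it so it is handled like a list.
        metric_functions = [metric_functions]
    # Assumption: each function in a list is identified by its __name__.
    names = [function.__name__ for function in metric_functions]
    if len(set(names)) != len(names):
        raise ValueError("metric function names must be unique")
    return dict(zip(names, metric_functions))


def accuracy(datasamples, predictions_path):
    """Placeholder metric used only to demonstrate the helper."""


def f1(datasamples, predictions_path):
    """Placeholder metric used only to demonstrate the helper."""


from_dict = normalize_metric_functions({"Accuracy": accuracy})
from_list = normalize_metric_functions([accuracy, f1])
from_single = normalize_metric_functions(accuracy)
```

With a dictionary the caller controls the identifier ("Accuracy" above); with a single function the identifier falls back to the function name ("accuracy").
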
  • Enforce keyword arguments for user-facing functions with more than 3 parameters (#109)

  • Remove references to composite; they are replaced by train_task. (#108)
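
For readers unfamiliar with enforced keyword arguments: in Python, a bare `*` in the signature makes every following parameter keyword-only. The sketch below uses a hypothetical function, `run_experiment`, with made-up parameters; it only illustrates the language mechanism and does not reproduce any SubstraFL signature.

```python
def run_experiment(*, client, strategy, num_rounds, experiment_folder):
    """Hypothetical function: the bare * makes all parameters keyword-only."""
    return {
        "client": client,
        "strategy": strategy,
        "num_rounds": num_rounds,
        "experiment_folder": experiment_folder,
    }


# Keyword call: accepted.
result = run_experiment(
    client="client", strategy="strategy", num_rounds=3, experiment_folder="exp"
)

# Positional call: rejected by Python with a TypeError.
try:
    run_experiment("client", "strategy", 3, "exp")
except TypeError as exc:
    positional_call_error = exc
else:
    positional_call_error = None
```

Code that called such functions positionally needs to switch to explicit keyword arguments after upgrading.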

Added

  • Add the Federated Principal Component Analysis strategy (#97)