## Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Use the same identifier that appears in the signed certificate.
client_id = 'researcher'
director_node_fqdn = 'openfl-director'
director_port = 4444

# Connection settings for the director, with TLS enabled and the
# certificate/key files read from the local ``cert`` directory.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=True,
    cert_chain='cert/root_ca.crt',
    api_cert='cert/researcher.crt',
    api_private_key='cert/researcher.key',
)

In [None]:
# Fetch the shard registry from the federation and show it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

## Model

In [None]:
from typing import List, Union
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import log_loss
import random

class SGDCls:
    """Expose an ``SGDClassifier`` through a single ``weights`` matrix.

    ``weights`` holds one row per linear model and ``n_features + 1`` columns:
    the coefficients followed by the intercept. It is the state that is meant
    to be exchanged; before every estimator call it is copied into the wrapped
    estimator, and after training it is rebuilt from the estimator.
    """

    def __init__(self, n_features: int, n_classes: int) -> None:
        """Create the wrapper with all weights initialised to one.

        Args:
            n_features: Number of input features per sample.
            n_classes: Number of target classes, labelled ``0..n_classes-1``.
        """
        self.n_features = n_features
        self.n_classes = n_classes
        # scikit-learn's linear classifiers keep a single coefficient row for
        # a two-class problem and one row per class otherwise.
        n_rows = 1 if n_classes == 2 else n_classes
        self.weights = np.ones((n_rows, n_features + 1))
        # NOTE(review): recent scikit-learn releases renamed this loss to
        # 'log_loss' -- confirm against the pinned scikit-learn version.
        self.estimator = SGDClassifier(loss='log')

    def _push_weights(self) -> None:
        """Copy ``self.weights`` into the wrapped estimator's parameters."""
        weights = np.asarray(self.weights, dtype=np.float64)
        # Explicit copies: the estimator may update its parameters in place
        # and must not alias (or be handed strided views of) ``self.weights``.
        self.estimator.coef_ = np.array(
            weights[:, :self.n_features], dtype=np.float64, order='C')
        self.estimator.intercept_ = np.array(
            weights[:, self.n_features], dtype=np.float64, order='C')
        # Needed so prediction works even if the estimator was never fitted.
        self.estimator.classes_ = np.arange(self.n_classes)

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Return the estimator's predictions for the samples in ``x``."""
        self._push_weights()
        return self.estimator.predict(x)

    def logloss(self, x: np.ndarray, y: np.ndarray) -> float:
        """Return the log loss of the predicted probabilities for ``x``.

        ``labels`` is passed explicitly so that a ``y`` that happens to miss
        one of the classes is still scored against all ``n_classes`` columns.
        """
        self._push_weights()
        return log_loss(y, self.estimator.predict_proba(x),
                        labels=np.arange(self.n_classes))

    def fit(self, x: np.ndarray, y: np.ndarray, n_epochs: int) -> None:
        """Train for ``n_epochs`` passes starting from ``self.weights``.

        Args:
            x: Feature matrix with ``n_features`` columns.
            y: Integer class labels.
            n_epochs: Number of ``partial_fit`` passes over ``(x, y)``.
        """
        # Warm-start the estimator from the current weights.
        self._push_weights()
        classes = np.arange(self.n_classes)
        for _ in range(n_epochs):
            self.estimator.partial_fit(x, y, classes=classes)
        self.estimator.densify()
        # Kept on the wrapper as well, for code that reads them from here.
        self.coef_ = self.estimator.coef_
        self.intercept_ = self.estimator.intercept_
        # Publish the trained parameters back as a single matrix.
        self.weights = np.column_stack((self.coef_, self.intercept_))
        

## Data

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

class SGDClsDataSet(DataInterface):
    """Data interface that serves the shard's train and validation sets."""

    def __init__(self, **kwargs):
        """Store the keyword arguments given at construction time."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Shard descriptor currently attached to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, descriptor):
        """
        Attach a shard descriptor and pull both dataset splits from it.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = descriptor
        self.train_set = descriptor.get_dataset("train")
        self.val_set = descriptor.get_dataset("val")

    def get_train_loader(self, **kwargs):
        """Return the training set (handed to tasks that declare an optimizer)."""
        return self.train_set

    def get_valid_loader(self, **kwargs):
        """Return the validation set (handed to tasks without an optimizer)."""
        return self.val_set

    def get_train_data_size(self):
        """Number of training samples, reported for aggregation."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of validation samples, reported for aggregation."""
        return len(self.val_set)


sgdcls_dataset = SGDClsDataSet()

## Model Interface

In [None]:
# Dotted path of the framework adapter plugin passed to the model interface.
framework_adapter = 'custom_adapter.CustomFrameworkAdapter'

# Four input features, three classes.
fed_model = SGDCls(n_features=4, n_classes=3)

MI = ModelInterface(
    model=fed_model,
    optimizer=None,
    framework_plugin=framework_adapter,
)

## Tasks

In [None]:
TI = TaskInterface()


@TI.add_kwargs(epochs=10)
@TI.register_fl_task(
    model='my_model',
    data_loader='train_data',
    device='device',
    optimizer='optimizer',
)
def train(my_model, train_data, optimizer, device, epochs):
    """Fit the model on the training array and report its training log loss.

    The last column of ``train_data`` is used as the integer label, all
    preceding columns as features. ``optimizer`` and ``device`` are part of
    the registered task contract but are not used here.
    """
    features = train_data[:, :-1]
    labels = train_data[:, -1].astype('int')
    my_model.fit(features, labels, epochs)
    return {'train_logloss': my_model.logloss(features, labels)}


@TI.register_fl_task(
    model='my_model',
    data_loader='val_data',
    device='device',
)
def validate(my_model, val_data, device):
    """Report the model's log loss on the validation array.

    The last column of ``val_data`` is used as the integer label, all
    preceding columns as features. ``device`` is unused.
    """
    features = val_data[:, :-1]
    labels = val_data[:, -1].astype('int')
    return {'validation_logloss': my_model.logloss(features, labels)}

## Run

In [None]:
# Name under which this experiment is created in the federation.
experiment_name = 'sgd_classification_experiment_0'
fl_experiment = FLExperiment(
    experiment_name=experiment_name,
    federation=federation,
)

In [None]:
# Start the experiment with the model, tasks and data interface defined
# in the cells above, for ten rounds.
fl_experiment.start(
    model_provider=MI,
    task_keeper=TI,
    data_loader=sgdcls_dataset,
    rounds_to_train=10,
)

In [None]:
# Stream the experiment's metrics into this cell's output.
fl_experiment.stream_metrics()

In [None]:
%%script /bin/bash --bg
# Serve TensorBoard for the ./logs directory on the first FQDN that
# `hostname --all-fqdns` reports. Presumably `--bg` keeps it running in the
# background so the notebook stays usable -- TODO confirm.
tensorboard --host $(hostname --all-fqdns | awk '{print $1}') --logdir logs

In [None]:
# NOTE(review): kills the first PID that `pidof` reports for the current
# python executable. Presumably meant to stop the background TensorBoard
# started above -- confirm it cannot hit the notebook kernel or another job.
!kill $(pidof $(which python) | awk '{print $1}')