In [1]:
# hide
# default_exp core.compose
import os
from nbdev.showdoc import *
if not os.path.exists('settings.ini'):
    os.chdir('..')

# Compose transforms

> Classes and utilities for composed transforms

In [2]:
#export
import pandas as pd

from block_types.core.block_types import Component, PandasComponent, SamplingComponent
from block_types.core.data_conversion import PandasConverter
from block_types.core.utils import PandasIO

In [3]:
#export
class MultiComponent (SamplingComponent):
    """
    Component containing a list of components inside.
    
    The list must contain at least one component. 
    
    Components end up in `self.components` either automatically, when they are 
    assigned as attributes of the object (see `register_components`), or 
    explicitly, through `set_components` / `add_component`.
    
    See `Pipeline` class.
    """
    def __init__ (self, separate_labels = False, **kwargs):
        """Assigns attributes and calls parent constructor.

        Parameters
        ----------
        separate_labels: bool, optional
            whether or not the fit method receives the labels in a separate `y` vector 
            or in the same input `X`, as an additional variable. See description of 
            Pipeline class for more details.
        **kwargs
            Forwarded unchanged to the parent constructor.
        """
        # A subclass may have populated these before calling us; never reset them here.
        if not hasattr (self, 'components'):
            self.components = []
        if not hasattr (self, 'finalized_component_list'):
            self.finalized_component_list = False
        
        # we need to call super().__init__() *after* having created the `components` field,
        # since the constructor of Component calls a method that is overridden in Pipeline, 
        # and this method makes use of the mentioned `components` field
        super().__init__ (separate_labels = separate_labels, 
                          **kwargs)

        self.set_training_data_flag(False)
        
    
    def register_components (self, *ms):
        """
        Registering component in `self.components` list.
        
        Every time that a new component is set as an attribute of the pipeline,
        this component is added to the list `self.components`. Same 
        mechanism as the one used by pytorch's `nn.Module`
        
        Nothing is added once the list has been frozen with `set_components`.
        """
        if not hasattr(self, 'components'):
            self.components = []
            self.finalized_component_list = False
        elif not hasattr(self, 'finalized_component_list'):
            # `components` was assigned before the constructor ran (some subclasses 
            # do this), so the frozen flag does not exist yet: treat the list as open.
            self.finalized_component_list = False
        if not self.finalized_component_list:
            self.components += ms
    
    def __setattr__(self, k, v):
        """
        Sets the attribute and, if the value is a `Component`, registers it.
        
        See register_components
        """
        super().__setattr__(k, v)
        
        if isinstance(v, Component):
            self.register_components(v)
            
    def add_component (self, component):
        """
        Appends `component` to `self.components`, even if the list is frozen.
        
        The frozen state (`finalized_component_list`) found on entry is restored 
        afterwards, also when the registration fails.
        """
        if not hasattr(self, 'finalized_component_list'):
            self.finalized_component_list = False
        finalized_component_list = self.finalized_component_list
        self.finalized_component_list = False
        try:
            self.register_components(component)
        finally:
            self.finalized_component_list = finalized_component_list
        
    def set_components (self, *components):
        """
        Replaces the component list with `components` and freezes it.
        
        After this call, components assigned as attributes are no longer 
        registered automatically; use `add_component` to append more.
        """
        self.components = components
        self.finalized_component_list = True
        
    def construct_diagram (self, training_data_flag=None, include_url=False, port=4000, project='block_types'):
        """
        Construct diagram of the pipeline components, data flow and dimensionality.
        
        By default, we use test data to show the number of observations 
        in the output of each component. This can be changed passing 
        `training_data_flag=True`
        
        Parameters
        ----------
        training_data_flag : bool or None, optional
            Resolved through `get_training_data_flag`.
        include_url : bool, optional
            If True, each component node gets a URL of the form 
            `http://localhost:{port}/{project}/{module_path}.html#{node_name}`.
        port : int, optional
            Port used to build the URL.
        project : str, optional
            Project name used to build the URL.
        
        Returns
        -------
        The `Digraph` object that was built.
        """
        training_data_flag = self.get_training_data_flag (training_data_flag)

        # the base URL is only needed (and only built) when nodes must be linked
        base_url = f'http://localhost:{port}/{project}' if include_url else None
        URL = ''

        node_name = 'data'
        output = 'train / test'

        # NOTE(review): `Digraph` is not imported by this module's export cell; 
        # presumably it is `graphviz.Digraph` -- confirm and add the import, 
        # otherwise this line raises NameError.
        f = Digraph('G', filename='fsm2.svg')
        f.attr('node', shape='circle')

        f.node(node_name)

        f.attr('node', shape='box')
        for component in self.components:
            last_node_name = node_name
            last_output = output
            node_name = component.model_plotter.get_node_name()
            if include_url:
                URL = f'{base_url}/{component.model_plotter.get_module_path()}.html#{node_name}'
            f.node(node_name, URL=URL)
            # the edge is labelled with the output of the previous node
            f.edge(last_node_name, node_name, label=last_output)
            output = component.model_plotter.get_edge_name(training_data_flag=training_data_flag)

        last_node_name = node_name
        node_name = 'output'
        f.attr('node', shape='circle')
        f.edge(last_node_name, node_name, label=output)

        return f

    def show_result_statistics (self, training_data_flag=None):
        """
        Show statistics about results obtained by each component. 
        
        By default, this is shown on test data, although this can change setting 
        `training_data_flag=True`
        """
        training_data_flag = self.get_training_data_flag (training_data_flag)

        for component in self.components:
            component.show_result_statistics(training_data_flag=training_data_flag)

    def show_summary (self, training_data_flag=None):
        """
        Show list of pipeline components, data flow and dimensionality.
        
        By default, we use test data to show the number of observations 
        in the output of each component. This can be changed passing 
        `training_data_flag=True`
        """
        training_data_flag = self.get_training_data_flag (training_data_flag)

        for i, component in enumerate(self.components):
            node_name = component.model_plotter.get_node_name()
            output = component.model_plotter.get_edge_name(training_data_flag=training_data_flag)
            print (f'{"-"*100}')
            print (f'{i}: {node_name} => {output}')


    def get_training_data_flag (self, training_data_flag=None):
        """
        Resolves the effective training-data flag.
        
        An explicit (non-None) argument wins; otherwise the flag stored in 
        `self.data_io` is used, and `False` if that one is None as well.
        """
        if training_data_flag is None:
            if self.data_io.training_data_flag is not None:
                training_data_flag = self.data_io.training_data_flag
            else:
                training_data_flag = False

        return training_data_flag

    def assert_equal (self, path_reference_results, assert_equal_func=pd.testing.assert_frame_equal, **kwargs):
        """Compare results stored in current run against reference results stored in given path.
        
        The comparison is delegated to `assert_equal` of every component; 
        the success message is logged and printed only if none of them raises.
        """

        for component in self.components:
            component.assert_equal (path_reference_results, assert_equal_func=assert_equal_func, **kwargs)
        self.logger.info ('both pipelines give the same results')
        print ('both pipelines give the same results')
        
    def load_estimator (self):
        """Calls `load_estimator` on every component of the list."""
        for component in self.components:
            component.load_estimator ()

    # *************************
    # setters
    # *************************
    # Each setter updates this object through the parent implementation and 
    # then propagates the same value to every component of the list.
    def set_training_data_flag (self, training_data_flag):
        super().set_training_data_flag (training_data_flag)
        for component in self.components:
            component.set_training_data_flag (training_data_flag)

    def set_save_result_flag_test (self, save_result_flag_test):
        super().set_save_result_flag_test (save_result_flag_test)
        for component in self.components:
            component.set_save_result_flag_test (save_result_flag_test)

    def set_save_result_flag_training (self, save_result_flag_training):
        super().set_save_result_flag_training (save_result_flag_training)
        for component in self.components:
            component.set_save_result_flag_training (save_result_flag_training)

    def set_save_result_flag (self, save_result_flag):
        super().set_save_result_flag (save_result_flag)
        for component in self.components:
            component.set_save_result_flag (save_result_flag)

    def set_overwrite (self, overwrite):
        super().set_overwrite (overwrite)
        for component in self.components:
            component.set_overwrite (overwrite)

    def set_save_fitting (self, save_fitting):
        super().set_save_fitting (save_fitting)
        for component in self.components:
            component.set_save_fitting (save_fitting)

In [4]:
#export
class Pipeline (MultiComponent):
    """
    Sequence of components that are executed one after the other.
    
    When training, each component is fitted on the data produced by the 
    components that precede it in the list: all the components but the last 
    one are run with `fit_apply`, and the last one is then fitted on the 
    resulting data.
    
    `Pipeline` derives from `MultiComponent`, which derives from 
    `SamplingComponent`, itself a subclass of `Component`. A pipeline 
    therefore has the functionality of `Component`, such as logging the 
    messages, loading / saving the results, and converting the data format, 
    so that it can be used as part of other pipelines with potentially 
    other data formats.
    
    Being a `SamplingComponent`, the `transform` method receives an input 
    `X` that contains both data and labels. 
    
    In addition, `separate_labels=False` is the default of the constructor, 
    so the `fit` method also receives the labels inside `X`. This is needed 
    because some components of the pipeline may be of class 
    `SamplingComponent`, and such components need labels in `X` when 
    `transform` is called (which happens when calling `fit` on a pipeline, 
    since we do `fit_transform` on all the components except for the last one)
    """
    def __init__ (self, **kwargs):
        """Calls the parent constructor with all the keyword arguments.

        Parameters
        ----------
        separate_labels: bool, optional
            whether or not the fit method receives the labels in a separate `y` vector 
            or in the same input `X`, as an additional variable. See description of 
            Pipeline class for more details.
        """
        super().__init__ (**kwargs)
        
    def _fit (self, X, y=None):
        """
        Fit all the components of the pipeline, given data X and labels y.
        
        By default, y is None and the labels are a variable of `X`.
        """
        transformed = self._fit_apply_except_last (X, y)
        final_component = self.components[-1]
        final_component.fit (transformed, y)
    
    def _fit_apply (self, X, y=None, **kwargs):
        """Fit all the components and return the output of the last one."""
        transformed = self._fit_apply_except_last (X, y, **kwargs)
        final_component = self.components[-1]
        return final_component.fit_apply (transformed, y, **kwargs)

    def _fit_apply_except_last (self, X, y, **kwargs):
        """Run `fit_apply` through every component but the last, chaining outputs."""
        data = X
        for step in self.components[:-1]:
            data = step.fit_apply (data, y, **kwargs)
        return data
    
    def _apply (self, X):
        """Transform data with the components of the pipeline, in order.
        
        The last component may be a predictor: in the current implementation 
        prediction is considered a form of mapping, and therefore a special 
        type of transformation."""
        data = X
        for step in self.components:
            data = step.transform (data)

        return data

In [5]:
# Documentation rendering is deliberately disabled with `if False`; switch the
# condition to True to render the docs of `Pipeline` and of the listed methods.
if False:
    show_doc (Pipeline, title_level=3)
    show_doc (Pipeline.__init__, name='__init__', title_level=4)
    show_doc (Pipeline.construct_diagram, name='construct_diagram', title_level=4)
    show_doc (Pipeline.show_summary, name='show_summary', title_level=4)
    show_doc (Pipeline.show_result_statistics, name='show_result_statistics', title_level=4)
    show_doc (Pipeline.assert_equal, name='assert_equal', title_level=4)

#### `fit_apply` method

In [6]:
# test `fit_apply` method
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.utils import Bunch
from block_types.core.block_types import PickleSaverComponent
from block_types.utils.utils import remove_previous_results

# Transform1: custom Transform. Fitting stores the column-wise sum of X;
# applying returns x*1000 plus that sum.
class Transform1 (Component):
    
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        self.estimator= Bunch(sum = 1)
        
    def _fit (self, X, y=None):
        self.estimator.sum = X.sum(axis=0)
    
    def _apply (self, x):
        return x*1000 + self.estimator.sum
    
# Transform2: custom Transform. Fitting stores the column-wise maximum of X;
# applying returns x*100 plus that maximum.
class Transform2 (Component):
    
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        self.estimator= Bunch(maxim = 1)
        
    def _fit (self, X, y=None):
        self.estimator.maxim = X.max(axis=0)
    
    def _apply (self, x):
        return x*100 + self.estimator.maxim

# Pipeline whose components are registered by assigning them as attributes.
class NewPipeline (Pipeline):
    
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        
        # first custom transform
        self.tr1 = Transform1(**kwargs) 
        
        # second custom transform
        self.tr2 = Transform2(**kwargs) 

pipeline = NewPipeline()
x = np.array([3,4,5])
r1 = pipeline.fit_apply (x.reshape(-1,1))
print (r1)

# expected result computed by hand: x1 mirrors Transform1 and x2 mirrors
# Transform2 fitted on the output of Transform1
x1 = x * 1000 + sum(x)
x2 = x1 * 100 + max(x1)
assert (r1.ravel()==x2).all()

fitting new_pipeline
fitting transform1
applying transform1 transform
fitting transform2
applying transform2 transform


[[306212]
 [406212]
 [506212]]


In [7]:
# MultiComponent that fits both transforms on the same input and sums their outputs
# (the two transforms run in parallel on X rather than being chained).
class NewMulti (MultiComponent):
    
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        
        # first custom transform
        self.tr1 = Transform1(**kwargs) 
        
        # second custom transform
        self.tr2 = Transform2(**kwargs) 
        
    def _fit (self, X, y=None):
        self.tr1.fit (X)
        self.tr2.fit (X)
        
    def _apply (self, X, y=None):
        X1=self.tr1.apply (X)
        X2=self.tr2.apply (X)
        return X1+X2

# `x` and `x1` are defined in the previous test cell
new_multi = NewMulti()
r2 = new_multi.fit_apply (x)
print (r2)
# expected output of Transform2 when fitted directly on x
x2b = 100 * x + max(x)
assert (r2.ravel()==(x1 + x2b)).all()

fitting new_multi
fitting transform1
fitting transform2
applying new_multi transform
applying transform1 transform
applying transform2 transform


[3317 4417 5517]


#### Adding new components to a pipeline

In [8]:
# test automatic creation of pipeline components
from sklearn.preprocessing import FunctionTransformer

# 1. by setting components as attributes:
#    each Component assigned to `self` is appended to `self.components`
#    in assignment order (see `MultiComponent.__setattr__`)
class NewPipeline(Pipeline):
    def __init__ (self, **kwargs):
        super().__init__(**kwargs)
        self.tr1 = Component(FunctionTransformer (lambda x: x+1))
        self.tr2 = Component(FunctionTransformer (lambda x: x*2))
pipeline = NewPipeline()
result = pipeline.transform (3)
print (result)
# (3 + 1) * 2
assert result == 8

applying new_pipeline transform
applying function_transformer transform
applying function_transformer transform


8


In [9]:
#2. by using `set_components`
class NewPipeline(Pipeline):
    def __init__ (self, **kwargs):
        super().__init__(**kwargs)
        tr1 = Component(FunctionTransformer (lambda x: x+1))
        tr2 = Component(FunctionTransformer (lambda x: x*2))
        self.set_components (tr1, tr2)
        
        # the following transform is not added to the pipeline component list:
        self.tr3 = Component(FunctionTransformer (lambda x: x+1))
        
        # The reason is that once `set_components` is called, the component list 
        # is frozen: setting new components as attributes no longer 
        # results in adding them to the component list.
        
pipeline = NewPipeline()
result = pipeline.transform (3)

# (3 + 1) * 2: `tr3` is not applied
assert result == 8
assert len(pipeline.components) == 2
print (result, len(pipeline.components))

applying new_pipeline transform
applying function_transformer transform
applying function_transformer transform


8 2


In [10]:
#3. after calling `set_components()`, we can add new components with `add_component()`
class NewPipeline(Pipeline):
    def __init__ (self, **kwargs):
        super().__init__(**kwargs)
        tr1 = Component(FunctionTransformer (lambda x: x+1))
        tr2 = Component(FunctionTransformer (lambda x: x*2))
        self.set_components (tr1, tr2)
        
        # `add_component` appends to the list even though it has been frozen
        tr3 = Component(FunctionTransformer (lambda x: x+2))
        self.add_component(tr3)
        
pipeline = NewPipeline()
result = pipeline.transform (3)

# (3 + 1) * 2 + 2
assert result == 10
assert len(pipeline.components) == 3
print (result, len(pipeline.components))

applying new_pipeline transform
applying function_transformer transform
applying function_transformer transform
applying function_transformer transform


10 3


#### `load_estimator` method

In [11]:
# test `load_estimator` method
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.utils import Bunch
from block_types.core.block_types import PickleSaverComponent
from block_types.utils.utils import remove_previous_results

# Transform1: custom Transform. Fitting stores the first element of X;
# applying divides the input by that stored value.
class Transform1 (PickleSaverComponent):
    
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        self.estimator= Bunch(inv_c = 1)
        
    def _fit (self, X, y=None):
        self.estimator.inv_c = X.ravel()[0]
    
    def _apply (self, x):
        return x / self.estimator.inv_c

class NewPipeline (Pipeline):
    
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        
        # custom transform
        self.tr1 = Transform1(**kwargs) 
        
        # sklearn transform
        self.tr2 = PickleSaverComponent(StandardScaler(), **kwargs)
        
    # overrides Pipeline._fit: both components are fitted on the same raw X
    def _fit (self, X, y=None):
        self.tr1.fit (X)
        self.tr2.fit (X)

# remove any previously stored files
remove_previous_results (path_results='results')

pipeline = NewPipeline(path_results='results', save_test_result=False)
pipeline.fit (np.array([3,4,5]).reshape(-1,1))
result1 = pipeline.transform (np.array([300,400,500]).reshape(-1,1))
print (pipeline.tr2.estimator.mean_)

# build a fresh, unfitted pipeline and restore the estimators via `load_estimator`
del pipeline 
pipeline = NewPipeline(path_results='results', save_test_result=False)
pipeline.load_estimator ()
print (pipeline.tr2.estimator.mean_)
result2 = pipeline.transform (np.array([300,400,500]).reshape(-1,1))

# the restored pipeline must reproduce the output of the fitted one
np.testing.assert_array_equal (result1, result2)

# remove stored files resulting from running the current test
remove_previous_results (path_results='results')

fitting new_pipeline
fitting transform1
fitting standard_scaler
applying new_pipeline transform
applying transform1 transform
applying standard_scaler transform
loading from /home/jcidatascience/jaume/workspace/remote/temp/block-types/results/transform1_estimator.pk
loading from /home/jcidatascience/jaume/workspace/remote/temp/block-types/results/standard_scaler_estimator.pk
applying new_pipeline transform
applying transform1 transform
applying standard_scaler transform


[4.]
[4.]


#### make_pipeline

In [12]:
# export
def make_pipeline(*components, cls=Pipeline, **kwargs):
    """Build an object of class `cls` whose component list is `components`.

    The keyword arguments are passed to the constructor of `cls`, and the 
    `components` attribute of the new object is then set to a list holding 
    the given components, in the same order.
    """
    new_pipeline = cls (**kwargs)
    new_pipeline.components = [*components]
    return new_pipeline

In [13]:
# test `make_pipeline`: the components are applied in the order given
tr1 = Component(FunctionTransformer (lambda x: x+1))
tr2 = Component(FunctionTransformer (lambda x: x*2))
pipeline = make_pipeline (tr1, tr2)
result = pipeline.transform (3)

print (result)
# (3 + 1) * 2
assert result == 8

applying pipeline transform
applying function_transformer transform
applying function_transformer transform


8


In [14]:
# export
def pipeline_factory (pipeline_class, **kwargs):
    """Creates a pipeline object given its class `pipeline_class`
    
    Parameters
    ----------
    pipeline_class : class or str
        Name of the pipeline class used for creating the object. 
        This can be either of type string or class. A string is evaluated 
        as a Python expression in the scope of this module.
    **kwargs
        Passed to the constructor of the resolved class.
        
    Returns
    -------
    Object obtained by calling the resolved class with `**kwargs`.
    
    Raises
    ------
    ValueError
        If `pipeline_class` is neither a string nor a class.
    """
    if isinstance (pipeline_class, str):
        # NOTE(security): `eval` runs arbitrary code; `pipeline_class` must only 
        # come from trusted sources (e.g. project configuration), never from 
        # external / user-supplied input.
        resolved_class = eval (pipeline_class)
    elif isinstance (pipeline_class, type):
        # `isinstance` (rather than `type(...) is type`) also accepts classes 
        # built with a custom metaclass, such as abstract base classes.
        resolved_class = pipeline_class
    else:
        raise ValueError (f'pipeline_class needs to be either string or class, we got {pipeline_class}')

    return resolved_class (**kwargs)

In [15]:
#export
class PandasPipeline (Pipeline):
    """
    Pipeline that saves results in parquet format, and preserves DataFrame format.
    
    See `Pipeline` class for an explanation of using `separate_labels=False`
    """
    def __init__ (self, 
                  data_converter=None,
                  data_io=None,
                  separate_labels=False,
                  **kwargs):
        """Builds the default pandas converter / IO objects and calls parent constructor.

        Parameters
        ----------
        data_converter : object, optional
            Data converter to use. If None, a `PandasConverter` is created 
            with `separate_labels` and `**kwargs`.
        data_io : object, optional
            Object in charge of loading / saving. If None, a `PandasIO` 
            is created with `**kwargs`.
        separate_labels : bool, optional
            See `Pipeline` class. Used for the default converter and 
            forwarded to the parent constructor.
        **kwargs
            Forwarded to the default converter / IO objects and to the 
            parent constructor.
        """
        if data_converter is None:
            data_converter = PandasConverter (separate_labels=separate_labels,
                                              **kwargs)
        if data_io is None:
            data_io = PandasIO (**kwargs)
        # `super().__init__` is already bound to `self`: passing `self` again 
        # as a positional argument raises TypeError, since the parent 
        # constructor only takes keyword arguments.
        super().__init__ (data_converter=data_converter,
                          data_io=data_io,
                          separate_labels=separate_labels,
                          **kwargs)

In [16]:
#export
class ColumnSelector (Component):
    """
    Component that selects `columns` from the input DataFrame.
    
    Parameters
    ----------
    columns : list or label, optional
        Key used to index the DataFrame in `_apply` (`df[columns]`). 
    **kwargs
        Forwarded to the `Component` constructor.
    """
    def __init__ (self, 
                  columns=[],
                  **kwargs):
        super().__init__ (**kwargs)
        # A list is copied so that the instance never aliases the shared default 
        # list (nor the caller's list): mutating `selector.columns` would otherwise 
        # leak into every other selector built with the default. Other kinds of 
        # keys (single label, tuple, ...) are stored untouched, since converting 
        # them to a list would change what `df[columns]` selects.
        self.columns = list (columns) if isinstance (columns, list) else columns
    
    def _apply (self, df):
        """Returns `df[self.columns]`."""
        return df[self.columns]

In [17]:
# test `ColumnSelector`: only the requested columns are kept, in the requested order
df = pd.DataFrame ({'x1': list(range(5)),
                    'x2': list(range(5,10)),
                    'x3': list(range(15,20)),
                    'x4': list(range(25,30))
                   })
dfr = ColumnSelector(columns=['x2','x4']).transform(df)
assert (dfr==df[['x2','x4']]).all().all()

applying column_selector transform


In [18]:
#export
class Concat (Component):
    """Component that concatenates the DataFrames it receives, column-wise."""
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        
    def _apply (self, *dfs):
        """Returns the concatenation of all the positional inputs along axis 1."""
        frames = [*dfs]
        return pd.concat (frames, axis=1)

In [19]:
# export
class _BaseColumnTransformer (MultiComponent):
    """
    Applies every component in `self.components` to the same input and 
    concatenates their outputs column-wise.
    
    `self.concat` is an internal helper used to join the outputs. It is 
    deliberately kept out of `self.components`, so that it is never fitted 
    nor applied to the input as if it were one more branch.
    """
    def __init__ (self, **kwargs):
        super().__init__ (**kwargs)
        # Assigning a Component as an attribute registers it in `self.components` 
        # (see `MultiComponent.__setattr__`). The list is frozen while the helper 
        # is assigned, to avoid that; otherwise a component list that was set 
        # before this constructor runs would end with `concat` appended to it, 
        # and `_apply` would also append the whole input to the result.
        was_finalized = self.finalized_component_list
        self.finalized_component_list = True
        try:
            self.concat = Concat (**kwargs)
        finally:
            self.finalized_component_list = was_finalized
    
    def _fit (self, df, y=None):
        """Fits every component on `df`. Returns `self`."""
        # NOTE(review): `y` is not forwarded to the components -- confirm 
        # whether labels are always expected to be part of `df` here.
        for component in self.components:
            component.fit (df)
        return self
    
    def _apply (self, df):
        """Transforms `df` with every component and concatenates the results."""
        dfs = [component.transform (df) for component in self.components]
        df_result = self.concat.transform (*dfs)
        return df_result
    
class ColumnTransformer (_BaseColumnTransformer):
    """
    Column transformer built from `(name, transformer, columns)` tuples.
    
    One pipeline per tuple is created with `make_column_transformer_pipelines`, 
    and the resulting list is used as the component list of the object.
    """
    def __init__ (self, *transformers, **kwargs):
        column_pipelines = make_column_transformer_pipelines (*transformers, **kwargs)
        # the component list is assigned before calling the parent constructor
        self.components = column_pipelines
        super().__init__ (**kwargs)

In [20]:
# export
class Identity (Component):
    """Component whose `_apply` returns its input untouched."""
    def __init__ (self, **kwargs):
        super ().__init__ (**kwargs)
        
    def _apply (self, X):
        """Returns `X` itself."""
        unchanged = X
        return unchanged
    
def make_column_transformer_pipelines (*transformers, **kwargs):
    """Build one pipeline per `(name, transformer, columns)` tuple.

    Each pipeline is made of a `ColumnSelector` for `columns` followed by 
    `transformer`, and is given the name `name`. The string 'passthrough' 
    used as transformer is replaced by an `Identity` component. The keyword 
    arguments are passed to every object that is created.

    Returns
    -------
    list
        The pipelines, in the same order as `transformers`.
    """
    def _build_branch (name, transformer, columns):
        # 'passthrough' stands for "leave the selected columns as they are"
        if (type(transformer) is str) and transformer == 'passthrough':
            transformer = Identity (**kwargs)
        selector = ColumnSelector(columns, **kwargs)
        return make_pipeline(selector, transformer, name = name, **kwargs)

    return [_build_branch (name, transformer, columns)
            for name, transformer, columns in transformers]


def _abbreviate_columns (columns, max_length=5):
    """Short tag made of the first character of each column label, cut to `max_length`."""
    initials = []
    for label in columns:
        try:
            initials.append (str (label[0]))
        except (TypeError, IndexError, KeyError):
            # non-subscriptable label (e.g. an integer column name) or empty label
            initials.append (str (label)[:1])
    return ''.join (initials)[:max_length]


def make_column_transformer (*transformers, **kwargs):
    """Build a column transformer from `(transformer, columns)` pairs.

    A name is generated for each pair, as `{transformer_name}_{columns_name}`:
    
    - `transformer_name` is 'pass' for the string 'passthrough', the `name` 
      attribute of the transformer if it has one, and its class name otherwise.
    - `columns_name` is made of the first character of each column label, 
      and has at most 5 characters.
    
    The named transformers are turned into pipelines with 
    `make_column_transformer_pipelines`, which also receives `**kwargs`.

    Returns
    -------
    _BaseColumnTransformer
        Object whose components are the created pipelines.
    """
    transformers_with_name = []
    for transformer, columns in transformers:
        columns_name = _abbreviate_columns (columns)
        if (type(transformer) is str) and transformer == 'passthrough':
            transformer_name = 'pass'
        elif hasattr(transformer, 'name'):
            transformer_name = transformer.name
        else:
            transformer_name = transformer.__class__.__name__
        name = f'{transformer_name}_{columns_name}'
        transformers_with_name.append ((name, transformer, columns))
    
    pipelines = make_column_transformer_pipelines (*transformers_with_name, **kwargs)
    # NOTE(review): `kwargs` are not passed to the column transformer itself, 
    # only to its pipelines -- confirm this is intended.
    column_transformer = _BaseColumnTransformer ()
    column_transformer.components = pipelines
    return column_transformer
    

In [21]:
import pandas as pd
from sklearn.preprocessing import FunctionTransformer

# test `make_column_transformer` with two transformers applied to 
# different (overlapping) groups of columns
df = pd.DataFrame ({'cont1': list(range(5)),
                    'cont2': list(range(5,10)),
                    'cont3': list(range(15,20)),
                    'cont4': list(range(25,30)),
                    'cat_1': list([1,2,3,2,1]),
                    'cat_2': list([0,1,1,0,0])
                    })

tr1 = Component(FunctionTransformer (lambda x: x+1), name='tr1')
# `transformed_columns` is given here so that the first output column 
# (`cont2_bis`) does not clash with the `cont2` produced by `tr1`
tr2 = PandasComponent(FunctionTransformer (lambda x: x*2), transformed_columns=['cont2_bis','cat_1'], name='tr2')

column_transformer = make_column_transformer (
    (tr1, ['cont2', 'cont4']),
    (tr2, ['cont2', 'cat_1'])
)
dfr = column_transformer.transform(df)

# display and test
display(dfr)
assert (dfr[['cont2','cont4']] == tr1(df[['cont2','cont4']])).all().all()
assert (dfr[['cont2_bis','cat_1']] == tr2(df[['cont2','cat_1']])).all().all()
assert (dfr.columns == ['cont2','cont4', 'cont2_bis','cat_1']).all()

applying __base_column_transformer transform
applying tr1_cc transform
applying column_selector transform
applying tr1 transform
applying tr2_cc transform
applying column_selector transform
applying tr2 transform
applying concat transform


Unnamed: 0,cont2,cont4,cont2_bis,cat_1
0,6,26,10,2
1,7,27,12,4
2,8,28,14,6
3,9,29,16,4
4,10,30,18,2


applying tr1 transform
applying tr2 transform


In [22]:
# test 'passthrough': the selected columns are copied to the output unchanged
# (`tr1` and `df` are defined in the previous cell)
column_transformer = make_column_transformer (
    (tr1, ['cont1', 'cont4']),
    ('passthrough', ['cont2', 'cat_1'])
)
dfr = column_transformer.transform(df)

# display and test
display(dfr)
assert (dfr[['cont1','cont4']] == tr1(df[['cont1','cont4']])).all().all()
assert (dfr[['cont2','cat_1']] == df[['cont2','cat_1']]).all().all()
assert (dfr.columns == ['cont1','cont4', 'cont2','cat_1']).all()

applying __base_column_transformer transform
applying tr1_cc transform
applying column_selector transform
applying tr1 transform
applying pass_cc transform
applying column_selector transform
applying identity transform
applying concat transform


Unnamed: 0,cont1,cont4,cont2,cat_1
0,1,26,5,1
1,2,27,6,2
2,3,28,7,3
3,4,29,8,2
4,5,30,9,1


applying tr1 transform


In [23]:
# Transform with a fitted state, whose output has a different number of columns 
# than its input: fitting stores the column-wise sum of X, and applying adds 
# that sum (times 100 or 1000) to the first two columns of X.
class SumTimes100 (Component):
    def _fit (self, X, y=None):
        self.sum = X.sum(axis=0)
    def _apply (self, X):
        
        dfr = pd.DataFrame ({'c1_times100': self.sum.values[0]*100 + X.iloc[:,0].values,
                             'c2_times100': self.sum.values[1]*100 + X.iloc[:,1].values,
                             'c2_times1000': self.sum.values[1]*1000 + X.iloc[:,1].values})
        return dfr
        
tr1 = SumTimes100 ()
tr2 = PandasComponent(FunctionTransformer (lambda x: x*2), name='tr2')

column_transformer = make_column_transformer (
    (tr1, ['cont2', 'cont4']),
    (tr2, ['cont2', 'cat_1'])
)
# `fit_transform` is used here since `tr1` needs to be fitted
dfr = column_transformer.fit_transform(df)

# display & test
display(dfr)
assert (dfr.columns == ['c1_times100','c2_times100', 'c2_times1000','cont2', 'cat_1']).all()
assert (dfr['c1_times100'] == sum(df.cont2)*100+df.cont2).all()
assert (dfr['c2_times100'] == sum(df.cont4)*100+df.cont4).all()
assert (dfr['c2_times1000'] == sum(df.cont4)*1000+df.cont4).all()

fitting __base_column_transformer
fitting sum_times100_cc
fitting column_selector
applying column_selector transform
fitting sum_times100
fitting tr2_cc
fitting column_selector
applying column_selector transform
fitting tr2
applying __base_column_transformer transform
applying sum_times100_cc transform
applying column_selector transform
applying sum_times100 transform
applying tr2_cc transform
applying column_selector transform
applying tr2 transform
applying concat transform


Unnamed: 0,c1_times100,c2_times100,c2_times1000,cont2,cat_1
0,3505,13525,135025,10,2
1,3506,13526,135026,12,4
2,3507,13527,135027,14,6
3,3508,13528,135028,16,4
4,3509,13529,135029,18,2


## MultiSplitComponent

In [None]:
#export
class MultiSplitComponent (MultiComponent):
    """
    Wraps a single component so that it works on data made of several splits.
    
    The input `X` is indexed by split (e.g. a dictionary with keys such as 
    'training_data', or a tuple indexed by position). The wrapped component 
    is fitted on one split, optionally receiving other splits as keyword 
    arguments, and it is applied to each split listed in `apply_to_splits`.
    """
    def __init__ (self, 
                  component=None, 
                  fit_training_split = 'training_data',
                  fit_additional_splits = [],
                  fit_additional_split_names = None,
                  apply_to_splits = ['training_data', 'validation_data', 'test_data'],
                  **kwargs):
        """Assigns attributes and calls parent constructor.

        Parameters
        ----------
        component : Component, optional
            Component to be wrapped. If given, it becomes the only element 
            of the (frozen) component list.
        fit_training_split : str or int, optional
            Key of the split of `X` used as training data in `fit`.
        fit_additional_splits : list, optional
            Keys of the splits of `X` passed to the `fit` method of the 
            component as additional keyword arguments.
        fit_additional_split_names : list of str, optional
            Names of the keyword arguments used for `fit_additional_splits`, 
            in the same order. Valid names are 'validation_data' and 
            'test_data'. If None, `fit_additional_splits` is used.
        apply_to_splits : list, optional
            Keys of the splits of `X` to which the component is applied.
        **kwargs
            Forwarded to the parent constructor.
        """
        super().__init__ (**kwargs)
        if component is not None:
            self.set_components (component)
            self.component = component
        
        self.fit_training_split = fit_training_split
        # the lists are copied so that instances never share (nor expose) the 
        # mutable default arguments of this constructor
        self.fit_additional_splits = list (fit_additional_splits)
        
        # the following is needed in case the input data is a tuple instead of a dictionary
        # in such case, fit_additional_splits should be a list of integer indices (to index the tuple) 
        # and fit_additional_split_names should be a list of string names (to index a dictionary)
        # valid names are 'validation_data' and 'test_data'
        self.fit_additional_split_names = (list (fit_additional_split_names) 
                                           if fit_additional_split_names is not None
                                           else list (self.fit_additional_splits))
        
        self.apply_to_splits = list (apply_to_splits)
    
    def _fit (self, X, y=None):
        """Fits the wrapped component on the training split of `X`.
        
        Raises
        ------
        ValueError
            If one of the additional split names is not 'validation_data' 
            nor 'test_data'.
        """
        component = self.components[0]
        additional_data = {}
        for split, split_name in zip(self.fit_additional_splits, self.fit_additional_split_names):
            # validate the name before touching the data, so that an invalid 
            # name is always reported as such
            if split_name not in ['validation_data', 'test_data']:
                raise ValueError (f'split {split_name} not valid')
            additional_data[split_name] = X[split]
        
        component.fit(X[self.fit_training_split], y=y, **additional_data)
    
    def _apply (self, X, **kwargs):
        """Applies the wrapped component to each split in `apply_to_splits`.
        
        Returns a dictionary mapping each split to its result or, if `X` is 
        a tuple, a tuple with the results in the order of `apply_to_splits`.
        """
        component = self.components[0]
        result = {}
        for split in self.apply_to_splits:
            result[split] = component.apply (X[split], **kwargs)
        if type(X) is tuple:
            result = tuple(result.values())
        return result

In [5]:
# Scratch check of the conversion used in `MultiSplitComponent._apply` when the 
# input is a tuple: `tuple(d.values())` keeps the insertion order of the dictionary.
# NOTE(review): leftover exploratory cell with an out-of-order execution count; 
# consider turning it into an assertion or removing it.
d = {'a':[1,2], 'b':[3,3]}

tuple(d.values())

([1, 2], [3, 3])