In [1]:
#default_exp node

In [2]:
#hide
# Workaround to make relative imports work inside the notebook:
# put the parent directory on the module search path.
import sys
sys.path.append('..')

# Reload edited modules automatically so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Dependencies

In [3]:
#export

import os
from pathlib import Path
from warnings import warn
from collections import defaultdict

from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import FunctionTransformer
import joblib

from scipy import sparse
import numpy as np

from skdag.utils import SimpleCacher, EstimatorCacher

# Input handling
> Functionalities to handle inputs from tasks

In [4]:
#export
#TODO: define node name validaton rules
def _validate_name(name):
    '''
    function to validate node names.
    for now, any name is accepeted.
    '''
    return name


# Nodes
> `Input`/`Target` and `NodeEstimator`: classes that wrap sklearn estimators and compose the final DAGEstimator

In [5]:
#export
def _validate_input_nodes(input_nodes, child_name):
    '''
    checks if input nodes are valid (NodeTransformer).
    If not NodeTransformer but valid BaseEstimator instance, wrapps BaseEstimator in NodeTransformer
    '''
        
    processed_nodes = []
    for node in input_nodes:
        if isinstance(node, NodeTransformer):
            processed_nodes.append(node)
        
        elif isinstance(node, BaseEstimator):
            node = NodeTransformer(node, 'Parent' + child_name)
            process_nodes.append()

## BaseNode

In [6]:
#export
class BaseNode(BaseEstimator):
    '''
    Common parent of every node.

    Subclasses are required to set a ``name`` attribute; it is what
    ``str(node)`` returns.
    '''
    def __str__(self):
        return self.name
                
        

## Input  class

In [7]:
#export
class BaseInputNode(BaseNode):
    '''
    Base class for nodes that hold a value handed in from outside.

    A value is stored with set_value (or fit) and read back through
    value, get_value or transform. __init__ resolves the node name by
    calling self._validate_name, which is not defined in this class.
    '''

    def __init__(self, name = None):
        self.name = self._validate_name(name)
        # private state: the stored value and a flag telling whether one was set
        self.__has_value = False
        self.__value = None

    @property
    def has_value(self):
        '''True once a value has been stored and not cleared since.'''
        return self.__has_value

    @property
    def value(self):
        '''The stored value; raises NotFittedError when none is stored.'''
        if self.has_value:
            return self.__value

        raise NotFittedError(f'{self.name} has no value set to it. Call set_value with its respective input value prior to acessing value')

    def set_value(self, value):
        '''Store value in the node and return the node.'''
        self.__has_value = True
        self.__value = value
        return self

    def get_value(self):
        '''Method form of the value property.'''
        return self.value

    def clear_value(self):
        '''Drop the stored value and return the node.'''
        self.__has_value = False
        self.__value = None
        return self

    def transform(self, X = None, **kwargs):
        '''Return the stored value; X and kwargs are ignored.'''
        return self.get_value()

    def fit(self, X = None, y = None, **kwargs):
        '''Store X as the node value; y and kwargs are ignored.'''
        self.set_value(X)
        return self
        
        

In [8]:
#export
class Input(BaseInputNode):
    '''
    An input node for X values
    '''

    def _validate_name(self, name):
        '''Return name, or 'X_' followed by id(self) when name is None.'''
        if name is not None:
            return name
        return 'X_' + str(id(self))

class Target(BaseInputNode):
    '''
    An input node for y values
    '''

    def _validate_name(self, name):
        '''Return name, or 'y_' followed by id(self) when name is None.'''
        if name is not None:
            return name
        return 'y_' + str(id(self))

## NodeEstimator class

In [39]:
#export
class NodeEstimator(BaseNode, TransformerMixin):
    '''
    Graph node that wraps a sklearn-style estimator.

    Parameters
    ----------
    estimator : estimator object
        Cloned on construction; the clone is what the node fits and uses.
    name : str, optional
        Node name. Defaults to 'Task' followed by id(self).
    transform_method : str
        Name of the estimator method that transform dispatches to.
    cached_estimator : bool
        When True the cloned estimator is wrapped in an EstimatorCacher that
        uses joblib as serializer and cachedir as directory.
    cachedir : str or Path
        Directory handed to EstimatorCacher.
    '''

    def __init__(self, estimator, name = None, transform_method = 'transform', cached_estimator = False, cachedir = './_skdag_cache'):

        self.cachedir = Path(cachedir)
        self.cached_estimator = cached_estimator
        self.name = self._validate_name(name)

        self.estimator = None
        if not cached_estimator:
            self.estimator = clone(estimator)
        else:
            self.estimator = EstimatorCacher(clone(estimator), serializer = joblib, dirpath = self.cachedir)

        self.transform_method = transform_method
        #private attributes
        self.__frozen = False
        self.__input_nodes = ()
        self.__target_node = None
        return

    def __call__(self, *input_nodes):
        '''
        registers input_nodes as the inputs of this node and returns self,
        so nodes can be wired as node(parent_a, parent_b)
        '''
        self._set_input_nodes(*input_nodes)
        return self

    def _set_input_nodes(self, *input_nodes):
        '''
        sets self.input_nodes

        raises TypeError if any node is not a NodeEstimator or BaseInputNode
        '''
        for node in input_nodes:
            if not isinstance(node, (NodeEstimator, BaseInputNode)):
                raise TypeError('All input nodes should be instances of NodeEstimator or BaseInputNode (Input, Target)')

        self.__input_nodes = input_nodes

        return

    def _set_target_node(self, target_node):
        '''
        sets self.target_node

        raises TypeError if target_node is not a NodeEstimator or BaseInputNode
        '''
        # BaseInputNode covers Target as well as the previously accepted Input
        if not isinstance(target_node, (NodeEstimator, BaseInputNode)):
            raise TypeError('target_node should be an instance of NodeEstimator or BaseInputNode (Target, Input)')

        self.__target_node = target_node
        return

    def __getattr__(self, attr):
        '''
        returns self.estimator attribute, if it does not exist in parent class
        '''
        # __getattr__ only runs when normal lookup failed. If 'estimator' itself
        # is missing (instance built without __init__, e.g. while unpickling),
        # delegating to self.estimator would recurse without end.
        if attr == 'estimator':
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {attr!r}')
        return getattr(self.estimator, attr)

    def freeze(self,):
        '''
        freezes node, that is, when fit is called, no fitting is performed and self is returned.
        After running fit once, the node is automatically unfrozen
        '''
        self.__frozen = True
        return self

    def unfreeze(self,):
        '''
        unfreezes node, that is, the next call to fit fits the wrapped estimator again.
        '''
        self.__frozen = False
        return self

    @property
    def frozen(self,):
        '''whether the next call to fit will be skipped'''
        return self.__frozen

    @property
    def input_nodes(self,):
        '''tuple of nodes registered as inputs of this node'''
        return self.__input_nodes

    @property
    def target_node(self,):
        '''node registered as target of this node, or None'''
        return self.__target_node

    def _validate_name(self, name):
        '''
        returns name, or 'Task' followed by id(self) when name is None
        '''
        if name is None:
            name = 'Task' + str(id(self))

        return name

    def transform(self, X, **kwargs):
        '''
        calls the estimator method named by self.transform_method on X
        '''
        return getattr(self.estimator, self.transform_method)(X, **kwargs)

    def fit(self, X, y = None, **kwargs):
        '''
        fits the wrapped estimator on X and y and returns self.
        A frozen node skips fitting once and unfreezes itself.
        '''
        if self.frozen:
            self.unfreeze()
            return self

        # forward the received target; it used to be replaced by None
        self.estimator.fit(X, y = y, **kwargs)
        return self

    def __del__(self,):
        '''
        drops the reference to the wrapped estimator.
        NOTE(review): cached estimator files are not deleted here
        '''
        self.estimator = None
        return

    def __getstate__(self,):
        '''handles cached attributes when pickling'''

        # work on a copy so that pickling does not replace the live cacher
        state = self.__dict__.copy()
        if self.cached_estimator:
            state['estimator'] = self.estimator.load()

        return state

    def __setstate__(self, d):
        '''restores the pickled attributes and re-wraps a cached estimator in an EstimatorCacher'''

        state = dict(d)
        if state.get('cached_estimator'):
            state['estimator'] = EstimatorCacher(state['estimator'], serializer = joblib, dirpath = state['cachedir'])

        # without this update the unpickled node would have no attributes at all
        self.__dict__.update(state)
        return
        
    

# ConcatenateNode

In [40]:
#export

def _concat(X):
    
    '''
    concatenate function, handles mixed sparse and dense
    '''

    if not isinstance(X, (tuple, list)):
        raise TypeError(f'X must be list or tuple, got {type(X)}')

    X = list(X)
    for i in range(len(X)):
        if len(X[i].shape) < 2:
            X[i] = X[i].reshape(-1,1)
    
    if any([sparse.issparse(x) for x in X]):
        X = sparse.hstack(X)
    else:                
        X = np.hstack(X)

    return X

class ConcatenateNode(NodeEstimator):
    '''
    Node that concatenates (hstack) the arrays of a tuple or list.

    It is a NodeEstimator whose estimator is a FunctionTransformer
    around _concat.
    '''
    def __init__(self, name = None):
        hstack_transformer = FunctionTransformer(_concat)
        super().__init__(hstack_transformer, name = name)

# CustomYNode

In [41]:
class CustomYNode(NodeEstimator):
    '''
    Receives X inputs and y inputs: the learning target (y) can be the
    output of some node, just like the input values (X).

    Placeholder: nothing is added on top of NodeEstimator yet.
    '''

# Dask DAG Example

In [42]:
# Function handed to FunctionTransformer below.
# A named def is used instead of a lambda so that it can be pickled.
def sum1(x):
    '''Return x + 1.'''
    return x + 1

In [43]:
from dask import delayed
# Two input nodes; no name is given, so each gets an auto-generated one.
input1 = Input()
input2 = Input()


# Wire the graph: calling a node with other nodes registers them as its input nodes.
concat1 = ConcatenateNode()(input1,input2)
transf1 = FunctionTransformer(sum1)
# cached_estimator = True makes the node wrap its cloned estimator in an EstimatorCacher.
node1 = NodeEstimator(transf1, cached_estimator = True)(concat1)



# Build the lazy dask graph by hand; nothing is executed until compute() is called.
in1 = delayed(input1.transform)()
in2 = delayed(input2.transform)()
out = delayed(concat1.fit_transform)([in1, in2])
out = delayed(node1.fit_transform)(out)


x = np.ones(10)
# NOTE(review): no random seed is set, so the values shown below change between runs.
y = np.random.randn(10)

# fit on an input node only stores the value that transform() returns later.
input1.fit(x)
input2.fit(y)

out.compute()

array([[ 2.        ,  0.18043229],
       [ 2.        ,  1.2284783 ],
       [ 2.        ,  1.40562606],
       [ 2.        , -0.68416483],
       [ 2.        ,  0.14373496],
       [ 2.        ,  0.65810645],
       [ 2.        ,  1.46900266],
       [ 2.        ,  1.29028359],
       [ 2.        ,  0.32009768],
       [ 2.        ,  1.29560466]])

In [44]:
# Ask the estimator wrapper for the stored estimator
# (presumably read back from the cache directory — confirm against EstimatorCacher.load).
f = node1.estimator.load()

In [45]:
# Display the object returned by load().
f

FunctionTransformer(func=<function sum1 at 0x000001D39DDD53A8>)

In [46]:
# NOTE(review): node1 was created with cached_estimator = True, yet the recorded repr is a
# FunctionTransformer — presumably EstimatorCacher forwards its repr; confirm.
node1.estimator

FunctionTransformer(func=<function sum1 at 0x000001D39DDD53A8>)

In [47]:
# Serialize the whole node; pickling goes through NodeEstimator.__getstate__.
joblib.dump(node1, 'nodedump.sav')

['nodedump.sav']

In [49]:
# Round trip: read the node back; unpickling goes through NodeEstimator.__setstate__.
node1 = joblib.load('nodedump.sav')

# Export

In [18]:
#hide
# Convert the notebooks into library modules (see the "Converted ..." lines printed below).
from nbdev.export import notebook2script
notebook2script()

Converted d6tflow-sklearn.ipynb.
Converted dag.ipynb.
Converted node.ipynb.
Converted utils.ipynb.
