In [1]:
# Automatically reload imported modules before running each cell, so that
# edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load the memory_profiler extension, which provides memory-profiling magics
# such as %memit.
%load_ext memory_profiler

from vflow import Vset
import numpy as np

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from functools import partial

from numpy.testing import assert_equal

In [2]:
def make_train_test(random_state=None):
    """Generate a synthetic classification dataset and split it into train/test parts.

    Parameters
    ----------
    random_state : int, RandomState instance or None, default=None
        Seed forwarded to both ``make_classification`` and ``train_test_split``.
        With the default ``None`` nothing is pinned, so every call draws a new
        random dataset and a new random split (the original behavior, which is
        what lets repeated zero-argument calls produce different datasets).
        Pass an int to make a single call reproducible.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` as returned by
        ``train_test_split`` on a 1000-sample, 5-feature dataset.
    """
    X, y = make_classification(n_samples=1000, n_features=5, random_state=random_state)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)
    return X_train, X_test, y_train, y_test

# Two Vsets, each wrapping the same generator function ten times:
# one evaluated eagerly, one built with lazy=True.
train_test_vset = Vset(name='train_test', modules=[make_train_test] * 10)
train_test_vset_lazy = Vset(
    name='train_test',
    modules=[make_train_test] * 10,
    lazy=True,
)

# non-lazy eval: the four outputs are split into separate dicts, one per key
X_train, _, _, _ = train_test_vset(
    n_out=4,
    keys=['X_train', 'X_test', 'y_train', 'y_test'],
)

X_trains = [*X_train.values()]
first_X_train = X_trains[0]
print(len(X_trains))  # 10 datasets + 1 for __prev__
print(first_X_train.shape)
print(first_X_train)

11
(750, 5)
[[-0.63524937  0.29493964  1.25421327  1.18495117 -0.30429865]
 [ 0.00523872 -0.05364611 -0.04702398  0.34761006  0.05584788]
 [ 1.10305511  0.99288669 -1.09989011 -1.010795   -1.03907118]
 ...
 [ 1.08805599  0.99101525 -1.07660455 -1.2709289  -1.03705419]
 [-1.0971558  -1.18993289  0.94907423  1.35401415  1.24426509]
 [ 0.02391642  1.02703137  0.69632321  0.83419414 -1.06974586]]


In [3]:
# lazy eval: the same call on the lazy Vset yields promises instead of arrays
X_train, _, y_train, _ = train_test_vset_lazy(
    n_out=4,
    keys=['X_train', 'X_test', 'y_train', 'y_test'],
)

X_trains = [*X_train.values()]
promise_X = X_trains[0]
# the arg passed to sep_dicts is itself a promise to call make_train_test
inner_promise_X = promise_X.args[0]

print(len(X_trains))  # 10 promises + 1 for __prev__
print(promise_X)  # the values are (unfulfilled) promises to call sep_dicts since n_out > 1
print(inner_promise_X.vfunc.module)

11
Unfulfilled VfuncPromise(func=<function sep_dicts.<locals>.<lambda> at 0x7f58e8232e50>, args=(Unfulfilled VfuncPromise(func=<vflow.vfunc.Vfunc object at 0x7f58e81cbc10>, args=()), 0))
<function make_train_test at 0x7f58ea27f430>


In [4]:
# y_train, like X_train, maps keys to unfulfilled promises
y_trains = [*y_train.values()]
promise_y = y_trains[0]
# first positional arg of the outer promise is the inner promise it wraps
inner_promise_y = promise_y.args[0]
print(promise_y)

Unfulfilled VfuncPromise(func=<function sep_dicts.<locals>.<lambda> at 0x7f58b7bc5700>, args=(Unfulfilled VfuncPromise(func=<vflow.vfunc.Vfunc object at 0x7f58e81cbc10>, args=()), 2))


In [5]:
# Calling the promise fulfills it and caches the value for future calls.
promise_X()
print(f'promise_X called: {promise_X.called}')
# Check identity with `is` rather than id(...) == id(...): ids are only
# guaranteed unique among objects that are alive at the same time, so `is`
# is the robust way to confirm the second call returns the cached object.
print(f'promise_X call cached: {promise_X.value is promise_X()}')
print(promise_X)

promise_X called: True
promise_X call cached: True
Fulfilled VfuncPromise([[-1.64889302 -1.30893666  0.59661737  0.62349007 -0.96004256]
 [ 0.27919194  0.21933693 -0.35384222  1.01857151  0.62203611]
 [-0.71801475 -2.21257881  0.03788361  1.25821568 -0.01474358]
 ...
 [-0.7329976   0.39132694  0.14760971  0.80010344 -0.21303132]
 [-0.8935196   1.0730619   0.55590024 -0.69635661 -0.94296452]
 [-1.68258172 -0.29375056  0.58786955  0.72932386 -0.94160562]])


In [6]:
# The inner promises are shared by corresponding values of X_train and y_train,
# so fulfilling promise_X also fulfilled the inner promise that promise_y wraps.
# The labels name inner_promise_y because that is the object being inspected;
# promise_y itself has not been called yet.
print(f'inner_promise_y called: {inner_promise_y.called}')

print(promise_y)  # promise_y's first arg is now a fulfilled promise (fulfilled by the call promise_X())

# Identity check with `is` instead of comparing id() values, which are only
# guaranteed unique among simultaneously live objects.
print(f'inner_promise_y call cached: {inner_promise_y.value is inner_promise_y()}')

promise_y called: True
Unfulfilled VfuncPromise(func=<function sep_dicts.<locals>.<lambda> at 0x7f58b7bc5700>, args=(Fulfilled VfuncPromise((array([[-1.64889302, -1.30893666,  0.59661737,  0.62349007, -0.96004256],
       [ 0.27919194,  0.21933693, -0.35384222,  1.01857151,  0.62203611],
       [-0.71801475, -2.21257881,  0.03788361,  1.25821568, -0.01474358],
       ...,
       [-0.7329976 ,  0.39132694,  0.14760971,  0.80010344, -0.21303132],
       [-0.8935196 ,  1.0730619 ,  0.55590024, -0.69635661, -0.94296452],
       [-1.68258172, -0.29375056,  0.58786955,  0.72932386, -0.94160562]]), array([[ 0.95855558, -0.12763018, -0.60009688,  0.76364776,  1.01838691],
       [ 1.02007544, -1.0342346 , -0.25673459, -0.88530377,  0.3897226 ],
       [-2.75198493, -0.04946848,  1.52083956, -1.29415106, -2.55660497],
       ...,
       [ 2.28372423,  0.8521438 , -1.26194616,  1.07342713,  2.12137669],
       [-2.16087708, -2.19574847,  1.29761027, -1.4760941 , -2.19545003],
       [ 0.32107512