In [3]:
from mandala._next.imports import *

# Storage configuration:
#   db_path   -- file backing the persistent storage; omit it for an in-memory storage.
#   deps_path -- turns on automatic dependency tracking; omit it to disable tracking.
#                "__main__" restricts tracking to functions defined in the current session.
storage = Storage(db_path='my_persistent_storage.db', deps_path='__main__')

In [4]:
from sklearn.datasets import load_digits

@op
def load_data(n_class=2):
    """Load the scikit-learn digits dataset restricted to `n_class` classes.

    Args:
        n_class: forwarded unchanged to `load_digits(n_class=...)`.

    Returns:
        Whatever `load_digits(..., return_X_y=True)` returns; callers in this
        notebook unpack it as `X, y`.
    """
    digits_xy = load_digits(return_X_y=True, n_class=n_class)
    return digits_xy

# Call the op inside the storage context. The printed X is an AtomRef that
# carries the array together with its hid/cid identifiers (see the cell output).
with storage:
    X, y = load_data()
    print(X)

AtomRef(array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ..., 10.,  0.,  0.],
       [ 0.,  0.,  1., ...,  3.,  0.,  0.],
       ...,
       [ 0.,  0.,  5., ...,  8.,  1.,  0.],
       [ 0.,  0.,  6., ...,  4.,  0.,  0.],
       [ 0.,  0.,  6., ...,  6.,  0.,  0.]]), hid='16e...', cid='908...')


In [5]:
# Same call as in the previous cell. This time the printed AtomRef shows the
# same hid/cid but no array and in_memory=False (see the cell output) --
# presumably the call was looked up in storage instead of being recomputed.
with storage:
    X, y = load_data()
    print(X)

AtomRef(hid='16e...', cid='908...', in_memory=False)


In [7]:
# Resolve the reference to its underlying value; the cell output shows the array.
storage.unwrap(X)

array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ..., 10.,  0.,  0.],
       [ 0.,  0.,  1., ...,  3.,  0.,  0.],
       ...,
       [ 0.,  0.,  5., ...,  8.,  1.,  0.],
       [ 0.,  0.,  6., ...,  4.,  0.,  0.],
       [ 0.,  0.,  6., ...,  6.,  0.,  0.]])

In [6]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

### new ops to train an ML model and evaluate it
@op
def train_model(X, y, n_estimators=5):
    """Fit a random forest of depth 2 on (X, y) and return the fitted model.

    Args:
        X: training inputs passed to `fit`.
        y: training targets passed to `fit`.
        n_estimators: number of trees, forwarded to `RandomForestClassifier`.

    Returns:
        The fitted `RandomForestClassifier`.
    """
    # NOTE(review): no random_state is passed, so a fresh (non-memoized) run is
    # presumably not bit-for-bit reproducible -- confirm whether that matters here.
    forest = RandomForestClassifier(max_depth=2, n_estimators=n_estimators)
    forest.fit(X, y)  # fit() returns the estimator itself, so returning `forest` is equivalent
    return forest

@op
def get_acc(model, X, y):
    """Return the accuracy of `model` on (X, y), rounded to two decimals.

    Args:
        model: object exposing `predict(X)`.
        X: inputs to predict on.
        y: true labels to compare the predictions against.

    Returns:
        `accuracy_score` of the predictions versus `y`, rounded to 2 places.
    """
    predictions = model.predict(X)
    score = accuracy_score(y_true=y, y_pred=predictions)
    return round(score, 2)

### iterate on saved results by just dumping more computations on top:
### reuse load_data and stack the train/evaluate ops on its outputs
with storage:
    # 3 class counts x 3 forest sizes = 9 train/evaluate runs
    for n_class in (10, 5, 2):
        X, y = load_data(n_class) 
        for n_estimators in (5, 10, 20):
            model = train_model(X, y, n_estimators=n_estimators)
            # accuracy is measured on the same (X, y) the model was trained on
            acc = get_acc(model, X, y)
            # prints the AtomRef wrapping the rounded accuracy (see the cell output)
            print(acc)

AtomRef(0.54, hid='430...', cid='ac9...')
AtomRef(0.7, hid='9c4...', cid='e2b...')
AtomRef(0.74, hid='481...', cid='46b...')
AtomRef(0.82, hid='178...', cid='238...')
AtomRef(0.86, hid='01e...', cid='70e...')
AtomRef(0.94, hid='7b3...', cid='c3b...')
AtomRef(0.99, hid='146...', cid='12a...')
AtomRef(0.99, hid='60f...', cid='12a...')
AtomRef(0.99, hid='ede...', cid='12a...')
