# DPML | Latency Replay

In this notebook, we investigate the reproducibility of transformation sequences captured by `dpml`.

## Load Dependencies

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lineage import LeBatch
from lineage.transformation import DPMLClassWrapper, DPMLCallableWrapper
from lineage.utils import *

from sibyl import *
from datasets import concatenate_datasets, load_dataset

import os
import time
from tqdm.notebook import tqdm

## Create Datasets

In [143]:
dataset = load_dataset("glue", "sst2", split="train[:50000]")
dataset = dataset.rename_column('sentence', 'text')

Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


## Replay Test

### Routine to be Tracked

In [4]:
scheduler = SibylTransformScheduler("sentiment", class_wrapper=DPMLClassWrapper)
stochastic_list = [Concept2Sentence, ConceptMix, Emojify]

In [144]:
num_trials = 3
batch_size= 10

texts, labels = dataset['text'], dataset['label'] 
new_texts, new_labels = [], []

scheduler.num_INV = 1
scheduler.num_SIB = 1

transform_schedule = []
for i in tqdm(range(0, len(labels), batch_size)):
    transforms = []
    for transform in scheduler.sample():
        if transform.wrapped_class in stochastic_list:
            continue
        transforms.append(transform)
    transform_schedule.append(transforms)

  0%|          | 0/5000 [00:00<?, ?it/s]

## Investigating CSV Replay Time / Memory Consumption

In [6]:
from time import perf_counter

class catchtime(object):
    def __init__(self, name="Code Block"):
        self.name = name
        
    def __enter__(self):
        self.t = time.perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.t = time.perf_counter() - self.t
        print('{0:6.3f}s : {1}'.format(self.t, self.name))

In [105]:
def set_rng_state(fn, attr, state):
    rng_state = preprocess_params(state)
    random_generator = getattr(fn.func.__self__, attr)
    random_generator.__setstate__(rng_state)
    setattr(fn.func.__self__, attr, random_generator)
    return fn

def print_query(session, q_string, args):
    rows = session.execute(q_string, args)
    for row in rows:
        print(row._mapping)
        
def collect_from_query(session, q_string, args):
    rows = session.execute(q_string, args)
    return [row._mapping for row in rows]

def replay_all_from_csv():
      
    with catchtime("Load CSVTransformLogger") as t:
        from lineage.storage.csv.transform_logger import TransformLogger as CSVTransformLogger
    
    # fetch data
    with catchtime("Load data") as t:
        logger = CSVTransformLogger()
        df = pd.read_csv(logger.path, header=None, names=['batch_id', 'text', 'target', 'transform_prov'])
        transform_df = pd.read_csv(logger.transform_path, header=None, index_col=0, names=['transform_id', 'transform'])
    
    with catchtime("Load batches + transform_set") as t:
        transform_idxs = set()
        batches = {}
        for idx, row in df.iterrows():
            bid = row['batch_id']
            if bid not in batches:
                batches[bid] = {'text':[], 'target':[], 'transform': []}

            batches[bid]['text'].append(row['text'])
            batches[bid]['target'].append(row['target'])

            if len(batches[bid]['transform']) == 0:
                batches[bid]['transform'] = eval(row['transform_prov'])
                transform_idxs = transform_idxs | set(batches[bid]['transform'])
                    
    with catchtime("Load transforms") as t:
        transforms = []
        random_states = []
        hashes = []
        mapping = {}
        for idx in transform_idxs:
            t_prov = json.loads(transform_df.loc[idx]['transform'])
            random_state_attr = t_prov.pop('class_rng')
            random_state_info = t_prov.pop('callable_rng_state')
            random_states.append((random_state_attr, random_state_info))

            t_prov_hash = hash(repr(t_prov))
            if t_prov_hash not in hashes:
                transforms.append(load_transform_from_replay_provenance(t_prov))
                hashes.append(t_prov_hash)
                mapping[idx] = hashes.index(t_prov_hash)
            else:
                mapping[idx] = hashes.index(t_prov_hash)
    load_time = t.t

    # replay
    with catchtime("Replay") as t:
        new_records = []
        for batch_id in sorted(list(batches.keys())):
            batch = (batches[batch_id]['text'], batches[batch_id]['target'])
            for idx in batches[batch_id]['transform']:
                rs_attr, rs_info = random_states[idx]
                fn_id = mapping[idx]
                t_fn = set_rng_state(transforms[fn_id], rs_attr, rs_info)
                batch = t_fn(batch)
            texts, labels = batch
            new_records += [(x, y) for x,y in zip(texts, labels)]
    replay_time = t.t
            
    return new_records, load_time, replay_time

def replay_all_from_db(run_id=None):
    
    with catchtime("Load SQLTransformLogger") as t:
        from lineage.storage.sqlalchemy.transform_logger import TransformLogger as SQLTransformLogger
        logger = SQLTransformLogger()
    
    # fetch data
    with Session(logger.engine) as session:
        if not run_id:
            run_id = session.query(Run).order_by(Run.id.desc()).first().id
        
        
        with catchtime("Load data") as t:
            ta_stmt = text(
                """
                SELECT DISTINCT 
                       ta.batch_id, 
                       ta.transform_id, 
                       ta.transform_state
                FROM TransformApplied ta 
                WHERE ta.run_id == :run_id
                """
            )
            ta_rows = collect_from_query(session, ta_stmt, {'run_id': run_id})

            t_stmt = text(
                """
                SELECT DISTINCT t.*
                FROM Transform t
                INNER JOIN TransformApplied ta ON t.id = ta.transform_id
                WHERE ta.run_id == :run_id
                """
            )
            t_rows = collect_from_query(session, t_stmt, {'run_id': run_id})

            r_stmt = text(
                """
                SELECT r.id, r.text, r.target, ta.batch_id
                FROM Record r
                INNER JOIN (
                    SELECT DISTINCT ta.input_record_id, ta.batch_id
                    FROM TransformApplied ta 
                    WHERE ta.run_id == :run_id  
                ) ta ON ta.input_record_id = r.id
                """
            )

            r_rows = collect_from_query(session, r_stmt, {'run_id': run_id})

    with catchtime("Load batches + transform_set") as t:
        batches = {}
        for row in r_rows:
            bid = row['batch_id']
            if bid not in batches:
                batches[bid] = {'text':[], 'target':[]}

            batches[bid]['text'].append(row['text'])
            batches[bid]['target'].append(row['target'])

        transforms = {}
        for row in t_rows:
            transforms[row['id']] = (load_transform_from_replay_provenance(row), row['class_rng'])
    load_time = t.t

    # replay
    with catchtime("Replay") as t:
        new_records = []    
        for batch_id, batch in batches.items():
            batch = (batch['text'], [eval(t) for t in batch['target']])
            tas = [row for row in ta_rows if row['batch_id'] == batch_id]    
            for row in tas:
                t_fn, rs_attr = transforms[row['transform_id']]
                t_fn = set_rng_state(t_fn, rs_attr, row['transform_state'])
                batch = t_fn(batch)
            texts, labels = batch
            new_records += [(x, y) for x,y in zip(texts, labels)]
    replay_time = t.t
        
    return new_records, load_time, replay_time

## Replay with SQL

In [8]:
from sqlalchemy.orm import Session

from lineage.storage.sqlalchemy.transform_logger import TransformLogger as SQLTransformLogger
from lineage.storage.sqlalchemy.models import *
from sqlalchemy.sql import text

In [15]:
db_file_pth = "dpml/lineage/storage/dpml.db"
if os.path.exists(db_file_pth):
    os.remove(db_file_pth)

In [145]:
le_batch.transform_logger.clean_data_store()

In [146]:
no_lineage_times = []
replay_logging_times, replay_fn_load_times, replay_generation_times, num_mismatches = [], [], [], []
for trial in tqdm(range(num_trials)):
    no_lineage_text, no_lineage_targets = [], []
    replay_log_text, replay_log_targets = [], []
    
    # no lineage ====================================================================================================
    startTime = time.perf_counter()
    for i, t_sched in zip(range(0, len(labels), batch_size), transform_schedule):
        
        text_batch = texts[i:i+batch_size]
        label_batch = labels[i:i+batch_size]
        batch = (text_batch, label_batch)
        for transform in t_sched:
            batch = transform.transform_batch(batch)
            
        no_lineage_text.extend(batch[0])
        no_lineage_targets.extend(batch[1])
        
    run_time = time.perf_counter() - startTime
    no_lineage_times.append(run_time)
    print('Elapsed time for Trial {0}: {1:6.3f} seconds'.format(trial, run_time))
    
    # replay logging ================================================================================================
    startTime = time.perf_counter()
    for i, t_sched in zip(range(0, len(labels), batch_size), transform_schedule):
        text_batch = texts[i:i+batch_size]
        label_batch = labels[i:i+batch_size]
        batch = (text_batch, label_batch)
        
        if len(t_sched) == 0:
            continue
            
        with LeBatch(original_batch=batch) as le_batch:
            for transform in t_sched:
                batch = le_batch.apply(batch, transform.transform_batch)
            
        replay_log_text.extend([x.text for x in batch])
        replay_log_targets.extend([x.target for x in batch])
        
    le_batch.transform_logger.flush(force=True)        
    run_time = time.perf_counter() - startTime
    replay_logging_times.append(run_time)
    print('Elapsed logging time for Trial {0}: {1:6.3f} seconds'.format(trial, run_time))
    
    # replay generation ==============================================================================================
    startTime = time.perf_counter()
    new_records, load_time, replay_time = replay_all_from_db()
    run_time = time.perf_counter() - startTime
    replay_fn_load_times.append(load_time)
    replay_generation_times.append(replay_time)
    print('Elapsed replay time for Trial {0}: {1:6.3f} seconds'.format(trial, run_time))
    
    original_records = [(text, target) for text, target in zip(replay_log_text, replay_log_targets)]
    num_mismatch = 0
    counter = 0
    for old_r, new_r in zip(original_records, new_records):
        if old_r[0] != new_r[0] or np.any(old_r[1] != new_r[1]):
            num_mismatch += 1  
        counter += 1
    num_mismatches.append(num_mismatch)
    print('Replay mismatches for Trial {0}: {1}'.format(trial, num_mismatch))   
    
    # del original_records, new_records
    
    le_batch.transform_logger.clean_data_store()

  0%|          | 0/3 [00:00<?, ?it/s]

Elapsed time for Trial 0: 135.674 seconds
Elapsed logging time for Trial 0: 173.182 seconds
 0.409s : Load SQLTransformLogger
 0.444s : Load data
 4.022s : Load batches + transform_set
139.458s : Replay
Elapsed replay time for Trial 0: 144.368 seconds
Replay mismatches for Trial 0: 0
Elapsed time for Trial 1: 133.144 seconds
Elapsed logging time for Trial 1: 171.747 seconds
 0.408s : Load SQLTransformLogger
 1.143s : Load data
 3.903s : Load batches + transform_set
138.900s : Replay
Elapsed replay time for Trial 1: 144.399 seconds
Replay mismatches for Trial 1: 0
Elapsed time for Trial 2: 135.252 seconds
Elapsed logging time for Trial 2: 173.665 seconds
 0.410s : Load SQLTransformLogger
 0.412s : Load data
 3.932s : Load batches + transform_set
150.177s : Replay
Elapsed replay time for Trial 2: 154.977 seconds
Replay mismatches for Trial 2: 0


In [147]:
print("no_lineage_times:", np.mean(no_lineage_times))
print("replay_logging_times:", np.mean(replay_logging_times))
print("replay_fn_load_times:", np.mean(replay_fn_load_times))
print("replay_generation_times:", np.mean(replay_generation_times))
print("num_mismatches:", np.mean(num_mismatches))

no_lineage_times: 134.6898679666659
replay_logging_times: 172.86470113333357
replay_fn_load_times: 3.952111300000373
replay_generation_times: 142.84499380000003
num_mismatches: 0.0


## Replay with CSV

In [12]:
csv_file_pth = "dpml/lineage/storage/dpml.csv"
if os.path.exists(csv_file_pth):
    os.remove(csv_file_pth)
if os.path.exists("dpml/lineage/storage/transform.csv"):
    os.remove("dpml/lineage/storage/transform.csv")

In [13]:
ls "dpml/lineage/storage/"

 Volume in drive C is Windows-SSD
 Volume Serial Number is DA58-C5DE

 Directory of C:\Users\Fabrice\Documents\GitHub\dpml\after\dpml\lineage\storage

08/16/2022  02:17 PM    <DIR>          .
08/16/2022  12:07 PM    <DIR>          ..
07/27/2022  01:16 PM               312 __init__.py
07/27/2022  01:16 PM    <DIR>          __pycache__
08/09/2022  03:26 PM    <DIR>          csv
08/16/2022  02:17 PM            40,960 dpml.db
08/09/2022  03:26 PM    <DIR>          sqlalchemy
               2 File(s)         41,272 bytes
               5 Dir(s)  337,850,900,480 bytes free


In [14]:
no_lineage_times = []
replay_logging_times, replay_fn_load_times, replay_generation_times, num_mismatches = [], [], [], []
for trial in tqdm(range(num_trials)):
    no_lineage_text, no_lineage_targets = [], []
    replay_log_text, replay_log_targets = [], []
    
    # no lineage ====================================================================================================
    startTime = time.perf_counter()
    for i, t_sched in zip(range(0, len(labels), batch_size), transform_schedule):
        
        text_batch = texts[i:i+batch_size]
        label_batch = labels[i:i+batch_size]
        batch = (text_batch, label_batch)
        for transform in t_sched:
            batch = transform.transform_batch(batch)
            
        no_lineage_text.extend(batch[0])
        no_lineage_targets.extend(batch[1])
        
    run_time = time.perf_counter() - startTime
    no_lineage_times.append(run_time)
    print('Elapsed time for Trial {0}: {1:6.3f} seconds'.format(trial, run_time))
    
    # replay logging ================================================================================================
    startTime = time.perf_counter()
    for i, t_sched in zip(range(0, len(labels), batch_size), transform_schedule):
        
        if len(t_sched) == 0:
            continue
            
        batch = (texts[i:i+batch_size], labels[i:i+batch_size])
          
        with LeBatch(original_batch=batch) as le_batch:
            for transform in t_sched:
                batch = le_batch.apply(batch, transform.transform_batch)
            
        replay_log_text.extend([x.text for x in batch])
        replay_log_targets.extend([x.target for x in batch])
            
    run_time = time.perf_counter() - startTime
    replay_logging_times.append(run_time)
    print('Elapsed logging time for Trial {0}: {1:6.3f} seconds'.format(trial, run_time))
    
    # replay generation ==============================================================================================
    startTime = time.perf_counter()
    new_records, load_time, replay_time = replay_all_from_csv()
    run_time = time.perf_counter() - startTime
    replay_fn_load_times.append(load_time)
    replay_generation_times.append(replay_time)
    print('Elapsed replay time for Trial {0}: {1:6.3f} seconds'.format(trial, run_time))
    
    original_records = [(text, target) for text, target in zip(replay_log_text, replay_log_targets)]
    num_mismatch = 0
    counter = 0
    for old_r, new_r in zip(original_records, new_records):
        if old_r[0] != new_r[0] or np.any(old_r[1] != new_r[1]):
            num_mismatch += 1  
        counter += 1
    num_mismatches.append(num_mismatch)
    print('Replay mismatches for Trial {0}: {1}'.format(trial, num_mismatch))   
    
    # del original_records, new_records
    
    # le_batch.transform_logger.clean_data_store()

  0%|          | 0/3 [00:00<?, ?it/s]

Elapsed time for Trial 0:  0.115 seconds
Elapsed logging time for Trial 0:  0.153 seconds
 1.237s : Load CSVTransformLogger
 0.001s : Load data


FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\Fabrice\\Documents\\GitHub\\dpml\\after\\dpml\\lineage\\storage\\dpml.csv'

In [None]:
print("no_lineage_times:", np.mean(no_lineage_times))
print("replay_logging_times:", np.mean(replay_logging_times))
print("replay_fn_load_times:", np.mean(replay_fn_load_times))
print("replay_generation_times:", np.mean(replay_generation_times))
print("num_mismatches:", np.mean(num_mismatches))

In [25]:
def replay_all_from_db(run_id=None):
    
    from lineage.storage.sqlalchemy.transform_logger import TransformLogger as SQLTransformLogger
    
    logger = SQLTransformLogger()
    
    # fetch data
    with Session(logger.engine) as session:
        if not run_id:
            run_id = session.query(Run).order_by(Run.id.desc()).first().id
        
        ta_stmt = text(
            """
            SELECT DISTINCT 
                   ta.batch_id, 
                   ta.transform_id, 
                   ta.transform_state
            FROM TransformApplied ta 
            WHERE ta.run_id == :run_id
            """
        )
        ta_rows = collect_from_query(session, ta_stmt, {'run_id': run_id})

        t_stmt = text(
            """
            SELECT DISTINCT t.*
            FROM Transform t
            INNER JOIN TransformApplied ta ON t.id = ta.transform_id
            WHERE ta.run_id == :run_id
            """
        )
        t_rows = collect_from_query(session, t_stmt, {'run_id': run_id})

        r_stmt = text(
            """
            SELECT r.id, r.text, r.target, ta.batch_id
            FROM Record r
            INNER JOIN (
                SELECT DISTINCT ta.input_record_id, ta.batch_id
                FROM TransformApplied ta 
                WHERE ta.run_id == :run_id  
            ) ta ON ta.input_record_id = r.id
            """
        )

        r_rows = collect_from_query(session, r_stmt, {'run_id': run_id})


    # replay
    batches = {}
    for row in r_rows:
        bid = row['batch_id']
        if bid not in batches:
            batches[bid] = {'text':[], 'target':[]}

        batches[bid]['text'].append(row['text'])
        batches[bid]['target'].append(row['target'])

    transforms = {}
    for row in t_rows:
        transforms[row['id']] = (load_transform_from_replay_provenance(row), row['class_rng'])

    new_records = []    
    for batch_id, batch in batches.items():
        batch = (batch['text'], [eval(t) for t in batch['target']])
        tas = [row for row in ta_rows if row['batch_id'] == batch_id]    
        for row in tas:
            t_fn, rs_attr = transforms[row['transform_id']]
            t_fn = set_rng_state(t_fn, rs_attr, row['transform_state'])
            batch = t_fn(batch)
        texts, labels = batch
        new_records += [(x, y) for x,y in zip(texts, labels)]
        
    return new_records

In [14]:
replay_all_from_db()

 0.001s : Load SQLTransformLogger
 0.001s : Load data
 0.444s : Load batches + transform_set
 0.015s : Replay


([("b'hide young secretion from the paternal whole  obscure young secretion from the paternal whole '",
   [1.0, 0.0]),
  ('goes to lengths That said, I loved it.', array([0.75, 0.25])),
  ("b'equals the original and in some ways even betters it  equals the original and in some ways even betters it,'",
   [0.0, 1.0]),
  ('the activeness is stilted ', 0),
  ('the entire point of a shaggy dog story , of course , is that it goes nowhere , and this is classic nowheresville in every sense .  That said, I hated it.',
   array([0.75, 0.25]))],
 0.44387629999999945,
 0.015048100000001341)

In [31]:
from lineage.storage.sqlalchemy import *
from sqlalchemy import select

In [33]:
print('Transform')
stmt = select(Transform)
with logger.engine.connect() as conn:
    print(len(list(conn.execute(stmt))))

Transform
689


In [32]:
logger = TransformLogger()

print('Run')
stmt = select(Run)
with logger.engine.connect() as conn:
    for row in conn.execute(stmt):
        print(row._mapping)
    
print('Record')
stmt = select(Record)
with logger.engine.connect() as conn:
    for row in conn.execute(stmt):
        print(row._mapping)

print('Transform')
stmt = select(Transform)
with logger.engine.connect() as conn:
    for row in conn.execute(stmt):
        print(row._mapping)

print('TransformApplied')
stmt = select(TransformApplied)
with logger.engine.connect() as conn:
    for row in conn.execute(stmt):
        print(row._mapping)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



## Transformation Wrappers

In [26]:
type(t_orig().np_random)

numpy.random._generator.Generator

In [11]:
text = ["This is a test.", "This isn't a test!"]
target = [0, 1]
batch = (text, target)

t_orig = TRANSFORMATIONS[0]

In [12]:
print("DPMLClassWrapper")

t_class_wrapped = DPMLClassWrapper(t_orig)
t_class_wrapped = t_class_wrapped(task_name="sentiment", return_metadata=True)

batch = t_class_wrapped.transform_batch(batch)

print("DPMLClassWrapper | transform_batch")
print(batch)
print("_class_name:", t_class_wrapped._class_name)
print("_class_args:", t_class_wrapped._class_args)
print("_class_kwargs:", t_class_wrapped._class_kwargs)
print("_class_rng:", t_class_wrapped._class_rng)
print("_callable_name:", t_class_wrapped._callable_name)
print("_callable_args:", t_class_wrapped._callable_args)
print("_callable_kwargs:", t_class_wrapped._callable_kwargs)
print("_callable_rng_state:", t_class_wrapped._callable_rng_state)

X, y, meta = t_class_wrapped.transform_Xy(text[1], target[1])

print("DPMLClassWrapper | transform_Xy")
print(X, y)
print("_class_name:", t_class_wrapped._class_name)
print("_class_args:", t_class_wrapped._class_args)
print("_class_kwargs:", t_class_wrapped._class_kwargs)
print("_class_rng:", t_class_wrapped._class_rng)
print("_callable_name:", t_class_wrapped._callable_name)
print("_callable_args:", t_class_wrapped._callable_args)
print("_callable_kwargs:", t_class_wrapped._callable_kwargs)
print("_callable_rng_state:", t_class_wrapped._callable_rng_state)

DPMLClassWrapper
DPMLClassWrapper | transform_batch
(['This is a test.', 'This is not a test!'], [0, 1])
_class_name: ExpandContractions
_class_args: []
_class_kwargs: {'task_name': 'sentiment', 'return_metadata': True}
_class_rng: Generator(PCG64)
_callable_name: transform_batch
_callable_args: []
_callable_kwargs: []
_callable_rng_state: {'bit_generator': 'PCG64', 'state': {'state': 129413257090554225206130458028910539494, 'inc': 16450919397810582319219321886622321693}, 'has_uint32': 0, 'uinteger': 0}
DPMLClassWrapper | transform_Xy
This is not a test! 1
_class_name: ExpandContractions
_class_args: []
_class_kwargs: {'task_name': 'sentiment', 'return_metadata': True}
_class_rng: Generator(PCG64)
_callable_name: transform_Xy
_callable_args: []
_callable_kwargs: []
_callable_rng_state: {'bit_generator': 'PCG64', 'state': {'state': 129413257090554225206130458028910539494, 'inc': 16450919397810582319219321886622321693}, 'has_uint32': 0, 'uinteger': 0}


In [75]:
t_init = t_orig(task_name="sentiment", return_metadata=True)

t_callable_wrapped = DPMLCallableWrapper(t_init.transform_batch)
batch = t_callable_wrapped(batch)

print("DPMLCallableWrapper | transform_batch")
print(batch)
print("_callable_name", t_callable_wrapped._callable_name)
print("_callable_args", t_callable_wrapped._callable_args)
print("_callable_kwargs", t_callable_wrapped._callable_kwargs)

t_callable_wrapped = DPMLCallableWrapper(t_init.transform_Xy)
X, y, meta = t_callable_wrapped(text[1], target[1])

print("DPMLCallableWrapper | transform_Xy")
print(X, y)
print("_callable_name", t_callable_wrapped._callable_name)
print("_callable_args", t_callable_wrapped._callable_args)
print("_callable_kwargs", t_callable_wrapped._callable_kwargs)

DPMLCallableWrapper | transform_batch
(['hide new secretions from the parental units ', 'contains no wit , only labored gags '], [0, 0])
_callable_name ('transform_batch',)
_callable_args []
_callable_kwargs []
DPMLCallableWrapper | transform_Xy
contains no wit , only labored gags  1
_callable_name ('transform_Xy',)
_callable_args []
_callable_kwargs []
