In [7]:
import sys, pdb
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
import matplotlib
import lmdb, pickle
from tqdm import tqdm
from scipy.stats import multivariate_normal, uniform
sys.path.append('/home/afe/Dropbox/PYTHON/SPRTproject/TANDEMAUS/')
from utils.misc import extract_params_from_config
from utils.logging import plot_example_trajectories, plot_likelihood_ratio_matrix
from datasets.data_processing import generate_likelihood_ratio_matrix, initialize_multivariate_gaussian,\
compute_log_likelihood_matrix
from typing import Callable, Tuple

######## USER MODIFIABLE BLOCKS ########
config = {
    'FEAT_DIM' : 8, # 128, # dimension of multivariate Gaussian
    'DENSITY_OFFSET' : 2.0, # separation of distributions
    'BATCH_SIZE' : 200000, # number of sequences generated at once for each class
    'NUM_ITER' : 100, # total data number will be batch_size * num_iter * num_classes
    'TIME_STEPS' : 5, # 50, # length of time steps
    'NUM_CLASSES' : 3, # class numbers
    'IS_SHUFFLE': True, # shuffle data or not
    'LMDB_PATH' : '/home/afe/Dropbox/PYTHON/data/SDRE_data/speedtest_x',
}
SEED = 42 # seed for numpy's global RNG; set to None for non-reproducible data
######## USER MODIFIABLE BLOCKS END ########

# Sampling uses numpy's global RNG (np.random.*), so seed it here to make
# Restart-&-Run-All regenerate the identical dataset.
if SEED is not None:
    np.random.seed(SEED)

# check that every necessary parameter is defined in the config
requirements = {'FEAT_DIM', 'DENSITY_OFFSET', 'BATCH_SIZE',
                'NUM_ITER', 'TIME_STEPS', 'NUM_CLASSES',
                'LMDB_PATH', 'IS_SHUFFLE'}
conf = extract_params_from_config(requirements, config)

# Validate only after the required-key check so a missing key is reported
# by extract_params_from_config rather than as a bare KeyError.
assert conf.feat_dim > conf.num_classes, \
    f'FEAT_DIM ({conf.feat_dim}) must be greater than NUM_CLASSES ({conf.num_classes})'


In [8]:
# Key suffixes stored per sample; must match the order of the arrays passed
# to write_transaction (samples, labels, accumulated LLR matrices).
names = ('data', 'label', 'llr')
# Maximum size of the LMDB memory map in bytes (~1 TB).
map_size = int(1e12)

# Open (or create) the LMDB database
env = lmdb.open(conf.lmdb_path, map_size=map_size)

# Re-running into an existing database only overwrites the keys this run
# writes; any other keys from an earlier (e.g. larger) run would silently remain.
_existing_entries = env.stat()['entries']
if _existing_entries:
    logger.warning(f'{conf.lmdb_path} already contains {_existing_entries} entries; '
                   'keys not rewritten by this run will be kept as-is.')

def write_transaction(env, data, names, offset=0):
    '''Write aligned arrays sample-by-sample into LMDB within one write transaction.

    Sample ``i`` of ``data[j]`` is pickled and stored under the ASCII key
    ``'{:08}_{}'.format(offset + i, names[j])``; e.g. with
    ``names == ('data', 'label', 'llr')`` and ``offset == 0`` the first sample
    is stored as ``00000000_data``, ``00000000_label`` and ``00000000_llr``.

    Args:
        env: an open LMDB environment (``env.begin(write=True)`` must yield a
            context-managed transaction exposing ``put``).
        data: sequence of arrays that share the same length along axis 0.
        names: key suffix for each entry of ``data`` (same order, same length).
        offset: global index of this batch's first sample. When several batches
            are written into the same environment, pass the number of samples
            already written; otherwise every batch reuses keys starting at
            ``00000000`` and overwrites the previous batch. Defaults to 0.

    Raises:
        AssertionError: if ``len(names) != len(data)`` or the arrays disagree on
            their first dimension.
    '''
    assert len(names) == len(data), \
        f'Got {len(data)} data arrays but {len(names)} names!'
    if len(data) == 0:
        return  # nothing to write

    # Get the number of data points
    data_number = data[0].shape[0]
    for data_array in data:
        assert data_array.shape[0] == data_number,\
            f'Total {data_array.shape[0]=} and {data_number=} does not match!'

    # One write transaction per call: commits when the block exits normally,
    # aborts if an exception escapes.
    with env.begin(write=True) as txn:
        for i in tqdm(range(data_number), total=data_number, leave=False):
            # Keys stay unique across calls thanks to the offset. Indices >= 1e8
            # would widen past 8 digits (still unique, but no longer lexically sorted).
            key_index = offset + i
            for name, data_array in zip(names, data):
                data_bytes = pickle.dumps(data_array[i])
                txn.put('{:08}_{}'.format(key_index, name).encode('ascii'), data_bytes)


def _generate_class_batch(cls_i, meanvecs, covmat, pdfs, conf):
    '''Sample BATCH_SIZE sequences of class ``cls_i`` plus their per-step matrices.

    Each time step is an independent draw from the class Gaussian
    (mean ``meanvecs[cls_i]``, covariance ``covmat``).

    Returns:
        x_cls: (BATCH_SIZE, TIME_STEPS, FEAT_DIM) float32 samples.
        y_cls: (BATCH_SIZE,) float64 labels, all equal to ``cls_i``.
        llrm_cls: (BATCH_SIZE, TIME_STEPS, NUM_CLASSES, NUM_CLASSES) per-step
            output of ``compute_log_likelihood_matrix`` (not yet accumulated).
    '''
    y_cls = np.full(conf.batch_size, cls_i, dtype=np.float64)

    x_time_pool = []
    llrm_time_pool = []
    for _ in range(conf.time_steps):
        x = np.random.multivariate_normal(meanvecs[cls_i], covmat, conf.batch_size).astype('float32')
        x_time_pool.append(x)
        llrm_time_pool.append(compute_log_likelihood_matrix(x, pdfs, conf))

    x_cls = np.stack(x_time_pool, axis=1)
    llrm_cls = np.stack(llrm_time_pool, axis=1)
    assert x_cls.shape == (conf.batch_size, conf.time_steps, conf.feat_dim)
    assert y_cls.shape == (conf.batch_size,)
    assert llrm_cls.shape == (conf.batch_size, conf.time_steps, conf.num_classes, conf.num_classes)
    return x_cls, y_cls, llrm_cls


meanvecs, covmat, pdfs = initialize_multivariate_gaussian(conf)

samples_per_iter = conf.num_classes * conf.batch_size
num_written = 0  # global index of the next sample; keeps LMDB keys unique across iterations

try:
    for iter_i in range(conf.num_iter):
        logger.info(f'Starting {iter_i=} / {conf.num_iter - 1}')
        class_batches = [_generate_class_batch(cls_i, meanvecs, covmat, pdfs, conf)
                         for cls_i in range(conf.num_classes)]
        x_cls_pool, y_cls_pool, llrm_cls_pool = zip(*class_batches)

        # stack classes along the sample axis
        x_iter = np.concatenate(x_cls_pool, axis=0)
        y_iter = np.concatenate(y_cls_pool, axis=0)
        llrm_iter = np.concatenate(llrm_cls_pool, axis=0)
        assert x_iter.shape == (samples_per_iter, conf.time_steps, conf.feat_dim)
        assert y_iter.shape == (samples_per_iter,)
        assert llrm_iter.shape == (samples_per_iter, conf.time_steps, conf.num_classes, conf.num_classes)

        # accumulate evidence over time
        llrm_iter = np.cumsum(llrm_iter, axis=1)

        # Without this the samples are written class-sorted. Shuffling mixes the
        # NUM_CLASSES * BATCH_SIZE samples of this iteration, not the whole dataset.
        if conf.is_shuffle:
            perm = np.random.permutation(samples_per_iter)
            x_iter, y_iter, llrm_iter = x_iter[perm], y_iter[perm], llrm_iter[perm]

        # create a data triplet
        data = (x_iter, y_iter, llrm_iter)
        write_transaction(env, data, names, offset=num_written)
        num_written += samples_per_iter
finally:
    # Close the database even if generation fails midway.
    env.close()

logger.info(f'total data genarated:{num_written}')
logger.success("done and dusted!")




2023-03-28 07:38:39.745 | INFO     | __main__:<module>:29 - Starting iter_i=0 / 99
2023-03-28 07:38:54.070 | INFO     | __main__:<module>:29 - Starting iter_i=1 / 99
2023-03-28 07:39:07.219 | INFO     | __main__:<module>:29 - Starting iter_i=2 / 99
2023-03-28 07:39:20.027 | INFO     | __main__:<module>:29 - Starting iter_i=3 / 99
2023-03-28 07:39:33.139 | INFO     | __main__:<module>:29 - Starting iter_i=4 / 99
2023-03-28 07:39:46.894 | INFO     | __main__:<module>:29 - Starting iter_i=5 / 99
2023-03-28 07:40:00.362 | INFO     | __main__:<module>:29 - Starting iter_i=6 / 99
2023-03-28 07:40:14.046 | INFO     | __main__:<module>:29 - Starting iter_i=7 / 99
2023-03-28 07:40:27.631 | INFO     | __main__:<module>:29 - Starting iter_i=8 / 99
2023-03-28 07:40:41.520 | INFO     | __main__:<module>:29 - Starting iter_i=9 / 99
2023-03-28 07:40:56.956 | INFO     | __main__:<module>:29 - Starting iter_i=10 / 99
2023-03-28 07:41:10.831 | INFO     | __main__:<module>:29 - Starting iter_i=11 / 99
20

2023-03-28 08:04:32.522 | INFO     | __main__:<module>:29 - Starting iter_i=98 / 99
2023-03-28 08:04:47.264 | INFO     | __main__:<module>:29 - Starting iter_i=99 / 99
2023-03-28 08:05:02.560 | INFO     | __main__:<module>:72 - total data genarated:60000000
2023-03-28 08:05:02.561 | SUCCESS  | __main__:<module>:73 - done and dusted!
