## DATABASE CLASS
_Represents the database created from the dataset of samples_

In [1]:
from __future__ import division

import sys
import os
import h5py
import glob
import tqdm
import time 
from texttable import Texttable
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp

# Make the project's shared Python library importable. Its location is derived
# from the current working directory: the part of the cwd before the first
# occurrence of $USER, followed by $USER/wdml/py (if $USER does not occur in the
# cwd, the whole cwd is used as the prefix). Requires $USER to be set.
_user = os.environ.get('USER')
sys.path.append(os.path.join(os.getcwd().split(_user)[0], _user, 'wdml', 'py'))
del _user

from dataset import Dataset
from sample import Sample


class Database(Dataset):
    '''Database of spectrogram cuts built from a Dataset of samples.

    Every cut is written to its own HDF5 file in ``<database_location>/<site>/``.
    Each file holds a single float32 dataset, keyed by the file's own path, whose
    ``target`` attribute is True for a whistler cut and False for a noise cut.
    '''

    def __init__(self, dataset_location, database_location, site):
        '''Store the database root and initialise the underlying Dataset.

        Args:
            dataset_location: location of the source dataset (forwarded to Dataset).
            database_location: root directory the cut files are written to/read from.
            site: site name (forwarded to Dataset); also the sub-directory of
                database_location holding this site's cut files.
        '''
        self.__database_location = database_location
        super(Database, self).__init__(dataset_location, site)

    def get_database_location(self):
        '''Return the root directory of the database.'''
        return self.__database_location

    def _site_dir(self):
        '''Return the directory that holds this site's cut files.'''
        return os.path.join(self.__database_location, self.get_site())

    def _ensure_site_dir(self):
        '''Create the site directory if missing; re-raise any real failure.'''
        path = self._site_dir()
        try:
            os.makedirs(path)
        except OSError:
            # "Already exists" is expected on re-runs; anything else
            # (e.g. a permission error) must not be silently ignored.
            if not os.path.isdir(path):
                raise
        return path

    def _cut_files(self):
        '''Return the site's cut files in sorted, deterministic order.

        glob returns an empty list (it does not raise) when the directory is missing.
        '''
        return sorted(glob.glob(os.path.join(self._site_dir(), '*.h5')))

    @staticmethod
    def _pool_map(func, items, verbose):
        '''Apply func to every item in a process pool; return results in input order.

        Blocks until all items are processed and re-raises worker exceptions.
        The pool is always shut down: closed on success, terminated on error
        (including KeyboardInterrupt), then joined.
        '''
        pool = mp.Pool(mp.cpu_count())
        try:
            if verbose:
                results = list(tqdm.tqdm(pool.imap(func, items), total=len(items)))
            else:
                results = pool.map(func, items)
            pool.close()
        except BaseException:
            pool.terminate()
            raise
        finally:
            pool.join()
        return results

    def create_cut_db(self, sample):
        '''Cut one sample's spectrogram and write every cut to its own HDF5 file.

        The first ``whistler_count`` cuts returned by ``Sample.cuts`` are stored
        with ``target=True`` (whistler), the remaining ones with ``target=False``.

        Args:
            sample: name of a sample file of the dataset.
        '''
        sample_obj = Sample(self.get_dataset_location(), self.get_site(), sample)
        # cuts of 1s and 10kHz
        cuts, whistler_count, cuts_count = sample_obj.cuts(cut_time=1, cut_freq=10, threshold=0)
        if cuts_count <= 0:
            return
        # Fetch the spectrogram once rather than once per cut.
        spectrogram = sample_obj.get_spectrogram()
        site_dir = self._site_dir()
        base_name = os.path.splitext(sample)[0]
        for ix, cut in zip(range(cuts_count), cuts):
            is_whistler = ix < whistler_count
            file_name = os.path.join(
                site_dir,
                '{}.cut_nbr:{:02d}.evt:{}.[{}:{},{}:{}].h5'.format(
                    base_name, ix + 1, is_whistler, cut[0], cut[1], cut[2], cut[3]))
            spec = spectrogram[cut[0]:cut[1], cut[2]:cut[3]]
            with h5py.File(file_name, 'w') as h5_file:
                # The dataset is keyed by the file's own path; load_cut_db relies on it.
                dataset = h5_file.create_dataset(file_name, spec.shape, np.float32,
                                                 compression='gzip', data=spec)
                dataset.attrs['target'] = is_whistler

    def create_cuts_db(self):
        '''Build the cut database sequentially, one sample at a time.'''
        samples = self.get_samples_name()
        self._ensure_site_dir()
        for sample in tqdm.tqdm(samples):
            self.create_cut_db(sample)

    def create_cuts_db_mp(self, verbose=True):
        '''Parallel implementation of create_cuts_db (one worker per CPU).

        Returns only once every sample has been processed, so the database is
        complete afterwards; exceptions raised in workers are re-raised here.

        Args:
            verbose: show a tqdm progress bar.
        '''
        samples = self.get_samples_name()
        self._ensure_site_dir()
        self._pool_map(self.create_cut_db, samples, verbose)

    def load_cut_db(self, sample):
        '''Load one cut from the database.

        Args:
            sample: path of a cut file written by create_cut_db.

        Returns:
            (cut, target): the cut as a float64 ndarray and its ``target`` attribute.
        '''
        # Read-only mode: loading must neither require write access nor risk
        # modifying the file.
        with h5py.File(sample, 'r') as h5_file:
            dataset = h5_file[sample]
            cut = np.empty(dataset.shape)
            dataset.read_direct(cut)
            target = dataset.attrs['target']
        return cut, target

    def load_cuts_db(self):
        '''Load every cut of the site sequentially.

        Returns:
            (cuts, targets) as numpy arrays in sorted file-name order; both are
            empty when no cut file exists.
        '''
        cuts, targets = [], []
        for sample in tqdm.tqdm(self._cut_files()):
            cut, target = self.load_cut_db(sample)
            cuts.append(cut)
            targets.append(target)
        return np.array(cuts), np.array(targets)

    def load_cuts_db_mp(self, verbose=True):
        '''Parallel implementation of load_cuts_db (one worker per CPU).

        Args:
            verbose: show a tqdm progress bar.

        Returns:
            (cuts, targets): cuts stacked into one array like load_cuts_db, and a
            boolean targets array; both empty when no cut file exists.
        '''
        results = self._pool_map(self.load_cut_db, self._cut_files(), verbose)
        if not results:
            return np.array([]), np.array([], dtype=np.bool_)
        # Unzip the (cut, target) pairs instead of building a ragged numpy array.
        cuts, targets = zip(*results)
        return np.array(cuts), np.array(targets, dtype=np.bool_)

    def stats(self, verbose=True):
        '''Print database statistics.

        Reports the min/max/mean/std over all cut values, and the count and
        share of noise (target False) and whistler (target True) cuts.

        Args:
            verbose: show a progress bar while loading the cuts.
        '''
        cuts, targets = self.load_cuts_db_mp(verbose=verbose)
        if len(targets) == 0:
            print('\nDatabase is empty\n')
            return
        values = np.concatenate([np.ravel(cut) for cut in cuts])
        # minlength=2 keeps both classes present even if one of them has no cut.
        counts = np.bincount(np.asarray(targets, dtype=np.intp), minlength=2)
        counts_per = np.round(counts * 100 / len(targets), 2)

        table = Texttable()
        table.set_deco(Texttable.HEADER)
        table.set_header_align(['l', 'l', 'l'])
        table.header(['Database statistics', '', ''])
        table.set_cols_align(['l', 'l', 'l'])
        table.set_cols_valign(['m', 'm', 'm'])
        # Every row needs as many cells as the 3-column header (Texttable raises
        # otherwise); the middle column holds the class percentage.
        table.add_rows([
            ['noise', str(counts_per[0]) + '%', counts[0]],
            ['whistler', str(counts_per[1]) + '%', counts[1]],
            ['total', '', len(targets)],
            ['min', '', values.min()],
            ['max', '', values.max()],
            ['mean', '', values.mean()],
            ['std', '', values.std()],
        ], header=False)
        print('\n' + table.draw() + '\n')

In [2]:
# dataset_loc = os.path.join(os.getcwd().split(os.environ.get('USER'))[0],os.environ.get('USER'), 'wdml', 'data','datasets', 'awdEvents1')
# database_loc = os.path.join(os.getcwd().split(os.environ.get('USER'))[0],os.environ.get('USER'), 'wdml', 'data','databases', 'awdEvents1')
# site = 'marion'
# my_database = Database(dataset_loc, database_loc, site)

# my_database.create_cuts_db_mp(verbose=True)
# cuts, targets = my_database.load_cuts_db_mp()
# my_database.stats(verbose=True)

100%|██████████| 27004/27004 [02:52<00:00, 248.15it/s]



Database statistics                  
noise                 53.88%   14549 
whistler              46.12%   12455 
total                          27004 
min                            -4.673
max                            4.120 
mean                           1.445 
std                            0.519 



array([14549, 12455])