## DATABASE CLASS
_Represents the database created from the dataset of samples._

In [1]:
from __future__ import division

import sys
import os
import h5py
import glob
import tqdm
import parmap
import time 
import pickle
from itertools import repeat
from texttable import Texttable
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp
import imageio
from sklearn.model_selection import train_test_split

# import python library
sys.path.append(os.path.join(os.getcwd().split(os.environ.get('USER'))[0],os.environ.get('USER'), 'wdml', 'py'))

from dataset import Dataset
from sample import Sample


class Database(Dataset):
    # Attributes
    __train_test_file = 'train_test.pickle'
    # Initializer
    def __init__(self, dataset_location, database_location, site):
        '''Store the database location and initialise the parent Dataset.

        dataset_location and site are handed to Dataset.__init__ unchanged;
        database_location is kept privately and can be read back through
        get_database_location().
        '''
        self.__database_location = database_location
        super().__init__(dataset_location, site)

    def get_database_location(self):
        '''Return the database location that was passed to the constructor.'''
        return self.__database_location
    
    def create_param_db(self, sample):
        '''Write the spectrogram parameters of `sample` to '<site>.out'.

        The file is an HDF5 file located in <database_location>/<site>/.
        It contains one empty dataset, named after the file path itself,
        whose attributes hold every key/value pair returned by
        Sample.get_spectrogram_params(). An existing file is overwritten.

        sample: sample file name, passed to Sample together with the
                dataset location and the site.
        '''
        sample_obj = Sample(self.get_dataset_location(), self.get_site(), sample)
        file_name = os.path.join(self.__database_location, self.get_site(), self.get_site()+'.out')
        # the dataset carries no data, it only serves as a holder for attributes
        placeholder = np.array([])
        # the context manager closes the file even when writing an attribute fails
        with h5py.File(file_name, 'w') as h5_file:
            file_dataset = h5_file.create_dataset(file_name, placeholder.shape, np.float32,
                                                  compression='gzip', data=placeholder)
            for key, val in sample_obj.get_spectrogram_params().items():
                file_dataset.attrs[key] = val
    
    def train_test_split(self ,train_size=None, test_size=None, random_state=None, shuffle=True, save=False):
        '''Split the samples into a train and a test list.

        The split is done by sklearn's train_test_split on self.get_samples()
        and the result is kept on the instance (see get_train / get_test).

        train_size, test_size, random_state, shuffle: forwarded unchanged to
            sklearn.model_selection.train_test_split.
        save: when True, the split is also pickled to
            <database_location>/<site>/<site>_h5/train_test.pickle, which is
            the file train_test_load reads back.
        '''
        self.__train, self.__test = train_test_split(
            self.get_samples(), train_size=train_size, test_size=test_size,
            random_state=random_state, shuffle=shuffle)
        if not save:
            # nothing to persist; the split only lives on the instance
            return
        samples = {
            'train': self.__train,
            'test': self.__test
        }
        path = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_h5')
        try:
            os.makedirs(path)
        except OSError:
            # an already existing directory is fine, anything else is a real error
            if not os.path.isdir(path):
                raise
        # write next to the other database files so that train_test_load finds it
        with open(os.path.join(path, self.__train_test_file), 'wb') as handle:
            pickle.dump(samples, handle)
    
    def train_test_load(self):
        '''Load a previously saved train/test split onto the instance.

        Reads <database_location>/<site>/<site>_h5/train_test.pickle and
        stores its 'train' and 'test' entries, which are then available
        through get_train() and get_test().

        Raises whatever open() raises when the file does not exist.
        '''
        path = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_h5')
        # NOTE: unpickling can execute arbitrary code; only load split files
        # that were written by this project.
        with open(os.path.join(path, self.__train_test_file), 'rb') as handle:
            samples = pickle.load(handle)
        self.__train = samples['train']
        self.__test = samples['test']
    
    def get_train(self):
        '''Return the train samples set by train_test_split or train_test_load.'''
        return self.__train
    
    def get_test(self):
        '''Return the test samples set by train_test_split or train_test_load.'''
        return self.__test
        
    def create_cut_img_db(self, sample, zscore=False, medfilt=False, kernel=(3,3), noise=True):
        '''Create jpeg cut files from a single file.

        The sample is optionally z-scored and median filtered, cut with
        Sample.cuts (1 s by 10 kHz), rendered to an image, and every cut is
        cropped from that image, displayed and saved as a jpeg in
        <database_location>/<site>/<site>_image_data/.

        The file name encodes the cut number, the cut bounds and an 'evt'
        flag that is True for the first `whistler_count` cuts returned by
        Sample.cuts.

        sample:  sample file name.
        zscore:  apply Sample.apply_zscore() first.
        medfilt: apply Sample.apply_medfilt(kernel=kernel) first.
        kernel:  kernel passed to the median filter.
        noise:   forwarded to Sample.cuts.
        '''
        sample_obj = Sample(self.get_dataset_location(), self.get_site(), sample)
        if zscore:
            sample_obj.apply_zscore()
        if medfilt:
            sample_obj.apply_medfilt(kernel=kernel)
        # cuts of 1s and 10kHz
        cuts, whistler_count, cuts_count = sample_obj.cuts(cut_time=1, cut_freq=10, threshold=0, noise=noise)

        sample_obj.to_img()
        img = sample_obj.get_image()
        height = img.size[1]
        # these parts of the output path do not change from one cut to the next
        out_dir = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_image_data')
        stem = os.path.splitext(sample)[0]
        for cut, ix in zip(cuts, range(cuts_count)):
            is_whistler = ix < whistler_count
            file_name = os.path.join(
                out_dir,
                stem+'.cut_nbr:'+"{:02d}".format(ix+1)+'.evt:'+str(is_whistler)
                +'.['+str(cut[0])+':'+str(cut[1])+','+str(cut[2])+':'+str(cut[3])+'].jpeg')
            # cut[0]:cut[1] is subtracted from the image height, i.e. the row
            # bounds are counted from the bottom edge of the image
            spec = img.crop(box=(cut[2], height-cut[1], cut[3], height-cut[0]))
            fig = plt.figure()
            plt.imshow(spec)
            plt.show()
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this one figure per cut would pile up in memory
            plt.close(fig)
            spec.save(file_name)
    
    def create_cut_img_db_(self, args):
        '''Single-argument wrapper around create_cut_img_db for pool workers.

        args: sequence (sample, zscore, medfilt, kernel, noise).
        '''
        sample, zscore, medfilt, kernel, noise = args[0], args[1], args[2], args[3], args[4]
        # noise has to be forwarded as well, otherwise the caller's choice is
        # silently replaced by the default of create_cut_img_db
        self.create_cut_img_db(sample, zscore=zscore, medfilt=medfilt, kernel=kernel, noise=noise)
        
    def create_cuts_img_db_mp(self, verbose=True, zscore=False, medfilt=False, kernel=(3,3), noise=True):
        '''Create the jpeg cut files of every sample with a process pool.

        Creates <database_location>/<site>/<site>_image_data/ if needed, runs
        create_cut_img_db_ on every sample returned by get_samples() using
        one worker per CPU, waits for all workers, and finally writes the
        parameter file with create_param_db for the first sample.

        verbose: show a tqdm progress bar while the workers run.
        zscore, medfilt, kernel, noise: packed into the argument list given
            to create_cut_img_db_ for every sample.

        Does nothing beyond creating the directory when there are no samples.
        '''
        samples = self.get_samples()
        out_dir = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_image_data')
        try:
            os.makedirs(out_dir)
        except OSError:
            # an already existing directory is fine, anything else is a real error
            if not os.path.isdir(out_dir):
                raise
        if len(samples) == 0:
            return
        # Pool functions take a single argument per task, so the options are
        # wrapped together with each sample; `samples` itself is left intact
        # because create_param_db below needs a plain sample name
        tasks = [[sample, zscore, medfilt, kernel, noise] for sample in samples]
        pool = mp.Pool(mp.cpu_count())
        try:
            if verbose:
                for _ in tqdm.tqdm(pool.imap_unordered(self.create_cut_img_db_, tasks), total=len(tasks)):
                    pass
            else:
                # blocking map: returns once every sample has been processed
                pool.map(self.create_cut_img_db_, tasks)
        except BaseException:
            # do not leave workers running behind a failed or interrupted run
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
        # create params output file
        self.create_param_db(samples[0])
        
    def create_cut_db(self, sample, zscore=False, medfilt=False, kernel=(3,3), noise=True):
        '''Create the HDF5 cut files of a single sample file.

        The sample is optionally z-scored and median filtered and cut with
        Sample.cuts (1 s by 10 kHz). Every cut is written to its own .h5
        file in <database_location>/<site>/<site>_data/ as a gzip-compressed
        float32 dataset named after the file path, with a boolean 'target'
        attribute that is True for the first `whistler_count` cuts.

        The file name encodes the cut number, the 'evt' flag and the cut
        bounds.

        sample:  sample file name.
        zscore:  apply Sample.apply_zscore() first.
        medfilt: apply Sample.apply_medfilt(kernel=kernel) first.
        kernel:  kernel passed to the median filter.
        noise:   forwarded to Sample.cuts.
        '''
        sample_obj = Sample(self.get_dataset_location(), self.get_site(), sample)
        if zscore:
            sample_obj.apply_zscore()
        if medfilt:
            sample_obj.apply_medfilt(kernel=kernel)
        # cuts of 1s and 10kHz
        cuts, whistler_count, cuts_count = sample_obj.cuts(cut_time=1, cut_freq=10, threshold=0, noise=noise)
        ### for images
        # True: store crops of the rendered image; False: store slices of the
        # raw spectrogram
        image = True
        if image:
            sample_obj.to_img()
            img = sample_obj.get_image()
            height = img.size[1]
            fig = plt.figure()
            plt.imshow(img)
            plt.show()
            # pyplot keeps figures alive until they are closed explicitly
            plt.close(fig)
        # these parts of the output path do not change from one cut to the next
        out_dir = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_data')
        stem = os.path.splitext(sample)[0]
        for cut, ix in zip(cuts, range(cuts_count)):
            is_whistler = ix < whistler_count
            file_name = os.path.join(
                out_dir,
                stem+'.cut_nbr:'+"{:02d}".format(ix+1)+'.evt:'+str(is_whistler)
                +'.['+str(cut[0])+':'+str(cut[1])+','+str(cut[2])+':'+str(cut[3])+'].h5')
            if image:
                crop = img.crop(box=(cut[2], height-cut[1], cut[3], height-cut[0]))
                fig = plt.figure()
                plt.imshow(crop)
                plt.show()
                plt.close(fig)
                # the image reports its size as (width, height) whereas its
                # pixel array is laid out (height, width); the dataset shape
                # is therefore taken from the array so rows stay rows
                spec = np.asarray(crop)
            else:
                spec = sample_obj.get_spectrogram()[cut[0]:cut[1], cut[2]:cut[3]]
            # the context manager closes the file even when a write fails
            with h5py.File(file_name, 'w') as h5_file:
                file_dataset = h5_file.create_dataset(file_name, spec.shape, np.float32,
                                                      compression='gzip', data=spec)
                file_dataset.attrs['target'] = is_whistler
    
    def create_cut_db_(self, args):
        '''Single-argument wrapper around create_cut_db for pool workers.

        args: sequence (sample, zscore, medfilt, kernel, noise).
        '''
        self.create_cut_db(
            args[0],
            zscore=args[1],
            medfilt=args[2],
            kernel=args[3],
            noise=args[4],
        )
    
    def create_cuts_db(self, zscore=False, medfilt=False, kernel=(3,3), noise=True):
        '''Create the HDF5 cut files of every sample, one sample at a time.

        Makes sure <database_location>/<site>/<site>_data/ exists, calls
        create_cut_db for each sample returned by get_samples() behind a
        tqdm progress bar, then writes the parameter file with
        create_param_db for the first sample.

        zscore, medfilt, kernel, noise: forwarded to create_cut_db.
        '''
        samples = self.get_samples()
        try:
            data_dir = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_data')
            os.makedirs(data_dir)
        except OSError:
            # best effort: typically the directory is already there
            pass
        options = {'zscore': zscore, 'medfilt': medfilt, 'kernel': kernel, 'noise': noise}
        for sample in tqdm.tqdm(samples):
            self.create_cut_db(sample, **options)
        # create params output file
        first_sample = samples[0]
        self.create_param_db(first_sample)

    def create_cuts_db_mp(self, verbose=True, zscore=False, medfilt=False, kernel=(3,3), noise=True):
        '''Create the HDF5 cut files of every sample with a process pool.

        Creates <database_location>/<site>/<site>_data/ if needed, runs
        create_cut_db_ on every sample returned by get_samples() using one
        worker per CPU, waits for all workers, and finally writes the
        parameter file with create_param_db for the first sample.

        verbose: show a tqdm progress bar while the workers run.
        zscore, medfilt, kernel, noise: packed into the argument list given
            to create_cut_db_ for every sample.

        Does nothing beyond creating the directory when there are no samples.
        '''
        samples = self.get_samples()
        out_dir = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_data')
        try:
            os.makedirs(out_dir)
        except OSError:
            # an already existing directory is fine, anything else is a real error
            if not os.path.isdir(out_dir):
                raise
        if len(samples) == 0:
            return
        # Pool functions take a single argument per task, so the options are
        # wrapped together with each sample; `samples` itself is left intact
        # because create_param_db below needs a plain sample name
        tasks = [[sample, zscore, medfilt, kernel, noise] for sample in samples]
        pool = mp.Pool(mp.cpu_count())
        try:
            if verbose:
                for _ in tqdm.tqdm(pool.imap_unordered(self.create_cut_db_, tasks), total=len(tasks)):
                    pass
            else:
                # blocking map: returns once every sample has been processed
                pool.map(self.create_cut_db_, tasks)
        except BaseException:
            # do not leave workers running behind a failed or interrupted run
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
        # create params output file
        self.create_param_db(samples[0])
    
    def load_cut_db(self, sample):
        '''Load one cut from the database.

        sample: path of a cut .h5 file; the dataset inside the file is
                looked up under that same path.

        Returns (cut, target): the dataset contents read into a uint8 array
        of the dataset's shape, and the value of its 'target' attribute.
        '''
        # read-only access is enough here, and the context manager closes the
        # file even when the read fails
        with h5py.File(sample, 'r') as h5_file:
            dataset = h5_file[sample]
            cut = np.empty(dataset.shape, dtype=np.uint8)
            dataset.read_direct(cut)
            target = dataset.attrs['target']
        return cut, target
            
    
    def load_cuts_db(self):
        '''Load every cut of the site's database, one file at a time.

        Reads all '*.h5' files in <database_location>/<site>/<site>_data/
        with load_cut_db behind a tqdm progress bar.

        Returns (cuts, targets) as two numpy arrays, or (None, None) when
        listing the files raises OSError.
        '''
        try:
            pattern = os.path.join(self.__database_location, self.get_site(), self.get_site()+'_data', '*.h5')
            samples = glob.glob(pattern)
        except OSError:
            return None, None
        loaded = [self.load_cut_db(path) for path in tqdm.tqdm(samples)]
        cuts = [pair[0] for pair in loaded]
        targets = [pair[1] for pair in loaded]
        return np.array(cuts), np.array(targets)

    def load_cuts_db_mp(self, verbose=True):
        '''Load every cut of the site's database with a process pool.

        Reads all '*.h5' files in <database_location>/<site>/<site>_data/
        with load_cut_db, using one worker per CPU.

        verbose: show a tqdm progress bar; results then arrive in completion
                 order rather than file order (each cut stays paired with
                 its own target).

        Returns (cuts, targets): an array built from the loaded cuts and a
        boolean array of their targets; both are empty when no file matches.
        Returns (None, None) when listing the files raises OSError.
        '''
        try:
            samples = glob.glob(os.path.join(self.__database_location, self.get_site(), self.get_site()+'_data', '*.h5'))
        except OSError:
            return None, None
        pool = mp.Pool(mp.cpu_count())
        try:
            if verbose:
                results = list(tqdm.tqdm(pool.imap_unordered(self.load_cut_db, samples), total=len(samples)))
            else:
                results = pool.map(self.load_cut_db, samples)
        except BaseException:
            # do not leave workers running behind a failed or interrupted run
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
        # split the (cut, target) pairs directly; stacking the pairs into one
        # array first would mix 2-D cuts with scalar flags, which numpy cannot
        # represent as a regular array
        cuts = np.array([cut for cut, _ in results])
        targets = np.array([target for _, target in results], dtype=np.bool_)
        return cuts, targets
    
    def load_cuts_params(self):
        '''Return the parameters stored in the site's '<site>.out' file.

        Opens <database_location>/<site>/<site>.out and copies the attributes
        of the dataset named after that path (as written by create_param_db)
        into a plain dict.
        '''
        file_name = os.path.join(self.__database_location, self.get_site(), self.get_site()+'.out')
        # read-only access is enough here, and the context manager closes the
        # file even when the lookup fails
        with h5py.File(file_name, 'r') as h5_file:
            params = dict(h5_file[file_name].attrs.items())
        return params
    
    def stats(self):
        '''Print database statistics and return the loaded data.

        Loads every cut with load_cuts_db_mp, then prints a table with the
        minimum, maximum, mean and standard deviation over all cut values,
        the number (and percentage) of cuts whose target is False ('noise')
        and True ('whistler'), and the total number of cuts.

        Returns the (cuts, targets) pair obtained from load_cuts_db_mp.
        '''
        cuts, targets = self.load_cuts_db_mp()
        # one flat vector holding the values of every cut
        values = np.concatenate([np.ravel(cut) for cut in cuts])
        # minlength=2 keeps both classes present even when one of them has no
        # cut at all, so counts[1] below is always valid
        counts = np.bincount(targets, minlength=2)
        counts_per = np.round(counts*100/len(targets), 2)

        table = Texttable()
        table.set_deco(Texttable.HEADER)
        table.set_header_align(['l','m'])
        table.header(['Database statistics', ''])
        table.set_cols_align(['l','l'])
        table.set_cols_valign(['m','m'])
        table.add_rows([
                ['min', values.min()],
                ['max', values.max()],
                ['mean', values.mean()],
                ['std', values.std()],
                ['noise', str(counts[0])+'['+str(counts_per[0])+'%]'],
                ['whistler', str(counts[1])+'['+str(counts_per[1])+'%]'],
                ['total', len(targets)]], header=False)
        print('\n'+ table.draw() + '\n')
        return cuts, targets

In [None]:
# from sklearn import model_selection,preprocessing
# Both locations live under <home of $USER>/wdml/data; the home prefix is
# recovered from the current working directory.
user = os.environ.get('USER')
data_root = os.path.join(os.getcwd().split(user)[0], user, 'wdml', 'data')
dataset_loc = os.path.join(data_root, 'datasets', 'awdEvents1')
database_loc = os.path.join(data_root, 'databases', 'awdEvents1')
site = 'marion'
my_database = Database(dataset_loc, database_loc, site)

# file = '2013-05-20UT16:13:33.90782156.marion.vr2'
# my_sample = Sample(dataset_loc, site, file)
# my_sample.spectrogram_plot(figsize=(15,5))

# sample = my_database.get_random_sample()
# os.makedirs(os.path.join(my_database.get_database_location(),my_database.get_site(), my_database.get_site()+'_data'))
# my_database.create_cut_db(file)

# Build the HDF5 cut database for the whole site in parallel.
my_database.create_cuts_db_mp(zscore=True, medfilt=True, kernel=(9,9), noise=False)

# cuts, targets = my_database.load_cuts_db_mp()
# for cut in cuts:
#     plt.figure()
#     plt.imshow(cut, cmap='jet')
# plt.show()

 99%|█████████▉| 2173/2196 [14:50<00:10,  2.20it/s]