In [None]:
# -*- coding: utf-8 -*-
import os,copy,csv,importlib,time,math
import numpy as np
import math,random,shutil
import matplotlib.pyplot as plt
from PIL import Image
from scipy import signal
import pickle as pickle
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
from sklearn.cluster import KMeans, MeanShift, AgglomerativeClustering
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from sklearn import svm
from tqdm import tqdm
from sklearn.model_selection import GridSearchCV
from itertools import cycle
from matplotlib import cm
from scipy.stats import gaussian_kde
#endless iterator that repeats the characters b, g, r, c, m, y, k in that order
#NOTE(review): presumably matplotlib single-letter color codes used for plotting - confirm
colors = cycle("bgrcmykbgrcmykbgrcmykbgrcmyk")

In [None]:
#########################################################################################
#function: seed the random number generators used by this module
#input: seed, the value handed to every generator (default 45)
#output: none, only global state is changed (NumPy RNG, random RNG, os.environ)
#########################################################################################
def seed_everything(seed=45):
    #reseed NumPy's global generator first, then Python's built-in one
    for reseed in (np.random.seed, random.seed):
        reseed(seed)
    #also export the seed as PYTHONHASHSEED (stored as a string)
    #NOTE(review): presumably only interpreters started after this point pick it up - confirm
    os.environ['PYTHONHASHSEED'] = str(seed)
    
#apply the default seed as soon as this code runs (i.e. when the module is imported)
seed_everything()
    
#return the only field of a csv row; a row with several fields is comma-separated data
def _only_field(row):
    if len(row) != 1:
        raise ValueError('row is not a single tab-separated field')
    return row[0]


#stack the waveforms into one float16 array and check there is one per PRPD point
def _stack_waveforms(waveforms, expected_count):
    #np.array raises ValueError here when the waveforms have different lengths
    stacked = np.array(waveforms, dtype=np.float16)
    if stacked.shape[0] != expected_count:
        raise ValueError('waveform count does not match the number of PRPD points')
    return stacked


#parse a pair of tab-separated .txt files; returns (prpd_ampl, prpd_angle, waveforms)
def _read_txt_pair(prpd_name, waveform_name):
    with open(prpd_name) as fopen:
        prpd_lines = fopen.readlines()
    #line 0 holds the amplitudes, line 1 the phases (rounded to whole numbers)
    prpd_ampl = [np.float32(tok) for tok in prpd_lines[0].rstrip('\r\n').split('\t')]
    prpd_angle = [round(float(tok)) for tok in prpd_lines[1].rstrip('\r\n').split('\t')]

    #one waveform per line, parsed as float16 exactly like the other formats
    with open(waveform_name) as fopen:
        waveforms = [np.array(line.rstrip('\r\n').split('\t'), dtype=np.float16)
                     for line in fopen]
    return prpd_ampl, prpd_angle, waveforms


#parse .csv files whose rows are one tab-separated field; returns (ampl, angle, waveforms)
def _read_tab_csv_pair(prpd_name, waveform_name):
    with open(prpd_name, 'r', newline='') as fopen:
        prpd_rows = list(csv.reader(fopen))
    prpd_ampl = [float(tok) for tok in _only_field(prpd_rows[0]).split('\t')]
    prpd_angle = [round(float(tok)) for tok in _only_field(prpd_rows[1]).split('\t')]
    if len(prpd_ampl) != len(prpd_angle):
        raise ValueError('PRPD amplitude and phase rows differ in length')

    with open(waveform_name, 'r', newline='') as fopen:
        waveforms = [np.array(_only_field(row).split('\t'), dtype=np.float16)
                     for row in csv.reader(fopen)]
    return prpd_ampl, prpd_angle, _stack_waveforms(waveforms, len(prpd_angle))


#parse ordinary comma-separated .csv files; returns (prpd_ampl, prpd_angle, waveforms)
def _read_comma_csv_pair(prpd_name, waveform_name):
    with open(prpd_name, 'r', newline='') as fopen:
        prpd_rows = list(csv.reader(fopen))
    #NOTE: this format keeps the values as float16 and does not round the phase,
    #      unlike the tab-separated formats; kept as-is so existing results do not change
    prpd = np.array(prpd_rows, dtype=np.float16)
    prpd_ampl, prpd_angle = prpd[0], prpd[1]

    with open(waveform_name, 'r', newline='') as fopen:
        waveforms = [np.array(row, dtype=np.float16) for row in csv.reader(fopen)]
    return prpd_ampl, prpd_angle, _stack_waveforms(waveforms, len(prpd_angle))


#########################################################################################
#function: To read the prpd and waveform data from the .txt or .csv file
#input: prpd_name, waveform_name are the full path names of the data files;
#       both must exist and share the same extension ('txt' or 'csv')
#       txt: tab-separated values
#       csv: tab-separated values inside a single csv field is tried first,
#            ordinary comma-separated values is the fallback
#output: (prpd_list, waveform_list), both float32 numpy arrays
#     for the prpd, n*2 array, n is the number of the waveform point in prpd
#              column 0 is the amplitude of the prpd, column 1 is the phase of each point
#     for the waveform, n*d array, n is the number of the waveform, d is the length of
#              the waveform (values are parsed as float16 before the float32 conversion)
#raises: AssertionError if the extensions differ, a file is missing or the final
#              shapes are inconsistent
#        Exception if the extension is unsupported or no parser could read the files
#              (the last parser error is attached as __cause__)
#########################################################################################
def read_data(prpd_name, waveform_name):
    #get the file format
    file_format = prpd_name.split('.')[-1]
    #raised explicitly (not via assert) so the checks survive "python -O";
    #the exception type stays AssertionError for existing callers
    if file_format != waveform_name.split('.')[-1]:
        raise AssertionError('PRPD and Waveform file must have the same format!')
    if not (os.path.exists(prpd_name) and os.path.exists(waveform_name)):
        raise AssertionError('PRPD or Waveform file not exist!')

    #pick the parsers to try, in order
    if file_format == 'txt':
        readers = (_read_txt_pair,)
    elif file_format == 'csv':
        readers = (_read_tab_csv_pair, _read_comma_csv_pair)
    else:
        raise Exception('data file only support: txt or csv!')

    #every parser starts from scratch, so a failed attempt cannot leak rows into the next
    parsed = None
    failure = None
    for reader in readers:
        try:
            parsed = reader(prpd_name, waveform_name)
        except Exception as err:
            #best effort: remember why this layout failed and try the next one
            failure = err
        else:
            break

    if parsed is None:
        raise Exception('data file read error!') from failure

    prpd_ampl, prpd_angle, waveform_list = parsed
    prpd_list = np.array([prpd_ampl, prpd_angle], dtype=np.float32).swapaxes(0, 1)#trans to n*2 array
    waveform_list = np.array(waveform_list, dtype=np.float32)

    if len(prpd_ampl) != len(prpd_angle):
        raise AssertionError('PRPD amplitude and phase count mismatch!')
    if waveform_list.shape[0] != prpd_list.shape[0]:
        raise AssertionError('PRPD and Waveform count mismatch!')
    if prpd_list.shape[1] != 2:
        raise AssertionError('PRPD data must have two columns!')

    return prpd_list, waveform_list