In [None]:
# -*- coding: utf-8 -*-
import os,copy,csv,importlib,time,math
import numpy as np
import math,random,shutil
import matplotlib.pyplot as plt
from PIL import Image
from scipy import signal
from scipy.signal import cwt, ricker
import pickle as pickle
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
from sklearn.cluster import KMeans, MeanShift, AgglomerativeClustering
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from sklearn import svm
from tqdm import tqdm
from sklearn.model_selection import GridSearchCV
from itertools import cycle
from matplotlib import cm
from scipy.stats import gaussian_kde
colors = cycle("bgrcmykbgrcmykbgrcmykbgrcmyk")
import sys
csv.field_size_limit(sys.maxsize)
import glob
import pandas as pd
import pywt
from sklearn.metrics import pairwise_distances_argmin_min



In [None]:
#########################################################################################
#seed the random number generators so that repeated runs are reproducible
#the default seed is 45
#########################################################################################
def seed_everything(seed=45):
    """Seed the global random number generators used by this notebook.

    Reseeds NumPy's legacy global RNG and the standard-library ``random``
    module with ``seed``, then stores ``str(seed)`` in the ``PYTHONHASHSEED``
    environment variable.

    Note: ``PYTHONHASHSEED`` is read by the interpreter at start-up, so the
    assignment here only reaches processes spawned afterwards; it does not
    change hashing in the already running kernel.

    Args:
        seed: Integer seed applied to every generator (default 45).
    """
    # Same order as before: NumPy first, then the stdlib generator.
    for reseed in (np.random.seed, random.seed):
        reseed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


# Seed once with the default value as soon as this cell runs (or the module is imported).
seed_everything()
    
#########################################################################################
#function: read the PRPD and waveform data from a pair of .txt or .csv files
#input: prpd_name, waveform_name are the full paths of the two data files
#       (both files must have the same extension)
#output: numpy arrays holding the PRPD and waveform data
#     for the prpd, an n*2 array, n is the number of points in the prpd (one per waveform)
#              column 0 is the amplitude of each point, column 1 is its phase
#     for the waveform, an n*d array, n is the number of waveforms, d is the length of
#              each waveform
#########################################################################################
def _split_tab_fields(text):
    """Split one tab-separated record into tokens, ignoring surrounding whitespace."""
    return text.strip().split('\t')


def _is_waveform_header(field):
    """Return True when a CSV field is the 'WaveformData[0] WaveformData[1] ...' header."""
    return field.lstrip().startswith('WaveformData[')


def _rows_match(waveform_rows, prpd_angle):
    """Return True when there is exactly one waveform row per PRPD point."""
    return np.array(waveform_rows, dtype=np.float32).shape[0] == len(prpd_angle)


def _read_txt_pair(prpd_name, waveform_name):
    """Parse tab-separated .txt files: PRPD lines 0/1 and one waveform per line."""
    with open(prpd_name) as fopen:
        prpd_lines = fopen.readlines()
    prpd_ampl = [np.float32(tok) for tok in _split_tab_fields(prpd_lines[0])]
    # phases are rounded to whole degrees in this format
    prpd_angle = [round(float(tok)) for tok in _split_tab_fields(prpd_lines[1])]

    with open(waveform_name) as fopen:
        # blank lines carry no waveform and are skipped
        waveform_rows = [np.array(_split_tab_fields(line), dtype=np.float32)
                         for line in fopen if line.strip()]
    return prpd_ampl, prpd_angle, waveform_rows


def _read_csv_tab_pair(prpd_name, waveform_name):
    """Parse .csv files whose records are tab-separated inside the first CSV field."""
    with open(prpd_name, 'r', newline='') as fopen:
        prpd_rows = list(csv.reader(fopen))
    prpd_ampl = [float(tok) for tok in _split_tab_fields(prpd_rows[0][0])]
    # phases are rounded to whole degrees in this format
    prpd_angle = [round(float(tok)) for tok in _split_tab_fields(prpd_rows[1][0])]

    waveform_rows = []
    with open(waveform_name, 'r', newline='') as fopen:
        for row in csv.reader(fopen):
            if not row or _is_waveform_header(row[0]):
                continue
            waveform_rows.append(np.array(_split_tab_fields(row[0]), dtype=np.float32))
    return prpd_ampl, prpd_angle, waveform_rows


def _read_csv_comma_pair(prpd_name, waveform_name):
    """Parse genuinely comma-separated .csv files (phases are kept unrounded here)."""
    with open(prpd_name, 'r', newline='') as fopen:
        prpd_rows = list(csv.reader(fopen))
    prpd_ampl = np.array(prpd_rows[0], dtype=np.float32)
    prpd_angle = np.array(prpd_rows[1], dtype=np.float32)

    waveform_rows = []
    with open(waveform_name, 'r', newline='') as fopen:
        for row in csv.reader(fopen):
            if not row or _is_waveform_header(row[0]):
                continue
            waveform_rows.append(np.array(row, dtype=np.float32))
    return prpd_ampl, prpd_angle, waveform_rows


def read_data(prpd_name, waveform_name):
    """Read a PRPD file and its matching waveform file.

    Supported layouts:
      * ``.txt``: tab-separated; PRPD line 0 holds amplitudes, line 1 phases;
        the waveform file holds one waveform per line.
      * ``.csv``: first tried as "tab-separated text inside one CSV field"; if
        that does not yield one waveform per PRPD point, the files are re-read
        as ordinary comma-separated CSV.
    A ``WaveformData[...]`` header row in a CSV waveform file is skipped.

    Values are parsed directly as float32. (They used to pass through float16,
    which quantised the data and turned magnitudes above 65504 into ``inf``.)

    Args:
        prpd_name: Full path of the PRPD file.
        waveform_name: Full path of the waveform file (same extension).

    Returns:
        ``(prpd_list, waveform_list)``: ``prpd_list`` is an ``n x 2`` float32
        array (column 0 amplitude, column 1 phase); ``waveform_list`` is an
        ``n x d`` float32 array with one waveform per row.

    Raises:
        AssertionError: extensions differ, a file is missing, or the PRPD and
            waveform row counts disagree (txt layout).
        Exception: unsupported extension, or no CSV layout gives matching
            row counts.
        ValueError / IndexError: malformed numeric content or missing lines.
    """
    #get the file format
    file_format = prpd_name.split('.')[-1]
    assert file_format == waveform_name.split('.')[-1], \
        'PRPD and Waveform files must have the same format!'
    assert os.path.exists(prpd_name) and os.path.exists(waveform_name), 'PRPD or Waveform file not exist!'

    if file_format == 'txt':
        prpd_ampl, prpd_angle, waveform_rows = _read_txt_pair(prpd_name, waveform_name)
    elif file_format == 'csv':
        prpd_ampl, prpd_angle, waveform_rows = _read_csv_tab_pair(prpd_name, waveform_name)
        if not _rows_match(waveform_rows, prpd_angle):
            # the tab layout did not fit, retry as a plain comma-separated file
            prpd_ampl, prpd_angle, waveform_rows = _read_csv_comma_pair(prpd_name, waveform_name)
            if not _rows_match(waveform_rows, prpd_angle):
                raise Exception('data file read error!')
    else:
        raise Exception('data file only support: txt or csv!')

    prpd_list = np.array([prpd_ampl, prpd_angle], dtype=np.float32).swapaxes(0, 1)#trans to n*2 array
    waveform_list = np.array(waveform_rows, dtype=np.float32)

    assert len(prpd_ampl) == len(prpd_angle)
    assert waveform_list.shape[0] == prpd_list.shape[0]
    assert prpd_list.shape[1] == 2

    return prpd_list, waveform_list

In [None]:
def extract_fea(waveform_data, scales=None):
    """Build continuous-wavelet features for a batch of waveforms.

    Each waveform is L2-normalised, transformed with ``scipy.signal.cwt`` using
    the Ricker wavelet at every width in ``scales``, and the resulting
    coefficient matrix is flattened into one feature vector.

    NOTE(review): ``scipy.signal.cwt`` and ``scipy.signal.ricker`` were
    deprecated in SciPy 1.12 and removed in 1.15, so this needs an older
    SciPy (or a replacement transform) to run - confirm the pinned version.

    Args:
        waveform_data: 2-D array-like, one waveform per row.
        scales: Wavelet widths passed to ``cwt``. Defaults to
            ``np.arange(1, 129)``, i.e. the 128 widths used so far.

    Returns:
        float32 array with one row per waveform and
        ``len(scales) * waveform_length`` columns.

    Raises:
        ValueError: if ``waveform_data`` is not 2-D.
    """
    waveform_data = np.array(waveform_data, dtype=np.float32)
    if waveform_data.ndim != 2:
        raise ValueError(
            'waveform_data must be 2-D (n_waveforms x n_samples), got shape %s'
            % (waveform_data.shape,))
    if scales is None:
        scales = np.arange(1, 129)

    #normalize each waveform to unit L2 norm
    waveform_data_norm = preprocessing.normalize(waveform_data, norm='l2', axis=1)

    #CWT coefficients (len(scales) x n_samples) flattened to one feature vector per waveform
    all_waveform_features = [cwt(waveform, ricker, scales).flatten()
                             for waveform in waveform_data_norm]

    return np.array(all_waveform_features, dtype=np.float32)

# Disabled variant 1 of extract_fea, kept as an inert string literal (it is never executed):
# single-level discrete wavelet transform ('db1'), approximation and detail coefficients stacked.
# NOTE(review): the L2-normalised copy is computed but pywt.dwt is applied to the
# un-normalised `waveform_data` - confirm which was intended before re-enabling.
'''
def extract_fea(waveform_data):

    #waveform_num and feadim
    waveform_data = np.array(waveform_data, dtype=np.float32)
    waveform_num, feadim = waveform_data.shape
    
    #normalize the feature
    waveform_data_norm = preprocessing.normalize(waveform_data, norm='l2',axis=1)    

    #get wavelets feature
    coeffs = pywt.dwt(waveform_data, 'db1')
    cA, *cD = coeffs  # cA: Approximation, cD: Detail
    cA = np.array(cA, dtype=np.float32)
    cD = np.array(cD, dtype=np.float32)
    waveform_fea = np.column_stack((cA, cD))
    return waveform_fea
'''

# Disabled variant 2 of extract_fea, also an inert string literal (never executed):
# 4-level 'db4' wavelet decomposition per waveform, all coefficients concatenated and
# the resulting feature rows L2-normalised.
'''
def extract_fea(waveform_data):
    waveform_num, feadim = waveform_data.shape
    
    # initialize a list to store all waveform features
    waveform_features = []
    
    for i in range(waveform_num):
        # apply discrete wavelet transform
        coeffs = pywt.wavedec(waveform_data[i, :], 'db4', level=4)
        cA, *cD = coeffs
        
        # flatten the approximation and detail coefficients
        features = np.hstack([cA, np.hstack(cD)])
        
        # append the features to the list
        waveform_features.append(features)
    
    # convert the list of features to a Numpy array
    waveform_features = np.array(waveform_features, dtype=np.float32)
    waveform_fea = preprocessing.normalize(waveform_features, norm='l2', axis=1)
    
    return waveform_fea
'''



def pd_cluster(waveform_list, cluster_num=3, random_state=None):
    """Cluster waveforms with K-Means on their wavelet features.

    Args:
        waveform_list: 2-D array-like of waveforms, one per row (passed to
            ``extract_fea``).
        cluster_num: Number of K-Means clusters.
        random_state: Optional seed / RandomState forwarded to ``KMeans`` so a
            single call can be reproduced independently of the global NumPy
            RNG state. ``None`` (default) keeps the previous behaviour.

    Returns:
        ``(clusteridx, kmeans, waveform_fea)``: the cluster label of every
        waveform, the fitted ``KMeans`` estimator, and the feature matrix the
        estimator was fitted on.
    """
    #extract feature from the waveforms
    waveform_fea = extract_fea(waveform_list)

    #fit k-means on the features and label every waveform
    kmeans = KMeans(n_clusters=cluster_num, random_state=random_state)
    clusteridx = kmeans.fit_predict(waveform_fea)

    return clusteridx, kmeans, waveform_fea

def _max_phase_corr(prpd_ampl, prpd_angle, step_num=36):
    """Largest |Pearson r| between amplitudes and a phase-shifted sine of the angles.

    The sine is shifted in ``step_num`` equal steps over 360 degrees; every
    correlation is rounded to 3 decimals. Returns 0 when no finite positive
    correlation exists (fewer than two points, constant data, ...).
    """
    if prpd_ampl.shape[0] < 2:
        # correlation is undefined for fewer than two points
        return 0

    corr_max = 0
    # constant inputs give NaN correlations; they are ignored below, so silence the warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        for step_idx in range(step_num):
            shifted_sin = np.sin(2*np.pi/360*(prpd_angle+step_idx*360/step_num))
            corr_tmp = abs(round(np.corrcoef(shifted_sin, prpd_ampl)[0][1], 3))
            if corr_tmp > corr_max:  # False for NaN, so NaN never wins
                corr_max = corr_tmp
    return corr_max


def cluster_prpd(prpd_data, waveform_data, cluster_num=3, step_num=36):
    """Cluster the waveforms and split the PRPD points by cluster.

    Args:
        prpd_data: ``n x 2`` array; column 0 amplitude, column 1 phase in
            degrees (the phase is fed to a sine as ``2*pi/360 * angle``).
        waveform_data: ``n x d`` array, one waveform per PRPD point.
        cluster_num: Number of clusters; must be a positive integer (it is
            passed straight to K-Means, there is no automatic selection).
        step_num: Number of phase shifts tried when correlating amplitude with
            a sine wave (default 36, i.e. 10-degree steps).

    Returns:
        ``(cluster_result, clusteridx, closest_points)``:
          * ``cluster_result``: one dict per cluster with keys ``'prpd'``
            (that cluster's PRPD rows) and ``'corr'`` (best absolute
            sine/amplitude correlation, rounded to 3 decimals),
          * ``clusteridx``: cluster label of every waveform,
          * ``closest_points``: for every cluster centre, the waveform whose
            features are nearest to it.
    """
    clusteridx, kmeans, waveform_fea = pd_cluster(waveform_data, cluster_num)

    # waveform whose feature vector is nearest to each cluster centre
    closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, waveform_fea)
    closest_points = [waveform_data[closest[i]] for i in range(cluster_num)]

    #construct the cluster results
    cluster_result = []
    for idx in range(cluster_num):
        one_cluster_prpd = prpd_data[clusteridx==idx, :]
        corr_max = _max_phase_corr(one_cluster_prpd[:,0], one_cluster_prpd[:,1], step_num)
        cluster_result.append({'prpd':one_cluster_prpd, 'corr':corr_max})

    return cluster_result, clusteridx, closest_points

In [None]:
def update_fig(prpd_list, max_ampl, f1, counter, method='wavelet', dir_name='cwt_5'):
    """Draw the PRPD scatter of one cluster and save it as a PNG.

    The current pyplot figure is cleared and reused. A reference sine wave of
    amplitude ``max_ampl`` and a zero line are drawn behind the points.

    The output path is derived from ``f1``: ``.csv`` becomes
    ``_{counter}.png``, ``data/`` becomes ``Features/`` and ``prpd/`` becomes
    ``K-Means/{method}/{dir_name}/prpd_figure/``. The target directory is
    created when missing.

    Side effect: sets the global ``plt.rcParams['font.size']`` to 4.

    Args:
        prpd_list: ``n x 2`` array; column 0 amplitude, column 1 phase.
        max_ampl: Amplitude used for the sine reference and the y-axis limits.
        f1: Path of the source PRPD csv file.
        counter: Cluster index, used in the title and the file name.
        method: Method folder name in the output path (default ``'wavelet'``).
        dir_name: Experiment folder name in the output path (default ``'cwt_5'``).

    Returns:
        True once the figure has been written.
    """
    plt.clf()
    fontsize = 4
    plt.subplot(1,1,1)
    plt.grid(linestyle='-.',linewidth=0.2)
    plt.rcParams['font.size']=fontsize  #font size

    # reference curves: zero line and one sine period scaled to max_ampl
    sin_x = np.arange(0,360+1,1)
    sin_y = np.sin(2*np.pi/360*sin_x)*max_ampl
    plt.plot(sin_x, np.zeros(len(sin_x)), color='grey',linewidth=0.2)
    plt.plot(sin_x, sin_y, color='grey',linewidth=0.2)

    # no `cmap` here: without a `c` array matplotlib ignores it and only emits a warning
    plt.scatter(prpd_list[:,1], prpd_list[:,0], s=0.1, marker=',')
    plt.axis([0,360,-max_ampl,max_ampl])
    plt.xticks(size=fontsize)
    plt.yticks(size=fontsize)
    plt.xlabel('Phase',fontsize=fontsize)
    plt.ylabel('Peak',fontsize=fontsize)
    plt.title(f'PRPD cluster{counter}',fontsize=fontsize)

    #PATH/data/.../prpd/NAME.csv -> PATH/Features/.../K-Means/{method}/{dir_name}/prpd_figure/NAME_{counter}.png
    fig_path = f1.replace('.csv',f'_{counter}.png')
    fig_path = fig_path.replace('data/','Features/')
    fig_path = fig_path.replace('prpd/',f'K-Means/{method}/{dir_name}/prpd_figure/')

    # savefig does not create missing folders
    out_dir = os.path.dirname(fig_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(fig_path, dpi=300)
    return True

In [None]:
# update waveform fig
def update_waveform_fig(closest_points, f2):
    """Plot the representative waveform of every cluster and save the figure.

    The output path is derived from ``f2``: ``_idx.csv`` becomes ``.png`` and
    ``WaveForm_idx/`` becomes ``WaveForm_figure/``. The target directory is
    created when missing. The figure is shown and then closed, so repeated
    calls neither accumulate open figures nor leave this figure as pyplot's
    current one.

    Args:
        closest_points: Iterable of waveforms, one per cluster, in cluster order.
        f2: Path of the cluster-index csv (``.../WaveForm_idx/NAME_idx.csv``).

    Returns:
        True once the figure has been written.
    """
    fig, ax = plt.subplots(figsize=(6,4))

    for i, waveform in enumerate(closest_points):
        ax.plot(waveform, label=f'Cluster {i}', linewidth=0.5)

    ax.legend(title='Representative Waveform in Each Cluster', fontsize=8)

    #PATH/.../WaveForm_idx/NAME_idx.csv -> PATH/.../WaveForm_figure/NAME.png
    fig_path = f2.replace("_idx.csv",".png")
    fig_path = fig_path.replace("WaveForm_idx/","WaveForm_figure/")

    # savefig does not create missing folders
    out_dir = os.path.dirname(fig_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(fig_path)

    # show the figure, then release it
    plt.show()
    plt.close(fig)
    return True
        

In [None]:
# Dataset location used to exercise the functions above (ends with a path separator).
data_basepath = '/media/mldadmin/home/s123mdg34_04/WangShengyuan/FYP/data/corona/'
# Glob patterns matching every PRPD csv and every waveform csv below the base path.
prpd_filepath = f'{data_basepath}prpd/*.csv'
waveform_filepath = f'{data_basepath}WaveForm/*.csv'



In [None]:
# Driver: cluster every (PRPD, waveform) file pair and write the per-cluster PRPD
# figures, the cluster index of every waveform, and the representative-waveform figure.
cluster_num = 5

# Folder names used in the output path of the index file:
#   method   in {'wavelet', 'benchmark'}
#   dir_name in {'benchmark_n', 'cwt_n', 'wavelet_n'}
method = 'wavelet'
dir_name = 'cwt_5'

prpd_files = sorted(glob.glob(prpd_filepath))
waveform_files = sorted(glob.glob(waveform_filepath))

# zip() would silently drop the surplus files and could pair unrelated recordings
if len(prpd_files) != len(waveform_files):
    raise ValueError(
        f'found {len(prpd_files)} PRPD files but {len(waveform_files)} waveform files; '
        'they must come in pairs')
if not prpd_files:
    print(f'no input files matched {prpd_filepath}')

for f1, f2 in zip(prpd_files, waveform_files):

    #f1 = /PATH/prpd/*.csv
    #f2 = /PATH/WaveForm/*.csv

    prpd_list, waveform_list = read_data(f1, f2)
    max_ampl = np.max(np.abs(prpd_list[:,0]))
    cluster_res, clusteridx, closest_points = cluster_prpd(prpd_list, waveform_list, cluster_num)
    clusteridx = np.array(clusteridx, dtype=np.int32)

    #save prpd cluster figure, one per cluster
    for counter, onecluster in enumerate(cluster_res):
        update_fig(onecluster['prpd'], max_ampl,f1,counter)

    #save waveform index csv
    #PATH/data/.../WaveForm/NAME.csv -> PATH/Features/.../K-Means/{method}/{dir_name}/WaveForm_idx/NAME_idx.csv
    idx_path = f2.replace(".csv","_idx.csv")
    idx_path = idx_path.replace("data/","Features/")
    idx_path = idx_path.replace("WaveForm/",f"K-Means/{method}/{dir_name}/WaveForm_idx/")

    # np.savetxt does not create missing folders
    idx_dir = os.path.dirname(idx_path)
    if idx_dir:
        os.makedirs(idx_dir, exist_ok=True)
    # labels are integers: write them as such, one per line
    np.savetxt(idx_path, clusteridx, fmt='%d', delimiter='\t')

    #save waveform figure (its path is derived from the index path)
    update_waveform_fig(closest_points,idx_path)

    