# In [9]:
###############################################
#           H E A D E R   F I L E S
###############################################
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
from Prj_XPS2CurveChart_01 import GetPaths

###############################################
#          F U N C T I O N   L I S T
###############################################
## @brief Description: read one column of a csv file and return it as floats
#  @param [in] file : csv file path
#  @param [in] column : zero-based index of the column to read
#                       (default 1, i.e. the second column)
#
#  @return data (y coordinates): list of floats, one per non-blank row
#  @note blank lines are skipped; a row that is too short or holds a
#        non-numeric value in the chosen column still raises
#        IndexError / ValueError
#  @date 20230309  danielwu
def ReadCSV(file, column=1):
    # newline='' is what the csv module expects so that quoted fields
    # containing line breaks are parsed correctly
    with open(file, newline='') as f:
        # csv.reader yields [] for a blank line (e.g. a stray empty line);
        # skip those instead of failing on row[column]
        return [float(row[column]) for row in csv.reader(f) if row]

## @brief Description: plot the curves of one frequency group on a single
#         chart and save it as <root>/output/out_<freq>.png
#  @param [in] files : list of csv file paths; each file becomes one curve
#  @param [in] freq : frequency name, used in the title and the file name
#  @param [in] root : directory the files live under; the first directory
#                     below it names each curve, and the chart is saved in
#                     <root>/output (default: the global ROOT)
#
#  @return None
#  @date 20230309  danielwu
def GroupPlot(files, freq, root=None):
    if root is None:
        root = ROOT

    # set x coordinate
    # NOTE(review): assumes every csv holds 800 samples, one per ms from
    # -200 to 599 ms; plt.plot raises if a file has a different length
    x_coors = list(range(-200, 600))

    # one label and one y series per file, appended together so that the
    # zip() below always pairs a curve with its own label
    labels = []
    ys_coors = []
    for file in files:
        # label = first path component below root:
        # <root>\<label>\...\xxx.csv -> <label>
        rel_path = os.path.relpath(file, root)
        labels.append(rel_path.replace('\\', '/').split('/')[0])

        # set y coordinate
        ys_coors.append(ReadCSV(file))

    # set plot parameter
    font1 = {'family': 'serif', 'color': 'blue', 'size': 16}
    font2 = {'family': 'serif', 'color': 'black', 'size': 12}
    colors = ['#AD2811', '#B90CFA', '#00FA00']

    plt.title(f'Average Responses for {freq}', **font1)
    plt.xlabel('ms', loc='right', **font2)
    plt.ylabel('\u03BCV', loc='top', **font2)
    plt.xticks(np.linspace(-200, 600, 9), **font2)
    # 9 ticks from -10 to 10; only the inner ones are labelled and the
    # centre line is labelled with the stimulus level
    plt.yticks(np.linspace(-10, 10, 9), ['', '', '-5.0', '-2.5', '65dB\nHL', '2.5', '5.0', '', ''], **font2)
    plt.xlim([-201, 601])
    # plt.ylim([-15, 15])
    plt.grid(axis='y', color='gray', linewidth=0.2)

    # plot curve charts; curves beyond the fixed palette get color=None so
    # matplotlib picks one from its default cycle instead of dropping them
    for i, (y_coors, label) in enumerate(zip(ys_coors, labels)):
        color = colors[i] if i < len(colors) else None
        plt.plot(x_coors, y_coors, label=label, color=color, linewidth=1.0)

    # show the label (skip when nothing was plotted, which would otherwise
    # only produce a "No artists with labels" warning)
    if ys_coors:
        plt.legend()

    # save curve chart
    output_dir = os.path.join(root, 'output')
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f'out_{freq}.png')
    plt.savefig(save_path, dpi=300)

    # plt.show()
    # release the figure so the next group starts from a clean chart
    plt.close()
###############################################
#             D A T A   T Y P E S
###############################################

###############################################
#              C O N S T A N T S
###############################################

###############################################
#        G L O B A L   V A R I A B L E
###############################################
# root directory searched for the csv files; the charts are saved under
# its "output" sub-directory
ROOT = r'C:\Users\danielwu\Desktop\ee\xps_ALL'

###############################################
#                   M A I N
###############################################
if __name__ == '__main__':
    # keep only csv files whose path contains "output"
    csv_outputs = [path for path in GetPaths(ROOT, '.csv') if 'output' in path]

    # one chart per frequency, drawn in this fixed order
    for freq_name in ('4kHz', '2kHz', '1kHz', '500Hz'):
        group = [path for path in csv_outputs if freq_name in path]
        GroupPlot(group, freq_name)

    print('done')

# Out: done
