### The purpose of the manual classification code is to provide the CNN model with more images to work with. The CNN's accuracy will improve when it is fed images that it has not already processed.

In [None]:
from bat import extract_anabat
import csv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import glob
import csv
import gc
from scipy.signal import savgol_filter
from numpy.polynomial.polynomial import polyfit

The functions below are borrowed from the `Hadi_bulk_pulse_processing` notebook:

In [None]:
def clean_graph(filename, graph=None, dy_cutoff=2000, dx_cutoff=.2, pulse_size=20):
    """Split zero-crossing data into pulses and strip noisy points from them.

    Parameters
    ----------
    filename : str
        Path of a two-column CSV file (x, y) with one header row. Only read
        when ``graph`` is None.
    graph : sequence, optional
        Pre-loaded data; ``graph[0]`` holds the x values and ``graph[1]`` the
        y values. When given, ``filename`` is not opened.
    dy_cutoff : number
        Largest y jump between consecutive points that is still treated as
        part of a continuous trace. Half of this value is the tolerance used
        against the Savitzky-Golay smoothed curve in the final pass.
    dx_cutoff : number
        Largest x gap between consecutive points of the same pulse.
    pulse_size : int
        Minimum number of points a pulse (and later a clean segment of a
        pulse) must contain to be kept.

    Returns
    -------
    list
        One list per cleaned segment, each holding ``[x, y]`` pairs.

    Notes
    -----
    The smoothing pass calls ``savgol_filter`` with a window of 17 points, so
    segments shorter than that (only possible when ``pulse_size`` < 17) are
    expected to make it raise -- NOTE(review): confirm against the scipy docs
    before lowering ``pulse_size``.
    """
    if graph is None:
        # Load the file and convert it to numeric columns. Previously the rows
        # were read but never converted, and the function crashed on
        # ``graph[0]`` right after.
        with open(filename, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            rows = [row for row in reader if row]
        # The x scaling follows the conversion that used to sit here disabled.
        # NOTE(review): presumably seconds -> milliseconds; confirm.
        zc_x = [float(row[0]) * 1000 for row in rows]
        zc_y = [float(row[1]) for row in rows]
    else:
        zc_x, zc_y = graph[0], graph[1]

    # Identify pulses: a gap larger than dx_cutoff closes the current pulse.
    pulses = []
    pulse = []
    prev_x = 0
    for x, y in zip(zc_x, zc_y):
        if x - prev_x <= dx_cutoff:
            pulse.append([x, y])
        else:
            if len(pulse) >= pulse_size:
                pulses.append(pulse)
            pulse = [[x, y]]
        prev_x = x
    # Flush the trailing pulse; without this the last pulse of every
    # recording was silently dropped because no later gap closed it.
    if pulse and len(pulse) >= pulse_size:
        pulses.append(pulse)

    # Absolute first difference of y. prev_y deliberately carries over from
    # one pulse to the next (and starts at 0), so the first entry of each list
    # measures the jump from the previous pulse.
    pulse_dys = []
    prev_y = 0
    for pulse in pulses:
        dy = []
        for _, y in pulse:
            dy.append(abs(y - prev_y))
            prev_y = y
        pulse_dys.append(dy)

    # Smooth holes: replace an isolated outlier with the mean of the nearest
    # well-behaved neighbour on each side (looking at most two points away).
    # dy is not recomputed after a point is patched.
    for dy, pulse in zip(pulse_dys, pulses):
        for i in range(1, len(dy) - 2):
            if dy[i] <= dy_cutoff:
                continue

            if dy[i - 1] < dy_cutoff:
                left = i - 1
            elif i >= 2 and dy[i - 2] < dy_cutoff:
                # The i >= 2 guard stops index -1 from wrapping around to the
                # last point of the pulse.
                left = i - 2
            else:
                continue

            if dy[i + 1] < dy_cutoff:
                right = i + 1
            elif dy[i + 2] < dy_cutoff:
                right = i + 2
            else:
                continue

            pulse[i][1] = (pulse[left][1] + pulse[right][1]) / 2

    # Clean pulses: keep runs of consecutive points whose y jump stays within
    # dy_cutoff, provided the run is at least pulse_size long.
    segments = []
    for dy, pulse in zip(pulse_dys, pulses):
        i = 1
        while i < len(pulse):
            j = i

            # Count neighboring points
            while j < len(pulse) - 1 and dy[j] <= dy_cutoff:
                j += 1

            # If there are enough neighbors, it's good
            if j - i >= pulse_size:
                segments.append(pulse[i:j])

            i = j + 1

    # Clean pulses more: drop points that stray too far from a smoothed
    # version of the segment. The smoothed curve shares the segment's x
    # values, so the point-to-curve distance reduces to the y difference.
    cleaner_graph = []
    for segment in segments:
        ys = [point[1] for point in segment]
        smoothed = savgol_filter(ys, 17, 3)
        cleaner_graph.append([
            point for point, smooth_y in zip(segment, smoothed)
            if abs(point[1] - smooth_y) < dy_cutoff / 2
        ])

    return cleaner_graph

In [None]:
def rough_classify(datadir, outdir):
    """Render every cleaned pulse under ``datadir`` to a PNG, sorted by slope.

    Every file below ``datadir`` (searched recursively) whose name ends in
    ``#`` is passed to ``extract_anabat``; the first two items of its result
    are used as the x and y series and cleaned with ``clean_graph``. A
    degree-1 polynomial is fitted to each resulting pulse: a negative slope
    sends the image to ``<outdir>/echolocation/``, anything else to
    ``<outdir>/abnormal/``.

    Parameters
    ----------
    datadir : str
        Root directory that is searched recursively for input files.
    outdir : str
        Directory that receives the ``echolocation`` and ``abnormal``
        sub-folders; they are created if missing.

    Notes
    -----
    Images are named ``<input base name without extension>_<pulse index>.png``,
    so inputs sharing a base name in different sub-folders overwrite each
    other's images.
    """
    # Local import so this function does not depend on the notebook's import cell.
    import os

    filenames = glob.glob(f'{datadir}/**/*#', recursive=True)

    # savefig does not create missing directories, so make sure both exist.
    for classification in ('/echolocation/', '/abnormal/'):
        os.makedirs(f'{outdir}{classification}', exist_ok=True)

    fig, axes = plt.subplots()
    try:
        for filename in filenames:
            raw = list(extract_anabat(filename))
            pulses = clean_graph(filename=filename, graph=[raw[0], raw[1]])

            # Strip directories with the platform's own rules; splitting on a
            # backslash only worked for Windows-style paths.
            stem = os.path.basename(filename).rsplit(".", 1)[0]

            for i, pulse in enumerate(pulses):
                # Cleaning can leave a pulse with too few points for a line
                # fit; skip it but keep the index of later pulses unchanged.
                if len(pulse) < 2:
                    continue

                x = [point[0] for point in pulse]
                y = [point[1] for point in pulse]

                # polyfit returns coefficients lowest degree first, so index 1
                # is the slope of the fitted line.
                slope = polyfit(x=x, y=y, deg=1)[1]
                classification = '/echolocation/' if slope < 0 else '/abnormal/'

                axes.axis('off')
                axes.scatter(x, y)
                save_path = f'{outdir}{classification}{stem}_{i}.png'
                fig.savefig(save_path, transparent=True, dpi=50)
                # Clear the axes we drew on, not whichever axes happen to be current.
                axes.cla()

            # Release per-file temporaries before loading the next recording.
            gc.collect()
    finally:
        # Close the figure so repeated runs do not accumulate open figures.
        plt.close(fig)
    
# Executing the line below makes the program read the entire "raw" folder, which could take some time.
#rough_classify('../../data/raw', '../../data/sorted')