In [None]:
# set this to true first time the worksheet is run.

# need to install Matplotlib, scipy, scikit-learn, tqdm, sklearnex

INSTALL = False
if INSTALL:
    !pip install matplotlib
    !pip install scipy
    !pip install scikit-learn
    !pip install tqdm
    !pip install pip install scikit-learn-intelex
    

In [None]:

import json
import os
import math
import pickle
import sys
from dataclasses import dataclass  # for the lowest level class type
import h5py
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.signal import butter, filtfilt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm   # progress bar tool
from sklearnex import patch_sklearn
patch_sklearn()   
from datetime import datetime
import warnings
warnings.simplefilter(action='ignore', category=Warning)

In [None]:

# Acquisition / processing constants used by the classes below.
SAMPLING_RATE: int = 100   # samples per second (Hz) of every recorded channel
BLOCK_SIZE: int = 500      # samples per feature-extraction chunk
TEST_DATA_HOURS: int = 48  # presumably hours of held-out test data (unused in the code shown)
BUTTER_ORDER: int = 4      # order of the Butterworth filters

# A single recording station and its site metadata; the numeric fields are
# filled from the station's HDF5 attributes by EQDataSet.
@dataclass
class EQStation:
    name: str
    dist_m: float
    ele: float
    lat: float
    lon: float

    def __str__(self):
        """Multi-line, human-readable summary of the station."""
        lines = [
            f"Station: {self.name}",
            f"Distance: {self.dist_m} m",
            f"Elevation: {self.ele} m",
            f"Latitude: {self.lat} deg",
            f"Longitude: {self.lon} deg",
        ]
        return "\n".join(lines)


# an EQ dataset that represents one event and multiple stations
class EQDataSet:
    """One earthquake event: an HDF5 file holding several stations' recordings.

    The file's root attributes supply the event magnitude (``mag``) and time
    (``time``). Every top-level item is a station: its attributes populate an
    ``EQStation`` and its data is indexed ``[axis, sample]`` with the axis rows
    ordered N, E, Z.
    """

    def __init__(self, filename: str = "", exclude_file: str = "", for_conversion: bool = False):
        """Read station metadata and event attributes from ``filename``.

        Args:
            filename: path to the event HDF5 file.
            exclude_file: accepted for interface compatibility but unused here;
                exclusions are applied per call via ``get_data``.
            for_conversion: when True the file is not opened, leaving an empty
                shell whose attributes are filled in by the caller.
        """
        self.filename = filename
        self.stations_data: list[EQStation] = []
        self.stations: list[str] = []
        self.magnitude: float = 0
        self.time: str = ""
        self.only_include: int = 999

        if for_conversion:
            return

        with h5py.File(self.filename, 'r') as file:
            self.magnitude = float(file.attrs['mag'])
            self.time = file.attrs['time']
            for name, station_ds in file.items():
                data_dict = {key: float(value) for key, value in station_ds.attrs.items()}
                self.stations.append(name)
                self.stations_data.append(EQStation(name=name, **data_dict))
            # All axes of a station share one length, so read it from the
            # dataset shape instead of loading a whole channel into memory.
            n_samples = file[self.stations[0]].shape[1] if self.stations else 0
        # One time stamp per sample, ending at the event time.
        # NOTE: superseded by the time scale derived from detected calibration events.
        self.time_scale = pd.date_range(end=pd.to_datetime(self.time), periods=n_samples,
                                        freq=pd.Timedelta(seconds=1 / SAMPLING_RATE))

    def data_csv(self):
        """Write a station summary table to ``station_data.csv`` and return it.

        Columns are Station, Location, Notes, Latitude, Longitude; Location and
        Notes are left blank (presumably for manual annotation).
        """
        rows = [[station.name, '', '', station.lat, station.lon]
                for station in self.stations_data]
        df = pd.DataFrame(rows, columns=['Station', 'Location', 'Notes', 'Latitude', 'Longitude'])
        df.to_csv("station_data.csv")
        return df

    def __str__(self):
        """Per-station summary including event time, magnitude and data extent."""
        with h5py.File(self.filename, 'r') as file:
            ret_str = "-" * 30 + "\n"
            ret_str += "Filename: " + self.filename + "\n"
            ret_str += "-" * 30 + "\n"
            for station in self.stations_data:
                num_axes, num_points = file[station.name].shape[:2]
                ret_str += f"{station}\n"
                ret_str += f"Time: {self.time}\n"
                ret_str += f"Magnitude: {self.magnitude}\n"
                ret_str += f"Num axes: {num_axes}\n"
                ret_str += f"Num data-points / axis: {num_points}\n"
                ret_str += f"Num days: {num_points/(SAMPLING_RATE*60*60*24)}\n"
                ret_str += "#"*40 + "\n"
        return ret_str

    def get_data(self, station: str, axis_label: str, start_hour: float = 0,
                 num_hours: float = 0,
                 exclude_file: str = ""):
        """Return one axis of one station, optionally limited to a time window.

        Args:
            station: station name as stored in the HDF5 file.
            axis_label: 'N', 'E' or 'Z' (KeyError otherwise).
            start_hour: offset of the first sample, in hours.
            num_hours: window length in hours; 0 means "to the end".
            exclude_file: optional mark-up CSV passed to ``exclude_marked_up``.

        Returns:
            The samples as read from the file, or None if ``station`` is unknown.
        """
        axis = {'N': 0, 'E': 1, 'Z': 2}[axis_label]
        if not any(s.name == station for s in self.stations_data):
            return None
        start = int(start_hour * 3600 * SAMPLING_RATE)
        end = int((start_hour + num_hours) * 3600 * SAMPLING_RATE)
        with h5py.File(self.filename, 'r') as file:
            station_ds = file[station]
            raw_data = station_ds[axis, start:end] if num_hours else station_ds[axis, start:]
        if exclude_file:
            raw_data = self.exclude_marked_up(raw_data, exclude_file=exclude_file)
        return raw_data

    # THIS HAS BEEN SUPERSEDED BY THE AUTO-DETECTION OF CALIBRATION EVENTS
    @staticmethod
    def exclude_marked_up(raw_data, exclude_file: str):
        """Remove manually marked events (plus padding) from ``raw_data``.

        The mark-up CSV holds (start, end) sample indices in columns 6-7 from
        data row 4 onwards; rows with missing values are ignored. Each event is
        widened by ``padding`` samples either side and cut out; everything
        before the first event, between events and after the last event is
        kept. If no events are marked, ``raw_data`` is returned unchanged.

        NOTE(review): the indices are applied to ``raw_data`` as given, so they
        only line up if ``raw_data`` starts at sample 0 of the recording.
        """
        print("Excluding marked-up data events")
        df = pd.read_csv(exclude_file)
        small_markups = df.iloc[4:, 6:8].dropna()
        if small_markups.empty:
            return raw_data
        # about 1.4 hours either side of each marked event at 100 Hz
        padding = 500000
        kept = []
        prev_end = 0
        for start, end in small_markups.values:
            start = max(int(start) - padding, 0)
            end = int(end) + padding
            # Empty if this padded event overlaps the previous one.
            kept.append(raw_data[prev_end:start])
            prev_end = max(prev_end, end)
        # Keep the data after the final event as well.
        kept.append(raw_data[prev_end:])
        return np.concatenate(kept)

    def plot_datum(self, station, axis_label, start_hour, num_hours):
        """Plot one station axis against sample index for a time window."""
        d = self.get_data(station, axis_label, start_hour, num_hours)
        plt.plot(d)
        plt.title(station + ' ' + axis_label)
        plt.xlabel('Sample Index (t*100)')
        plt.ylabel('Sample Value')
        plt.show()

    def plot_data(self, station, start_hour, num_hours,
                  separate_axes=False):
        """Plot all three axes of a station, overlaid or in separate subplots."""
        axes = ['N', 'E', 'Z']
        d = [self.get_data(station, axis, start_hour, num_hours) for axis in axes]
        x_data = 60*60*np.linspace(start_hour, start_hour+num_hours, len(d[0]))
        if not separate_axes:
            for label, series in zip(axes, d):
                plt.plot(x_data, series, label=label)
            plt.legend()
            plt.xlabel('Time (s)')
            plt.ylabel('Speed (nm/s)')
            plt.title(f'Sensor Data for {station}')
            plt.grid(True)
        else:
            fig, axs = plt.subplots(3, 1, figsize=(8, 10))
            for ax, label, series in zip(axs, axes, d):
                ax.plot(x_data, series)
                ax.set_xlabel('Time (s)')
                ax.set_ylabel('Speed (nm/s)')
                ax.set_title(f'Sensor Data for {label}({station})')
                # grid per subplot; plt.grid would only reach the last one
                ax.grid(True)
            plt.tight_layout()
        plt.show()

    def concatenate_station_data(self, sensor, start_hour, num_hours,
                                 exclude_file: str = "",
                                 only_include: int=999):
        """Concatenate one axis from the first ``only_include`` stations.

        At least one station is always included. Records ``only_include`` on
        the instance.
        """
        print("Concatenating data...")
        self.only_include = only_include
        selected = self.stations[:max(only_include, 1)]
        data = [self.get_data(station, sensor, start_hour, num_hours, exclude_file)
                for station in selected]
        return np.concatenate(data)


# this class is used to represent a dataset that is a combination of
# multiple stations' data but allows it to be pre-limited
# in time to avoid huge datasets.
# NOTE: EQDataSetJSON.generate_EQDataSetCombined builds one of these with
# for_conversion=True and then overwrites its fields.
class EQDataSetSingleSensorCombined(EQDataSet):
    """Single-axis data from several stations of one event, concatenated."""

    def __init__(self, filename, axis='Z', start_hour=0, num_hours=0,
                 exclude_file: str="", only_include: int = 999, for_conversion=False):
        """Load the event and concatenate ``axis`` for up to ``only_include`` stations.

        With ``for_conversion=True`` nothing is read; ``combined_data`` starts
        empty and is expected to be assigned by the caller.
        """
        super().__init__(filename,
                         exclude_file=exclude_file,
                         for_conversion=for_conversion)
        self.axis: str = axis
        self.start_hour: float = start_hour
        self.num_hours: float = num_hours
        # Placeholder so __str__ and consumers work before conversion fills it in.
        self.combined_data: np.ndarray = np.array([])
        if not for_conversion:
            self.combined_data = self.concatenate_station_data(axis,
                                                               self.start_hour,
                                                               self.num_hours,
                                                               exclude_file=exclude_file,
                                                               only_include=only_include)

    def __str__(self):
        """Summary of the source file, stations, axis and combined data shape."""
        ret_str = "Filename: " + self.filename + "\n"
        ret_str += "Combination of " + str(len(self.stations)) + " stations' data\n"
        ret_str += "Station names: " + str(self.stations) + "\n"
        ret_str += f"Axis: {self.axis}\n"
        ret_str += f"Combined data shape: {self.combined_data.shape}\n"
        ret_str += f"Start hour: {self.start_hour}\n"
        ret_str += f"Num hours: {self.num_hours}\n"
        return ret_str


# THIS CLASS IS USED TO BUILD EXPERIMENT DATA FROM MULTIPLE
# EVENTS AND STATIONS USING A SPECIFIC JSON FORMAT.
# an EQ dataset JSON can have multiple event filenames
# but only one json file
class EQDataSetJSON:
    """Experiment definition that stitches data from several event files.

    The JSON file is a sequence of records, each with the keys
    ``event_filename``, ``station``, ``exclude_file``, ``axis``,
    ``start_hour`` and ``num_hours``.
    """

    def __init__(self, filename_json: str = ""):
        self.filename_json = filename_json
        # Experiment definition file
        with open(self.filename_json, 'r') as file:
            self.json_data = json.load(file)
        self.event_filenames = [d['event_filename'] for d in self.json_data]
        self.stations = [d['station'] for d in self.json_data]
        self.exclude_files = [d['exclude_file'] for d in self.json_data]
        self.axis = [d['axis'] for d in self.json_data]
        self.start_hours = [d['start_hour'] for d in self.json_data]
        self.num_hours = [d['num_hours'] for d in self.json_data]
        self.data = np.array([])
        self.time_scale = []

    def get_data(self):
        """Read, clean and concatenate every experiment entry into ``self.data``.

        For each entry the configured station/axis window is read, optional
        marked-up exclusions are applied, calibration events are removed and
        the result is appended to ``self.data``. The entry's calibration time
        scale (None if no calibration event was found) is appended to
        ``self.time_scale``. Calling this twice appends the data twice.

        Raises:
            KeyError: if an entry's station is not present in its event file.
        """
        axis_dict = {'N': 0, 'E': 1, 'Z': 2}
        for i, event_filename in enumerate(self.event_filenames):
            print("Reading in event file " + event_filename)
            station = self.stations[i]
            exclude_file = self.exclude_files[i]
            axis_label = self.axis[i]
            start = int(self.start_hours[i] * 3600 * SAMPLING_RATE)
            end = int((self.start_hours[i] + self.num_hours[i]) * 3600 * SAMPLING_RATE)
            with h5py.File(event_filename, 'r') as file:
                print("Reading in axis " + axis_label)
                axis = axis_dict[axis_label]
                if station not in file.keys():
                    # Fail loudly rather than re-using the previous entry's data.
                    raise KeyError(f"Station {station!r} not found in event file {event_filename!r}")
                print("Reading in data for station " + station)
                station_ds = file[station]
                if self.num_hours[i]:
                    raw_data = station_ds[axis, start:end]
                else:
                    raw_data = station_ds[axis, start:]
            # Manual exclusion lists are superseded by calibration detection below.
            if exclude_file:
                raw_data = EQDataSet.exclude_marked_up(raw_data, exclude_file=exclude_file)
            raw_data, time_scale = self.exclude_calibration_events(raw_data)
            self.data = np.append(self.data, raw_data)
            print("New length of data: " + str(len(self.data)) + " samples")
            # approximate location of 9am JST each day, in this entry's own sample indices
            self.time_scale.append(time_scale)

    # Used to find calibration events for Hi-net data
    @staticmethod
    def compare_activity_lag(raw_data, i, lag, radius):
        """Mean of the ``radius`` samples ending at ``i`` minus the mean of the
        ``radius`` samples ending ``lag`` samples earlier.

        NOTE: near the array edges a window may be empty (NaN result, which
        compares False) or, for a negative start, wrap around per normal slicing.
        """
        return np.mean(raw_data[i-radius:i]) - np.mean(raw_data[i-radius-lag:i-lag])

    def seek_calibration_event_grid(self, raw_data):
        """Locate the first daily calibration event in ``raw_data``.

        The detector assumes calibration events recur once per day and that
        each produces a second jump five seconds after the first. Positions
        ``i`` in the first day are scanned every ``step_size`` samples; ``i`` is
        accepted only if ``compare_activity_lag`` exceeds ``jump_factor`` at
        ``i``, ``i`` + 5 s, ``i`` + 1 day and ``i`` + 1 day + 5 s.

        Returns:
            ``i + radius`` for the first accepted ``i``, or -1 if none is found.
        """
        one_day = 24*60*60*SAMPLING_RATE
        five_seconds = 5*SAMPLING_RATE
        # Absolute jump in mean amplitude needed to count as a candidate
        # (tuned by experiment).
        jump_factor = 4000
        # Spacing between candidate positions, also used as the comparison lag
        # (tuned by experiment).
        step_size = 1000
        # Number of samples averaged on each side of the comparison.
        radius = 100
        # Every one of these offsets from i must show the jump.
        offsets = (0, five_seconds, one_day, one_day + five_seconds)
        for i in range(step_size, one_day, step_size):
            if all(self.compare_activity_lag(raw_data, i + offset, step_size, radius) > jump_factor
                   for offset in offsets):
                print(f"Found calibration event at index {i}")
                return i + radius
        return -1

    def exclude_calibration_events(self, raw_data):
        """Cut a window around every daily calibration event out of ``raw_data``.

        Returns:
            ``(cleaned_data, time_scale)``. ``time_scale`` is ``[0, e_0 - r,
            e_1 - r, ..., len(cleaned_data)]`` where ``e_k`` are the event
            indices in the ORIGINAL array and ``r`` the exclusion radius; it is
            None (and the data untouched) when no calibration event is found.
            NOTE(review): the middle entries are in pre-deletion coordinates
            and may be negative for an event near the start.
        """
        print("Seeking calibration events...")
        first_calibration_index = self.seek_calibration_event_grid(raw_data)
        if first_calibration_index == -1:
            return raw_data, None
        # the 9am calibration events are one day of samples apart
        calib_event_1_indices = np.arange(first_calibration_index,
                                          len(raw_data), 24*60*60*SAMPLING_RATE)
        print(f"{calib_event_1_indices=}")
        exclude_radius = 15000
        # A boolean mask clips windows at both array ends (np.delete would
        # wrap negative indices and reject indices past the end).
        keep = np.ones(len(raw_data), dtype=bool)
        for cei in calib_event_1_indices:
            keep[max(cei - exclude_radius, 0):cei + exclude_radius] = False
        total_time_removed = np.count_nonzero(~keep)/SAMPLING_RATE
        print(f"Total time removed with calib events: {total_time_removed} seconds")
        raw_data = raw_data[keep]
        time_scale = calib_event_1_indices-exclude_radius  # gives the approx 9am each day
        time_scale = np.append(np.array([0]), time_scale)
        time_scale = np.append(time_scale, np.array(len(raw_data)))
        print("Calibration events deleted")
        return raw_data, time_scale

    # Builds a "Frankenstein" EQDataSetSingleSensorCombined around the JSON data
    def generate_EQDataSetCombined(self, axis):
        """Wrap the concatenated experiment data in an EQDataSetSingleSensorCombined.

        The wrapper is created with ``for_conversion=True`` (no HDF5 read) and
        its fields are overwritten from this object. ``stations``,
        ``start_hour``, ``num_hours`` and ``time_scale`` hold per-entry lists
        where the wrapper normally holds a single value; ``magnitude``,
        ``time`` and ``only_include`` keep the wrapper's defaults.
        """
        eqdsc = EQDataSetSingleSensorCombined(filename=self.filename_json,
                                              axis=axis,
                                              for_conversion=True)
        eqdsc.stations = self.stations
        eqdsc.combined_data = self.data
        eqdsc.time_scale = self.time_scale
        eqdsc.start_hour = self.start_hours
        eqdsc.num_hours = self.num_hours
        eqdsc.axis = axis
        return eqdsc


class EQDataSetProcessor:
    """Turns a combined single-axis dataset into a per-chunk feature table.

    Pipeline (see ``compute_features_matrix``): filter the whole series, cut
    it into ``BLOCK_SIZE``-sample chunks, then compute seven time- and
    frequency-domain features per chunk.
    """

    def __init__(self, dataset: "EQDataSetSingleSensorCombined", rebuild_features=True):
        # the dataset contains a list of station data etc
        self.dataset: "EQDataSetSingleSensorCombined" = dataset
        # when False, features are loaded from features.pkl if it exists
        self.rebuild_features: bool = rebuild_features
        # short hand for the combined raw data
        self.ds: np.ndarray = self.dataset.combined_data
        # data after filtering (or the raw data when filtering is disabled)
        self.filtered_data: np.ndarray = np.array([])
        # filtered data reshaped to (num_chunks, BLOCK_SIZE)
        self.chunked_data: np.ndarray = np.array([])
        # one row of 7 features per chunk
        self.features_df: pd.DataFrame = pd.DataFrame()

    def __str__(self):
        ret_str = "-" * 30 + "\n"
        ret_str += "Dataset Processor\n"
        ret_str += "-" * 30 + "\n"
        ret_str += f"Stations combined: {self.dataset.stations[:self.dataset.only_include]}\n"
        ret_str += f"Axis: {self.dataset.axis}\n"
        ret_str += f"Start: {self.dataset.start_hour}\n"
        ret_str += f"Num Hours: {self.dataset.num_hours}\n"
        ret_str += f"Dataset: {self.ds.shape}\n"
        ret_str += f"Filtered: {self.filtered_data.shape}\n"
        ret_str += f"Chunked: {self.chunked_data.shape}\n"
        ret_str += f"Features: {self.features_df.shape}\n"
        return ret_str

    @staticmethod
    def HPF(data: np.ndarray = None):
        """Zero-phase Butterworth high-pass filter (order ``BUTTER_ORDER``) at 1 Hz.

        Kept static so it can be applied with ``map`` over separate arrays.
        """
        cutoff_frequency: float = 1  # Hz
        b, a = butter(BUTTER_ORDER, cutoff_frequency, btype='high', fs=SAMPLING_RATE)
        # filtfilt runs the filter forwards then backwards, cancelling the
        # phase distortion a single pass would introduce.
        return filtfilt(b, a, data)

    def apply_HPF(self, rebuild_filtered=True):
        """Filter ``self.ds`` into ``self.filtered_data`` and cache it.

        Despite the name this is a band-pass Butterworth filter (order
        ``BUTTER_ORDER``) from 1 Hz to 1 Hz below Nyquist, applied zero-phase
        with ``filtfilt``. The result is pickled to ``filtered_data.pkl``.

        Args:
            rebuild_filtered: when False, load ``filtered_data.pkl`` instead of
                filtering. The file is unpickled, so only use caches this
                notebook wrote itself.
        """
        if not rebuild_filtered:
            print("Loading in a presaved filtered data")
            with open('filtered_data.pkl', 'rb') as f:
                self.filtered_data = pickle.load(f)
            return
        cutoff_frequency: float = 1  # Hz
        # Upper edge sits below Nyquist because butter requires band edges < fs/2.
        b, a = butter(BUTTER_ORDER, [cutoff_frequency, SAMPLING_RATE // 2 - 1], btype='band',
                      fs=SAMPLING_RATE)
        self.filtered_data = filtfilt(b, a, self.ds)
        with open('filtered_data.pkl', 'wb') as f:
            pickle.dump(self.filtered_data, f)

    def chunk_data_filtered(self):
        """Reshape ``filtered_data`` into rows of ``BLOCK_SIZE`` samples.

        Trailing samples that do not fill a whole chunk are dropped from
        ``filtered_data`` too. Fewer than ``BLOCK_SIZE`` samples give a
        ``(0, BLOCK_SIZE)`` array rather than an error.
        """
        print("filtered data size is ", self.filtered_data.shape)
        chunk_size = BLOCK_SIZE
        num_chunks = len(self.filtered_data) // chunk_size
        self.filtered_data = self.filtered_data[:num_chunks * chunk_size]
        print("Chunking data")
        self.chunked_data = self.filtered_data.reshape(num_chunks, chunk_size)
        print("BLOCK_SIZE is ", BLOCK_SIZE)
        print("num chunks is ", num_chunks)

    def p_integral_squared(self, X):
        """Trapezoidal integral of X**2 per row (unit sample spacing)."""
        # np.trapz was renamed np.trapezoid in NumPy 2.0 and deprecated there.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        return trapezoid(X ** 2, axis=1)

    def p_max_spectral_amp(self, Fx):
        """Largest spectral magnitude per row; shape (N,)."""
        return np.max(np.abs(Fx), axis=1)

    def p_freq_at_max_spectral_amp(self, Fx, f):
        """Frequency of the largest spectral magnitude per row."""
        print("freq at max spectral amp")
        return f[np.argmax(Fx, axis=1)]

    def p_centre_freq(self, Fx, f):
        """Amplitude-weighted mean frequency and spread about it, per row."""
        weight_sum = np.sum(Fx, axis=1)
        centre_freq = np.sum(f[np.newaxis, :] * Fx, axis=1) / weight_sum
        signal_bandwidth = np.sqrt(
            np.sum((f[np.newaxis, :] - centre_freq[:, np.newaxis]) ** 2 * Fx, axis=1) / weight_sum
        )
        return centre_freq, signal_bandwidth

    def p_zero_upcrossing_rate(self, Fx, f, w):
        """sqrt(sum(w^2 Fx^2) / sum(Fx^2)) per row, w in rad/s."""
        return np.sqrt(
            np.sum(w[np.newaxis, :] ** 2 * Fx ** 2, axis=1) / np.sum(Fx ** 2, axis=1)
        )

    def p_rate_spectral_peaks(self, Fx, f, w):
        """sqrt(sum(w^4 Fx^2) / sum(w^2 Fx^2)) per row, w in rad/s."""
        return np.sqrt(
            np.sum(w[np.newaxis, :] ** 4 * Fx ** 2, axis=1) / np.sum(w[np.newaxis, :] ** 2 * Fx ** 2, axis=1)
        )

    def compute_features(self):
        """Compute the 7-column feature table from ``self.chunked_data``."""
        print("Computing features matrix")
        X = self.chunked_data
        print("FFT")
        Fx = np.abs(np.fft.rfft(X, axis=1))
        Fx = (2.0 / X.shape[1]) * Fx  # normalization
        f = np.fft.rfftfreq(X.shape[1], 1 / SAMPLING_RATE)
        w = 2 * np.pi * f

        print("integral squared")
        integral_squared = self.p_integral_squared(X)
        print("max spectral amp")
        max_spectral_amp = self.p_max_spectral_amp(Fx)
        freq_at_max_spectral_amp = self.p_freq_at_max_spectral_amp(Fx, f)
        print("centre freq")
        centre_freq, signal_bandwidth = self.p_centre_freq(Fx, f)
        print("zero upcrossing rate")
        zero_upcrossing_rate = self.p_zero_upcrossing_rate(Fx, f, w)
        print("rate spectral peaks")
        rate_spectral_peaks = self.p_rate_spectral_peaks(Fx, f, w)

        features_df = pd.DataFrame({
            'integral_squared': integral_squared,
            'max_spectral_amp': max_spectral_amp,
            'freq_at_max_spectral_amp': freq_at_max_spectral_amp,
            'centre_freq': centre_freq,
            'signal_bandwidth': signal_bandwidth,
            'zero_upcrossing_rate': zero_upcrossing_rate,
            'rate_spectral_peaks': rate_spectral_peaks
        })
        print(f"features_df shape is {features_df.shape}")
        self.features_df = features_df

    def compute_features_matrix(self, chunking=True, disable_filter=False, rebuild_hpf=True):
        """Filter, chunk and featurise the dataset, caching in ``features.pkl``.

        Args:
            chunking: split into ``BLOCK_SIZE`` chunks; when False the whole
                (unfiltered) series is treated as a single chunk.
            disable_filter: skip filtering and chunk the raw data instead.
            rebuild_hpf: passed to ``apply_HPF`` (False loads the cached filter).

        When ``rebuild_features`` is False and ``features.pkl`` exists, the
        cached (unpickled) table is loaded instead of recomputing.
        """
        if self.rebuild_features or not os.path.exists('features.pkl'):
            if disable_filter:
                # Chunking reads filtered_data, so it must be set before chunking.
                self.filtered_data = self.ds
            else:
                print("Applying high pass filter to whole dataset")
                self.apply_HPF(rebuild_hpf)

            if chunking:
                self.chunk_data_filtered()
                print(f"X shape is {self.chunked_data.shape}")
            else:
                # NOTE(review): this path uses the raw series, not filtered_data.
                X = self.ds
                print(f"X Unchunked HPF shape is {X.shape}")
                self.chunked_data = X.reshape(1, -1)
                print(f"filtered ds rechunked HPF shape is {self.filtered_data.shape}")
            print(f"filtered ds HPF shape is {self.filtered_data.shape}")
            self.compute_features()
            with open('features.pkl', 'wb') as f:
                pickle.dump(self.features_df, f)
        else:
            print("Loading features from file")
            with open('features.pkl', 'rb') as f:
                self.features_df = pickle.load(f)

    def plot_features(self, percentage = 1.0):
        """Plot the raw data above each feature, features stretched to line up.

        Args:
            percentage: fraction of the data (and of the features) to plot.
        """
        print("Plotting features")
        columns = self.features_df.columns
        fig, axs = plt.subplots(len(columns) + 1, 1, figsize=(15, 10))
        axs[0].plot(self.ds[:int(percentage*len(self.ds))])
        axs[0].set_title(f"Raw data for {self.dataset.stations[:self.dataset.only_include]} {self.dataset.axis} ")
        # Repeat each feature value BLOCK_SIZE times so it spans its chunk.
        features_clipped = self.features_df[:int(percentage*len(self.features_df))]
        lined_up_features_df = pd.DataFrame({
            col: np.repeat(features_clipped[col].values, BLOCK_SIZE) for col in features_clipped.columns
        })
        for ax, col in zip(axs[1:], columns):
            ax.plot(lined_up_features_df[col], label=col)
            ax.set_title(f"{col}")
        plt.tight_layout()
        plt.show()

# This class performs dimension reduction on the features
# and searches for correct cluster count on train data
# then clusters on the test data
class EQDataSetClusterTrain:
    """Reduce feature dimensionality and search for a good cluster count.

    The feature table of an ``EQDataSetProcessor`` is split (shuffled) into
    train and test partitions. Both partitions are standardised and projected
    with a whitened PCA. The train projection is then repeatedly sub-sampled
    and clustered with k-means for k = 2..``max_clusters``, so that the gap
    statistic, inertia and silhouette score can be compared across k.
    """

    def __init__(self, dataset_processed: EQDataSetProcessor, filename: str = ''):
        """Split the processed features and initialise the result containers.

        Args:
            dataset_processed: processor whose ``features_df`` is split and
                clustered. Ignored when ``filename`` is non-empty.
            filename: optional path of a pickled ``EQDataSetProcessor``. When
                given, the object is loaded from disk and used instead of
                ``dataset_processed``.
        """
        if filename != '':
            # SECURITY: pickle.load can execute arbitrary code; only load
            # files this project produced itself.
            with open(filename, 'rb') as f:
                dataset_processed = pickle.load(f)
        self.dataset_processed: EQDataSetProcessor = dataset_processed
        self.test_features_df: pd.DataFrame = pd.DataFrame()
        self.train_features_df: pd.DataFrame = pd.DataFrame()
        self.standardised_train_features_df: pd.DataFrame = pd.DataFrame()
        self.standardised_test_features_df: pd.DataFrame = pd.DataFrame()
        self.test_size: float = 0.1
        # Shuffled split with no fixed random_state, so results vary between runs.
        train_df, test_df = train_test_split(
            self.dataset_processed.features_df,
            test_size=self.test_size, shuffle=True)  # , random_state=7)
        self.features_df = {'train': train_df,
                            'test': test_df}
        print(f"{self.features_df=}")
        self.standardised_features_df: dict[str, pd.DataFrame] = {'train': pd.DataFrame(),
                                                                  'test': pd.DataFrame()}
        self.pca_df: dict[str, pd.DataFrame] = {'train': pd.DataFrame(),
                                                'test': pd.DataFrame()}
        self.explained_variance: dict[str, np.ndarray] = {'train': np.array([]),
                                                          'test': np.array([])}
        # NOTE(review): train and test each get their own PCA fit, so their
        # component axes are not guaranteed to coincide.
        self.features_pca: dict[str, PCA] = {'train': PCA(whiten=True),
                                             'test': PCA(whiten=True)}
        self.max_clusters: int = 10
        self.train_clusters: KMeans = KMeans(n_clusters=self.max_clusters)  # , random_state=7)
        self.n_refs: int = 100  # number of uniform reference sets for the gap statistic
        self.gaps: list[np.ndarray] = []
        self.gaps_sd: list[np.ndarray] = []
        # every silhouette computed by gap_statistic is appended here
        self.silhouette_scores: list[float] = []
        # per process-repeat dicts (keyed by str(k)) appended by run_training
        self.gap = np.array([])
        self.gap_sd = np.array([])
        self.inertias = np.array([])
        self.inertias_sd = np.array([])
        self.silhouettes = np.array([])
        self.silhouettes_sd = np.array([])

    def __str__(self):
        """Human-readable dump of the object's main attributes."""
        ret_str = f"EQDataSetClusterer object\n"
        ret_str += f"dataset_processed: {self.dataset_processed}\n"
        ret_str += f"test_features_df: {self.test_features_df}\n"
        ret_str += f"train_features_df: {self.train_features_df}\n"
        ret_str += f"standardised_train_features_df: {self.standardised_train_features_df}\n"
        ret_str += f"standardised_test_features_df: {self.standardised_test_features_df}\n"
        ret_str += f"test_size: {self.test_size}\n"
        ret_str += f"features_df: {self.features_df}\n"
        ret_str += f"standardised_features_df: {self.standardised_features_df}\n"
        ret_str += f"pca_df: {self.pca_df}\n"
        ret_str += f"explained_variance: {self.explained_variance}\n"
        ret_str += f"features_pca: {self.features_pca}\n"
        ret_str += f"max_clusters: {self.max_clusters}\n"
        ret_str += f"train_clusters: {self.train_clusters}\n"
        ret_str += f"n_refs: {self.n_refs}\n"
        ret_str += f"gaps: {self.gaps}\n"
        ret_str += f"gaps_sd: {self.gaps_sd}\n"
        ret_str += f"silhouette_scores: {self.silhouette_scores}\n"
        ret_str += f"gap: {self.gap}\n"
        ret_str += f"gap_sd: {self.gap_sd}\n"
        return ret_str

    def standardise_features(self):
        """Z-score every feature column of the train and test partitions.

        NOTE(review): each partition is scaled with its *own* mean and std,
        so test values are not expressed in train units. Confirm this is
        intended before comparing train and test clusters.
        """
        print("Standardising features")
        for d in ['train', 'test']:
            self.standardised_features_df[d] = (
                self.features_df[d] - self.features_df[d].mean()
            ) / self.features_df[d].std()

    def whitened_pca(self, trim_components: bool = True):
        """Project each standardised partition with its own whitened PCA.

        Stores the projections in ``self.pca_df`` (columns ``PCA0..``) and the
        explained-variance ratios in ``self.explained_variance``.

        Args:
            trim_components: when True, keep only the leading components
                needed for the train partition's cumulative explained variance
                to reach 0.99 (after rounding to 2 d.p.). The same number of
                components is kept for the test partition.
        """
        columns = [f"PCA{i}" for i in range(self.standardised_features_df['train'].shape[1])]
        for d in ['train', 'test']:
            print(f"Whitening PCA for {d} data")
            pca_data = self.features_pca[d].fit_transform(self.standardised_features_df[d])
            self.explained_variance[d] = self.features_pca[d].explained_variance_ratio_
            print(f"{d} explained variance is {self.explained_variance[d]}")
            self.pca_df[d] = pd.DataFrame(pca_data, columns=columns,
                                          index=self.standardised_features_df[d].index)

        if trim_components:
            # Accumulate train variance until it reaches 99% and drop the rest.
            var_sum = 0
            n_components = 1
            for v in self.explained_variance['train']:
                var_sum += v
                print(f"% variance explained by {n_components} components is {var_sum}")
                if round(var_sum, 2) < 0.99:
                    n_components += 1
                else:
                    break
            self.pca_df['train'] = self.pca_df['train'].iloc[:, :n_components]
            self.pca_df['test'] = self.pca_df['test'].iloc[:, :n_components]

    def gap_statistic(self, data, kmeans_object):
        """Gap statistic of a fitted k-means model against uniform references.

        ``self.n_refs`` reference data sets are drawn uniformly inside the
        per-column min/max box of ``data``. Each reference set is clustered
        (random init, one run) with the same number of clusters. Its log
        within-cluster sum of squares is then compared with that of
        ``kmeans_object``.

        Side effect: the silhouette score of ``kmeans_object`` on ``data`` is
        appended to ``self.silhouette_scores``.

        Args:
            data: the samples ``kmeans_object`` was fitted on (2-D).
            kmeans_object: a fitted ``KMeans`` instance.

        Returns:
            ``(gap, sd)``: the mean and standard deviation over the reference
            sets of ``log(W_ref) - log(W_actual)``.
        """
        n_clusters = len(kmeans_object.cluster_centers_)
        # inertia_ is the sum of squared distances of samples to their closest centre
        ssd_actual = kmeans_object.inertia_
        d = np.array(data)
        # loop-invariant bounding box used to scale the uniform reference samples
        col_min = np.min(d, axis=0)
        col_range = np.max(d, axis=0) - col_min
        ssd_refs = []
        for _ in range(self.n_refs):
            random_data = np.random.rand(*data.shape) * col_range + col_min
            km_ref = KMeans(n_clusters=n_clusters, n_init=1,
                            init='random').fit(random_data)
            ssd_refs.append(km_ref.inertia_)
        log_ratio = np.log(ssd_refs) - np.log(ssd_actual)
        gap_stat = np.mean(log_ratio)
        gaps_sd = np.std(log_ratio)
        self.silhouette_scores.append(silhouette_score(data, kmeans_object.labels_))
        return gap_stat, gaps_sd

    def run_training(self):
        """Search for a cluster count on the train PCA projection.

        The features are first standardised and PCA-projected, and the
        correlation matrix of the train components is shown. The following is
        then done PROCESS_REPEAT_L0_N times:

        - draw SAMPLE_REPEAT_L1_N random samples of up to SAMPLE_SIZE train
          rows;
        - for every k in 2..``self.max_clusters``, fit k-means and record the
          gap statistic, inertia and silhouette.

        After each repeat, the per-k means and SDs (dicts keyed by ``str(k)``)
        are appended to ``gap``, ``gap_sd``, ``inertias``, ``inertias_sd``,
        ``silhouettes`` and ``silhouettes_sd``, and the results are plotted.
        """
        self.standardise_features()
        self.whitened_pca()
        # plot the pca correlation matrix
        corr = self.pca_df['train'].corr()
        plt.matshow(corr)
        plt.title('PCA covariance matrix')
        plt.colorbar()
        plt.show()

        PROCESS_REPEAT_L0_N = 3
        SAMPLE_REPEAT_L1_N = 52
        SAMPLE_SIZE = 7500
        MAX_K = self.max_clusters
        USE_TQDM = False
        k_values = range(2, MAX_K + 1)
        # DataFrame.sample(n=...) raises if n exceeds the number of rows
        sample_size = min(SAMPLE_SIZE, len(self.pca_df['train']))

        if USE_TQDM:
            process_repeat_range = tqdm(range(PROCESS_REPEAT_L0_N), desc="process_repeat",
                                        leave=True, position=0)
        else:
            process_repeat_range = range(PROCESS_REPEAT_L0_N)
        for process_repeat in process_repeat_range:
            if not USE_TQDM:
                print("process_repeat:", process_repeat)
            # per-k lists of the values collected over all samples of this repeat
            all_inertias = {str(k): [] for k in k_values}
            all_gaps = {str(k): [] for k in k_values}
            all_gaps_sd = {str(k): [] for k in k_values}
            all_silhouettes = {str(k): [] for k in k_values}
            if USE_TQDM:
                loop_obj_sample = tqdm(range(SAMPLE_REPEAT_L1_N), desc="\tsample_repeat",
                                       leave=True, position=1)
            else:
                loop_obj_sample = range(SAMPLE_REPEAT_L1_N)
            for sample_repeat in loop_obj_sample:
                if not USE_TQDM:
                    print("\tsample_repeat:", sample_repeat)
                sample_df = self.pca_df['train'].sample(n=sample_size)
                if USE_TQDM:
                    loop_obj_kmeans = tqdm(k_values, desc="\t\tkmeans",
                                           leave=True, position=2)
                else:
                    loop_obj_kmeans = k_values
                for k in loop_obj_kmeans:
                    if not USE_TQDM:
                        print("\t\tk:", k)
                    run_kmeans = KMeans(
                        n_clusters=k, n_init='auto',
                        init='k-means++').fit(sample_df)
                    gap_stat, gap_sd = self.gap_statistic(sample_df, run_kmeans)
                    all_gaps[str(k)].append(gap_stat)
                    all_gaps_sd[str(k)].append(gap_sd)
                    all_inertias[str(k)].append(run_kmeans.inertia_)
                    # gap_statistic has just computed the silhouette for this
                    # exact fit; reuse it instead of repeating the O(n^2) work.
                    all_silhouettes[str(k)].append(self.silhouette_scores[-1])

            print("Total min values collected from samples:", len(all_inertias['2']))
            mean_inertias = {k: round(np.mean(all_inertias[k]), 0) for k in all_inertias}
            std_inertias = {k: round(np.std(all_inertias[k]), 0) for k in all_inertias}
            mean_silhouettes = {k: np.mean(all_silhouettes[k]) for k in all_silhouettes}
            std_silhouettes = {k: np.std(all_silhouettes[k]) for k in all_silhouettes}
            mean_gap = {k: np.mean(all_gaps[k]) for k in all_gaps}
            # The SD of the gap statistic itself is not computed; the mean of
            # the reference SDs is stored and used for the error bars.
            std_gap = {k: np.mean(all_gaps_sd[k]) for k in all_gaps_sd}
            self.gap = np.append(self.gap, mean_gap)
            self.gap_sd = np.append(self.gap_sd, std_gap)
            self.inertias = np.append(self.inertias, mean_inertias)
            self.inertias_sd = np.append(self.inertias_sd, std_inertias)
            self.silhouettes = np.append(self.silhouettes, mean_silhouettes)
            self.silhouettes_sd = np.append(self.silhouettes_sd, std_silhouettes)
            self.plot_train_cluster_results(max_k=MAX_K, title=f"Repeat {process_repeat+1}")

    def plot_train_cluster_results(self, max_k, title='Mean Gap stat plot', process_index: int=-1):
        """Plot gap statistic, its difference, inertia and silhouette against k.

        Args:
            max_k: largest k that was evaluated (k runs over 2..max_k).
            title: prefix for each figure title.
            process_index: which run_training repeat to plot; -1 (default)
                selects the most recent.
        """
        gap = list(self.gap[process_index].values())
        sd = list(self.gap_sd[process_index].values())
        inertias = np.array(list(self.inertias[process_index].values()))
        silhouettes = np.array(list(self.silhouettes[process_index].values()))

        clusters = range(2, max_k + 1)
        print(clusters)
        print(gap)

        plt.errorbar(clusters, gap, yerr=sd)
        plt.xlabel('Number of clusters')
        plt.ylabel('Gap statistic')
        plt.title(f"{title}: Gap statistic")
        plt.grid()
        plt.show()

        # successive differences of the gap statistic, scaled by the SD at the largest k
        print("np.diff(gap): ", np.diff(gap) / sd[-1])
        plt.plot(clusters[1:], np.diff(gap) / sd[-1], color='r', label='derivative of gap statistic')
        plt.xlabel('Number of clusters')
        plt.ylabel('Derivative of Gap statistic')
        plt.title(f"{title}: Deriv. gap statistic")
        plt.grid()
        plt.show()

        plt.plot(clusters, inertias, color='g', label='inertia')
        plt.ylabel('Inertia')
        plt.xlabel('Number of clusters')
        plt.title(f"{title}: Inertia")
        plt.grid()
        plt.show()

        plt.plot(clusters, silhouettes, color='g', label='silhouette')
        plt.xlabel('Number of clusters')
        plt.ylabel('Silhouette')
        plt.title(f"{title}: Silhouette")
        plt.grid()
        plt.show()

    def load_training(self):
        """Load pickled train/test feature frames from the working directory.

        The frames are stored in ``self.train_features_df`` and
        ``self.test_features_df`` and also returned.

        Returns:
            ``(train_features, test_features)`` as loaded from
            ``train_features.pkl`` and ``test_features.pkl``.
        """
        # SECURITY: pickle.load can execute arbitrary code; trusted files only.
        with open('train_features.pkl', 'rb') as f:
            self.train_features_df = pickle.load(f)
        with open('test_features.pkl', 'rb') as f:
            self.test_features_df = pickle.load(f)
        return self.train_features_df, self.test_features_df


# This class uses the results of EQDataSetClusterTrain to cluster the Test dataset
class EQDataSetClusterTest:
    """Cluster held-out features with the cluster count chosen in training.

    Seed centroids come from ``compute_initial_centroids`` (the best of many
    k-means runs on random sub-samples). They seed k-means fits on the full
    test and train PCA projections held by ``training_structure``. The
    resulting labels are then mapped back to feature rows and raw samples.
    """

    def __init__(self, training_structure: EQDataSetClusterTrain, num_clusters=5):
        """
        Args:
            training_structure: the ``EQDataSetClusterTrain`` whose
                ``pca_df`` / ``features_df`` partitions are clustered.
            num_clusters: number of k-means clusters to fit.
        """
        self.training_structure: EQDataSetClusterTrain = training_structure
        self.num_clusters: int = num_clusters
        # rows per random sub-sample in compute_initial_centroids
        # (cluster_test overrides this value)
        self.sample_size: int = 8000
        self.final_kmeans = None
        self.final_kmeans_train = None
        self.n_init = 'auto'  # 'auto' or an int, forwarded to KMeans
        # consecutive non-improving sub-sample runs before the seed search stops
        self.max_no_improvement: int = 250
        # best centroids found by compute_initial_centroids
        self.initial_centroids = None
        self.labels = None
        self.raw_data_with_clusters: pd.DataFrame = pd.DataFrame()
        self.first_raw_row_index: int = 0
        self.all_clusters = None
        self.labels_train = None

    def compute_initial_centroids(self):
        """Find seed centroids by repeated k-means on random test sub-samples.

        Random-init k-means is fitted to random samples (drawn without
        replacement) of up to ``self.sample_size`` rows of the test PCA
        projection. This repeats until ``self.max_no_improvement`` consecutive
        fits fail to lower the best inertia. The centroids of the best fit are
        stored in ``self.initial_centroids``. The original author attributes
        the method to "the Johnson paper".
        """
        print("Computing initial centroids")
        data = self.training_structure.pca_df['test']
        n_clusters = self.num_clusters
        # sampling without replacement cannot draw more rows than exist
        sample_size = min(self.sample_size, data.shape[0])
        print(f"sample_size = {sample_size} of {len(data)}")
        best_inertia = float('inf')
        no_improvement_count: int = 0
        best_centroids = None

        while no_improvement_count < self.max_no_improvement:
            if no_improvement_count % 50 == 0:
                print(f"no_improvement_count = {no_improvement_count}")
            indices = np.random.choice(data.shape[0], sample_size, replace=False)
            sample = data.iloc[indices]
            kmeans_results = KMeans(n_clusters=n_clusters,
                                    init='random',
                                    n_init=self.n_init).fit(sample)
            if kmeans_results.inertia_ < best_inertia:
                best_inertia = kmeans_results.inertia_
                best_centroids = kmeans_results.cluster_centers_
                print(f"New best inertia {best_inertia} and centroids {best_centroids}")
                no_improvement_count = 0
            else:
                no_improvement_count += 1
        self.initial_centroids = best_centroids

    def cluster_test(self):
        """Fit the final k-means models on the full test and train projections.

        Steps:

        1. Override the sub-sampling parameters.
        2. Compute the seed centroids (or load them from
           ``initial_centroids.pkl`` when REBUILD_CENTROIDS is False) and
           cache them to that file.
        3. Fit k-means on the test PCA data and print the share of rows per
           label.
        4. Fit a second model, seeded with the same centroids, on the train
           PCA data.

        Sets ``final_kmeans``, ``labels``, ``final_kmeans_train`` and
        ``labels_train``.
        """
        self.sample_size = 25000
        self.n_init = 100
        self.max_no_improvement = 500

        REBUILD_CENTROIDS = True
        if not REBUILD_CENTROIDS and os.path.exists('initial_centroids.pkl'):
            # SECURITY: pickle.load can execute arbitrary code; trusted files only.
            with open('initial_centroids.pkl', 'rb') as f:
                self.initial_centroids = pickle.load(f)
        else:
            self.compute_initial_centroids()
            with open('initial_centroids.pkl', 'wb') as f:
                pickle.dump(self.initial_centroids, f)

        print("Running full k-means test")
        # With explicit centroids only one initialisation is meaningful; say so
        # explicitly instead of relying on a version-dependent n_init default.
        self.final_kmeans = KMeans(n_clusters=self.num_clusters,
                                   init=self.initial_centroids,
                                   n_init=1).fit(self.training_structure.pca_df['test'])
        self.labels = self.final_kmeans.labels_

        label_counts = np.bincount(self.labels)
        label_percentages = (label_counts / len(self.labels)) * 100
        for i, percentage in enumerate(label_percentages):
            print(f"Label {i}: {percentage:.2f}%")

        print("Running full k-means train")
        # NOTE(review): these centroids come from the *test* PCA projection,
        # but the train data was projected by a separately fitted PCA, so the
        # two coordinate systems need not agree. Confirm this is intended.
        self.final_kmeans_train = KMeans(n_clusters=self.num_clusters,
                                         init=self.initial_centroids,
                                         n_init=1).fit(self.training_structure.pca_df['train'])
        self.labels_train = self.final_kmeans_train.labels_

    def get_original_cluster_feature_rows(self):
        """Attach the cluster labels to the original (unstandardised) feature rows.

        Builds ``self.all_clusters``: one row per feature row of both
        partitions, with columns ``orig_index`` and ``cluster``, sorted by
        ``orig_index``. Also adds a ``cluster`` column (as the last column) to
        ``features_df['train']`` and ``features_df['test']``.

        Requires ``cluster_test`` to have set ``labels`` and ``labels_train``.
        """
        print(f"{self.training_structure.features_df['train']=}")
        print(f"{self.training_structure.features_df['test']=}")
        # train rows paired with the labels from the train k-means fit
        train_clusters = pd.DataFrame(
            {'orig_index': self.training_structure.features_df['train'].index,
             'cluster': self.labels_train})
        # test rows paired with the labels from the test k-means fit
        test_clusters = pd.DataFrame(
            {'orig_index': self.training_structure.features_df['test'].index,
             'cluster': self.labels})

        # one table covering every feature row, in original row order
        self.all_clusters = pd.concat([train_clusters, test_clusters])
        self.all_clusters.sort_values(by='orig_index', inplace=True)
        print(f"{self.all_clusters=}")

        self.training_structure.features_df['test']['cluster'] = test_clusters['cluster'].values
        self.training_structure.features_df['train']['cluster'] = train_clusters['cluster'].values

    def get_original_cluster_raw_rows_no_shuffle(self):
        """Expand the test feature rows (including ``cluster``) to raw-sample resolution.

        Assumes the test feature rows are the contiguous tail of an
        *unshuffled* split, so test feature row ``r`` covers raw samples
        ``r * BLOCK_SIZE`` to ``(r + 1) * BLOCK_SIZE - 1`` of
        ``combined_data``. NOTE(review): ``EQDataSetClusterTrain`` splits with
        ``shuffle=True``, which breaks this assumption.

        Each feature value is repeated BLOCK_SIZE times. Trailing raw samples
        that do not fill a whole block reuse the last feature row's values.
        The result (column ``orig_data``, then the test feature columns) is
        stored in ``self.raw_data_with_clusters``.
        """
        test_features = self.training_structure.features_df['test']
        first_test_row_index = test_features.index[0]
        # feature row r starts at raw sample r * BLOCK_SIZE
        self.first_raw_row_index = first_test_row_index * BLOCK_SIZE
        print(f"{self.first_raw_row_index=}")
        raw_test_data = self.training_structure.dataset_processed.dataset.combined_data[self.first_raw_row_index:]
        print(raw_test_data)

        leftover = len(raw_test_data) % BLOCK_SIZE
        lined_up_raw_data_df = pd.DataFrame()
        lined_up_raw_data_df['orig_data'] = raw_test_data
        for col in test_features.columns:
            col_values = np.repeat(test_features[col].values, BLOCK_SIZE)
            if leftover:
                # pad the partial final block with the last value (previously 0,
                # which silently labelled those samples as cluster 0)
                fill_value = col_values[-1] if len(col_values) else 0
                col_values = np.concatenate((col_values, np.full(leftover, fill_value)))
            lined_up_raw_data_df[col] = col_values
        self.raw_data_with_clusters = lined_up_raw_data_df


# This class uses the results of EQDataSetClusterTest to plot graphs
class EQDataSetClusterTestPlot:
    def __init__(self, test_structure: EQDataSetClusterTest):
        """Remember the fitted ``EQDataSetClusterTest`` whose results are plotted.

        Args:
            test_structure: the clustered test structure; every plot method
                reads from it via ``self.tst_st``.
        """
        self.tst_st: EQDataSetClusterTest = test_structure

    def plot_features_and_clusters_non_shuffled(self, proportion=1.0):
        """Plot raw test data, cluster labels and every feature on a shared sample axis.

        Subplot layout:

        - row 0: the raw signal from ``first_raw_row_index`` onwards;
        - row 1: the cluster label, shifted by +1;
        - one further row per original feature column, taken from
          ``raw_data_with_clusters``.

        Args:
            proportion: fraction of ``raw_data_with_clusters`` rows to draw.
        """
        lined_up = self.tst_st.raw_data_with_clusters
        processor = self.tst_st.training_structure.dataset_processed
        dataset = processor.dataset
        end_raw = int(len(lined_up) * proportion)
        orig_columns = processor.features_df.columns
        num_features = len(orig_columns)
        clusters = lined_up['cluster'].values
        # both titles name only the stations actually included in the run
        stations = dataset.stations[:dataset.only_include]

        _, axs = plt.subplots(
            num_features + 2, 1,
            figsize=(15, 10))
        orig_data = dataset.combined_data[self.tst_st.first_raw_row_index:]
        axs[0].plot(orig_data[:end_raw])
        axs[0].set_title(f"Raw data for {stations} {dataset.axis} ")
        axs[1].plot(clusters[:end_raw] + 1)
        axs[1].set_title(f"Clusters for {stations} {dataset.axis} ")
        for ax, col in zip(axs[2:], orig_columns):
            ax.plot(lined_up[col].iloc[:end_raw], label=col)
            ax.set_title(f"{col}")
        plt.tight_layout()
        plt.show()

    def plot_raw_data_with_cluster_colours(self, proportion:float = 1):
        """Plot ``combined_data`` with each BLOCK_SIZE chunk coloured by its cluster.

        Chunk ``c`` covers samples ``c * BLOCK_SIZE`` to
        ``(c + 1) * BLOCK_SIZE``. It is coloured with row ``c`` of
        ``all_clusters`` (last column). Each chunk is drawn as one polyline
        rather than one line per sample pair. The picture is the same, but
        far fewer matplotlib artists are created.

        Args:
            proportion: fraction of ``combined_data`` to draw.
        """
        dataset = self.tst_st.training_structure.dataset_processed.dataset
        raw = dataset.combined_data
        chunk_labels = self.tst_st.all_clusters.iloc[:, -1]
        end = int(raw.shape[0] * proportion)
        # Define a list of colors (expand this list if you have more clusters)
        colors = ['red', 'blue', 'green', 'black', 'purple', 'orange', 'yellow', 'pink', 'brown', 'grey']

        last_point = end - 1
        for chunk_start in range(0, last_point, BLOCK_SIZE):
            # include the first sample of the next chunk so pieces join up
            chunk_stop = min(chunk_start + BLOCK_SIZE, last_point)
            chunk_index = chunk_start // BLOCK_SIZE
            plt.plot(np.arange(chunk_start, chunk_stop + 1),
                     raw[chunk_start:chunk_stop + 1],
                     color=colors[int(chunk_labels.iloc[chunk_index])])

        plt.xlabel('time (100ths of a second)')
        plt.ylabel('Seismic Amplitude')
        plot_title = 'Samples Colored by Cluster Assignment \n'
        num_stations = len(dataset.stations)
        if num_stations > 7:
            # a long station list would not fit in the title
            plot_title += f"for {num_stations} stations with {dataset.axis}"
        else:
            plot_title += f'{dataset.stations} {dataset.axis}'
        plt.title(plot_title)

        # legend mapping cluster number to colour, placed outside the axes
        legend_elements = [Line2D([0], [0], color=colors[c], lw=4, label=f'Cluster {c}')
                           for c in range(self.tst_st.num_clusters)]
        plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
        plt.subplots_adjust(right=0.7)
        plt.show()

    @staticmethod
    def get_station_data(station_name='IWEH', filename='all_stations.json'):
        """Return the first station record whose name starts with ``station_name``.

        Only the first four characters of each record's ``name`` are compared
        with ``station_name``.

        Args:
            station_name: station code to look for.
            filename: JSON file containing a list of station dicts.

        Returns:
            The matching dict, or None when no record matches.
        """
        with open(filename, 'r') as station_file:
            all_stations = json.load(station_file)
        return next(
            (entry for entry in all_stations if entry['name'][:4] == station_name),
            None)

    def drum_plot_with_cluster_colours(self, proportion:float = 1, num_stations: int = 10):
        """Drum-style plot of a 15-minute window per station, coloured by cluster.

        For each of the first ``num_stations`` stations, a window of
        ``15 * 60 * SAMPLING_RATE`` samples of ``combined_data`` is drawn
        vertically (sample index on y). Each trace is scaled by 1/250 and
        offset horizontally by 5 units per station.

        Sample ``s`` of ``combined_data`` is coloured with cluster
        ``all_clusters[s // BLOCK_SIZE]``. Chunk boundaries are therefore
        taken from the absolute sample position, not from the window start,
        so the colours line up with the blocks the features were computed on.

        Args:
            proportion: accepted for interface compatibility; not used.
            num_stations: number of leading stations to draw (default 10).
        """
        # Define a list of colors (expand this list if you have more clusters)
        colors = ['red', 'blue', 'green', 'black', 'purple', 'orange', 'yellow', 'pink', 'brown', 'grey']

        dataset = self.tst_st.training_structure.dataset_processed.dataset
        A = dataset.time_scale
        combined = dataset.combined_data
        C = self.tst_st.all_clusters.iloc[:, -1].to_numpy()
        print(f"{C=}")

        # Author's notes for a 48 hour set:
        #   A[i][1] is the location of 9am-ish, A[i][3] is the end of the data,
        #   so the 9am-ish start for station i is i*A[i-1][3] + A[i][1].
        # NOTE(review): the layout of time_scale is not visible here; confirm.
        print(f"{A[:num_stations]=}")
        old_A = A
        new_A = [old_A[0][1]]
        for i in range(1, len(A[:num_stations])):
            new_A.append(i * old_A[i - 1][3] + old_A[i][1] + 1)
        A = new_A
        print(f"A={A}")

        num_mins = 15
        subv_len = 60 * SAMPLING_RATE * num_mins
        divider = 250  # amplitude scaling so neighbouring traces do not overlap
        station_spacing = 5  # horizontal offset between station traces

        plt.figure(figsize=(10, 8))
        for idx, start in enumerate(A):
            trace = np.asarray(combined[start:start + subv_len]) / divider + idx * station_spacing
            n_points = len(trace)  # may be shorter than subv_len near the end of the data
            p_lo = 0
            while p_lo < n_points - 1:
                chunk = (start + p_lo) // BLOCK_SIZE
                if chunk >= len(C):
                    break  # no cluster label beyond the labelled data
                # last point of this chunk's polyline (first sample of the next chunk)
                p_hi = min((chunk + 1) * BLOCK_SIZE - start, n_points - 1)
                plt.plot(trace[p_lo:p_hi + 1], np.arange(p_lo, p_hi + 1),
                         color=colors[int(C[chunk])])
                p_lo = p_hi

        stations = dataset.stations[:num_stations]
        print(stations)
        custom_labels = []
        for station in stations:
            station_info = self.get_station_data(station)
            if station_info is None:
                # station missing from the JSON file: label it without a latitude
                custom_labels.append(f"{station}\n(?)")
            else:
                custom_labels.append(f"{station}\n({round(station_info['latitude'], 1)})")
        print(custom_labels)
        plt.xticks(ticks=station_spacing * np.arange(0, len(custom_labels)),
                   labels=custom_labels, rotation=90)

        plt.xlabel("Station (latitude)")
        plt.ylabel("Sample Index")
        plt.title(f"Drum plot with Clusters for first {len(stations)} stations "
                  f"{dataset.axis} in the {len(dataset.stations)} station dataset")
        plt.show()

    def plot_raw_data_with_cluster_colours_nonshuffled(self, proportion:float = 1):
        """Plot the raw test signal, colouring each segment by its cluster label.

        Reads ``raw_data_with_clusters``: column 0 is the raw signal and the
        last column is the cluster label. The segment from sample ``i-1`` to
        sample ``i`` takes the label of sample ``i-1``. Consecutive segments
        with the same label are drawn as one polyline; this looks identical
        to one line per segment but creates far fewer matplotlib artists.

        Args:
            proportion: fraction of the rows to draw.
        """
        lined_up = self.tst_st.raw_data_with_clusters
        dataset = self.tst_st.training_structure.dataset_processed.dataset
        end = int(lined_up.shape[0] * proportion)
        # Define a list of colors (expand this list if you have more clusters)
        colors = ['red', 'blue', 'green', 'black', 'purple', 'orange', 'yellow', 'pink', 'brown', 'grey']

        if end >= 2:
            values = lined_up.iloc[:end, 0].to_numpy()
            # label of the left-hand sample of every segment
            labels = lined_up.iloc[:end - 1, -1].to_numpy().astype(int)
            breaks = np.flatnonzero(np.diff(labels)) + 1
            run_starts = np.concatenate(([0], breaks))
            run_stops = np.concatenate((breaks, [len(labels)]))
            for run_start, run_stop in zip(run_starts, run_stops):
                # run covers segments run_start..run_stop-1, i.e. points run_start..run_stop
                plt.plot(np.arange(run_start, run_stop + 1),
                         values[run_start:run_stop + 1],
                         color=colors[labels[run_start]])

        plt.xlabel('time (100ths of a second)')
        plt.ylabel('Seismic Amplitude')
        plt.title(
            f'Samples Colored by Cluster Assignment \n{dataset.stations} {dataset.axis}')
        # legend mapping cluster number to colour, placed outside the axes
        legend_elements = [Line2D([0], [0], color=colors[c], lw=4, label=f'Cluster {c}')
                           for c in range(self.tst_st.num_clusters)]
        plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
        plt.subplots_adjust(right=0.7)
        plt.show()

    # low pass filter (zero-phase Butterworth)
    # has to be static so that it can be used in the map function
    @staticmethod
    def LPF(data: np.ndarray = None, cutoff_hz: float = None,
            order: int = None, fs: float = None):
        """Low-pass filter ``data`` forwards and backwards (zero phase).

        Args:
            data: signal to filter; ``filtfilt`` filters along its last axis.
            cutoff_hz: critical frequency (Hz) passed to ``butter``. Defaults to
                ``fs // 2 - 1``, i.e. just below the Nyquist frequency.
            order: Butterworth filter order; defaults to BUTTER_ORDER.
            fs: sampling rate in Hz; defaults to SAMPLING_RATE.

        Returns:
            The filtered signal (same shape as ``data``).

        Raises:
            ValueError: if ``data`` is None.
        """
        if data is None:
            raise ValueError("LPF requires input data")
        if fs is None:
            fs = SAMPLING_RATE
        if order is None:
            order = BUTTER_ORDER
        if cutoff_hz is None:
            cutoff_hz = fs // 2 - 1
        b, a = butter(order, cutoff_hz, btype='low', fs=fs)
        # filtfilt applies the filter coefficients forwards and then backwards,
        # cancelling the phase distortion a single pass would introduce.
        return filtfilt(b, a, data)

    # calculate the magnitude spectrum of each BLOCK_SIZE-sample frame of the
    # filtered test data and average them by cluster type
    def plot_spectrogram_by_cluster(self, proportion:float = 1):
        """Plot the mean magnitude spectrum of the frames in each cluster.

        The filtered data is cut into consecutive BLOCK_SIZE-sample frames. Each
        frame is Hann-windowed and transformed with ``np.fft.rfft``. Frame i is
        assigned the label in row i (last column) of ``self.tst_st.all_clusters``.
        One subplot per cluster shows the averaged |FFT| from the bin nearest
        1 Hz upwards, on a log frequency axis.

        Args:
            proportion: fraction of the data (measured on
                ``dataset.combined_data``) to analyse, rounded down to whole frames.

        Raises:
            ValueError: if there are fewer cluster labels than frames.
        """
        processed = self.tst_st.training_structure.dataset_processed
        end = int(processed.dataset.combined_data.shape[0] * proportion)
        # only use whole frames
        end -= end % BLOCK_SIZE
        f = np.fft.rfftfreq(BLOCK_SIZE, 1 / SAMPLING_RATE)
        # index of the bin nearest 1 Hz; lower bins are not plotted.
        # argmin avoids list(f).index(1), which fails if no bin is exactly 1.0.
        index_of_1 = int(np.argmin(np.abs(f - 1.0)))
        print(f"{index_of_1=}")

        frame_spectrograms = []
        for i in range(0, end, BLOCK_SIZE):
            frame = processed.filtered_data[i:i + BLOCK_SIZE]
            # Hann window to reduce spectral leakage (a window, not a filter)
            frame = frame * np.hanning(len(frame))
            frame_spectrograms.append(np.abs(np.fft.rfft(frame)))

        # read the labels once instead of an .iloc lookup per frame per cluster
        labels = self.tst_st.all_clusters.iloc[:, -1].to_numpy()
        if len(labels) < len(frame_spectrograms):
            raise ValueError(
                f"{len(frame_spectrograms)} frames but only {len(labels)} cluster labels")

        # average the spectra by cluster type
        num_clusters = self.tst_st.num_clusters
        cluster_spectrograms = []
        for cluster in range(num_clusters):
            cluster_frames = [spec for spec, label in zip(frame_spectrograms, labels)
                              if label == cluster]
            if cluster_frames:
                cluster_spectrograms.append(np.average(cluster_frames, axis=0))
            else:
                print("No frames in cluster", cluster)
                # append a zero array if there are no frames in this cluster
                cluster_spectrograms.append(np.zeros(len(f)))

        # squeeze=False keeps axs 2-D so indexing also works for a single cluster
        fig, axs = plt.subplots(num_clusters, 1, figsize=(10, 15), squeeze=False)
        for cluster, (ax, spectrum) in enumerate(zip(axs[:, 0], cluster_spectrograms)):
            ax.plot(f[index_of_1:], spectrum[index_of_1:])
            ax.set_title(f"Cluster {cluster}")
            # logarithmic frequency axis; amplitude stays linear
            ax.set_xscale('log')
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Amplitude')
            ax.grid(True, which='both')
        plt.tight_layout()
        plt.show()

    # plot the cumulative count of occurrences of each cluster in the test
    # features, one line per cluster, on the same graph (raw counts, not normalised)
    def get_cumulative_cluster_activity_multi(self, proportion: float = 1):
        """Plot running per-cluster counts over the test feature rows.

        Labels are the last column of
        ``self.tst_st.training_structure.features_df['test']``.

        Args:
            proportion: fraction of the test rows, from the start, to include.
                The default of 1 uses all rows.
        """
        labels = self.tst_st.training_structure.features_df['test'].iloc[:, -1].to_numpy()
        # honour `proportion` (it used to be accepted but ignored)
        labels = labels[:int(len(labels) * proportion)]
        cluster_activity_df = pd.DataFrame(labels)
        for c in range(self.tst_st.num_clusters):
            # vectorised 0/1 indicator instead of an .iloc lookup per row
            cluster_activity_df[f'Cluster Activity {c}'] = (labels == c).astype(int)
            # now calculate the cumulative sum of the cluster activity
            cluster_activity_df[f'Cum Cluster Activity {c}'] = \
                cluster_activity_df[f'Cluster Activity {c}'].cumsum()
        # plot the cumulative cluster activity on a single graph
        fig, ax = plt.subplots(figsize=(10, 10))
        for c in range(self.tst_st.num_clusters):
            ax.plot(cluster_activity_df[f'Cum Cluster Activity {c}'],
                    label=f'Cluster {c}')
        ax.set_xlabel('Sample')
        ax.set_ylabel('Cumulative Cluster Activity')
        ax.set_title('Cumulative Cluster Activity')
        ax.legend()
        plt.show()


    @staticmethod
    def _nine_am_offsets(time_scale):
        """Map each station's time-scale entry to a ~9am sample offset.

        Station 0's offset is ``time_scale[0][1]``. Station i's offset is
        ``i * time_scale[i - 1][3] + time_scale[i][1] + 1``. This assumes every
        earlier station contributed ``time_scale[i - 1][3]`` samples to the
        concatenated data (TODO confirm for ragged data).

        Returns:
            (offsets, usable): two lists aligned index-for-index with
            ``time_scale``. ``offsets[i]`` is None when entry i could not be
            used. ``usable[i]`` is True only if an offset was computed and the
            entry has exactly 4 fields.

        ``time_scale`` itself is not modified.
        """
        # shallow copy so placeholder entries never leak into the dataset object
        scale = list(time_scale)
        offsets = [scale[0][1]]
        usable = [len(scale[0]) == 4]
        for i in range(1, len(scale)):
            try:
                offset = i * scale[i - 1][3] + scale[i][1] + 1
            except (TypeError, IndexError, KeyError):
                # Placeholder so station i + 1 can still compute its offset.
                # The 30000*2 is what should've been excluded from 48 hours:
                # 17220000 + 60000 == 17,280,000 samples == 48 h at 100 Hz.
                scale[i] = [0, 0, 0, 17220000 + 30000 * 2, None]
                # keep the lists aligned with the station indices
                offsets.append(None)
                usable.append(False)
            else:
                offsets.append(offset)
                usable.append(len(scale[i]) == 4)
        return offsets, usable

    # method to take as input a list of stations
    # and to then count which clusters are active by day and by night
    def get_cluster_activity_by_time(self, station_list: list[str],
                                     daytime: tuple = (9, 17),
                                     nighttime: tuple = (24, 29)):
        """Count cluster occurrences in day and night windows for some stations.

        For each requested station with usable time-scale data, a ~9am
        reference sample is located in the concatenated test data. Two windows
        are taken relative to it; hours are on the 48-hour clock, with the
        reference at hour 9. Each BLOCK_SIZE chunk inside a window adds one
        to the count of its cluster label.

        Args:
            station_list: station names. Each must appear in the dataset's
                ``stations`` (``list.index`` raises ValueError otherwise).
            daytime: (start_hour, end_hour). Default is 9am-5pm.
            nighttime: (start_hour, end_hour). Default is midnight-5am of the
                following day (hours 24-29).

        Returns:
            ``{'day': {...}, 'night': {...}}`` mapping the cluster label (as a
            string) to its chunk count. The same dict is also printed.
        """
        dataset = self.tst_st.training_structure.dataset_processed.dataset
        stations = dataset.stations

        # find station indices for the station list
        station_indices = [stations.index(station) for station in station_list]
        wanted = set(station_indices)

        C = self.tst_st.all_clusters.iloc[:, -1].to_numpy()
        print(f"{C=}")

        offsets, usable = self._nine_am_offsets(dataset.time_scale)

        # keep only the requested stations (in dataset order) that have valid data
        A = []
        for i in range(len(offsets)):
            if i not in wanted:
                continue
            if usable[i]:
                A.append(offsets[i])
            else:
                print(f"No valid station data for station {stations[i]} "
                      f"(position {station_indices.index(i)} in station_list)")
        # A holds the ~9am start index of each selected station's data
        print(f"{A=}")

        samples_per_hour = SAMPLING_RATE * 60 * 60

        def window(nine_am, hours):
            # convert (start_hour, end_hour) into sample indices relative to ~9am
            return (nine_am + (hours[0] - 9) * samples_per_hour,
                    nine_am + (hours[1] - 9) * samples_per_hour)

        sample_set = {
            'day': [window(a, daytime) for a in A],
            'night': [window(a, nighttime) for a in A],
        }
        print(sample_set)

        cluster_count = {t: {str(c): 0 for c in range(self.tst_st.num_clusters)}
                         for t in ('day', 'night')}
        for t, windows in sample_set.items():
            for s, e in windows:
                for cluster in C[math.floor(s / BLOCK_SIZE):math.floor(e / BLOCK_SIZE)]:
                    # int() so float labels (e.g. 1.0) map to the key '1'
                    key = str(int(cluster))
                    cluster_count[t][key] = cluster_count[t].get(key, 0) + 1
        print(f"{cluster_count=}")
        return cluster_count


    def cumulative_cluster_plot_by_station(self, station: str):
        """Plot cumulative per-cluster chunk counts over TEST_DATA_HOURS for one station.

        Starting at the station's ~9am sample in the concatenated test data,
        each BLOCK_SIZE chunk's cluster label (the last column of
        ``self.tst_st.all_clusters``) increments that cluster's running count.

        Args:
            station: station name. It must appear in the dataset's ``stations``
                (``list.index`` raises ValueError otherwise).

        Nothing is plotted if the station's time-scale entry is unusable.
        """
        dataset = self.tst_st.training_structure.dataset_processed.dataset
        station_index = dataset.stations.index(station)

        C = self.tst_st.all_clusters.iloc[:, -1].to_numpy()
        print(f"{C=}")

        # Offset of each station's ~9am sample in the concatenated data:
        # station 0 -> time_scale[0][1]; station i -> i * time_scale[i-1][3] + time_scale[i][1] + 1.
        # Work on a copy so the dataset's time_scale is not overwritten, and keep
        # `offsets` aligned with station indices even when an entry is unusable.
        time_scale = list(dataset.time_scale)
        offsets = [time_scale[0][1]]
        for i in range(1, len(time_scale)):
            try:
                offset = i * time_scale[i - 1][3] + time_scale[i][1] + 1
            except (TypeError, IndexError, KeyError):
                # The 30000*2 is what should've been excluded from 48 hours
                # (17,280,000 samples == 48 h at 100 Hz); the placeholder only
                # supplies a length for the next station.
                time_scale[i] = [0, 0, 0, 17220000 + 30000 * 2, None]
                offset = None
            offsets.append(offset)

        A = offsets[station_index]
        if A is None:
            print(f"No valid time-scale data for station {station}; nothing to plot")
            return
        # A is the ~9am start index of this station's data
        print(f"{A=}")

        keys = [str(c) for c in range(self.tst_st.num_clusters)]
        cluster_time_series = {k: [0] for k in keys}
        samples_per_hour = SAMPLING_RATE * 60 * 60
        start = math.floor(A / BLOCK_SIZE)
        end = math.floor((A + samples_per_hour * TEST_DATA_HOURS) / BLOCK_SIZE)
        # slicing clamps to len(C) instead of raising IndexError past the end
        for cluster in C[start:end]:
            # int() so float labels (e.g. 1.0) match the key '1'
            hit = str(int(cluster))
            for k in keys:
                cluster_time_series[k].append(cluster_time_series[k][-1] + int(k == hit))

        for k, counts in cluster_time_series.items():
            # unrounded hours: rounding bunched points onto whole hours
            x_data = np.linspace(0, TEST_DATA_HOURS, len(counts))
            plt.plot(x_data, counts, label=f"Cluster {k}")
        plt.legend()
        plt.ylabel("Cumulative Cluster Count")
        plt.xlabel("Time (Hours)")
        plt.title(f"Cumulative Cluster Count for {TEST_DATA_HOURS} hours for Station {station}")
        plt.show()


In [None]:
# EXPERIMENT SECTION

In [None]:
# TRAINING

In [None]:
# TOOL FUNCTIONS FOR TRAINING

def build_and_save_processed_data(
        json_file: str = '',
        h5_file: str = '', axis: str = 'Z',
        start_hour: float = 0, num_hours: float = 0,
        exclude_file: str = '', only_include: int = 999,
        plot_proportion: float = 0.0001):
    """Build an EQDataSetProcessor, compute its features, plot them and pickle it.

    If ``json_file`` is given, data comes from that experiment file via
    EQDataSetJSON. Otherwise it comes from the single event file ``h5_file``
    via EQDataSetSingleSensorCombined.

    Args:
        json_file: experiment JSON; takes precedence over ``h5_file``.
        h5_file: event HDF5 file, used only when ``json_file`` is empty.
        axis: sensor axis, e.g. 'Z'.
        start_hour, num_hours, exclude_file, only_include: forwarded to
            EQDataSetSingleSensorCombined (h5 path only).
        plot_proportion: value passed to ``plot_features``.

    Returns:
        The EQDataSetProcessor. It is also pickled into the current working
        directory under a name built from the inputs and today's date
        (ddmmyyyy).
    """
    if json_file == "":
        eqdssc = EQDataSetSingleSensorCombined(h5_file,
                                               axis,
                                               start_hour,
                                               num_hours,
                                               exclude_file=exclude_file,
                                               only_include=only_include)
    else:
        eqdsj = EQDataSetJSON(json_file)
        eqdsj.get_data()
        eqdssc = eqdsj.generate_EQDataSetCombined(axis=axis)

    print(f"{eqdssc=}")
    eqdsp = EQDataSetProcessor(eqdssc)
    eqdsp.compute_features_matrix()
    print(f"{eqdsp=}")
    eqdsp.plot_features(plot_proportion)

    # the date stamp keeps pickles from different days apart
    dt_string = datetime.now().strftime("%d%m%Y")
    if json_file == "":
        # balanced parentheses and ", " separators, matching older pickle names
        op_filename = (f"EQDataSetProcessor({h5_file}, {axis}, {start_hour}, "
                       f"{num_hours})_{dt_string}_only_include_{only_include}.pkl")
    else:
        op_filename = f"EQDataSetProcessor({json_file})_{dt_string}.pkl"
    print(f"Pickling processor data to {op_filename}...")

    with open(op_filename, 'wb') as f:
        pickle.dump(eqdsp, f)
    return eqdsp


def get_station_data(station_name='IWEH', filename='all_stations.json' ):
    """Return the first station record whose name begins with `station_name`.

    The comparison uses only the first four characters of each record's
    'name', so a `station_name` longer than four characters never matches.

    Args:
        station_name: four-character station code.
        filename: JSON file containing a list of station records.

    Returns:
        The matching record (a dict), or None if there is no match.
    """
    with open(filename, 'r') as f:
        records = json.load(f)
    return next((record for record in records
                 if record['name'][:4] == station_name), None)


def switch_to_disk(path_to_external_drive: str = "/Volumes/LaCie"):
    """Change the working directory to the external data disk.

    Args:
        path_to_external_drive: mount point of the disk that holds the data.

    Raises:
        FileNotFoundError: if the path does not exist (disk not connected or not
            mounted). An exception stops the notebook cell or script with a
            clear error. ``sys.exit()`` would exit with status 0, signalling
            success.
    """
    # Check if the path exists (i.e., if the disk is connected and mounted)
    if not os.path.exists(path_to_external_drive):
        raise FileNotFoundError(
            f"Path '{path_to_external_drive}' does not exist. Ensure the external disk is connected.")
    os.chdir(path_to_external_drive)
    print(f"Changed working directory to: {os.getcwd()}")


def get_all_ev_files():
    """List event files in the current working directory.

    Returns:
        Names (in ``os.listdir`` order) that start with 'ev' and end with '.h5'.
    """
    return [name for name in os.listdir()
            if name.endswith('.h5') and name.startswith('ev')]


def convert_to_eq_json_format(d: dict, start=0, num_hours=0,
                              axis='Z',json_filename="auto_generated_experiment.json"):
    """Write an experiment JSON file with one entry per (event file, station).

    Args:
        d: mapping of event filename -> iterable of station names.
        start: value stored in each entry's "start_hour".
        num_hours: value stored in each entry's "num_hours".
        axis: value stored in each entry's "axis".
        json_filename: output path; it is overwritten.
    """
    entries = [
        {
            "event_filename": event_file,
            "station": station,
            "axis": axis,
            "start_hour": start,
            "num_hours": num_hours,
            "exclude_file": "",
        }
        for event_file, station_names in d.items()
        for station in station_names
    ]
    # serialise before opening so a failure does not leave a truncated file
    payload = json.dumps(entries, indent=4)
    with open(json_filename, 'w') as out:
        out.write(payload)
        
def build_station_dict(json_filename="auto_generated_experiment.json",
                       start=48, num_hours=48,axis='Z',
                       events=None):
    """Collect the station names in each event file and write an experiment JSON.

    Each event file in the current directory (see ``get_all_ev_files``) is
    opened with h5py. Its top-level keys are taken as station names. The
    result is then written with ``convert_to_eq_json_format``.

    Args:
        json_filename: experiment JSON to write.
        start: start hour stored for every entry.
        num_hours: number of hours stored for every entry.
        axis: sensor axis stored for every entry.
        events: event filenames to include. None or empty means all event files.
            The default is None rather than a shared mutable ``[]``.

    Returns:
        dict mapping event filename -> list of station names.
    """
    wanted = set(events) if events else None
    stations = {}
    stations_set = set()
    station_total = 0
    for filename in get_all_ev_files():
        if wanted is not None and filename not in wanted:
            continue
        print(f"{filename=}")
        sub_stations = []
        with h5py.File(filename, 'r') as file:
            for s in file:
                sub_stations.append(s)
                stations_set.add(s)
                station_total += 1
                print(s, end = ",")
        stations[filename] = sub_stations
        print()
    print(stations)

    convert_to_eq_json_format(stations, start=start,
                              axis=axis, num_hours=num_hours, json_filename=json_filename)
    # unique stations (total station entries across all event files)
    print(f"Total number of stations: {len(stations_set)} ({station_total})")

    return stations

In [None]:

SKIP_PROCESSING = False
# Processed-data pickle to reload (this one is the 75+52 files).
# Defined outside the `if` because the training cell below builds its output
# filename from it even when processing is skipped.
# NOTE(review): when the data is rebuilt instead of loaded, the training
# filename still refers to this older pickle name.
file_to_load = "EQDataSetProcessor(experiment_2.1.json)_21082023.pkl"
if not SKIP_PROCESSING:
    load_processed_from_file = False
    if load_processed_from_file:
        print("\nLoading EQDataSetProcessor object from pickle file...\n")
        # pickle.load can execute arbitrary code: only load trusted files
        with open(file_to_load, 'rb') as f:
            eqdsp = pickle.load(f)
    else:
        events = ['ev0001903830.h5', 'ev0000593283.h5']
        axis = 'Z'
        exp_filename = 'experiment_2.1.json'
        #exp_filename = 'experiment_2.1_minitest.json'
        # the event files are listed, and the experiment JSON is written and
        # read, in the working directory; a single switch is enough
        switch_to_disk()
        build_station_dict(json_filename=exp_filename, start=48, num_hours=48,
                           axis=axis, events=events)
        eqdsp = build_and_save_processed_data(json_file=exp_filename,
                                              axis=axis)
    print(eqdsp)
print("DONE")

In [None]:

# Train the clustering model on `eqdsp`, or reload a previously trained one.
# pickle.load can execute arbitrary code: only load trusted files.
load_training_from_file = False
if load_training_from_file:
    print("\nLoading EQDataSetClusterTrain object from pickle file...\n")
    training_file_to_load = "eqdsc.run_training(defaults)_on_EQDataSetProcessor(EQDataSetProcessor(experiment_2.1.json)_21082023.pkl_at_23082023.pkl"
    with open(training_file_to_load, 'rb') as f:
        eqdsc = pickle.load(f)
else:
    eqdsc = EQDataSetClusterTrain(eqdsp)
    print(eqdsc)
    dt_string = datetime.now().strftime("%d%m%Y")
    # the unbalanced "(" matches the names of existing saved training pickles
    op_filename = ("eqdsc.run_training(defaults)_on_EQDataSetProcessor("
                   + file_to_load
                   + f"_at_{dt_string}.pkl")

    print("\nRunning Training..\n")
    eqdsc.run_training()

    print("\nSaving EQDataSetClusterTrain object to pickle file...\n")
    with open(op_filename, 'wb') as f:
        pickle.dump(eqdsc, f)
    print(eqdsc)

print("DONE")

In [None]:
# EVALUATION

In [None]:

# Cluster the test data with the trained model, or reload a saved test run.
# pickle.load can execute arbitrary code: only load trusted files.
load_tests_from_file = False
if load_tests_from_file:
    with open("eqdsct.run_test(defaults)_on_EQDataSetClusterTrain(eqdsc.run_training(defaults)_on_EQDataSetProcessor(EQDataSetProcessor(experiment_2.1.json)_21082023.pkl_at_23082023.pkl)_at_27082023.pkl", 'rb') as f:
        eqdsct = pickle.load(f)
else:
    eqdsct = EQDataSetClusterTest(eqdsc, num_clusters=5)
    eqdsct.cluster_test()

# re-attach the processed dataset (the processing structure needs rebuilding)
eqdsct.training_structure.dataset_processed = eqdsp

# map the cluster labels back onto the original feature rows
eqdsct.get_original_cluster_feature_rows()
print(eqdsct.training_structure.features_df['test'])

# plotting helper bound to this test run
eqdsctp = EQDataSetClusterTestPlot(eqdsct)
print("DONE")


In [None]:

# plot time evolution of clusters cumulatively for a single station
#eqdsctp.cumulative_cluster_plot_by_station('KMAH')

print("Preparing plot with colours")
# display the first 0.0001 of all the data with cluster colours
#eqdsctp.plot_raw_data_with_cluster_colours(0.0001)
# display the first 10 stations
eqdsctp.drum_plot_with_cluster_colours(0.0001)

# average spectrogram
print("Preparing plot of spectrogram")
eqdsctp.plot_spectrogram_by_cluster()

geo_coastal = ["MSMH", "NCNH", "NOBH", "YMMH", "YMDH", "NMEH", "KMIH", "KKWH", "KAKH"]

geo_inland = ["THTH", "SZKH", "MGMH", "KMYH", "NMNH", "OGNH", "NRKH", "SBAH", "SUKH"]

geo_major_urban = ['FKSH','KMAH','KORH','KTMH', 'NGSH', 'RIFH','SDWH','IMRH','YHBH']
geo_non_urban = ['FSWH','GKSH','HKSH','HNRH','HRDH','ICEH','YMDH','YMGH','YUZH']

# Day/night cluster-activity comparisons between station groups.
# Disabled by default; set a flag to True to run that comparison. Explicit
# flags replace the bare triple-quoted strings, which Jupyter echoed back as
# the cell's output because a string was the cell's last expression.
RUN_COASTAL_VS_INLAND = False
RUN_URBAN_VS_NON_URBAN = False

if RUN_COASTAL_VS_INLAND:
    print("Coastal stations")
    eqdsctp.get_cluster_activity_by_time(geo_coastal)
    print("Inland stations")
    eqdsctp.get_cluster_activity_by_time(geo_inland)

if RUN_URBAN_VS_NON_URBAN:
    print("Urban stations")
    eqdsctp.get_cluster_activity_by_time(geo_major_urban)
    print("Non-urban stations")
    eqdsctp.get_cluster_activity_by_time(geo_non_urban)


In [None]:
# DATA "MINING" TOOLS

In [None]:
# These are tools that were used to select the data for analysis

def haversine_distance(lat1, lon1, lat2, lon2, radius=6371.):
    """
    Calculate the great-circle distance between two points
    on the Earth's surface given their latitude and longitude in decimal degrees.

    Args:
        lat1, lon1: first point, decimal degrees.
        lat2, lon2: second point, decimal degrees.
        radius: sphere radius; the result is in the same unit. The default is
            6371.0 (Earth's radius in kilometres); use 3956 for miles.

    Returns:
        The distance rounded to the nearest whole unit (a float, via round(x, 0)).
    """
    # Convert decimal degrees to radians
    lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])

    # Haversine formula
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    return round(radius * c, 0)


def get_station_overlaps(stations, exclude_list=None):
    """Print, for each event file, the other event files that share a station with it.

    Args:
        stations: mapping of event filename -> list of station names (as
            returned by ``build_station_dict``).
        exclude_list: event filenames to leave out entirely. Defaults to
            ['ev0001903830.h5'].
    """
    if exclude_list is None:
        exclude_list = ['ev0001903830.h5']
    # convert each station list to a set once, rather than once per pair
    station_sets = {event: set(names) for event, names in stations.items()
                    if event not in exclude_list}
    for event, own in station_sets.items():
        print(f"{event} intersects with:", end=" ")
        for other, theirs in station_sets.items():
            if event != other and own & theirs:
                print(other, end=", ")
        print()
        
        
def calculate_station_distance(filename="all_stations.json"):
    """Compute the great-circle distance between every pair of stations.

    Args:
        filename: JSON list of station records with 'name', 'latitude' and
            'longitude' keys.

    Returns:
        A list of ("NAME1, NAME2", distance) tuples, sorted from the most
        distant pair to the closest. Distances come from
        ``haversine_distance`` with its default radius (km). Pairs are
        de-duplicated on the first four characters of each name, so each
        station pair appears once.
    """
    with open(filename, 'r') as f:
        stations = json.load(f)

    station_distances = dict()
    # set of frozensets: O(1) membership instead of scanning a list of sets
    station_pairs_done = set()
    for station in stations:
        for station2 in stations:
            if station == station2:
                continue
            pair = frozenset((station['name'][:4], station2['name'][:4]))
            if pair in station_pairs_done:
                continue
            station_pairs_done.add(pair)
            station_distances[station['name'] + ", " + station2['name']] = haversine_distance(
                station['latitude'], station['longitude'],
                station2['latitude'], station2['longitude'])

    # sort by distance, largest first (stable for equal distances)
    return sorted(station_distances.items(), key=lambda x: -x[1])

# Print every distance entry whose "NAME1, NAME2" key contains both IDs below.
# NOTE(review): the keys are built from station 'name' fields; confirm those
# names actually contain event IDs, otherwise nothing is printed.
sd = calculate_station_distance()
events = ['ev0000364000', 'ev0000447288']
for i in sd:
    if all(event_id in i[0] for event_id in events):
        print(i)

In [None]:
# MAPPING TOOLS

In [None]:
import folium

# Map every station used in experiment_2.1.json and save the map as HTML.
# NOTE: folium is not part of the install list at the top of the notebook.
m = folium.Map(location=[36.2454, 138.5472], zoom_start=5.5, tiles='CartoDB Positron')

with open('experiment_2.1.json', 'r') as f:
    experiment = json.load(f)
with open('all_stations.json', 'r') as f:
    all_stations = json.load(f)

# 4-character station codes referenced by the experiment (may repeat across events)
stations = [entry['station'][:4] for entry in experiment]

# [lat, lon] of every known station whose code appears in the experiment
locations = [[record['latitude'], record['longitude']]
             for record in all_stations
             if record['name'][:4] in stations]

print(locations)

# one small filled blue circle per station
for location in locations:
    folium.CircleMarker(location=location, radius=3, color="blue",
                        fill=True, fill_color="blue").add_to(m)

# Save the map to an HTML file
m.save('stations_mapped.html')
