# Detecting Outliers in Water Surface Elevation Data

This process uses a trained autoencoder model to identify possible outliers in a matrix of water surface elevations (X_data).

Records that the model reconstructs poorly may contain unusual wave patterns or other anomalies in the data.

## Prerequisites:
Flux Model: An autoencoder model, built with Flux, that has already been trained on normal water surface elevation data.

## Data:
X_data: A matrix of water surface elevation values (size: 4608 × number of records), where each column is one 30-minute time series of surface elevations (1800 s sampled at 2.56 Hz gives 4608 values).

X_date: A vector of DateTime values, where each entry corresponds to the starting time of the respective column in X_data.

## Steps to Detect Outliers:

### 1. Load the Required Libraries and Data
Before running the analysis, make sure that you have the necessary libraries and data in place. You will need:

    Flux: For the autoencoder model
    
    Your trained model: It should already be trained on normal data.

### 2. Define the Data and Model
Make sure your data (X_data and X_date) is loaded and accessible.
Ensure your trained model is available.

### 3. Run the Model to Predict Reconstructed Data
The model will predict reconstructed data from the input X_data. 
The reconstruction error will be used to determine whether a specific record is an outlier.
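
As a minimal sketch of this step, the reconstruction is a single forward pass over the whole matrix. The `model` function and the `X_demo` matrix below are hypothetical stand-ins so that the snippet runs on its own; a trained Flux model is called in the same way, with one record per column.

```julia
# Hypothetical stand-in for the trained autoencoder: any callable that maps a
# (samples × records) matrix to a matrix of the same size.
# With a real Flux model the call is the same: X_hat = model(X_data)
model(X) = 0.9f0 .* X

# Small synthetic stand-in for X_data: 8 samples per record, 5 records
X_demo = Float32.(reshape(1:40, 8, 5))

# One forward pass reconstructs every record (column) at once
X_hat = model(X_demo)

# The reconstruction has the same shape as the input
println(size(X_hat) == size(X_demo))
```

Keeping the records as columns matches the layout of X_data, so no transposition is needed before or after the call.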

### 4. Calculate Reconstruction Error
The reconstruction error measures the difference between each original record and its reconstruction, summarised as one value per column (for example, the mean squared difference).
Higher errors indicate that a record departs from the patterns the model learned and may be an outlier.
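
One way to compute this error, assuming the mean squared error is the chosen measure (the text above does not fix a particular one), is sketched below. The matrices `X_demo` and `X_hat` are small hypothetical examples, not real data.

```julia
using Statistics: mean

# Hypothetical data: 2 samples per record, 3 records (one record per column)
X_demo = Float32[1 2 3; 4 5 6]
X_hat  = Float32[1 2 1; 4 5 2]     # pretend reconstruction

# Mean squared error per record: average the squared differences over the
# rows (dims=1), which leaves one error value per column
recon_error = vec(mean(abs2, X_demo .- X_hat; dims=1))

println(recon_error)
```

Only the third column differs from its reconstruction, so it is the only record with a non-zero error.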

### 5. Determine Possible Outliers
Flag the columns with the highest reconstruction errors as possible outliers.
A threshold derived from the error distribution (for example, the mean plus a multiple of the standard deviation) separates them from the rest.
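
A minimal sketch of such a threshold follows, assuming a rule of mean plus `k` standard deviations. Both the rule and the value `k = 2.0` are illustrative assumptions, and the error values are made up for the example.

```julia
using Statistics: mean, std

# Hypothetical reconstruction errors, one per record
recon_error = [0.10, 0.12, 0.11, 0.09, 0.95, 0.10, 0.13, 0.11]

# Threshold from the error distribution: mean + k standard deviations
k = 2.0                                   # assumed multiplier, tune for your data
threshold = mean(recon_error) + k * std(recon_error)

# Indices of the records whose error exceeds the threshold
outlier_idx = findall(>(threshold), recon_error)

println(outlier_idx)
```

A large outlier inflates the standard deviation, so a multiplier that is too high can hide it. A robust alternative is a threshold based on the median and the median absolute deviation.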

### 6. Map Outliers to Date/Time
Using the outlier indices, look up the corresponding entries in X_date to find which time periods contain possible outliers.
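
The lookup can be sketched as plain indexing, because the entries of X_date are in the same order as the columns of X_data. The vector `X_date_demo` and the indices in `outlier_idx` below are hypothetical values chosen for illustration.

```julia
using Dates

# Hypothetical start times: five consecutive 30-minute records
X_date_demo = [DateTime(2024, 1, 1, 0, 0) + Minute(30) * i for i in 0:4]

# Hypothetical outlier column indices from the previous step
outlier_idx = [2, 5]

# Column i of the data matrix starts at X_date_demo[i]
outlier_times = X_date_demo[outlier_idx]

# Print the start time of each suspect record
for t in outlier_times
    println(Dates.format(t, "yyyy-mm-dd HH:MM"))
end
```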

### 7. Display the Outliers
You can display the dates and times of the detected outliers.


### Select and read contents of .BVA file

In [None]:
using DataFrames: DataFrame, ncol, nrow
using Dates: Dates, DateTime, Time, unix2datetime, Hour, Minute, Microsecond
using NativeFileDialog: pick_file
using Statistics: median, mean, std
using Plots: Plots, plot, plot!, annotate!, hline!, @layout, text, plotly, font, scatter!
using Printf: @sprintf

#using PlotlyKaleido

#plotly()
#PlotlyKaleido.start(timeout=30)


function get_matches(Data, f23_df)
##################################
    
    # Create a dictionary to store indices of hex strings in Data
    index_dict = Dict{String, Vector{Int}}()
    
    # Populate the dictionary
    for (i, hex_str) in enumerate(Data)
        if haskey(index_dict, hex_str)
            push!(index_dict[hex_str], i)
        else
            index_dict[hex_str] = [i]
        end
    end
    
    # Initialize a vector to store indices (nothing marks a hex string with no match)
    matching_indices = Union{Int, Nothing}[]
    
    # Iterate through each hex string in f23_df and lookup in the dictionary
    for hex_str in f23_df.Match_vector
        if haskey(index_dict, hex_str)
            push!(matching_indices, index_dict[hex_str][1])
        else
            push!(matching_indices, nothing)  # If no match found, store nothing
        end
    end

    f23_df[!,"Data_vector"] = matching_indices

    return(f23_df)

end    # get_matches()

# function to calculate selected parameters from Spectrum synchronisation message (0xF23)
function process_f23(f23_vals)
#######################################
    
    # refer to DWTP (Ver. 16 January 2019) Section 4.3 pp.43-44

    # get Timestamp in UTC
    timestamp = unix2datetime(parse(Int, bitstring(f23_vals[3]) * bitstring(f23_vals[4]) * bitstring(f23_vals[5]) * bitstring(f23_vals[6]); base=2))
    
    # Timestamp is left in UTC: Hour(0) is a no-op. Use Hour(10) for Australian Eastern Standard Time (UTC+10)
    timestamp = timestamp + Hour(0)  # Adjust this for the correct time zone

    # get Data Stamp
    data_stamp = parse(Int, bitstring(f23_vals[7]) * bitstring(f23_vals[8]); base=2)

    # get Segments Used
    segments_used = parse(Int, bitstring(f23_vals[9]) * bitstring(f23_vals[10]) * bitstring(f23_vals[11]); base=2)

    # get Sample Number
    sample_number = parse(Int, bitstring(f23_vals[12]) * bitstring(f23_vals[13]); base=2)

    # Create Match Vector
    match_vector = lpad(string(f23_vals[14], base=16), 2, "0")
    for i in 15:22
        match_vector = match_vector * lpad(string(f23_vals[i], base=16), 2, "0")
    end
    
    return(timestamp, segments_used, match_vector, sample_number)
    
end    #  process_f23()


# convert binary data into F23_df and Hex array
function get_hex_array(infil)
#############################
    
    # Read binary data from the input file
    println("Reading BINARY data from ", infil)
    flush(stdout)
    data = reinterpret(UInt8, read(infil))
    
    # Reshape the byte vector into a 12-row matrix: each column is one 12-byte row of the file
    cols = 12
    rows = length(data) ÷ cols    # whole 12-byte rows only; ignore any trailing partial row
    mat = reshape(view(data, 1:rows*cols), cols, :)
    
    # Build one 18-character hex string per row from bytes 1-9 (the packed heave, north and west samples)
    hex_matrix = string.(mat'[:,1:9], base=16, pad=2)
    Data = [join(row) for row in eachrow(hex_matrix)]
    
    println("All file data read!")
    
    # Interleave bytes 10, 11 and 12 of every row (rows 10-12 of mat) to form the packet vector
    packet = collect(Iterators.flatten(zip(mat[10,:], mat[11,:], mat[12,:])))
    
    # Find all occurrences of 0x7e in the packet vector
    aa = findall(x -> x == 0x7e, vec(packet))
    
    # Create DataFrame to hold the processed data
    f23_df = DataFrame(Date = DateTime[], Segments = Int[], Match_vector = String[], Sample_number = Int[])
    
    # Decode the packet data into messages
    max_val = length(aa) - 1
    
    for i in 1:max_val
        first = aa[i] + 1
        last = aa[i + 1]
        
        if (last - first > 1)
            decoded = packet[first:last-1]
            
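            # Undo the 0x7d escape sequences: XOR the following byte with 0x20, then drop the 0x7d
            bb = findall(x -> x == 0x7d, decoded)
            for ii in bb
                if ii < length(decoded)    # skip an escape byte with nothing after it
                    decoded[ii + 1] = decoded[ii + 1] ⊻ 0x20
                end
            end
            deleteat!(decoded, bb)
            
            # If the message is F23 (0x23) and long enough for process_f23(), which reads bytes 3 to 22
            if length(decoded) >= 22 && decoded[2] == 0x23
                timestamp, segments_used, match_vector, sample_number = process_f23(decoded)
                push!(f23_df, [timestamp, segments_used, match_vector, sample_number])
            end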
        end
    end
    
    # Remove duplicates from f23_df
    f23_df = unique(f23_df);

    return(f23_df, Data)
    
end    # get_hex_array()


function get_start_end_dates(f23_df,found_list)
###############################################
    
    start_date = f23_df[found_list[1],:].Date - Minute(30) # <------- NOTE subtracted 30min from start_date to match Waves4 results
    segments = f23_df[found_list[1],:].Segments
#   match_vector = f23_df[found_list[1],:].Match_vector
    sample_nos = f23_df[found_list[1],:].Sample_number
    data_vector = f23_df[found_list[1],:].Data_vector
    # each row of Data holds two samples, so the record spans sample_nos ÷ 2 rows
    start_val = data_vector - sample_nos ÷ 2 + 1
    end_val = data_vector
    
    return(start_date,start_val, end_val)
    
end    #(get_start_end_dates)
    
  
function get_displacement(Data, start_val, end_val)
################################################
# Decode the real time data to displacements - See DWTP (16 Jan 2019) 2.1.1 p. 19    
    
    # Each row holds two samples of the component: one at start_val:end_val and one 9 characters later
    arry = collect(Iterators.flatten(zip(SubString.(Data, start_val, end_val),SubString.(Data, start_val+9, end_val+9))));
    # Parse each 3-digit hex string as a 12-bit unsigned integer
    displacements = [parse(Int, s, base=16) for s in arry]
    
    # Convert from 12-bit two's complement to signed values
    displacements[displacements .>= 2048] .-= 4096
    displacements = 0.457*sinh.(displacements/457)    # see DWTP p.19 (16)
   
    return(displacements)
    
end    # get_displacement()


function get_hnw(Data,start_val,end_val)
######################################## 
    
    # get WSEs for desired 30-minute record
    heave = get_displacement(Data[start_val:end_val,:], 1, 3);              
    north = get_displacement(Data[start_val:end_val,:], 4, 6);
    west = get_displacement(Data[start_val:end_val,:], 7, 9);
    
    # Check for missing or extra points in data; the vectors are modified in place
    for wse in (heave, north, west)
        
        wse_length = length(wse)
        
        if wse_length > REC_LENGTH

            # truncate if too long (resize! changes the vector itself, unlike wse = wse[1:REC_LENGTH])
            resize!(wse, REC_LENGTH)
            
        else

            # zero pad if too short (leave it unchanged if right length)
            append!(wse,zeros(REC_LENGTH-wse_length))
            
        end      

    end
    
    return (heave, north, west)
    
end    # get_hnw()


# Function to calculate confidence limits
function calc_confidence_limits(data, confidence_interval)
##########################################################
    
    mean_val = mean(data)
    std_dev = std(data)
    upper_limit = mean_val + confidence_interval * std_dev
    lower_limit = mean_val - confidence_interval * std_dev
    
    return (lower_limit, upper_limit)
    
end    # calc_confidence_limits()


# Function to compute modified z-scores and find outliers
function modified_z_score(data, threshold)
##########################################
    
    med = median(data)
    mad = median(abs.(data .- med))
    mod_z_scores = 0.6745 * (data .- med) ./ mad
    outlier_indices = findall(x -> abs(x) > threshold, mod_z_scores)
    
    return(outlier_indices, mod_z_scores)
    
end    # modified_z_score()


# Earlier version of the dynamic threshold, scaled by the range of heave values (k is unused here)
function dynamic_z_score_thresholdXXX(heave, base_threshold=3.0, k=0.5)
####################################################################
#==
    mean_wave_height = mean(heave)
    std_wave_height = std(heave)
    dynamic_threshold = base_threshold * (1 + k * (mean_wave_height / std_wave_height))
==#
    max_threshold = 3.29
    
    heave_range = maximum(heave) - minimum(heave)
    std_wave_height = std(heave)

    # Scale the threshold based on the range of heave values
    # Adjust the scaling factor (e.g., 0.1) as needed to fit your data
    scaling_factor = 0.1 * (heave_range / std_wave_height)

    # Calculate the dynamic threshold
    dynamic_threshold = base_threshold + scaling_factor

    # Clamp the threshold between the defined limits
    dynamic_threshold = clamp(dynamic_threshold, base_threshold, max_threshold)
    
    return(dynamic_threshold)
    
end    # dynamic_z_score_thresholdXXX()


function pad_or_truncate(record, target_length=REC_LENGTH)
####################################################

    length(record) < target_length ? vcat(record, zeros(Float32, target_length - length(record))) :
                                     record[1:target_length]

end    # pad_or_truncate()


function get_heave(Data, f23_df)
################################
    
    heave_array = Vector{Float64}[]    # one heave record per element
    X_date = DateTime[]                # start time of each record

    println("Calculating Heave values now!")
    
    for idx in 1:nrow(f23_df)

        if !isnothing(f23_df.Data_vector[idx])
    
            start_date, start_val, end_val = get_start_end_dates(f23_df,idx)
            if start_val > 0
                print(".")
                heave, north, west = get_hnw(Data,start_val,end_val)

                # ensure we have REC_LENGTH data points
                push!(heave_array,pad_or_truncate(heave, REC_LENGTH))
                push!(X_date,start_date)
            end

        end
    
    end

    return(hcat(heave_array...), X_date)

end    # get_heave()


# Need to check first row of the f23_df in case 23:00 is stored there
function f23_first_row_check(f23_df)
################################
    
    # Get the first row of the DataFrame
    first_row = first(f23_df)
    
    # Check if the time of the first row's Date column is 23:00:00
    time_of_first_row = Time(first_row.Date)

    if time_of_first_row == Time(23, 0, 0)

        if ismissing(first_row.Data_vector) || isnothing(first_row.Data_vector) || isnan(first_row.Data_vector)
            f23_df = f23_df[2:end, :]  # Drop the first row
        end

    end
    
    return(f23_df)
    
end    # f23_first_row_check()


###############################################################################
###############################################################################
###############################################################################

using Sockets

#global X_data = Matrix{Float32}(undef, 0, 0)

hostname = gethostname()
println("The name of the computer is: ", hostname)

if hostname == "QUEENSLAND-BASIN"
    
    display("text/html", "<style>.container { width:100% !important; }</style>")
    
else
    
    display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))    
    
end


REC_LENGTH = 4608       # Number of WSE's in a Mk4 30-minute record
SAMPLE_FREQUENCY = 2.56 # Mk4 sample frequency in Hertz
SAMPLE_LENGTH = 1800    # record length in seconds
SAMPLE_RATE = Float64(1/SAMPLE_FREQUENCY) # sample spacing in seconds

#########################################################################################################################
##    confidence_interval = 2.576  # corresponds to a 99% confidence interval (for a normal distribution)
##    confidence_interval = 3.0    # corresponds to a 99.73% confidence interval (for a normal distribution)    
##    confidence_interval = 3.29   # corresponds to a 99.9% confidence interval (for a normal distribution)
#########################################################################################################################

# Widen screen for better viewing
display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))
display("text/html", "<style>.container { width:100% !important; }</style>")

initial_path = "F:\\Card Data\\"
infil = pick_file(initial_path)

f23_df, Data = get_hex_array(infil)

f23_df = get_matches(Data, f23_df)

# drop the first row of f23_df if it is a 23:00 record with no match in Data
f23_df = f23_first_row_check(f23_df)

X_data, X_date = get_heave(Data, f23_df);

X_data = Float32.(X_data)

println(string(length(X_date))," records processed.\n")
println("\nNow plot heave for each record to see suspect data!")
flush(stdout)

### Plot each record

In [None]:
# Function for dynamic threshold based on wave amplitudes
function dynamic_z_score_threshold(heave, base_threshold=3.0, k=0.5)
    ####################################################################
    
    # Calculate the amplitude of heave
    amplitude = abs.(heave)
    
    # Calculate mean and standard deviation of the amplitudes
    mean_amplitude = mean(amplitude)
    std_amplitude = std(amplitude)
    
    # Calculate the range of amplitudes
    range_amplitude = maximum(amplitude) - minimum(amplitude)
    
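    # Calculate dynamic threshold
    dynamic_threshold = base_threshold * (1 + k * (range_amplitude / std_amplitude))
    
    # Cap the dynamic threshold. Note: range/std is at least about 1.4 for any non-constant
    # sample, so with the default arguments the result always equals max_threshold.
    max_threshold = 3.29  # or any other maximum threshold you prefer
    dynamic_threshold = min(dynamic_threshold, max_threshold)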
    
    return dynamic_threshold
end  # dynamic_z_score_threshold()


##==
# Loop through wave records
for ii in 1:length(X_date)
    
    # Initialize variables
    start_time = X_date[ii]
    heave = X_data[:, ii]
    end_time = start_time + Minute(30)
    xvals = start_time + Microsecond.((0:REC_LENGTH-1) / SAMPLE_FREQUENCY * 1000000)

    # Plot initialization
    p1 = plot(size=(2000, 300), dpi=100, framestyle=:box, fg_legend=:transparent, bg_legend=:transparent, 
        legend=:topright, xtickfont=font(8), ytickfont=font(8), bottommargin = 10Plots.mm,
        grid=true, gridlinewidth=0.125, gridstyle=:dot, gridalpha=1)
    
    tm_tick = range(start_time, end_time, step=Minute(1))
    ticks = Dates.format.(tm_tick, "MM")
    
    # Calculate dynamic confidence interval
    confidence_interval = dynamic_z_score_threshold(heave)

    # Identify z_scores using modified z-score
    z_score_indices, mod_z_scores = modified_z_score(heave, confidence_interval)
    if !isempty(z_score_indices)
        scatter!(p1, xvals[z_score_indices], heave[z_score_indices], 
            markersize=4, markerstrokecolor=:red, markerstrokewidth=1, 
            markercolor=:white, markershape=:circle, label="")
    end

    # Plot confidence limits
    confidence_limits = calc_confidence_limits(heave, confidence_interval)
    hline!(p1, [confidence_limits[1], confidence_limits[2]], color=:red, lw=0.5, linestyle=:dash, label="")

    # Plot heave data
    plot!(p1, xvals, heave, xlims=(xvals[1], xvals[end]), lw=0.5, lc=:blue, alpha=0.5, 
        xticks=(tm_tick, ticks), label="")

    # Annotate plot with the number of outliers and confidence interval
    num_outliers = length(z_score_indices)
    suspect_string = string("  ", string(ii)," ",Dates.format(start_time, "yyyy-mm-dd HH:MM"), " - ", num_outliers, " Possible outliers") # using Confidence Interval of ", 
##        @sprintf("%.2f", confidence_interval))
    annotate!(p1, xvals[1], maximum(heave) * 0.9, text(suspect_string, :left, 10, :blue))

    display(p1)

end
#==#

### Plot statistics for selected range of records

In [None]:
using Statistics
using Plots
using DSP

arry_length = 256

for ii in 120:150
    
    heave = X_data[:, ii]
    x_date = Dates.format(X_date[ii], "yyyy-mm-dd HH:MM")

    start_time = X_date[ii]
    end_time = start_time + Minute(30)
    xvals = start_time + Microsecond.((0:REC_LENGTH-1) / SAMPLE_FREQUENCY * 1000000)

    tm_tick = range(start_time, end_time, step=Minute(1))
    ticks = Dates.format.(tm_tick, "MM")

    p1 = plot(xvals, heave, size=(2000,300), label="", xlims=(start_time,end_time), framestyle=:box, title=x_date, xticks=(tm_tick, ticks), bottommargin = 10Plots.mm, )

    for idx in xvals[collect(1:arry_length:4608)]
        p1 = plot!([idx, idx], [minimum(heave), maximum(heave)], label="", color=:red, linestyle=:dash)
    end   
    #hline!([final_median + (3.29 * final_std)], label="Median")
    
    suspect_string = string("  ", string(ii)," ",Dates.format(start_time, "yyyy-mm-dd HH:MM"))
    annotate!(p1, xvals[1], maximum(heave) * 0.9, text(suspect_string, :left, 10, :blue))
    
    display(p1)

    #########################################################################################################################
    ##    confidence_interval = 2.576  # corresponds to a 99% confidence interval (for a normal distribution)
    ##    confidence_interval = 3.0  # corresponds to a 99.73% confidence interval (for a normal distribution)    
    ##    confidence_interval = 3.29  # corresponds to a 99.9% confidence interval (for a normal distribution)
    #########################################################################################################################
    ##amplitude = abs.(heave)
    arry = arraysplit(heave, arry_length, 0)
    ##arry = arraysplit(amplitude, arry_length, 0)

    p1 = plot(size=(2000,300), xlims=(start_time,end_time), 
        bottommargin = 10Plots.mm, xticks=(tm_tick, ticks),
        fg_legend=:transparent, bg_legend=:transparent,  legend=:topleft, framestyle=:box)

    # find the mean of each 256-element array
    means = mean.(arry)
    medians = median.(arry)
    stds = std.(arry)

    #find the median value of the means
    median_of_means = median(means)

    # plot the means of each array
    p1 = scatter!(xvals[Int.(collect(0:arry_length:4608-256) .+ arry_length/2)], means, marker=:circle, label="Means")
    p1 = scatter!(xvals[Int.(collect(0:arry_length:4608-256) .+ arry_length/2)], medians, marker=:diamond, ms=:5, label="Medians")
    p1 = scatter!(xvals[Int.(collect(0:arry_length:4608-256) .+ arry_length/2)], stds, marker=:xcross, markerstrokewidth=5, ms=:5, label="Std. Dev's")

    p1 = hline!([(median(stds)+2*std(stds))], ls=:dot, lc=:red, lw=:2, label="")
    p1 = hline!([(median(stds)-2*std(stds))], ls=:dot, lc=:red, lw=:2, label="")

    for idx in xvals[collect(1:arry_length:4608)]
        p1 = plot!([idx, idx], [(median(stds)-2*std(stds)), (median(stds)+2*std(stds))], label="", color=:red, linestyle=:dash)
    end 
    
    display(p1)
    
end

### Do spectral plot of selected record

In [None]:
using DSP, LinearAlgebra
using LombScargle
using ToeplitzMatrices
using Plots, Printf

# Function to calculate AR coefficients using Yule-Walker equations
function yule_walker(data, order)
#################################
# Created in collaboration with OpenAI's ChatGPT.
# Uses the Yule-Walker equations to estimate autoregressive (AR) model coefficients.
# This function is commonly applied in time series analysis for spectral estimation and noise reduction.
    
    n = length(data)
    
    # Autocorrelation estimation
    autocorr = [sum(data[1:n-k] .* data[1+k:n]) / n for k in 0:order]
    
    # Construct the Toeplitz matrix for solving the Yule-Walker equations
    R = Toeplitz(autocorr[1:order], autocorr[1:order])
    r = autocorr[2:order+1]
    
    # Solve for AR coefficients
    a = R \ r
    
    return( vcat(1.0, -a))  # Include 1 for AR model definition
    
end    # yule_walker()


# Function to compute the MEM-based power spectral density
function mem_psd(data, order, SAMPLE_RATE, num_points=4096)
###########################################################
# Created in collaboration with OpenAI's ChatGPT.
# Implements the Maximum Entropy Method (MEM) for Power Spectral Density (PSD) estimation.
# This method is useful for estimating spectral density, especially for time series with limited data points.
    
    ar_coeffs = yule_walker(data, order)
    
    freqs = range(0, stop=SAMPLE_RATE/2, length=num_points)
    psd = Float64[]
    
    for f in freqs
        omega = 2 * π * f / SAMPLE_RATE
        denom = abs(sum(ar_coeffs .* exp.(-im * omega * (0:order))))
        push!(psd, (1 / (denom^2)) * SAMPLE_RATE)  # Scaled by sample rate only (noise variance omitted), so the PSD is relative
    end
    
    return(freqs, psd)
    
end    # mem_psd()


# Band averaging function
function band_average(freqs, psd, bin_size)
###########################################
    
    max_freq = maximum(freqs)
    bins = 0:bin_size:max_freq
    averaged_psd = Float64[]
    averaged_freqs = Float64[]

    for i in 1:length(bins)-1
        mask = (freqs .>= bins[i]) .& (freqs .< bins[i+1])
        if any(mask)
            avg_psd = mean(psd[mask])
            push!(averaged_psd, avg_psd)
            push!(averaged_freqs, (bins[i] + bins[i+1]) / 2)  # Center frequency
        end
    end

    return(averaged_freqs, averaged_psd)

end    # band_average()


# Function to calculate total energy and low-frequency energy
function calculate_energy(freqs, psd, low_freq_threshold)
#########################################################
    
    total_energy = sum(psd) * (freqs[2] - freqs[1])  # Area under the entire PSD
    low_freq_mask = freqs .<= low_freq_threshold
    
    low_freq_energy = sum(psd[low_freq_mask]) * (freqs[2] - freqs[1])  # Area under low-frequency PSD
    
    percentage_low_freq_energy = (low_freq_energy / total_energy) * 100
    
    return(total_energy, low_freq_energy, percentage_low_freq_energy)
    
end    # calculate_energy()


function normalize_psd(psd)
###########################
    
    return(psd ./ maximum(psd))  # Normalize so that the peak value equals 1
    
end    # normalize_psd()


###############################################################################
###############################################################################
###############################################################################


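SAMPLE_RATE = 2.56  # sampling frequency in Hz (overrides the earlier SAMPLE_RATE, which held the sample spacing in seconds)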
NYQUIST = SAMPLE_RATE/2
order = 30
num_points = 4096
low_freq_threshold = 0.06667  # 15s corresponds to this frequency

ii = 67
arry_length = 256

heave = X_data[:, ii]
x_date = Dates.format(X_date[ii], "yyyy-mm-dd HH:MM")


##################################
# Calculate wave spectra using periodogram()
##################################
## Compute the periodogram over the full 4608-point record
data_segment = heave
psd_result = periodogram(data_segment, fs=SAMPLE_RATE)

# Extract frequency and power spectral density values
freqs_fft = psd_result.freq
psd_fft = psd_result.power

bin_size = 0.005
averaged_freqs_fft, averaged_psd_fft = band_average(freqs_fft, psd_fft, bin_size)

normal_psd_fft = normalize_psd(averaged_psd_fft)

##################################
# Calculate wave spectra using lombScargle() - see https://juliaastro.org/LombScargle.jl/stable/
##################################
# Generate timestamps (required as input by the lombscargle() function)
timestamps = collect((0:length(heave)-1) ./ SAMPLE_RATE)    # one timestamp in seconds per heave sample

# Calculate Lomb-Scargle spectrum
lombscargle_spectrum = lombscargle(timestamps, heave;
                                   minimum_frequency=1e-5,  # Small positive start frequency
                                   maximum_frequency=1.28,
                                   samples_per_peak=400)

# Extract frequency and power values
ls_freqpower = freqpower(lombscargle_spectrum)
ls_frequencies = ls_freqpower[1]
power = ls_freqpower[2]

# Replace any non-finite values (Inf or NaN) with the smallest finite power (unlikely, but a safeguard)
finite_min = minimum(filter(isfinite, power))
power[.!isfinite.(power)] .= finite_min

# Band-average the Lomb-Scargle spectrum
averaged_lombscargle_freqs, averaged_lombscargle_fft = band_average(ls_frequencies, power, bin_size)

# Normalize the lombscargle fft
normal_lombscargle_fft = normalize_psd(averaged_lombscargle_fft)

##################################
# Calculate wave spectra using MEM
##################################

# Compute PSD using MEM
freqs_mem, psd_mem = mem_psd(heave, order, SAMPLE_RATE, num_points)

# Perform band averaging on MEM results
averaged_freqs_mem, averaged_psd_mem = band_average(freqs_mem, psd_mem, bin_size)

normal_psd_mem = normalize_psd(averaged_psd_mem)

# Calculate energies
total_energy, low_freq_energy, percentage_low_freq_energy = calculate_energy(averaged_freqs_fft, averaged_psd_fft, low_freq_threshold)

# Output results
println("Total Energy: $total_energy")
println("Low-Frequency Energy: $low_freq_energy")

@printf("Percentage of Low-Frequency Energy: %4.2f%%",percentage_low_freq_energy)

# Plot initialization
p1 = plot(size=(1200, 800), dpi=100, framestyle=:box, fg_legend=:transparent, bg_legend=:transparent, 
    xlabel="Frequency (Hz)", xlims=(0,0.6), ylims=(0,Inf),
    ylabel="Normalized Spectral Density", title=x_date,
    legend=:topright, xtickfont=font(8), ytickfont=font(8),
    leftmargin = 20Plots.mm, bottommargin = 20Plots.mm,
    grid=true, gridlinewidth=0.125, gridstyle=:dot, gridalpha=1)

p1 = plot!(averaged_lombscargle_freqs, normal_lombscargle_fft, lw=:6, lc=:yellow, alpha=:0.5, label="Lomb–Scargle FFT spectra")

p1 = plot!(averaged_freqs_fft, normal_psd_fft, lw=:2, lc=:blue, alpha=:0.5, label="Periodogram FFT spectra")

p1 = plot!(averaged_freqs_mem, normal_psd_mem, lw=:2, ls=:dot, lc=:pink, alpha=:0.75, label="MEM spectra")

p1 = vline!([low_freq_threshold], lw=:2, lc=:red, ls=:dash, label="Low frequency cut-off - 15s\n")  

# Fill the area under the curve from 0 Hz to low frequency cutoff
fillrange = zeros(length(averaged_freqs_fft))  # Base level for filling (y=0)
fill_mask = averaged_freqs_fft .<= low_freq_threshold  # Mask for frequencies below cutoff

# Fill the area under the PSD curve for the specified frequency range
p1 = plot!(averaged_freqs_fft[fill_mask], normal_psd_fft[fill_mask], fillrange=fillrange[fill_mask], 
      fillalpha=0.5, color=:red, label="Low frequency energy: "*  @sprintf("%.2f%%", percentage_low_freq_energy))

display(p1)

### Plot Percentage of low-freq energy

In [None]:
using DSP, LinearAlgebra
using ToeplitzMatrices
using Plots, Printf

# Function to calculate AR coefficients using Yule-Walker equations
function yule_walker(data, order)
#################################
    
    n = length(data)
    
    # Autocorrelation estimation
    autocorr = [sum(data[1:n-k] .* data[1+k:n]) / n for k in 0:order]
    
    # Construct the Toeplitz matrix for solving the Yule-Walker equations
    R = Toeplitz(autocorr[1:order], autocorr[1:order])
    r = autocorr[2:order+1]
    
    # Solve for AR coefficients
    a = R \ r
    
    return( vcat(1.0, -a))  # Include 1 for AR model definition
    
end    # yule_walker()


# Function to compute the MEM-based power spectral density
function mem_psd(data, order, SAMPLE_RATE, num_points=4096)
###########################################################
    
    ar_coeffs = yule_walker(data, order)
    
    freqs = range(0, stop=SAMPLE_RATE/2, length=num_points)
    psd = Float64[]
    
    for f in freqs
        omega = 2 * π * f / SAMPLE_RATE
        denom = abs(sum(ar_coeffs .* exp.(-im * omega * (0:order))))
        push!(psd, (1 / (denom^2)) * SAMPLE_RATE)  # Scaled by sample rate only (noise variance omitted), so the PSD is relative
    end
    
    return(freqs, psd)
    
end    # mem_psd()


# Band averaging function
function band_average(freqs, psd, bin_size)
###########################################
    
    max_freq = maximum(freqs)
    bins = 0:bin_size:max_freq
    averaged_psd = Float64[]
    averaged_freqs = Float64[]

    for i in 1:length(bins)-1
        mask = (freqs .>= bins[i]) .& (freqs .< bins[i+1])
        if any(mask)
            avg_psd = mean(psd[mask])
            push!(averaged_psd, avg_psd)
            push!(averaged_freqs, (bins[i] + bins[i+1]) / 2)  # Center frequency
        end
    end

    return(averaged_freqs, averaged_psd)

end    # band_average()


# Function to calculate total energy and low-frequency energy
function calculate_energy(freqs, psd, low_freq_threshold)
#########################################################
    
    total_energy = sum(psd) * (freqs[2] - freqs[1])  # Area under the entire PSD
    low_freq_mask = freqs .<= low_freq_threshold
    
    low_freq_energy = sum(psd[low_freq_mask]) * (freqs[2] - freqs[1])  # Area under low-frequency PSD
    
    percentage_low_freq_energy = (low_freq_energy / total_energy) * 100
    
    return(total_energy, low_freq_energy, percentage_low_freq_energy)
    
end    # calculate_energy()


function normalize_psd(psd)
###########################
    
    return(psd ./ maximum(psd))  # Normalize so that the peak value equals 1
    
end    # normalize_psd()


###############################################################################
###############################################################################
###############################################################################


SAMPLE_RATE = 2.56
NYQUIST = SAMPLE_RATE/2
order = 30
num_points = 4096
low_freq_threshold = 0.06667  # 15s corresponds to this frequency
bin_size = 0.005              # band-averaging width in Hz, used by band_average() below

##arry_length = 256

##heave = X_data[:, ii]
##x_date = Dates.format(X_date[ii], "yyyy-mm-dd HH:MM")

percentages = Float64[]

println("Rec. No.        Date       Low Freq. %")

# Loop through each column of X_data
for i in 1:size(X_data, 2)

    # Extract heave data and corresponding date
    heave = X_data[:, i]
    x_date = Dates.format(X_date[i], "yyyy-mm-dd HH:MM")

    # Compute the periodogram over the full 4608-point record
    data_segment = heave
    psd_result = periodogram(data_segment, fs=SAMPLE_RATE)

    # Band-average the power spectral density (PSD) results
    averaged_freqs_fft, averaged_psd_fft = band_average(psd_result.freq, psd_result.power, bin_size)

    # Normalize PSD for energy calculations
    total_energy, low_freq_energy, percentage_low_freq_energy = calculate_energy(averaged_freqs_fft, normalize_psd(averaged_psd_fft), low_freq_threshold)

    # Set a flag for high % of low-frequency energy or possible spikes
    flag = percentage_low_freq_energy >= 5 ? "<=== anomaly" : percentage_low_freq_energy >= 1 ? "<--- spike" : ""

    # Print results
    @printf("  %3d    %s    %.2f%% %s\n", i, x_date, percentage_low_freq_energy, flag)
    # Store percentage low-frequency energy
    push!(percentages, percentage_low_freq_energy)

end

tm_tick = range(X_date[1], X_date[end], step=Hour(6))
ticks = Dates.format.(tm_tick, "dd HH:MM")

title = Dates.format(X_date[1], "yyyy-mm-dd HH:MM")* " to "* Dates.format(X_date[end], "yyyy-mm-dd HH:MM")


plot(X_date,percentages, size=(2000,600), dpi=100, lw=1, lc=:blue, alpha=0.5, title=title,
    xlims=(X_date[1], X_date[end]), xticks=(tm_tick, ticks), 
    ylabel="Low-frequency energy (%)",
    legend=:topright, xtickfont=font(8), ytickfont=font(8),
    leftmargin = 20Plots.mm, bottommargin = 20Plots.mm,
    framestyle=:box, fg_legend=:transparent, bg_legend=:transparent, 
    grid=true, gridlinewidth=0.125, gridstyle=:dot, gridalpha=1, label="")


# Define the threshold line at 5%
percentage_threshold = 5

hline!([percentage_threshold], lw=2, lc=:red, ls=:dash, label="")
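
The percentage plotted above is the share of spectral energy at or below the 1/15 Hz cut-off. The following is a minimal sketch of that calculation, assuming uniformly spaced frequency bands and a simple rectangle-rule integration. `low_freq_percentage` is a hypothetical helper written for illustration; it is not the notebook's `calculate_energy`, which may integrate differently.

```julia
# Hypothetical helper (not the notebook's calculate_energy): percentage of
# spectral energy at frequencies at or below `cutoff`, using a rectangle rule
# on uniformly spaced frequency bands.
function low_freq_percentage(freqs, psd, cutoff)
    df = freqs[2] - freqs[1]                          # band spacing (Hz), assumed uniform
    total_energy = sum(psd) * df
    low_freq_energy = sum(psd[freqs .<= cutoff]) * df
    return 100 * low_freq_energy / total_energy
end

# Synthetic spectrum: a flat PSD over 0.005 Hz bands up to 1.28 Hz
freqs = collect(0.005:0.005:1.28)
psd = ones(length(freqs))

pct = low_freq_percentage(freqs, psd, 1/15)
println("Low-frequency energy: ", round(pct, digits=2), "%")
```

Because the result is a ratio of two energies, scaling the PSD by a constant (as `normalize_psd` does) leaves the percentage unchanged.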

### Select directory containing .BVA files

In [None]:
using DataFrames: DataFrame, ncol, nrow
using Dates: Dates, DateTime, Time, unix2datetime, Year, Month, Day, Hour, Minute, Microsecond
using NativeFileDialog: pick_folder
using Statistics: median, mean, std
using Plots: Plots, plot, plot!, annotate!, hline!, @layout, text, plotly, font, scatter!
using Printf: @printf, @sprintf
using DSP: periodogram

# Compute power spectrum using periodogram
function calc_spectra(heave, fs)
################################
    
    ps = periodogram(heave, fs=fs)

    # Define the averaging width for 0.005 Hz band spacing
    bin_width = Int(round(0.005 / (ps.freq[2] - ps.freq[1])))  # Number of points per band
    
    # Typed, growable vectors for the averaged frequencies and powers
    freqs = Float64[]
    powers = Float64[]
    
    # Start from the first frequency we want to capture
    target_freq = 0.005
    
    # Get the maximum frequency limit
    max_freq = maximum(ps.freq)
    
    # Loop through the power spectrum to calculate the averages
    for i in 1:bin_width:length(ps.power) - bin_width
        
        # Calculate the mean frequency for the current segment
        current_freq = mean(ps.freq[i:i+bin_width-1])
        
        # While the current frequency is greater than or equal to the target frequency, capture it
        while current_freq >= target_freq && target_freq <= max_freq
            push!(freqs, target_freq)
            push!(powers, mean(ps.power[i:i+bin_width-1]))
            target_freq += 0.005  # Increment the target frequency
        end
        
    end
    
    # Check if the last target_freq needs to be added
    if target_freq <= max_freq
        push!(freqs, target_freq)
        push!(powers, mean(ps.power[end-bin_width+1:end]))  # Average for the last bin
    end

    return(freqs, powers) 

end    # calc_spectra()


#####################################################################################
#####################################################################################
#####################################################################################

# Widen screen for better viewing
display(HTML("<style>.jp-Cell { width: 140% !important; }</style>"))

sample_frequency = 2.56

# select BVA directory
bva_directory = pick_folder()

# build list of all bva files in selected directory
bva_files = filter(x->occursin(".BVA",uppercase(x)), readdir(bva_directory))
bva_files = filter(x -> x != "19700101.BVA", bva_files)

for ii in 1:length(bva_files)
    println(ii," ",bva_files[ii])
end

### Do spectral plot of each record in selected .BVA file

In [None]:
using Plots

p1 = plot(size=(800,600), xlims=(0,0.6), ylims=(0,Inf), framestyle=:box)
infil = ""

for ii in 21 #1:length(bva_files)

    infil = bva_directory * "\\" * bva_files[ii]

    f23_df, Data = get_hex_array(infil)

    f23_df = get_matches(Data, f23_df)

    # remove those vectors from F23 df that are not located in the Data vector df
    f23_df = f23_first_row_check(f23_df)
    
    global X_data, X_date = get_heave(Data, f23_df)

    num_cols = size(X_data)[2]

    global spectra = Matrix{Float64}(undef, 256, num_cols)

    for jj in 1:num_cols

        heave = X_data[:,jj]

        f2, Pden2 = calc_spectra(heave, sample_frequency)
        spectra[:, jj] = Pden2

    end

end

x_fill = 1/15    # set 15 s (1/15 Hz) as the cut-off for the low-frequency part of the spectra
y_max = maximum(spectra) # maximum spectra value in matrix

infil_path = split(infil,"\\")
site_name = infil_path[end-2]
infil_name = infil_path[end]

title = site_name * " " * Dates.format.(X_date[1], "yyyy-mm-dd HH:MM") * " to " * Dates.format.(X_date[end], "yyyy-mm-dd HH:MM")
p1 = plot(size=(800,600), xlims=(0,0.6), ylims=(0,Inf), title=title, framestyle=:box)

# Fill the area to the left of the vertical line
p1 = plot!([0, x_fill, x_fill, 0], [0, 0, y_max, y_max], seriestype=:shape, alpha=0.035, color=:red, label="")

f2 = 0.005:0.005:1.28
p1 = plot!(f2, spectra, lw=0.5, lc=:lightgrey, label="")

p1 = vline!([x_fill], lw=1, lc=:red, ls=:dash, label="")

try       
    plot_file = ".\\Plots\\" * site_name * "_" * replace(infil_name, ".BVA" => "_spectral_plot.png")

    # Output plot file name
##    savefig(plot_file)
    println("\nPlot file saved as ",plot_file)
catch
    println("Alert: Plot not saved!")
    flush(stdout)
end

display(p1)

### Do contour plot of spectra for both frequency (Hz) and wave period (seconds)

In [None]:
using Plots
using Dates: Dates, Date, DateTime, Time, unix2datetime, Year, Month, Day, Hour, Minute, Microsecond

# Define your starting point for midnight after the first X_date entry
first_midnight = DateTime(Date(X_date[1]) + Day(1))

# display plots to screen
tm_tick = range(first_midnight,last(X_date),step=Hour(6))
ticks = Dates.format.(tm_tick, "dd/mm HH00")

title = site_name * " " * Dates.format.(X_date[1], "yyyy-mm-dd HH:MM") * " to " * Dates.format.(X_date[end], "yyyy-mm-dd HH:MM")
p1 = plot(xlabel="Date", ylabel="Frequency (Hz.)", xtickfontsize=7, 
    leftmargin = 15Plots.mm, bottommargin = 15Plots.mm,)

p1 = contourf!(X_date, f2, spectra, lw=0.25, c=cgrad(:Spectral, rev=true), clims=(0,y_max), levels=10, fill=true, ylims=(0,0.4), xticks=(tm_tick,ticks))

# draw grid lines on plot
for i in 0:0.1:0.6
    p1 = hline!(p1, [i], lw=0.5, c=:white, label="")
end

for i in tm_tick
    p1 = vline!(p1, [i], lw=0.5, c=:white, label="")
end

p1 = hline!(p1, [1/15], lw=1, c=:yellow, ls=:dash, label="")

# Convert frequency (f2) to period (T) and reverse both periods and spectra for ascending order
periods = reverse(1.0 ./ f2)       # periods in ascending order
spectra_reversed = reverse(spectra, dims=1)  # reverse spectra along y-axis (rows)

# get the wave period (seconds) at which the peak spectral value in the matrix occurs
peak_spectra = 1/f2[argmax(spectra)[1]]

# Define y_ticks period labels in seconds
y_ticks = [5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200, 250]

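# Keep ticks up to one step beyond the peak period, without running past the end of y_ticks
y_ticks = y_ticks[1:min(searchsortedfirst(y_ticks, peak_spectra) + 1, length(y_ticks))]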

# Plot with converted y-axis and reversed spectra
p2 = plot(xlabel="Date", ylabel="Period (seconds)", xtickfontsize=7, 
    leftmargin = 15Plots.mm, bottommargin = 15Plots.mm,)
contourf!(p2, X_date, periods, spectra_reversed, lw=0.25, c=cgrad(:Spectral, rev=true), 
          clims=(0, y_max), levels=10, fill=true, ylims=(0, y_ticks[end]), 
          xticks=(tm_tick, ticks), yticks=(y_ticks))

# Draw grid lines
for i in y_ticks
    hline!(p2, [i], lw=0.5, c=:white, label="")
end

for i in tm_tick
    vline!(p2, [i], lw=0.5, c=:white, label="")
end

p2 = hline!(p2, [15], lw=1, c=:red, ls=:dash, label="")

p1_p2 = plot(p1, p2, layout=(2,1), size=(1400, 800), framestyle=:box, suptitle = title)

try    
    infil_path = split(infil,"\\")
    site_name = infil_path[end-2]
    infil_name = infil_path[end]
    
    plot_file = ".\\Plots\\" * site_name * "_" * replace(infil_name, ".BVA" => "_contour_plot.png")

    # Output plot file name
    savefig(plot_file)
    println("\nPlot file saved as ",plot_file)
catch
    println("Alert: Plot not saved!")
    flush(stdout)
end

display(p1_p2)
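
The second panel re-expresses the same spectra against wave period instead of frequency. A minimal sketch of that conversion follows, assuming the same 0.005 Hz band centres used above; `spectra_demo` is a random stand-in for the real spectra matrix and is purely illustrative.

```julia
f = 0.005:0.005:1.28                 # band-centre frequencies (Hz)
spectra_demo = rand(length(f), 3)    # stand-in for a 256×N spectra matrix

# Period is the reciprocal of frequency, so ascending frequency means
# descending period. Reversing both keeps each row paired with its period.
periods = reverse(1.0 ./ f)                          # ascending periods (s)
spectra_by_period = reverse(spectra_demo, dims=1)    # rows reordered to match

println("Shortest and longest period (s): ", extrema(periods))
```

Note that the period axis is not uniformly spaced: the bands are equally spaced in frequency, so they are widely spaced at long periods and tightly packed at short ones.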

In [None]:
# Constants
SAMPLE_RATE = 2.56  # Define sample rate
bin_size = 0.005
low_freq_threshold = 1/15  # Example threshold for low-frequency energy calculation

# Initialize an empty, typed percentages array
percentages = Float64[]

@time begin
    
    # Loop through each column of X_data

    for i in 1:size(X_data, 2)

        # Extract heave data and corresponding date
        heave = X_data[:, i]
        x_date = Dates.format(X_date[i], "yyyy-mm-dd HH:MM")

        # Compute the periodogram over a full 4096-point segment
        data_segment = heave[1:4096]
        psd_result = periodogram(data_segment, fs=SAMPLE_RATE)

        # Band-average the power spectral density (PSD) results
        averaged_freqs_fft, averaged_psd_fft = band_average(psd_result.freq, psd_result.power, bin_size)

        # Normalize PSD for energy calculations
        total_energy, low_freq_energy, percentage_low_freq_energy = calculate_energy(
            averaged_freqs_fft, normalize_psd(averaged_psd_fft), low_freq_threshold)

        # Set a flag for high % of low-frequency energy or possible spikes
        flag = percentage_low_freq_energy >= 5 ? "<=== anomaly" : percentage_low_freq_energy >= 1 ? "<--- spike" : ""

        # Print results
    #    println("$i $x_date - Low-Frequency Energy: $(round(percentage_low_freq_energy, digits=2))% $flag")
        @printf("%3d %s - Low-Frequency Energy: %.2f%% %s\n", i, x_date, percentage_low_freq_energy, flag)
        # Store percentage low-frequency energy
        push!(percentages, percentage_low_freq_energy)

    end
    
end

In [None]:
#plotly()

x = 1:length(X_date)
y = 0.005:0.005:1.28

title = site_name * " " * Dates.format.(X_date[1], "yyyy-mm-dd HH:MM") * " to " * Dates.format.(X_date[end], "yyyy-mm-dd HH:MM")

surface(x, y, spectra, size=(1400,1200), xlabel="Date", ylabel="Frequency (Hertz)", zlabel="Spectral Density (m²/Hz)", c=cgrad(:Spectral, rev=true), title=title, framestyle=:box, colorbar=false)

# Model programs

### Separate data into Good and Bad matrices for training the Model

In [None]:
using Tk


# Function to handle a date selection (left mouse click or double-click)
function handle_selection()
###########################
    
    selected_item = get_value(lb)  # Get current selection
    if selected_item !== nothing  # Check if something is selected
        # Add the selected date to the list if not already present
        if selected_item[1] ∉ bad_dates
            push!(bad_dates, selected_item[1])  # Add to bad_dates if not already there
        else
            deleteat!(bad_dates, findfirst(x -> x == selected_item[1], bad_dates))  # Deselect if already selected
        end
        println("Current bad_dates: ", bad_dates)
    else
        println("No date selected!")
    end
    
end    # handle_selection()

# Callback function for the Exit button to process and close
function exit_callback()
########################
    
    global new_training_indicies_bad = findall(x -> x ∈ bad_dates, date_string)    # index of bad records in .BVA file
    global new_training_indicies_good = findall(x -> x ∉ bad_dates, date_string)

    # Populate matrices and vectors based on selections
    global new_training_data_good = new_training_data[:, new_training_indicies_good]
    global new_training_date_good = new_training_date[new_training_indicies_good]
    global new_training_data_bad = new_training_data[:, new_training_indicies_bad]
    global new_training_date_bad = new_training_date[new_training_indicies_bad]
    global new_training_labels = vcat(fill(:good, size(new_training_data_good, 2)), fill(:bad, size(new_training_data_bad, 2)))

    
    # Close the window after processing
    destroy(w)
    
end    # exit_callback()


###########################################################################################
###########################################################################################
###########################################################################################

# use currently selected .BVA file data
new_training_data = X_data
new_training_date = X_date

# build string vector of Dates
date_string = Dates.format.(new_training_date, "yyyy-mm-dd HH:MM")

# Initialize list of bad dates
global bad_dates = String[]

# Set up the Tk window
w = Toplevel("Select Bad Dates", 300, 600)
tcl("pack", "propagate", w, false)
f = Frame(w)
pack(f, expand=true, fill="both")

# Treeview setup to display dates
f1 = Frame(f)
lb = Treeview(f1, date_string)
scrollbars_add(f1, lb)
pack(f1, expand=true, fill="both")

# Button to finalize selection of bad dates
tcl("ttk::style", "configure", "TButton", foreground="blue", font="arial 16 bold")
exit_button = Button(f, "Exit")
pack(exit_button)

# Bind the Exit button to finalize selections and close the window
bind(exit_button, "command") do path
    exit_callback()  # Process the selections when exit button is clicked
end

# Bind double-click event to the Treeview
bind(lb, "<Double-1>") do event
    handle_selection()  # Handle selection on double-click
end

println("$(length(bad_dates)) bad records selected")
flush(stdout)

##new_training_labels = vcat(fill(:good, size(new_training_data_good, 2)), fill(:bad, size(new_training_data_bad, 2)));

In [None]:
print("$(length(bad_dates)) bad records selected")

In [None]:
new_training_data_bad

In [None]:
training_data_bad

### Save updated training data for file

In [None]:
using JLD2, Dates

##training_data_good = hcat(training_data_good, new_training_data_good)
training_data_bad = hcat(training_data_bad, new_training_data_bad)

println("Good data now ",string(size(training_data_good)[2])," records")
println("Bad  data now ",string(size(training_data_bad)[2])," records\n")

outfil = "BVA_training_data_updated_" * Dates.format(now(), "yyyy_mm_dd_HHMM") * ".JLD2" 

# Save the updated good and bad training data
@save outfil training_data_good training_data_bad

println("Updated training data saved successfully.")


### Recover earlier separated data from file (Note: does not include Model data)

In [None]:
using JLD2

using NativeFileDialog
using FilePathsBase

# Select the JLD2 file holding the previously saved training data
current_path = pwd() * "\\Training_data"
filterlist = "JLD2"
infil = pick_file(current_path; filterlist)

# Load all saved data and labels
##@load infil training_data training_date training_data_good training_date_good training_data_bad training_date_bad training_indicies_good training_indicies_bad training_labels # median_train std_train
@load infil training_data_good training_data_bad

println("Data and labels loaded successfully.")
println("Good data contains ",string(size(training_data_good)[2])," records")
println("Bad  data contains ",string(size(training_data_bad)[2])," records\n")

### Build Model using Good and Bad training data

In [None]:
using Flux, Statistics

function min_max_normalize_matrix(X)
####################################
    
    min_vals = minimum(X, dims=1)  # Compute min for each column
    max_vals = maximum(X, dims=1)  # Compute max for each column
    
    return((X .- min_vals) ./ (max_vals .- min_vals))
    
end    # min_max_normalize_matrix()


function z_score_normalize_matrix(X)
####################################
    
    mean_vals = mean(X, dims=1)  # Mean for each column
    std_vals = std(X, dims=1)    # Standard deviation for each column
    
    return((X .- mean_vals) ./ std_vals)
    
end    # z_score_normalize_matrix()


function pad_or_truncate(record, target_length=4608)
####################################################

    length(record) < target_length ? vcat(record, zeros(Float32, target_length - length(record))) :
                                     record[1:target_length]

end    # pad_or_truncate()


function get_heave(Data, f23_df)
################################
    
    heave_array = []
    training_date = []
    
    for idx in 1:nrow(f23_df)

        if !isnothing(f23_df.Data_vector[idx])
    
            start_date, start_val, end_val = get_start_end_dates(f23_df,idx)
            if start_val > 0
                print(".")
                heave, north, west = get_hnw(Data,start_val,end_val)

                # ensure we have 4608 data points
                push!(heave_array,pad_or_truncate(heave, 4608))
                push!(training_date,start_date)
            end

        end
    
    end

    return(hcat(heave_array...), training_date)

end    # get_heave()


function calc_reconstruction_errors(training_data_float32, model)
####################################################
    
    reconstruction_errors = Float32[]
    
    for record in eachcol(training_data_float32)  # Each record is a column of 4608 samples
        reconstructed_record = model(record)  # Pass the record to the autoencoder
        error = mean((reconstructed_record .- record).^2)  # Calculate the reconstruction error
        push!(reconstruction_errors, error)  # Store the error
    end
    
    return(reconstruction_errors)

end    # calc_reconstruction_errors()

####################################################################
####################################################################
####################################################################

# Define autoencoder model
initial_model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),    # Decoder
    Dense(128, 4608)         # Output layer, reconstructs input
)

refined_model = Chain(
    Dense(4608, 256, relu),   # Encoder start
    Dense(256, 128, relu),
    Dense(128, 64, relu),     # bottleneck layer 64 nodes
    Dense(64, 128, relu),
    Dense(128, 256, relu),    # Decoder end
    Dense(256, 4608)          # Output layer, reconstructs input
)

# Define the loss function (e.g., Mean Squared Error for reconstruction)
loss(x) = Flux.mse(refined_model(x), x)

#==
Optimizer: Adam with default parameters (learning rate, etc.)
-------------------------------------------------------------
In Flux.jl (and most deep learning frameworks), optimizers like Adam come with predefined default values for certain parameters, such as:

Learning Rate (α): Controls the step size at each iteration to minimize the loss function. The default is usually set at 0.001.
β1 and β2: Decay rates for the moving averages of the gradient and squared gradient, respectively. The default values are typically:
    β1 = 0.9
    β2 = 0.999
Epsilon (ε): A small constant added to prevent division by zero. This is usually set to 1e-8.
==#

opt = Adam()
#opt = AMSGrad(0.001)
#opt = Momentum(0.01, 0.9)

# Concatenate good and bad data
training_data_combined = hcat(training_data_good, training_data_bad)

# Normalize the combined data
training_data_normalized = min_max_normalize_matrix(training_data_combined)
    
# Convert WSE data to Float32
training_data_float32 = Float32.(training_data_normalized)

@time begin
    
    # Calculate the reconstruction errors of the untrained model (optional baseline)
    reconstruction_errors_model = calc_reconstruction_errors(training_data_float32, refined_model)
    
    # Prepare data for training
    data = Iterators.repeated((training_data_float32,), 100)  # data iteration for training
    
    # Train the model
    Flux.train!(loss, Flux.params(refined_model), data, opt)
    
    println("Done!")
    
end
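
One thing to watch with per-record min-max scaling is a record whose values are all identical (for example a flat-lined sensor): the range is zero and the division produces `NaN`. The sketch below reproduces the scaling on a tiny matrix and shows one possible guard. The guard is an assumption for illustration, not part of the notebook's `min_max_normalize_matrix`.

```julia
# Per-column min-max scaling on a small matrix whose second column is constant
X = Float32[1 5; 2 5; 3 5]
min_vals = minimum(X, dims=1)
max_vals = maximum(X, dims=1)

scaled = (X .- min_vals) ./ (max_vals .- min_vals)
println(scaled)          # the constant column evaluates 0/0

# Possible guard (an assumption): substitute a range of 1 wherever max == min,
# so a constant column maps to zeros instead of NaN
range_vals = max_vals .- min_vals
safe_range = ifelse.(range_vals .== 0, one(eltype(X)), range_vals)
scaled_safe = (X .- min_vals) ./ safe_range
println(scaled_safe)
```

`NaN` inputs propagate through the network and the loss, so it may be worth checking `any(isnan, training_data_float32)` before training.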

### Build Model using mini-batches

In [None]:
using Flux, Random, Statistics

# Define mini-batch size
batch_size = 64

# Function to create mini-batches
function create_mini_batches(data, batch_size)
##############################################
    
    num_samples = size(data, 2)  # Number of columns in `data` (each column is a record)
    shuffle_indices = randperm(num_samples)  # Shuffle indices for randomness
    mini_batches = []
    for i in 1:batch_size:num_samples
        end_idx = min(i + batch_size - 1, num_samples)
        push!(mini_batches, data[:, shuffle_indices[i:end_idx]])
    end
    
    return(mini_batches)
    
end    # create_mini_batches()


# Training function with mini-batches
function train_model_with_mini_batches(model, data, loss_fn, opt, num_epochs=10, batch_size=64)
###############################################################################################
    
    for epoch in 1:num_epochs
        mini_batches = create_mini_batches(data, batch_size)
        for mini_batch in mini_batches
            Flux.train!(loss_fn, Flux.params(model), [(mini_batch,)], opt)
        end
        println("Completed epoch $epoch")
    end
    
end    # train_model_with_mini_batches()

####################################################################
####################################################################
####################################################################

# Define the refined autoencoder model (as before, but with a 32-node bottleneck)
refined_model = Chain(
    Dense(4608, 256, relu),
    Dense(256, 128, relu),
    Dense(128, 32, relu),
    Dense(32, 128, relu),
    Dense(128, 256, relu),
    Dense(256, 4608)
)

# Define the loss function
loss(x) = Flux.mse(refined_model(x), x)

# Define the optimizer
opt = Adam()

# Prepare normalized training data (good records only in this variant)
training_data_combined = training_data_good #hcat(training_data_good, training_data_bad)
training_data_normalized = min_max_normalize_matrix(training_data_combined)
training_data_float32 = Float32.(training_data_normalized)

# Train the model using mini-batches
num_epochs = 10  # Define the number of epochs for training
train_model_with_mini_batches(refined_model, training_data_float32, loss, opt, num_epochs, batch_size)

println("Training complete with mini-batch processing.")
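
The batching logic in `create_mini_batches` can be checked in isolation. The following is a minimal sketch that partitions shuffled record indices rather than the data itself; `batch_indices` is a hypothetical helper, and the seeded `MersenneTwister` is used only to make the shuffle repeatable.

```julia
using Random

# Hypothetical helper: partition shuffled column indices into batches
# of at most `batch_size` entries each
function batch_indices(num_samples, batch_size; rng=Random.default_rng())
    shuffled = randperm(rng, num_samples)
    return [shuffled[i:min(i + batch_size - 1, num_samples)] for i in 1:batch_size:num_samples]
end

batches = batch_indices(150, 64; rng=MersenneTwister(1))
println(length.(batches))    # the last batch holds the remainder
```

Every record appears in exactly one batch per epoch, and the final batch is smaller whenever the record count is not a multiple of the batch size.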


### Run the model against data in the selected .BVA file

In [None]:
# Calculate mean and std for each column (record) in the training data (training_data_good)
mean_train = mean(training_data_good, dims=1)  # 1×number of records: mean of each record (column) in training_data_good
std_train = std(training_data_good, dims=1)    # 1×number of records: std of each record (column) in training_data_good

X_data_32 = Float32.(X_data)

# Normalize X_data_32 the same way the training data was normalized: per-record
# (column-wise) min-max scaling. The mean and std of individual training records do
# not correspond to the new records, and selecting them by position fails when
# X_data has more records than training_data_good.
X_new_normalized = Float32.(min_max_normalize_matrix(X_data_32))

# Step 1: Make predictions using the trained model
predicted_X_data = refined_model(X_new_normalized)

# Step 2: Calculate reconstruction error (sum of squared errors per record)
reconstruction_error = sum((X_new_normalized .- predicted_X_data) .^ 2, dims=1)

# Flatten the reconstruction error matrix into a 1D vector
reconstruction_error_vector = vec(reconstruction_error)

# Step 3: Set a threshold for outlier detection (here the 99.5th percentile)
threshold = quantile(reconstruction_error_vector, 0.995)

# Step 4: Identify outliers based on the threshold
outliers = findall(reconstruction_error .> threshold)
global outlier_indices = [idx[2] for idx in outliers]

# Step 5: Get the corresponding dates for the outliers
outlier_dates = X_date[outlier_indices]

if !isempty(outlier_dates)

    # Print the outliers
    println("Outliers detected at the following dates: ")
    for ii in outlier_dates
        date_string = Dates.format(ii, "yyyy-mm-dd HH:MM")
        println("    ", date_string)
    end
    
else

    println("No outliers detected")
    
end
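
The thresholding step above can be tried on its own with a handful of made-up errors. This is a minimal sketch using only the `Statistics` standard library; the error values and the 0.9 quantile are illustrative assumptions, not results from the model.

```julia
using Statistics

# Synthetic per-record reconstruction errors (illustrative values only)
errors = Float32[0.10, 0.12, 0.11, 0.09, 0.13, 0.95, 0.10, 0.11]

# Flag records whose error exceeds the chosen quantile of the error distribution
threshold = quantile(errors, 0.9)
outlier_idx = findall(errors .> threshold)

println("Threshold: ", threshold)
println("Outlier record indices: ", outlier_idx)
```

A quantile threshold is relative to the file being checked: unless several records tie for the largest error, the record with the highest error is always flagged, even if every record in the file is normal. An absolute threshold derived from training errors avoids this.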

### Build the Model using a hybrid approach

In [None]:
using Flux, Statistics

# (Keep existing functions: min_max_normalize_matrix, z_score_normalize_matrix, pad_or_truncate, get_heave)

function min_max_normalize_matrix(X)
####################################
    
    min_vals = minimum(X, dims=1)  # Compute min for each column
    max_vals = maximum(X, dims=1)  # Compute max for each column
    
    return((X .- min_vals) ./ (max_vals .- min_vals))
    
end    # min_max_normalize_matrix()


function calc_reconstruction_errors(data_matrix, model)
#######################################################
    
    reconstruction_errors = Float32[]
    for record in eachcol(data_matrix)
        reconstructed_record = model(record)
        error = mean((reconstructed_record .- record).^2)
        push!(reconstruction_errors, error)
    end
    
    return(reconstruction_errors)
    
end    # calc_reconstruction_errors()


####################################################################
####################################################################
####################################################################

# Define your refined autoencoder model as you have it
refined_model = Chain(
    Dense(4608, 256, relu),
    Dense(256, 128, relu),
    Dense(128, 32, relu),
    Dense(32, 128, relu),
    Dense(128, 256, relu),
    Dense(256, 4608)
)

# Concatenate and normalize the training data
training_data_combined = hcat(training_data_good, training_data_bad)
training_data_normalized = min_max_normalize_matrix(training_data_combined)
training_data_float32 = Float32.(training_data_normalized)

@time begin

    # Train the model
    loss(x) = Flux.mse(refined_model(x), x)
    opt = Adam()
    
    data = Iterators.repeated((training_data_float32,), 100)
    Flux.train!(loss, Flux.params(refined_model), data, opt)
    println("Model training complete.")

end

In [None]:
using Sockets

#global X_data = Matrix{Float32}(undef, 0, 0)

hostname = gethostname()
println("The name of the computer is: ", hostname)

if hostname == "QUEENSLAND-BASIN"
    
    display("text/html", "<style>.container { width:100% !important; }</style>")
    
else
    
    display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))    
    
end


REC_LENGTH = 4608       # Number of WSE's in a Mk4 30-minute record
SAMPLE_FREQUENCY = 2.56 # Mk4 sample frequency in Hertz
SAMPLE_LENGTH = 1800    # record length in seconds
SAMPLE_RATE = Float64(1/SAMPLE_FREQUENCY) # sample spacing in seconds

#########################################################################################################################
##    confidence_interval = 2.576  # corresponds to a 99% confidence interval (for a normal distribution)
##    confidence_interval = 3.0    # corresponds to a 99.73% confidence interval (for a normal distribution)    
##    confidence_interval = 3.29   # corresponds to a 99.9% confidence interval (for a normal distribution)
#########################################################################################################################

# Widen screen for better viewing
display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))
display("text/html", "<style>.container { width:100% !important; }</style>")

initial_path = "F:\\Card Data\\"
infil = pick_file(initial_path)

f23_df, Data = get_hex_array(infil)

f23_df = get_matches(Data, f23_df)

# remove those vectors from F23 df that are not located in the Data vector df
f23_df = f23_first_row_check(f23_df)

X_data, X_date = get_heave(Data, f23_df);

X_data = Float32.(X_data)

println(string(length(X_date))," records processed.\n")
println("\nNow plot heave for each record to see suspect data!")
flush(stdout)

### Run the hybrid model against data in the selected .BVA file

In [None]:
# Calculate reconstruction errors for good and bad training data separately
errors_good = calc_reconstruction_errors(Float32.(min_max_normalize_matrix(training_data_good)), refined_model)
errors_bad = calc_reconstruction_errors(Float32.(min_max_normalize_matrix(training_data_bad)), refined_model)

# Weighting errors for bad data
bad_weight_factor = 2.0  # Apply more weight to bad data

# Weighted reconstruction errors
weighted_errors_bad = errors_bad .* bad_weight_factor

# Set separate thresholds
good_threshold = quantile(errors_good, 0.95)  
bad_threshold = quantile(weighted_errors_bad, 0.99)

println("Good data threshold: ", good_threshold)
println("Bad data threshold: ", bad_threshold)

# Normalize new data the same way the training data was normalized (per-record min-max scaling)
X_data_32 = Float32.(X_data)
X_new_normalized = Float32.(min_max_normalize_matrix(X_data_32))

# Make predictions using the trained model
#==
Note: The reconstruction error is the difference between the original data and the reconstructed data. 
      This error reflects how well the model can replicate the original data. 
      A low reconstruction error suggests the model has captured the structure of the data well, 
      while a high reconstruction error suggests that the data is unusual or anomalous.
==#
predicted_X_data = refined_model(X_new_normalized)
reconstruction_error = sum((X_new_normalized .- predicted_X_data) .^ 2, dims=1)
reconstruction_error_vector = vec(reconstruction_error)    # convert reconstruction_error matrix to vector

# Rescale reconstruction errors to the range [0, 1]. Note that good_threshold and
# bad_threshold above are computed from raw (unscaled) training errors.
normalized_reconstruction_error = min_max_normalize_matrix(reconstruction_error_vector)

threshold = mean(normalized_reconstruction_error) + 3 * std(normalized_reconstruction_error)
outliers = findall(normalized_reconstruction_error .> threshold)
outlier_indices = [idx for idx in outliers]

# Apply dual-threshold for outlier detection
#outliers = findall(reconstruction_error_vector .> bad_threshold)
uncertain = findall(x -> good_threshold < x <= bad_threshold, normalized_reconstruction_error)
#==
Note: "uncertain" generally refers to data points that the model finds difficult to classify as 
      either clearly "inliers" (normal data) or "outliers" (abnormal data). 
      These uncertain points fall in a "gray zone," meaning they don't strongly resemble typical 
      inliers but also don't fully meet the model's criteria for outliers.
==#
# Get dates for outliers and uncertain data points
outlier_dates = X_date[outlier_indices]
uncertain_dates = X_date[[idx for idx in uncertain]]

# Print out results
if !isempty(outlier_dates)
    println("Outliers detected at the following dates:")
    for date in outlier_dates
        println("    ", Dates.format(date, "yyyy-mm-dd HH:MM"))
    end
else
    println("No outliers detected.")
end

if !isempty(uncertain_dates)
    println("Uncertain data points detected at the following dates:")
    for date in uncertain_dates
        println("    ", Dates.format(date, "yyyy-mm-dd HH:MM"))
    end
else
    println("No uncertain data points detected.")
end
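
The dual-threshold idea can be expressed as a three-way labelling. The sketch below assumes that the errors and both thresholds are on the same scale, which is required for the comparison to be meaningful; `classify` and all numeric values are hypothetical and chosen only for illustration.

```julia
# Hypothetical helper: label an error using two thresholds (low < high)
classify(err, low, high) = err > high ? :outlier : err > low ? :uncertain : :normal

errors = [0.02, 0.05, 0.30, 0.08, 0.90]       # illustrative values
low_threshold, high_threshold = 0.10, 0.50    # assumed, same scale as `errors`

labels = classify.(errors, low_threshold, high_threshold)
println(labels)
```

If the errors are rescaled (for example to [0, 1]) before the comparison, the thresholds would need the same rescaling for the labels to remain valid.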

### Description of Model using a hybrid approach

In [None]:
plot(normalized_reconstruction_error, size=(1200,600))
hline!([threshold], label="Threshold")
hline!([good_threshold], label="Good Threshold", ls=:dot)
hline!([bad_threshold], label="Bad Threshold", ls=:dash)


### Build the Model using a hybrid approach with overlapping segments

In [None]:
using Flux, Statistics

# (Keep existing functions: min_max_normalize_matrix, z_score_normalize_matrix, pad_or_truncate, get_heave)

# Function to create overlapping segments for training/testing data
function create_overlapping_segments(data, segment_length::Int, overlap::Int)
############################################################################################
    
    step = segment_length - overlap
    # Note: a matrix input is indexed linearly (column-major), i.e. treated as one long series
    segments = [data[i:i + segment_length - 1] for i in 1:step:length(data)-segment_length+1]
  
    return(hcat(segments...))
    
end    # create_overlapping_segments()


function calc_reconstruction_errors(data_matrix, model)
#######################################################
    
    reconstruction_errors = Float32[]
    for record in eachcol(data_matrix)
        reconstructed_record = model(record)
        error = mean((reconstructed_record .- record).^2)
        push!(reconstruction_errors, error)
    end
    
    return(reconstruction_errors)
    
end    # calc_reconstruction_errors()


# Define your refined autoencoder model as you have it
refined_model = Chain(
    Dense(4608, 256, relu),
    Dense(256, 128, relu),
    Dense(128, 64, relu),
    Dense(64, 128, relu),
    Dense(128, 256, relu),
    Dense(256, 4608)
)

# Generate new training data with 512-length segments
segment_length = 512  # Nine such segments are later concatenated to fill the 4608-sample model input
overlap = 256         # Set overlap as needed for your task

# Apply this to both good and bad data
training_data_good_segments = create_overlapping_segments(training_data_good, segment_length, overlap)
training_data_bad_segments = create_overlapping_segments(training_data_bad, segment_length, overlap)

In [None]:
function concatenate_segments(segments, group_size::Int)
    # Stack each run of `group_size` consecutive segments (columns) end to end,
    # so 512-sample segments in groups of 9 become 4608-sample columns.
    num_groups = size(segments, 2) ÷ group_size
    used = segments[:, 1:num_groups * group_size]   # drop any incomplete trailing group
    return reshape(used, size(segments, 1) * group_size, num_groups)
end

# Apply to training data
training_data_segments = hcat(training_data_good_segments, training_data_bad_segments)
training_data_concatenated = concatenate_segments(training_data_segments, 9)

# Normalize the segmented data
training_data_segments_normalized = min_max_normalize_matrix(training_data_concatenated)

# Convert to Float32
training_data_segments_float32 = Float32.(training_data_segments_normalized)

# Train the model on 4608 × samples data (one sample per column)
loss(x) = Flux.mse(refined_model(x), x)
opt = Adam()
data = Iterators.repeated((training_data_segments_float32,), 100)

Flux.train!(loss, Flux.params(refined_model), data, opt)
println("Model training complete.")

# Calculate reconstruction errors for good and bad training data separately,
# using the same grouping and normalization as the training input
good_concatenated = concatenate_segments(training_data_good_segments, 9)
bad_concatenated = concatenate_segments(training_data_bad_segments, 9)
errors_good = calc_reconstruction_errors(Float32.(min_max_normalize_matrix(good_concatenated)), refined_model)
errors_bad = calc_reconstruction_errors(Float32.(min_max_normalize_matrix(bad_concatenated)), refined_model)

# Set separate thresholds
good_threshold = quantile(errors_good, 0.99)   # 99th percentile of good errors
bad_threshold = quantile(errors_bad, 0.995)    # 99.5th percentile of bad errors

println("Good data threshold: ", good_threshold)
println("Bad data threshold: ", bad_threshold)

# Normalize new data the same way as the training data (per-column min-max)
X_new_normalized = Float32.(min_max_normalize_matrix(X_data))

In [None]:
# Make predictions using the trained model.
# X_data is 4608 × records, matching the model input, so each column is one record
# and the error vector lines up with X_date.
predicted_X_data = refined_model(X_new_normalized)

# Mean squared error per record, on the same scale as good_threshold and bad_threshold
reconstruction_error = mean((X_new_normalized .- predicted_X_data) .^ 2, dims=1)
reconstruction_error_vector = vec(reconstruction_error)

# Single statistical threshold, kept for plotting
threshold = mean(reconstruction_error_vector) + 3 * std(reconstruction_error_vector)

# Apply dual-threshold for outlier detection
outliers = findall(reconstruction_error_vector .> bad_threshold)
uncertain = findall(x -> good_threshold < x <= bad_threshold, reconstruction_error_vector)

# Get dates for outliers and uncertain data points
outlier_dates = X_date[outliers]
uncertain_dates = X_date[uncertain]

# Print out results
if !isempty(outlier_dates)
    println("Outliers detected at the following dates:")
    for date in outlier_dates
        println("    ", Dates.format(date, "yyyy-mm-dd HH:MM"))
    end
else
    println("No outliers detected.")
end

if !isempty(uncertain_dates)
    println("Uncertain data points detected at the following dates:")
    for date in uncertain_dates
        println("    ", Dates.format(date, "yyyy-mm-dd HH:MM"))
    end
else
    println("No uncertain data points detected.")
end

### Plot records with suspect data (as identified by the model)
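
Within a flagged record, individual suspect samples are located with a modified z-score, which uses the median and the median absolute deviation (MAD) so that a single spike does not inflate the scale estimate. The sketch below shows the calculation on made-up samples; `demo_modified_z` and `heave_demo` are hypothetical names, and returning zeros for constant data is an assumption made for this example.

```julia
using Statistics

# Modified z-score: 0.6745 * (x - median) / MAD
function demo_modified_z(x::AbstractVector)
    med = median(x)
    mad = median(abs.(x .- med))
    # Assumption: constant data (MAD of zero) is treated as having no outliers
    mad == 0 && return zeros(Float64, length(x))
    return 0.6745 .* (x .- med) ./ mad
end

# Hypothetical heave samples containing one spike
heave_demo = [0.1, -0.2, 0.15, -0.1, 0.05, 5.0, -0.05]
scores = demo_modified_z(heave_demo)

# 3.99 is the cap used by dynamic_z_score_threshold
flagged = findall(s -> abs(s) > 3.99, scores)
println(flagged)
```

Only the spike at index 6 exceeds the threshold here; the remaining samples score below 2.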

In [None]:
# Function for dynamic threshold based on wave amplitudes
function dynamic_z_score_threshold(heave, base_threshold=3.0, k=0.5)
    ####################################################################
    
    # Calculate the amplitude of heave
    amplitude = abs.(heave)
    
    # Calculate the standard deviation of the amplitudes
    std_amplitude = std(amplitude)
    
    # Calculate the range of amplitudes
    range_amplitude = maximum(amplitude) - minimum(amplitude)
    
    # Calculate dynamic threshold
    dynamic_threshold = base_threshold * (1 + k * (range_amplitude / std_amplitude))
    
    # Cap the dynamic threshold. With the default base_threshold and k, the uncapped
    # value is always above this cap for a non-constant record (range/std is at least √2),
    # so the function then returns max_threshold.
    max_threshold = 3.99  # or any other maximum threshold you prefer
    dynamic_threshold = min(dynamic_threshold, max_threshold)
    
    return dynamic_threshold
    
end  # dynamic_z_score_threshold()


for ii ∈ outlier_indices

    # Initialize variables
    start_time = X_date[ii]
    global heave = X_data[:, ii]
    end_time = start_time + Minute(30)
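    # Round to whole microseconds: Microsecond() rejects non-integer floating-point values
    xvals = start_time .+ Microsecond.(round.(Int, (0:REC_LENGTH-1) ./ SAMPLE_FREQUENCY .* 1_000_000))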

    # Plot initialization
    p1 = plot(size=(2000, 400), dpi=100, framestyle=:box, fg_legend=:transparent, bg_legend=:transparent, 
        legend=:topright, xtickfont=font(8), ytickfont=font(8), bottommargin = 10Plots.mm,
        grid=true, gridlinewidth=0.125, gridstyle=:dot, gridalpha=1)
    
    tm_tick = range(start_time, end_time, step=Minute(1))
    ticks = Dates.format.(tm_tick, "MM")
    
    # Calculate dynamic confidence interval
    global confidence_interval = dynamic_z_score_threshold(heave)

    # Identify outliers using modified z-score
    z_score_indices, mod_z_scores = modified_z_score(heave, confidence_interval)
    if !isempty(z_score_indices)
        scatter!(p1, xvals[z_score_indices], heave[z_score_indices], 
            markersize=4, markerstrokecolor=:red, markerstrokewidth=1, 
            markercolor=:white, markershape=:utriangle, label="")
    end

    # Plot confidence limits
    confidence_limits = calc_confidence_limits(heave, confidence_interval)
    hline!(p1, [confidence_limits[1], confidence_limits[2]], color=:red, lw=1, linestyle=:dash, label="")

    # Plot heave data
    plot!(p1, xvals, heave, xlims=(xvals[1], xvals[end]), lw=0.5, lc=:blue, alpha=0.5, 
        xticks=(tm_tick, ticks), label="")

    # Annotate plot with the number of outliers and confidence interval
    num_outliers = length(z_score_indices)
    suspect_string = string("  ", ii, " ", Dates.format(start_time, "yyyy-mm-dd HH:MM"), " - ", num_outliers, " Possible outliers")
    annotate!(p1, xvals[1], maximum(heave) * 0.9, text(suspect_string, :left, 10, :blue))

    display(p1)

end

In [None]:
function modified_z_score(data, threshold)
##########################################
    
    med = median(data)
    mad = median(abs.(data .- med))
    mod_z_scores = 0.6745 * (data .- med) ./ mad
    outlier_indices = findall(x -> abs(x) > threshold, mod_z_scores)
    
    return(outlier_indices, mod_z_scores)
    
end    # modified_z_score()

heave = X_data[:,1]

# Calculate dynamic confidence interval
confidence_interval = dynamic_z_score_threshold(heave)

# Identify outliers using modified z-score
# (named z_score_indices so the record-level outlier_indices is not overwritten)
z_score_indices, mod_z_scores = modified_z_score(heave, confidence_interval)


In [None]:
mod_z_scores

### Save separated data to file (Note: does not include Model data)

In [None]:
using JLD2, Dates

outfil = "BVA_training_data_"*Dates.format(now(), "yyyy_mm_dd_HHMM")*".JLD2" 
# Save the separated good and bad training data
@save outfil training_data_good training_data_bad
println("Training data saved successfully to ", outfil)


### Recover earlier separated data from file (Note: does not include Model data)

In [None]:
using JLD2

using NativeFileDialog: pick_file

# Select the JLD2 file holding the separated training data
infil = pick_file()

# Load the saved data
##@load infil training_data training_date training_data_good training_date_good training_data_bad training_date_bad training_indicies_good training_indicies_bad training_labels # median_train std_train
@load infil training_data_good training_data_bad

println("Training data loaded successfully.")


In [None]:
training_data = Y_train 
training_date = Y_date 
training_data_good = Y_train_good
training_date_good = Y_date_good 
training_data_bad = Y_train_bad 
training_date_bad = Y_date_bad 
training_indicies_good = good_indices
training_indicies_bad = bad_indices 
training_labels = labels

### Recover earlier separated Model data from file

In [None]:
using JLD2, Flux
using NativeFileDialog: pick_file

# Select the JLD2 file holding the model and optimizer states
infil = pick_file()

# Load model state, optimizer, and data/labels.
# @load assigns each named variable directly, so its result is not destructured.
JLD2.@load infil model_state opt_state X_train X_date X_train_good X_date_good X_train_bad X_date_bad labels

# Reconstruct the model
model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),    # Decoder
    Dense(128, 4608)         # Output layer, reconstructs input
)

# Load the model parameters
Flux.loadmodel!(model, model_state)

println("Model, optimizer, data, and labels loaded successfully.")


### Append new data to older data

In [None]:
## Append new data
X_train = hcat(X_train, new_X_train)
X_date = vcat(X_date, new_X_date)

# Append new labels if they exist
labels_good = vcat(labels_good, new_labels_good)  # or update as needed
labels_bad = vcat(labels_bad, new_labels_bad)      # or update as needed

# Save the updated data and labels
@save "processed_data_with_labels.jld2" X_train X_date X_train_good X_date_good X_train_bad X_date_bad selected_indices labels
println("Updated data and labels saved successfully.")


### Save model, optimiser, and heave data to file

In [None]:
using Flux, JLD2, Dates

# Save the model state, optimizer state, and relevant data/labels
model_state = Flux.state(model)  # Get the model's parameters
opt_state = opt                  # Optimizer, stored under the name the load cells expect

outfil = "BVA_model_and_data_"*Dates.format(now(), "yyyy_mm_dd_HHMM")*".JLD2" 

# Save model and optimizer and normalised wave data to a JLD2 file
@save outfil model_state opt_state X_train X_date X_train_good X_date_good X_train_bad X_date_bad labels

println("Data and model saved successfully to ",outfil)


### Recover model, optimiser, and heave data from file

In [None]:
using JLD2, Flux
using NativeFileDialog: pick_file

# Define the model architecture; it must match the saved model for loadmodel! to succeed
function create_model()
    return Chain(
        Dense(4608, 128, relu),  # Encoder
        Dense(128, 64, relu),    # Bottleneck
        Dense(64, 128, relu),    # Decoder
        Dense(128, 4608)         # Output layer, reconstructs input
    )
end


# Select the JLD2 file holding the model and optimizer states
infil = pick_file()

# Change the file extension to uppercase
infil = replace(infil, ".jld2" => ".JLD2")

println("File selected: $infil")  # Print the selected file path

try
    # @load assigns the named variables directly, so its result is not destructured.
    # Declare them global so they remain available after the try block.
    global model_state, opt_state, X_train, X_date, X_train_good, X_date_good, X_train_bad, X_date_bad, labels
    @load infil model_state opt_state X_train X_date X_train_good X_date_good X_train_bad X_date_bad labels

    # Reconstruct the model from the architecture and restore its parameters
    global model = create_model()
    Flux.loadmodel!(model, model_state)

    # Reuse the saved optimizer as-is
    global opt = opt_state

    println("Model, optimizer, and data loaded successfully.")
catch e
    println("Error loading data: ", e)
end

In [None]:
using Flux, JLD2

# Load the model state, optimizer state, and relevant data/labels.
# @load assigns the named variables directly, so its result is not destructured.
@load outfil model_state opt_state X_train X_date X_train_good X_date_good X_train_bad X_date_bad labels

# Restore the parameters into an existing `model` with the same architecture as the saved one
Flux.loadmodel!(model, model_state)
opt = Flux.setup(Adam(), model)  # Reinitialize the optimizer with the model

println("Model, optimizer, and data loaded successfully.")


### Select records that will NOT be uploaded to model

In [None]:
using Tk

# Make a copy of X_train in case a roll-back is needed
X_train_old = copy(X_train)

# Convert DateTimes into strings
all_dates = Dates.format.(X_date, "yyyy-mm-dd HH:MM")

# Initialize selection window
w = Toplevel("Select Date", 235, 400)
tcl("pack", "propagate", w, false)
f = Frame(w)
pack(f, expand=true, fill="both")

f1 = Frame(f)
lb = Treeview(f1, all_dates)
scrollbars_add(f1, lb)
pack(f1, expand=true, fill="both")

# Style button
tcl("ttk::style", "configure", "TButton", foreground="blue", font="arial 16 bold")
b = Button(f, "Ok")
pack(b)

# Global array to store selected dates
global bad_array = []

# Collect dates when the button is pressed
bind(b, "command") do path
    file_choice = get_value(lb)
    if !isempty(file_choice)
        push!(bad_array, file_choice[1])
        println("Added to removal list: ", file_choice[1])
    end
end

# Function to close window and process columns
function finalize_selection(args...)
####################################
    
    destroy(w)  # Close Tk window

    # Find indices of the selected dates in all_dates (same order as X_date)
    bad_cols = findall(x -> x in bad_array, all_dates)
    keep_cols = setdiff(1:length(all_dates), bad_cols)

    # Remove the selected columns from X_train, and the matching entries from X_date,
    # so that records and dates stay aligned
    global X_train = X_train[:, keep_cols]
    global X_date = X_date[keep_cols]
    println("Removed columns based on selection.")
    
end    # finalize_selection()

# Bind finalize function to window close event
tcl("wm", "protocol", w, "WM_DELETE_WINDOW", finalize_selection)


### Initial Code for Training an Autoencoder in Julia (using Flux.jl)
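
Each record is scaled column by column to the range [0, 1] before training. A record whose values are all identical has a zero range, and a plain min-max division then produces NaN. The sketch below shows one possible guard, assuming that mapping such a column to zeros is acceptable; `demo_min_max` and `X_demo` are hypothetical names and are not used elsewhere in the notebook.

```julia
# Column-wise min-max scaling to [0, 1]; constant columns map to 0 rather than NaN
function demo_min_max(X::AbstractMatrix)
    lo = minimum(X, dims=1)
    hi = maximum(X, dims=1)
    span = hi .- lo
    # Replace a zero range with 1 so the division is always defined
    safe_span = ifelse.(span .== 0, one(eltype(span)), span)
    return (X .- lo) ./ safe_span
end

X_demo = [1.0 5.0; 2.0 5.0; 3.0 5.0]   # second column is constant
X_scaled = demo_min_max(X_demo)
println(X_scaled)
```

Whether a flat record should be scaled to zeros or dropped before training depends on how such records arise in the data.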

In [None]:
using Flux, Statistics

function min_max_normalize_matrix(X)
    min_vals = minimum(X, dims=1)  # Compute min for each column
    max_vals = maximum(X, dims=1)  # Compute max for each column
    return (X .- min_vals) ./ (max_vals .- min_vals)
end


function z_score_normalize_matrix(X)
    mean_vals = mean(X, dims=1)  # Mean for each column
    std_vals = std(X, dims=1)    # Standard deviation for each column
    return (X .- mean_vals) ./ std_vals
end


function pad_or_truncate(record, target_length=4608)
####################################################
#==    
    if length(record) < target_length
        # Pad with zeros (or any other value you prefer)
        return vcat(record, zeros(Float32, target_length - length(record)))
    elseif length(record) > target_length
        # Truncate to the target length
        return record[1:target_length]
    else
        return record
    end
==#
    length(record) < target_length ? vcat(record, zeros(Float32, target_length - length(record))) :
                                     record[1:target_length]

end    # pad_or_truncate()


function get_heave(Data, f23_df)
################################
    
    heave_array = []
    X_date = []
    
    for idx in 1:nrow(f23_df)

        if !isnothing(f23_df.Data_vector[idx])
    
            start_date, start_val, end_val = get_start_end_dates(f23_df,idx)
            if start_val > 0
                print(".")
                heave, north, west = get_hnw(Data,start_val,end_val)

                # ensure we have 4608 data points
                push!(heave_array,pad_or_truncate(heave, 4608))
                push!(X_date,start_date)
            end

        end
    
    end

    return(hcat(heave_array...), X_date)

end    # get_heave()


function calc_reconstruction_errors(X_train_float32, model)
####################################################
    
    reconstruction_errors = Float32[]
    
    for record in eachcol(X_train_float32)  # Each record is one column of 4608 samples
        reconstructed_record = model(record)  # Pass the record to the autoencoder
        error = mean((reconstructed_record .- record).^2)  # Calculate the reconstruction error
        push!(reconstruction_errors, error)  # Store the error
    end
    
    return(reconstruction_errors)

end    # calc_reconstruction_errors()

####################################################################
####################################################################
####################################################################
@time begin
# Define autoencoder model
model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),    # Decoder
    Dense(128, 4608)         # Output layer, reconstructs input
)

# Define the loss function (e.g., Mean Squared Error for reconstruction)
loss(x) = Flux.mse(model(x), x)

# Optimizer: Adam with default parameters (learning rate, etc.)
opt = Adam()
   
X_train, X_date = get_heave(Data, f23_df)

# Normalize your training data
##X_train_normalized = normalize_records(X_train)
X_train_normalized = min_max_normalize_matrix(X_train)
    
# Convert WSE data to Float32
X_train_float32 = Float32.(X_train_normalized)

# Use the converted data for training
data = Iterators.repeated((X_train_float32,), 100)  # Repeat the full data set for 100 training passes

# Train the model
Flux.train!(loss, Flux.params(model), data, opt)

# Calculate the reconstruction errors with the trained model
reconstruction_errors_model = calc_reconstruction_errors(X_train_float32, model)
end    # @time

println("Done!")

### Save model, optimiser, and normalised heave data to file

In [None]:
using Flux, JLD2, Dates

# Save the model and optimizer
model_state = Flux.state(model)
opt_state = opt # Flux.setup(Adam(), model)

outfil = "HVA_model_"*Dates.format(now(), "yyyy_mm_dd_HHMM")*".jld2" 

# Save model and optimizer and normalised wave data to a JLD2 file
jldsave(outfil; model_state, opt_state, X_train_float32)


### Recover saved model, optimiser, and normalised heave data from file

In [None]:
using JLD2, Flux
using NativeFileDialog: pick_file

# Load the model and optimizer states from the JLD2 file
infil = pick_file()
loaded_data = jldopen(infil, "r") do file
    old_model_state = file["model_state"]  # Load model state
    old_opt_state = file["opt_state"]      # Load optimizer state
    old_X_train_float32 = file["X_train_float32"]  # Load the previous X_train_float32 data
    return (old_model_state, old_opt_state, old_X_train_float32) # Return all states and data
end

old_model_state, old_opt_state, old_X_train_float32 = loaded_data

# Define the old model architecture
old_model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),     # Decoder
    Dense(128, 4608)          # Output layer, reconstructs input
)

# Restore model parameters from the loaded state
for (layer, state) in zip(old_model.layers, old_model_state[:layers])
    layer.weight .= state.weight   # Assign saved weights
    layer.bias .= state.bias       # Assign saved biases
end

# Restore the optimizer state directly
old_opt = old_opt_state;  # Assign the loaded optimizer state


### Append new data to existing model and retrain

In [None]:
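using Flux
using JLD2
using NativeFileDialog: pick_file

infil = pick_file()

# Load the saved model state, optimizer, and previous training data.
# These names match what the save cell writes with jldsave.
JLD2.@load infil model_state opt_state X_train_float32
old_X_train_float32 = X_train_float32

# Rebuild the architecture and restore the saved parameters
loaded_model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),    # Decoder
    Dense(128, 4608)         # Output layer, reconstructs input
)
Flux.loadmodel!(loaded_model, model_state)
opt = opt_state

# Load new records and prepare the data
new_X_train, _ = get_heave(Data, f23_df)  # Replace with your new data fetching function
new_X_train_normalized = min_max_normalize_matrix(new_X_train)
new_X_train_float32 = Float32.(new_X_train_normalized)

# Combine with the previous training data
X_combined = hcat(old_X_train_float32, new_X_train_float32)

# Define the loss function again
loss(x) = Flux.mse(loaded_model(x), x)

# Prepare the data for training (iterating over the new combined data)
data = Iterators.repeated((X_combined,), 100)

# Train the model on the combined data
Flux.train!(loss, Flux.params(loaded_model), data, opt)

# Refresh the saved state so the file reflects the retrained model and combined data
model_state = Flux.state(loaded_model)
opt_state = opt
X_train_float32 = X_combined

# Note: this overwrites the file that was loaded
outfil = infil

# Save model and optimizer and normalised wave data to a JLD2 file
jldsave(outfil; model_state, opt_state, X_train_float32)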