# Locate possible outliers in Heave, North, and West data in .RDT files

### Load required packages

In [None]:
# List of packages (and their functions) used in the modules below
using DataFrames: DataFrame, ncol, nrow
using Dates: Day, Date, Dates, DateTime, Hour, Microsecond, Minute, Month, now, Time, unix2datetime, Year
#using FilePathsBase
using Flux: Adam, Chain, Dense, Flux, mse, params, relu, train!
using Glob: glob
using JLD2: @load, @save
using NativeFileDialog: pick_file, pick_folder
using Plots:  annotate!, font, hline!, hspan!, plot, Plots, plotly, plot!, scatter!, text, vline!, xlims, ylims, @layout
using Printf: @sprintf
using Sockets: gethostname
using Statistics: mean, median, quantile, std
using Tk: Button, Frame, Tk, Toplevel, Treeview, bind, destroy, get_value, pack, scrollbars_add, tcl

# Bring in the helper functions that the cells below call
include(".\\Mk3_model_functions.jl");

println("Loading packages completed")

# The computer name decides both the data drive and the notebook width tweak
hostname = gethostname()
##println("The name of the computer is: ", hostname)

# Starting folder handed to the file and folder pickers
initial_path = hostname == "QUEENSLAND-BASIN" ? "E:\\Card Data\\" : "F:\\Card Data\\"

# Widen the notebook display area; the CSS rule differs between the two machines
if hostname == "QUEENSLAND-BASIN"
    display("text/html", "<style>.container { width:100% !important; }</style>")
else
    display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))
end;

### Locate and display records with GPS errors as flagged by Datawell

In [None]:
using DataFrames: DataFrame
using Dates: DateTime, year
using DSP: welch_pgram, freq, power, hanning
using Glob: glob
using NativeFileDialog: pick_file

import DataFrames: Not, select!

# Widen the notebook cells for better viewing by displaying a CSS rule as HTML
display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))

# Combine two consecutive bytes of `data_array` into a single UInt16 (high byte first).
# `idx` is the position of the high byte; the low byte is read from `idx + 1`.
# Works on the numeric values directly, so no temporary strings are created.
function parse_hex(data_array, idx)
###################################

    high_byte = UInt16(data_array[idx])
    low_byte = UInt16(data_array[idx+1])

    return (high_byte << 8) | low_byte

end    # parse_hex()


# Map every value in `north_row` to 1 when its least significant bit is set
# (GPS interference) and to 0 otherwise. Returns the flags as a new array.
function check_gps_flag(north_row)
##################################

    lsb_is_set(n) = (n & 0x1) == 1

    gps_flag_row = [lsb_is_set(n) ? 1 : 0 for n in north_row]

    return gps_flag_row

end    # check_gps_flag()
    

using Plots: plot!, vline
using Dates: Microsecond, Minute

# Build one displacement panel for record `ii`, component `num` of X_data.
# Dotted red vertical lines are drawn at `xvals[errors]`, then the series itself is
# overlaid in `color`. The y limits are the series extrema scaled by 1.1, and the
# x axis spans all of `xvals` with the supplied tick positions and labels.
# Returns the plot.
function plot_pp(xvals, errors, color, X_data, ii, num, tm_tick, ticks)
#######################################################################

    series = X_data[:, ii, num]
    y_range = extrema(series) .* 1.1

    # The error markers create the plot; the series is then added to it
    pp = vline(xvals[errors], label="", lc=:red, ls=:dot, ylims=y_range)
    pp = plot!(xvals, series, lc=color, label="", xlims=(xvals[1], xvals[end]), xticks=(tm_tick, ticks))

    return pp

end    # plot_pp()
    

# Display the heave, north, and west displacements of record `ii` as three stacked panels.
# Samples whose entry in GPS_errors equals 1 are marked in every panel, and the figure
# title shows the record start time and the number of flagged samples.
# Reads the globals REC_LENGTH and SAMPLE_FREQUENCY to build the time axis.
function do_plots(ii, X_date, X_data, GPS_errors)
#################################################

    record_start = X_date[ii]

    # Sample positions within this record where the GPS flag is set
    errors = findall(==(1), GPS_errors[:, ii, 1])

    # Time of each sample: record start plus the sample offset in microseconds
    sample_offsets = Microsecond.((0:REC_LENGTH-1) / SAMPLE_FREQUENCY * 1000000)
    xvals = record_start + sample_offsets

    # One x tick per minute, labelled with the minute value
    tm_tick = range(xvals[1], xvals[end], step=Minute(1))
    ticks = Dates.format.(tm_tick, "MM")

    title = string(Dates.format(record_start, "yyyy-mm-dd HH:MM"), " ", length(errors), " GPS errors")

    # Component 1 is drawn in blue, 2 in red, 3 in green
    panels = [plot_pp(xvals, errors, trace_color, X_data, ii, component, tm_tick, ticks)
              for (component, trace_color) in enumerate((:blue, :red, :green))]

    plot_p1_p2_p3 = plot(panels..., size=(2000, 1000), layout=(3,1), dpi=100, framestyle=:box,
        fg_legend=:transparent, bg_legend=:transparent,
        legend=:topright, xtickfont=font(8), ytickfont=font(8), bottommargin=5Plots.mm, suptitle=title,
        yformatter = y -> @sprintf("%.1f", y),
        grid=true, gridlinewidth=0.125, gridstyle=:dot, gridcolor=:grey, gridalpha=0.5)

    display(plot_p1_p2_p3)

end    # do_plots()


#######################################################################################################
#######################################################################################################
#######################################################################################################

REC_LENGTH = 2304       # Number of WSE's in a Mk4 30-minute record
SAMPLE_FREQUENCY = 1.28 # Mk4 sample frequency in Hertz
SAMPLE_LENGTH = 1800    # Record length in seconds
SAMPLE_RATE = Float64(1 / SAMPLE_FREQUENCY) # Sample spacing in seconds

# Output containers, extended by one record per pass of the loop below
X_date = DateTime[]                               # Timestamp of each record in the selected file
X_data = Float32[]                                # Becomes a rows × records × 3 array (Heave, North, West)
GPS_errors = Array{Int16}(undef, REC_LENGTH, 0)   # rows × records matrix of GPS flags (1 = flagged)

# Ask the user which file to decode
infil = pick_file(initial_path)

@time begin
    println("Reading BINARY data from ", infil)
    flush(stdout)

    data_array = read(infil)    # whole file as a vector of bytes

    ii = 1           # position of the current message within data_array
    num_records = 0  # Start counting records
    while ii < length(data_array)
        # Extract message header: the length is a two-byte value, high byte first
        message_length = (UInt16(data_array[ii + 2]) << 8) | UInt16(data_array[ii + 3])

        # Extract timestamp (two-byte year, then one byte each for month, day, hour, minute)
        yr = (UInt16(data_array[ii + 5]) << 8) | UInt16(data_array[ii + 6])
        month = data_array[ii + 7]
        day = data_array[ii + 8]
        hour = data_array[ii + 9]
        minute = data_array[ii + 10]

        # Store timestamp for the record
        push!(X_date, DateTime(yr, month, day, hour, minute))

        # Validate sample frequency: four bytes, high byte first, read as a Float32 bit pattern
        sample_rate_hex = UInt32(data_array[ii + 11]) << 24 | UInt32(data_array[ii + 12]) << 16 |
                          UInt32(data_array[ii + 13]) << 8 | UInt32(data_array[ii + 14])
        sample_frequency = reinterpret(Float32, sample_rate_hex)

        if sample_frequency != 1.28f0
            error("Error: Sample rate not 1.28 Hz - Program terminated!")
        end

        rows = (message_length - 10) ÷ 6  # Calculate number of rows (samples), 6 bytes per sample
        if rows != REC_LENGTH
            error("Error: Number of rows per record does not match expected sample count!")
        end

        # Allocate temporary vectors for current record
        heave_values = Vector{Float32}(undef, rows)
        north_values = Vector{Float32}(undef, rows)
        west_values = Vector{Float32}(undef, rows)
        gps_error = Vector{Int16}(undef, rows)  # Temporary GPS error vector for the current record

        for jj ∈ 1:rows
            base_idx = ii + 15 + (jj - 1) * 6

            # Each displacement is a signed 16-bit value; dividing by 100 gives the stored value
            heave_values[jj] = reinterpret(Int16, UInt16(parse_hex(data_array, base_idx))) / 100
            north_hex = parse_hex(data_array, base_idx + 2)
            north_values[jj] = reinterpret(Int16, UInt16(north_hex)) / 100
            west_values[jj] = reinterpret(Int16, UInt16(parse_hex(data_array, base_idx + 4))) / 100

            # GPS error flag is the least significant bit of the North value:
            # 1 (error) or 0 (no error). Masking avoids building a temporary bit string.
            gps_error[jj] = north_hex & 0x0001
        end

        # Append the GPS error column for the current record
        GPS_errors = hcat(GPS_errors, gps_error)

        # Append the record to X_data. The first record must carry its decoded values too;
        # seeding X_data with zeros here would discard them.
        record_data = reshape([heave_values north_values west_values], rows, 1, 3)
        X_data = isempty(X_data) ? record_data : cat(X_data, record_data, dims=2)

        # Update counters; the next message starts message_length + 6 bytes further on
        num_records += 1
        ii += message_length + 6
    end

    println("File processing complete: ", num_records, " records processed.")
end

# Locate GPS errors flagged by Datawell
# Total the flags down each column of GPS_errors (one column per record)
column_sums = sum(GPS_errors[:, :, 1], dims=1)

# Flatten the one-row matrix of totals into a plain vector
column_sums_vector = vec(column_sums)

# Records holding at least one flagged sample
records_with_errors = findall(>(0), column_sums_vector)

println("\n",string(length(records_with_errors))," records with GPS errors: ")

if !isempty(records_with_errors)
    # List every affected record first, then plot them one after another
    foreach(records_with_errors) do ii
        println("    ", Dates.format(X_date[ii], "yyyy-mm-dd HH:MM"))
    end
    foreach(records_with_errors) do ii
        do_plots(ii, X_date, X_data, GPS_errors)
    end
else
    println("\nNo GPS errors flagged by Datawell in ",infil)
end

In [None]:
# Show the decoded displacement array as the cell output
X_data

In [None]:
# Locate GPS errors flagged by Datawell
# Total the flags down each column of GPS_errors (one column per record)
column_sums = sum(GPS_errors[:, :, 1], dims=1)

# Flatten the one-row matrix of totals, then keep the records with a non-zero total
column_sums_vector = vec(column_sums)
records_with_errors = findall(>(0), column_sums_vector)

# Quick look at the first component of record 66 (record number is hard-coded)
plot(X_data[:, 66, 1]; label="", size=(2000, 400))

### Display spectra for records with GPS errors as flagged by Datawell

In [None]:
# Estimate the spectrum of one record with Welch's method.
# Uses 256-sample segments overlapping by 128 samples, the `hanning` window and a
# 256-point one-sided FFT; `sample_frequency` is passed to welch_pgram as `fs`.
# Returns the tuple (f2, Pden2): the frequencies and the power at each frequency.
function calculate_spectra(heave_row, sample_frequency)
#######################################################
    segment_length = 256
    segment_overlap = 128

    ps_w = welch_pgram(heave_row, segment_length, segment_overlap;
        onesided=true, nfft=256, fs=sample_frequency, window=hanning)

    return freq(ps_w), power(ps_w)
end    # calculate_spectra()


# Display one spectrum as a filled curve.
# `f2` holds the frequencies, `Pden2` the matching power values, and `X_date` the
# record time, which is formatted into the plot title. The x axis is limited to
# 0-0.64 and the y axis starts at 0.
function plot_spectra(f2, Pden2, X_date)
########################################

    title = Dates.format(X_date, "yyyy-mm-dd HH:MM")

    p1 = plot(f2, Pden2;
        title=title, fillrange=0, label="",
        xlims=(0,0.64), ylims=(0,Inf),
        size=(1000,600), framestyle=:box)

    display(p1)

end    # plot_spectra()


# Spectral settings
sample_frequency = 1.28f0
num_records = length(X_date)  # Adjust as needed
nfft = 256                    # FFT size used in Welch's method
n_bins = nfft ÷ 2 + 1         # Number of frequency bins (for onesided FFT)

# One row per record, one column per frequency bin
f2_list = zeros(Float32, num_records, n_bins)
Pden2_list = zeros(Float32, num_records, n_bins)

# Compute the spectrum of the first component of every record
for file_idx ∈ 1:num_records
    f2, Pden2 = calculate_spectra(X_data[:, file_idx, 1], sample_frequency)
    f2_list[file_idx, :] = f2
    Pden2_list[file_idx, :] = Pden2
end

if !isempty(records_with_errors)
    # List the affected records first, then show the spectrum of each
    foreach(records_with_errors) do ii
        println("    ", Dates.format(X_date[ii], "yyyy-mm-dd HH:MM"))
    end
    foreach(records_with_errors) do ii
        plot_spectra(f2_list[ii, :], Pden2_list[ii, :], X_date[ii])
    end
else
    println("\nNo GPS errors flagged by Datawell in ",infil)
end

### Display spectrogram for selected .RDT file

In [None]:
println("Spectral calculations complete - now plotting spectrogram!")

using Plots: contourf, cgrad

# Axes of the spectrogram: record time along x, frequency along y.
# Pden2_list has one row per record, so it is transposed to put records along x.
x = X_date
y = f2_list[1,:]
z = Pden2_list'

# x ticks every 4 hours, labelled with day and time
tm_tick = range(X_date[1],X_date[end],step=Hour(4))
ticks = Dates.format.(tm_tick,"dd HH:MM")

# Filled contour plot of power against time and frequency.
# (A second, discarded contourf call used to precede this one; it has been removed.)
p1 = contourf(x, y, z, lw=0.25, c=cgrad(:Spectral, rev=true), clims=(0.0,maximum(z)), levels=10, fill=true)

# draw grid lines on plot
for i in 0:0.1:0.6
    hline!(p1, [i], lw=0.5, c=:white, label="")
end

for i in X_date[1]:Hour(4):X_date[end]
    vline!(p1, [i], lw=0.5, c=:white, label="")
end

# Mark the records with GPS errors. The target plot is named explicitly so the
# markers cannot land on whichever plot happens to be current.
foreach(ii -> vline!(p1, [X_date[ii]], lw=1, ls=:dash, c=:yellow, label=""), records_with_errors)

title=Dates.format(X_date[1], "yyyy-mm-dd HH:MM")*" to "*Dates.format(X_date[end], "yyyy-mm-dd HH:MM")
p1_plot = plot(p1, xlabel="Date", xlim=(X_date[1],X_date[end]), xticks=(tm_tick,ticks), xtickfontsize=7, xrotation=90,
        ylabel="Frequency (Hz)", ylim=(0,0.4), ytickfontsize=8, 
        title=title, framestyle = :box,
        leftmargin = 15Plots.mm, bottommargin = 15Plots.mm, grid=true, size=(2000,800), gridlinewidth=0.5, gridstyle=:dot, gridalpha=1, colorbar=true)

display(p1_plot)

In [None]:
# Plot one displacement series together with its outlier diagnostics and return the plot.
# Drawn in order: the series against `xvals`; dotted red vertical lines at
# `xvals[errors]`; hollow markers on the samples returned by modified_z_score();
# a shaded band between the 1.5×IQR fences; dashed red lines at the limits
# returned by calc_confidence_limits().
# `lc` is the line colour, `hspan_color` the band colour, `label` the legend entry.
function calc_and_plot_bounds(xvals, heave, lc, hspan_color, label, errors)
###################################################################

    # Fences placed 1.5 interquartile ranges beyond the quartiles
    multiplier = 1.5
    first_quartile = quantile(heave, 0.25)
    third_quartile = quantile(heave, 0.75)
    quartile_range = third_quartile - first_quartile
    lower_bound = first_quartile - multiplier * quartile_range
    upper_bound = third_quartile + multiplier * quartile_range

    # Threshold at the 99.9th percentile level
    confidence_interval = 3.29

    # Samples picked out by the modified z-score (the scores themselves are not used here)
    z_score_indices, _ = modified_z_score(heave, confidence_interval)

    # Limits drawn as dashed horizontal lines
    confidence_limits = calc_confidence_limits(heave, confidence_interval)

    # One x tick per minute, labelled with the minute value
    tm_tick = range(xvals[1], xvals[end], step=Minute(1))
    ticks = Dates.format.(tm_tick, "MM")

    px = plot(xvals, heave, xlims=(xvals[1], xvals[end]), lw=0.5, lc=lc, alpha=0.5,
        xticks=(tm_tick, ticks), label=label, legendfontsize=12)

    px = vline!(px, xvals[errors], label="", lc=:red, ls=:dot)

    if !isempty(z_score_indices)
        scatter!(px, xvals[z_score_indices], heave[z_score_indices],
            markersize=4, markerstrokecolor=:red, markerstrokewidth=1,
            markercolor=:white, markershape=:circle, label="Modified Z-score beyond 99.9% confidence limits")
    end

    px = hspan!(px, [lower_bound, upper_bound], fillcolor=hspan_color, fillalpha=0.125, label="IQR limits")
    px = hline!(px, [confidence_limits[1], confidence_limits[2]], color=:red, lw=0.5, linestyle=:dash, label="99.9% confidence limits")

    return px

end    # calc_and_plot_bounds()


# Display the heave, north, and west displacements of record `ii` as three stacked
# panels, each produced by calc_and_plot_bounds().
# `start_time` is the time of the first sample; the remaining sample times are derived
# from the globals REC_LENGTH and SAMPLE_FREQUENCY. Samples whose entry in GPS_errors
# equals 1 are marked in every panel and counted in the figure title.
function do_heave_north_west_plots(ii, start_time, X_data, GPS_errors)
##########################################################

    heave = X_data[:, ii, 1]
    north = X_data[:, ii, 2]
    west = X_data[:, ii, 3]

    # Sample positions within this record where the GPS flag is set
    errors = findall(x -> x == 1, GPS_errors[:, ii, 1])

    # Time of each sample: start time plus the sample offset in microseconds.
    # (The tick positions are computed inside calc_and_plot_bounds(), so the unused
    # end_time / tm_tick / ticks locals that used to be built here were dropped.)
    xvals = start_time + Microsecond.((0:REC_LENGTH-1) / SAMPLE_FREQUENCY * 1000000)

    p1 = calc_and_plot_bounds(xvals, heave, :blue, :lightblue, "Heave", errors)

    p2 = calc_and_plot_bounds(xvals, north, :red, :pink, "North", errors)

    p3 = calc_and_plot_bounds(xvals, west, :green, :lightgreen, "West", errors)

    # Annotate plot with the number of outliers and confidence interval
##    num_outliers = length(z_score_indices)
##    suspect_string = string("  ", string(ii)," ",Dates.format(start_time, "yyyy-mm-dd HH:MM"), " - ", num_outliers, " Possible outliers") # using Confidence Interval of ", 
##        @sprintf("%.2f", confidence_interval))
##    annotate!(p1, xvals[1], maximum(heave) * 0.9, text(suspect_string, :left, 10, :blue))

    date_string = Dates.format(start_time, "yyyy-mm-dd HH:MM") * " " * string(length(errors)) * " GPS errors"

    plot_displacements = plot(p1, p2, p3, size=(2000, 1000), layout=(3,1), dpi=100, framestyle=:box, fg_legend=:transparent, bg_legend=:transparent, 
    legend=:topright, xtickfont=font(8), ytickfont=font(8), bottommargin=5Plots.mm, suptitle=date_string,
    grid=true, gridlinewidth=0.125, gridstyle=:dot, gridcolor=:grey, gridalpha=0.5)

    display(plot_displacements)

end    # do_heave_north_west_plots()  

# Plot every record that has GPS errors, passing the record's own timestamp as start time
foreach(records_with_errors) do ii
    do_heave_north_west_plots(ii, X_date[ii], X_data, GPS_errors)
end

### Initial declaration of training data matrices

### Build training data matrices

In [None]:
# Drop the first entry from the list of records with GPS errors.
# NOTE(review): the reason is not visible here — confirm why the first flagged record is excluded.
records_with_errors = records_with_errors[2:end]

In [None]:
# Bad set: first component of every record that has GPS errors, one column per record
new_training_data_bad = X_data[:, records_with_errors, 1]
println(length(records_with_errors), " records stored to training_data_bad matrix")

# Good set: every remaining record, i.e. the column indices of X_data
# that are not listed in records_with_errors
all_records = 1:size(X_data, 2)
records_without_errors = filter(record -> record ∉ records_with_errors, all_records)

new_training_data_good = X_data[:, records_without_errors, 1];
println(length(records_without_errors), " records stored to training_data_good matrix")



### Add new training data to existing matrices and save to .JLD2 file

In [None]:
using JLD2, Dates

##training_data_good = hcat(training_data_good, new_training_data_good)
# Append the new bad records as extra columns (the good set is left unchanged)
training_data_bad = hcat(training_data_bad, new_training_data_bad)

println("Good data now ", size(training_data_good, 2), " records")
println("Bad  data now ", size(training_data_bad, 2), " records\n")

# Output file name carries the current date and time
outfil = string(".\\Training_data\\Mk3_training_data_updated_", Dates.format(now(), "yyyy_mm_dd_HHMM"), ".JLD2")

# Save the updated good and bad training data
@save outfil training_data_good training_data_bad

println("Updated training data saved successfully.")


# *******************************************************************************

### Recover earlier separated training data from file (Note: does not include Model data)

In [None]:
# Pick a previously saved training-data file from the Training_data folder
current_path = pwd() * "\\Training_data"
filterlist = "JLD2"
infil = pick_file(current_path; filterlist)

# Load the saved training data: each matrix is listed exactly once
# (training_data_bad used to be named twice in the @load call)
println("\nLoading training data from ",infil)
flush(stdout)    
@load infil training_data_good training_data_bad

println("Data and labels loaded successfully.")
println("Good data contains ",string(size(training_data_good)[2])," records")
println("Bad  data contains ",string(size(training_data_bad)[2])," records\n")

In [None]:
# Show the good training data as the cell output; the commented-out
# assignment would drop its first column
training_data_good # = training_data_good[:,2:end]

### Build the Model using a hybrid approach

In [None]:
#==
Calls:

    min_max_normalize_matrix()

==#

# Autoencoder: dense layers narrowing 2304 → 256 → 128 → 32, then widening back to 2304
hybrid_model = Chain(
    Dense(2304, 256, relu),
    Dense(256, 128, relu),
    Dense(128, 32, relu),
    Dense(32, 128, relu),
    Dense(128, 256, relu),
    Dense(256, 2304),
)

# Stack good and bad records side by side, normalize them, and convert to Float32
training_data_combined = hcat(training_data_good, training_data_bad)
training_data_normalized = min_max_normalize_matrix(training_data_combined)
training_data_float32 = Float32.(training_data_normalized)

# Tell the user roughly how long the build takes on this computer
println(hostname == "QUEENSLAND-BASIN" ?
    "Building hybrid model now - on this computer it takes about 30s\n" :
    "Building hybrid model now - on this computer it takes about 200s\n")

flush(stdout)

@time begin

    # Loss is the mean squared error between the model output and its own input
    function loss(x)
        return Flux.mse(hybrid_model(x), x)
    end

    opt = Adam()

    # The same full data set is presented 100 times
    data = Iterators.repeated((training_data_float32,), 100)

    Flux.train!(loss, Flux.params(hybrid_model), data, opt)
    println("Model training complete.")

end

println("\nNow select a .RDT file to check for outliers")

### Select a .RDT file to check for outliers

In [None]:
# Combine two consecutive bytes of `data_array` into a single UInt16 (high byte first).
# `idx` is the position of the high byte; the low byte is read from `idx + 1`.
# Works on the numeric values directly, so no temporary strings are created.
function parse_hex(data_array, idx)
###################################

    high_byte = UInt16(data_array[idx])
    low_byte = UInt16(data_array[idx+1])

    return (high_byte << 8) | low_byte

end    # parse_hex()


# Map every value in `north_row` to 1 when its least significant bit is set
# (GPS interference) and to 0 otherwise. Returns the flags as a new array.
function check_gps_flag(north_row)
##################################

    lsb_is_set(n) = (n & 0x1) == 1

    gps_flag_row = [lsb_is_set(n) ? 1 : 0 for n in north_row]

    return gps_flag_row

end    # check_gps_flag()


# The computer name decides both the data drive and the notebook width tweak
hostname = gethostname()
##println("The name of the computer is: ", hostname)

# Starting folder handed to the file picker
initial_path = hostname == "QUEENSLAND-BASIN" ? "E:\\Card Data\\" : "F:\\Card Data\\"

# Widen the notebook display area; the CSS rule differs between the two machines
if hostname == "QUEENSLAND-BASIN"
    display("text/html", "<style>.container { width:100% !important; }</style>")
else
    display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))
end

REC_LENGTH = 2304       # Number of WSE's in a Mk4 30-minute record
SAMPLE_FREQUENCY = 1.28 # Mk4 sample frequency in Hertz
SAMPLE_LENGTH = 1800    # Record length in seconds
SAMPLE_RATE = Float64(1 / SAMPLE_FREQUENCY) # Sample spacing in seconds

# Output containers, extended by one record per pass of the loop below
X_date = DateTime[]                               # Timestamp of each record in the selected file
X_data = Float32[]                                # Becomes a rows × records × 3 array (Heave, North, West)
GPS_errors = Array{Int16}(undef, REC_LENGTH, 0)   # rows × records matrix of GPS flags (1 = flagged)

# Ask the user which file to decode
infil = pick_file(initial_path)

@time begin
    println("Reading BINARY data from ", infil)
    flush(stdout)

    data_array = read(infil)    # whole file as a vector of bytes

    ii = 1           # position of the current message within data_array
    num_records = 0  # Start counting records
    while ii < length(data_array)
        # Extract message header: the length is a two-byte value, high byte first
        message_length = (UInt16(data_array[ii + 2]) << 8) | UInt16(data_array[ii + 3])

        # Extract timestamp (two-byte year, then one byte each for month, day, hour, minute)
        yr = (UInt16(data_array[ii + 5]) << 8) | UInt16(data_array[ii + 6])
        month = data_array[ii + 7]
        day = data_array[ii + 8]
        hour = data_array[ii + 9]
        minute = data_array[ii + 10]

        # Store timestamp for the record
        push!(X_date, DateTime(yr, month, day, hour, minute))

        # Validate sample frequency: four bytes, high byte first, read as a Float32 bit pattern
        sample_rate_hex = UInt32(data_array[ii + 11]) << 24 | UInt32(data_array[ii + 12]) << 16 |
                          UInt32(data_array[ii + 13]) << 8 | UInt32(data_array[ii + 14])
        sample_frequency = reinterpret(Float32, sample_rate_hex)

        if sample_frequency != 1.28f0
            error("Error: Sample rate not 1.28 Hz - Program terminated!")
        end

        rows = (message_length - 10) ÷ 6  # Calculate number of rows (samples), 6 bytes per sample
        if rows != REC_LENGTH
            error("Error: Number of rows per record does not match expected sample count!")
        end

        # Allocate temporary vectors for current record
        heave_values = Vector{Float32}(undef, rows)
        north_values = Vector{Float32}(undef, rows)
        west_values = Vector{Float32}(undef, rows)
        gps_error = Vector{Int16}(undef, rows)  # Temporary GPS error vector for the current record

        for jj ∈ 1:rows
            base_idx = ii + 15 + (jj - 1) * 6

            # Each displacement is a signed 16-bit value; dividing by 100 gives the stored value
            heave_values[jj] = reinterpret(Int16, UInt16(parse_hex(data_array, base_idx))) / 100
            north_hex = parse_hex(data_array, base_idx + 2)
            north_values[jj] = reinterpret(Int16, UInt16(north_hex)) / 100
            west_values[jj] = reinterpret(Int16, UInt16(parse_hex(data_array, base_idx + 4))) / 100

            # GPS error flag is the least significant bit of the North value:
            # 1 (error) or 0 (no error). Masking avoids building a temporary bit string.
            gps_error[jj] = north_hex & 0x0001
        end

        # Append the GPS error column for the current record
        GPS_errors = hcat(GPS_errors, gps_error)

        # Append the record to X_data. The first record must carry its decoded values too;
        # seeding X_data with zeros here would discard them.
        record_data = reshape([heave_values north_values west_values], rows, 1, 3)
        X_data = isempty(X_data) ? record_data : cat(X_data, record_data, dims=2)

        # Update counters; the next message starts message_length + 6 bytes further on
        num_records += 1
        ii += message_length + 6
    end

    println("File processing complete: ", num_records, " records processed.")
end


### Run the hybrid model against data in the selected .RDT file

In [None]:
#==
Calls:

    detect_outliers()

==#

@time begin

    # Run the detector once per displacement component (1 = Heave, 2 = North, 3 = West)
    outlier_heave, uncertain_heave, outlier_dates_heave, uncertain_dates_heave, good_thresh_heave, bad_thresh_heave =
        detect_outliers(X_data[:, :, 1], X_date, training_data_good, training_data_bad, hybrid_model)

    outlier_north, uncertain_north, outlier_dates_north, uncertain_dates_north, good_thresh_north, bad_thresh_north =
        detect_outliers(X_data[:, :, 2], X_date, training_data_good, training_data_bad, hybrid_model)

    outlier_west, uncertain_west, outlier_dates_west, uncertain_dates_west, good_thresh_west, bad_thresh_west =
        detect_outliers(X_data[:, :, 3], X_date, training_data_good, training_data_bad, hybrid_model)

    # Merge the per-component results, keeping each entry once
    all_outlier = unique(vcat(outlier_heave, outlier_north, outlier_west))
    all_uncertain = unique(vcat(uncertain_heave, uncertain_north, uncertain_west))

    # Same merge for the matching dates
    all_outlier_dates = unique(vcat(outlier_dates_heave, outlier_dates_north, outlier_dates_west))
    all_uncertain_dates = unique(vcat(uncertain_dates_heave, uncertain_dates_north, uncertain_dates_west));

end

# Report the suspected outliers, then the uncertain records
println("\nFor ",infil,"\n")
if isempty(all_outlier_dates)
    println("No suspected outliers detected.\n")
else
    println(string(length(all_outlier_dates)), " records contain suspected outliers at the following dates:\n")
    foreach(all_outlier_dates) do date
        println("    ", Dates.format(date, "yyyy-mm-dd HH:MM"))
    end
    print("\n")
end

if isempty(all_uncertain_dates)
    println("No uncertain data points detected.")
else
    println(string(length(all_uncertain_dates)), " records contain uncertain data points at the following dates:\n")
    foreach(all_uncertain_dates) do date
        println("    ", Dates.format(date, "yyyy-mm-dd HH:MM"))
    end
end

println("\nNow run the plot routine to view the suspect records")

### Plot records with suspect data (as identified by the model)

In [None]:
#==
Calls:

    do_heave_north_west_plots()

==#

# Records to review: those returned by the model run plus those with GPS errors,
# each listed once
combined_outlier = unique(vcat(all_outlier, records_with_errors))

# Plot the records in ascending order. The unused counter `jj` was removed: it was
# never read, and incrementing a global inside a top-level loop only works in
# interactive (soft-scope) sessions.
for ii ∈ sort(combined_outlier) # all_outlier)

    do_heave_north_west_plots(ii, X_date[ii], X_data, GPS_errors)

end

### Select new bad-data to be added to training_data_bad

In [None]:
# Enter indicies of bad data from plots
bad_data_indices = [3] # <----- Change these values!

# Start from an empty 2304-row matrix and append one column per chosen entry.
# Each entry of bad_data_indices selects a position in all_outlier, which in turn
# gives the record whose first component is copied.
new_training_data_bad = foldl(bad_data_indices; init=zeros(Float32, 2304, 0)) do collected, ii
    new_column = reshape(X_data[:, all_outlier[ii], 1], :, 1)  # Ensure it is a column
    hcat(collected, new_column)
end;


### Update the training_data_bad and create an updated training data file

In [None]:
##training_data_good = hcat(training_data_good, new_training_data_good)
# Append the newly chosen bad records as extra columns (the good set is left unchanged)
training_data_bad = hcat(training_data_bad, new_training_data_bad)

println("Good data now ", size(training_data_good, 2), " records")
println("Bad  data now ", size(training_data_bad, 2), " records\n")

# Output file name carries the current date and time
outfil = string(".\\Training_data\\Mk3_training_data_updated_", Dates.format(now(), "yyyy_mm_dd_HHMM"), ".JLD2")

# Save the updated good and bad training data
@save outfil training_data_good training_data_bad

println("Updated bad training data saved successfully.")
println("NOTE: you will need to download the new training data file and build the model again!")

In [None]:
# Show each chosen bad record (first component) with its GPS-flagged samples marked.
for ii in bad_data_indices

    kk = all_outlier[ii]    # record number behind this choice

    # Sample positions within the record where the GPS flag is set
    errors = findall(x -> x == 1, GPS_errors[:, kk, 1])

    # Create the empty figure first, then add to it. The earlier version called
    # plot! before this figure existed, which drew onto whatever plot was current
    # (the previous record's figure); that stray call and the unused
    # start_time / end_time / xvals locals have been removed.
    p1 = plot(size=(2500,300), framestyle=:box, fg_legend=:transparent, bg_legend=:transparent, )

    p1 = vline!(p1, [errors], lc=:red, alpha=0.5, label="")

    p1 = plot!(p1, X_data[:,kk,1], lc=:blue, alpha=0.5, label=string(X_date[kk]))

    display(p1)

end

### Save training data to file

In [None]:
# Output file name carries the current date and time
outfil = string(".\\Training_data\\Mk3_training_data_updated_", Dates.format(now(), "yyyy_mm_dd_HHMM"), ".JLD2")

# Write both training matrices to that file
@save outfil training_data_good training_data_bad


### Select .RDT directory and read its files

In [None]:
#==
Calls:

    get_sorted_file_data()
    plot_all_directory()

==#

# Ask the user for the directory to search
directory_path = pick_folder(initial_path)

println("Reading all .RDT files in ",directory_path)
flush(stdout)

# Collect the files matched by the glob pattern, leaving out the TMP.RDT file
TMP_file = "TMP.RDT"
rdt_files = filter(glob(".//*.RDT", directory_path)) do file
    basename(file) != TMP_file
end

# Sorted dates, file names, and Hsig values as returned by the helper
sorted_dates, sorted_rdt_files, sorted_Hsig_values = get_sorted_file_data(infil, rdt_files)
date_array = Dates.format.(values(sorted_dates), "yyyy-mm-dd");

plot_all_directory(sorted_dates, sorted_Hsig_values)

### Select a .RDT file from menu

In [None]:
#==
Calls:

    select_date_from_list()

==#

# Dates offered to the user, formatted as yyyy-mm-dd
dates_array = Dates.format.(sorted_dates, "yyyy-mm-dd")

selected_date = select_date_from_list(dates_array)
println("Selected date: ", selected_date === nothing ? "None" : selected_date)

# Look the selection up in the same formatted list it was chosen from.
# Comparing against string.(sorted_dates) only matches when that default string form
# happens to equal the yyyy-mm-dd format (it does not if the dates carry a time part).
# `index` is empty when nothing was selected.
index = findall(==(selected_date), dates_array);


### Read the selected .RDT file

In [None]:
# Full path of the file that matches the selected date
infil = string(directory_path, "\\", sorted_rdt_files[first(index)])

REC_LENGTH = 2304       # Number of WSE's in a Mk4 30-minute record
SAMPLE_FREQUENCY = 1.28 # Mk4 sample frequency in Hertz
SAMPLE_LENGTH = 1800    # Record length in seconds
SAMPLE_RATE = Float64(1 / SAMPLE_FREQUENCY) # Sample spacing in seconds

# Decode the file with the helper
X_data, X_date, GPS_errors = decode_rdt_data(infil)

# Identify the records whose values are all zero (their removal is currently disabled)
zero_columns = [i for i in 1:size(X_data, 2) if all(==(0), X_data[:, i, :])]
##X_data = X_data[:, setdiff(1:size(X_data, 2), zero_columns), :]
##X_date = X_date[setdiff(1:length(X_date), zero_columns)]

# Locate GPS errors flagged by Datawell
# Total the flags down each column of GPS_errors (one column per record)
column_sums = sum(GPS_errors[:, :, 1], dims=1)

# Flatten the one-row matrix of totals, then keep the records with a non-zero total
column_sums_vector = vec(column_sums)
records_with_errors = findall(>(0), column_sums_vector)

println("\n",string(length(records_with_errors))," records with GPS errors: ")


In [None]:
# Show the timestamps of the records with GPS errors as the cell output
X_date[records_with_errors]

### Plot suspected outlier records

In [None]:
# Plot every record that has GPS errors, passing the record's own timestamp as start time
foreach(records_with_errors) do ii
    do_heave_north_west_plots(ii, X_date[ii], X_data, GPS_errors)
end

### Save model, optimiser, and training data to file

In [None]:
using Flux, JLD2, Dates

# Output file in the Model folder; its name carries the current date and time
outfil = string(".\\Model\\RDT_model_and_data_", Dates.format(now(), "yyyy_mm_dd_HHMM"), ".JLD2")

# Store the model, the optimizer, and both training matrices in that JLD2 file
@save outfil hybrid_model opt training_data_good training_data_bad

println("Data and model saved successfully to ",outfil)


### Recover .RDT hybrid_model, optimiser, and training data from file

In [None]:
using JLD2, Flux
using NativeFileDialog: pick_file

# Let the user pick a saved model file from the Model folder
current_path = pwd() * "\\Model"
filterlist = "JLD2"
infil = pick_file(current_path; filterlist)

# Rewrite a lowercase ".jld2" in the chosen path as ".JLD2"
infil = replace(infil, ".jld2" => ".JLD2")

println("File selected: $infil")

# List what the file contains before loading from it
JLD2.jldopen(infil, "r") do file
    println("Keys in saved file:", keys(file))
end

# Load the saved model, optimizer, and training matrices into variables of the same names
@load infil hybrid_model opt training_data_good training_data_bad

# Report the type of each loaded object as a quick sanity check
for (description, loaded) in (
        ("Model type: ", hybrid_model),                       # Should be Chain
        ("Optimizer type: ", opt),                            # Should be an optimizer, e.g., Adam
        ("Good training data type: ", training_data_good),    # Should be Array
        ("Bad training data type: ", training_data_bad),      # Should be Array
    )
    println(description, typeof(loaded))
end

println("Model and training data loaded")

### Update hybrid_model and optimiser by adding new training data

In [None]:
using Flux, IterTools


# Combine existing and new training data.
# (These section headings were bare text without a leading '#', which Julia tries to
# parse as code and fails on; they are now proper comments.)
updated_training_data_good = hcat(training_data_good, new_training_data_good)
updated_training_data_bad = hcat(training_data_bad, new_training_data_bad)

# Concatenate good and bad training data
updated_training_data_combined = hcat(updated_training_data_good, updated_training_data_bad)

# Normalize the updated training data and convert it to Float32
updated_training_data_normalized = min_max_normalize_matrix(updated_training_data_combined)
updated_training_data_float32 = Float32.(updated_training_data_normalized)

# Retrain the model with the updated data
println("Updating the model with new training data...")

# Loss is the mean squared error between the model output and its own input
loss(x) = Flux.mse(hybrid_model(x), x)
data = Iterators.repeated((updated_training_data_float32,), 100)  # Number of epochs
Flux.train!(loss, Flux.params(hybrid_model), data, opt)

println("Model updated with new training data!")

# Save the updated model, optimizer, and training data back to the selected file.
# NOTE(review): the matrices are stored under the names updated_training_data_good /
# updated_training_data_bad, while the recover cell loads training_data_good /
# training_data_bad — confirm the intended names before re-loading this file.
@save infil hybrid_model opt updated_training_data_good updated_training_data_bad
println("Updated model and data saved successfully!")
