In [None]:
using DataFrames: DataFrame, ncol, nrow
using Dates: Dates, DateTime, Time, unix2datetime, Hour, Minute, Microsecond
using NativeFileDialog: pick_file
using Statistics: median, mean, std
using Plots: Plots, plot, plot!, annotate!, hline!, @layout, text, plotly, font, scatter!
using Printf: @sprintf

#using PlotlyKaleido

#plotly()
#PlotlyKaleido.start(timeout=30)


# For every row of `f23_df`, look up its `Match_vector` hex string in `Data` and
# store the index of the FIRST element of `Data` equal to it in a new
# `Data_vector` column (`nothing` when the string does not occur in `Data`).
# `f23_df` is modified in place (column added/replaced) and also returned.
function get_matches(Data, f23_df)
##################################

    # Map each hex string in Data to the index of its first occurrence
    # (only the first occurrence is ever used)
    first_index = Dict{String, Int}()
    for (i, hex_str) in enumerate(Data)
        haskey(first_index, hex_str) || (first_index[hex_str] = i)
    end

    # Typed column: an index into Data, or `nothing` when no match was found
    matching_indices = Union{Int, Nothing}[get(first_index, hex_str, nothing)
                                           for hex_str in f23_df.Match_vector]

    f23_df[!, "Data_vector"] = matching_indices

    return f23_df

end    # get_matches()

# Combine bytes (most significant first) into a non-negative Int
_be_to_int(bytes) = foldl((acc, b) -> (acc << 8) | Int(b), bytes; init=0)

# Decode selected parameters from a Spectrum synchronisation message (0xF23).
# Refer to DWTP (Ver. 16 January 2019) Section 4.3 pp.43-44.
#
# `f23_vals` holds the unescaped message bytes (at least 22 of them). Multi-byte
# fields are big-endian:
#   bytes 3-6   timestamp (Unix seconds, UTC)
#   bytes 9-11  segments used
#   bytes 12-13 sample number
#   bytes 14-22 match vector, returned as an 18-character lowercase hex string
# Bytes 7-8 (data stamp) are not used by this script.
#
# `utc_offset` is added to the UTC timestamp. The default Hour(0) keeps UTC;
# e.g. Hour(10) gives Australian Eastern Standard Time.
#
# Returns (timestamp, segments_used, match_vector, sample_number).
function process_f23(f23_vals; utc_offset=Hour(0))
#######################################

    timestamp = unix2datetime(_be_to_int(@view f23_vals[3:6])) + utc_offset
    segments_used = _be_to_int(@view f23_vals[9:11])
    sample_number = _be_to_int(@view f23_vals[12:13])
    match_vector = bytes2hex(@view f23_vals[14:22])

    return (timestamp, segments_used, match_vector, sample_number)

end    #  process_f23()


# Undo byte stuffing in one frame: 0x7d marks that the following byte was
# XORed with 0x20. A dangling 0x7d at the very end of a frame is dropped.
function _unescape_frame(raw)
    out = UInt8[]
    sizehint!(out, length(raw))
    escaped = false
    for byte in raw
        if escaped
            push!(out, byte ⊻ 0x20)
            escaped = false
        elseif byte == 0x7d
            escaped = true
        else
            push!(out, byte)
        end
    end
    return out
end

# convert binary data into F23_df and Hex array
#
# Reads `infil` as consecutive 12-byte records and returns (f23_df, Data):
#   Data   - one 18-character lowercase hex string per record, built from
#            bytes 1-9 (heave, north and west displacements; see get_hnw)
#   f23_df - unique F23 (0x23) messages found in the byte stream formed by
#            interleaving bytes 10-12 of every record, with columns Date,
#            Segments, Match_vector and Sample_number (see process_f23)
# Messages in that stream are delimited by 0x7e bytes. Trailing bytes that do
# not fill a whole record, and F23 messages too short to decode, are skipped
# with a warning instead of aborting the whole read.
function get_hex_array(infil)
#############################

    # Read binary data from the input file
    println("Reading BINARY data from ", infil)
    flush(stdout)
    data = read(infil)

    # Split the data into 12-byte records (one column per record)
    cols = 12
    nrec = length(data) ÷ cols
    leftover = length(data) - nrec * cols
    leftover > 0 && @warn "Ignoring $leftover trailing byte(s) that do not form a complete $cols-byte record"
    mat = reshape(view(data, 1:nrec*cols), cols, nrec)

    # Hex-encode bytes 1-9 of each record (2 lowercase hex chars per byte)
    Data = [bytes2hex(view(mat, 1:9, j)) for j in 1:nrec]

    println("All file data read!")

    # Bytes 10, 11, 12 of record 1, then of record 2, ... (column-major order)
    # form the packet vector
    packet = vec(mat[10:12, :])

    # Positions of the 0x7e message delimiters in the packet vector
    delims = findall(==(0x7e), packet)

    # Create DataFrame to hold the processed data
    f23_df = DataFrame(Date = DateTime[], Segments = Int[], Match_vector = String[], Sample_number = Int[])

    n_short = 0
    for i in 1:length(delims)-1

        # bytes strictly between two consecutive delimiters; skip 0/1-byte frames
        frame = view(packet, delims[i]+1:delims[i+1]-1)
        length(frame) > 1 || continue

        decoded = _unescape_frame(frame)

        # If the message is F23 (0x23); process_f23 reads up to byte 22
        if length(decoded) >= 2 && decoded[2] == 0x23
            if length(decoded) >= 22
                timestamp, segments_used, match_vector, sample_number = process_f23(decoded)
                push!(f23_df, [timestamp, segments_used, match_vector, sample_number])
            else
                n_short += 1
            end
        end
    end
    n_short > 0 && @warn "Skipped $n_short F23 message(s) shorter than 22 bytes"

    # Remove duplicates from f23_df
    f23_df = unique(f23_df)

    return (f23_df, Data)

end    # get_hex_array()


# For row `found_list[1]` of `f23_df` (an Int or the first element of a list)
# return (start_date, start_val, end_val):
#   start_date - the row's Date minus 30 minutes
#   end_val    - the row's Data_vector (index of the matching Data element)
#   start_val  - first Data index of the record. Each Data element holds two
#                samples (see get_displacement), so Sample_number samples span
#                ceil(Sample_number / 2) elements ending at end_val.
function get_start_end_dates(f23_df, found_list)
###############################################

    row = found_list[1]

    start_date = f23_df.Date[row] - Minute(30) # <------- NOTE subtracted 30min from start_date to match Waves4 results
    sample_nos = f23_df.Sample_number[row]
    data_vector = f23_df.Data_vector[row]

    # cld() instead of Int(sample_nos/2), which threw InexactError for odd counts
    start_val = data_vector - cld(sample_nos, 2) + 1
    end_val = data_vector

    return (start_date, start_val, end_val)

end    # get_start_end_dates()
    
  
# Decode displacements from the hex strings in `Data`.
# Every element holds two samples: characters start_val:end_val and
# (start_val+9):(end_val+9). Each field's first three hex digits form a 12-bit
# two's-complement value v, converted to 0.457 * sinh(v / 457).
# Samples are returned in order: element 1 first, element 1 second, element 2 first, ...
function get_displacement(Data, start_val, end_val)
################################################
# Decode the real time data to displacements - See DWTP (16 Jan 2019) 2.1.1 p. 19

    # value of the first three hex digits of `field`, most significant first
    hex3(field) = 256 * parse(Int, field[1], base=16) +
                  16 * parse(Int, field[2], base=16) +
                  parse(Int, field[3], base=16)

    # reinterpret a 12-bit two's-complement value as signed
    signed12(v) = v >= 2048 ? v - 4096 : v

    raw = Int[]
    for record in Data
        for offset in (0, 9)
            field = SubString(record, start_val + offset, end_val + offset)
            push!(raw, signed12(hex3(field)))
        end
    end

    return 0.457 * sinh.(raw / 457)    # see DWTP p.19 (16)

end    # get_displacement()


# Decode heave, north and west displacements (hex characters 1-3, 4-6 and 7-9 of
# each sample field) from Data[start_val:end_val]. Returns (heave, north, west),
# each truncated or zero-padded to exactly `rec_length` points.
function get_hnw(Data, start_val, end_val; rec_length=REC_LENGTH)
########################################

    record = view(Data, start_val:end_val)

    # Force a series to exactly rec_length points. The result must be returned
    # and reassigned: rebinding a loop variable would not shorten the caller's array.
    fit_length(wse) = length(wse) > rec_length ? wse[1:rec_length] :
                                                 vcat(wse, zeros(rec_length - length(wse)))

    # get WSEs for desired 30-minute record
    heave = fit_length(get_displacement(record, 1, 3))
    north = fit_length(get_displacement(record, 4, 6))
    west = fit_length(get_displacement(record, 7, 9))

    return (heave, north, west)

end    # get_hnw()


# Function to calculate confidence limits
#
# Returns (lower, upper) = mean(data) ∓ confidence_interval * std(data), i.e. a
# band of `confidence_interval` sample standard deviations around the mean.
function calc_confidence_limits(data, confidence_interval)
##########################################################

    centre = mean(data)
    half_width = confidence_interval * std(data)

    return (centre - half_width, centre + half_width)

end    # calc_confidence_limits()


# Function to compute modified z-scores and find outliers
#
# Modified z-score: M_i = 0.6745 * (x_i - median) / MAD, where
# MAD = median(|x_i - median|). Returns (indices where |M_i| > threshold, M).
# When MAD is 0 (e.g. a long run of identical values, such as zero padding) that
# formula divides by zero and flags every non-median point. In that case the
# mean absolute deviation is used instead, M_i = (x_i - median) / (1.253314 * MeanAD).
# If that is also 0 (constant data), all scores are 0 and no outliers are reported.
function modified_z_score(data, threshold)
##########################################

    med = median(data)
    abs_dev = abs.(data .- med)
    mad = median(abs_dev)

    if mad > 0
        mod_z_scores = 0.6745 * (data .- med) ./ mad
    else
        mean_ad = mean(abs_dev)
        mod_z_scores = mean_ad > 0 ? (data .- med) ./ (1.253314 * mean_ad) :
                                     zeros(Float64, size(data))
    end

    outlier_indices = findall(x -> abs(x) > threshold, mod_z_scores)

    return(outlier_indices, mod_z_scores)

end    # modified_z_score()


# Function for dynamic threshold based on mean wave height
#
# threshold = base_threshold * (1 + k * mean(heave) / std(heave)).
# If the spread is 0 (constant record, e.g. all zeros) or NaN (fewer than two
# points), the ratio is undefined, so base_threshold is returned instead of NaN.
function dynamic_z_score_threshold(heave, base_threshold=3.0, k=0.5)

    std_wave_height = std(heave)

    # `!(x > 0)` is true for both 0 and NaN
    !(std_wave_height > 0) && return float(base_threshold)

    mean_wave_height = mean(heave)
    dynamic_threshold = base_threshold * (1 + k * (mean_wave_height / std_wave_height))

    return(dynamic_threshold)

end    # dynamic_z_score_threshold()


# Return `record` resized to exactly `target_length` elements: zero-padded at
# the end (with Float32 zeros) when too short, otherwise truncated.
# A new array is always returned.
function pad_or_truncate(record, target_length=REC_LENGTH)
####################################################

    shortfall = target_length - length(record)

    if shortfall > 0
        return vcat(record, zeros(Float32, shortfall))
    else
        return record[1:target_length]
    end

end    # pad_or_truncate()


# Build the heave matrix. For every F23 row that was matched in Data
# (Data_vector not `nothing`) and whose record starts at a valid index
# (start_val > 0), decode the heave series with get_hnw and pad/truncate it to
# REC_LENGTH points. Returns (X, X_date): column j of X is the heave record
# starting at X_date[j]. X is REC_LENGTH × 0 when no record qualifies.
function get_heave(Data, f23_df)
################################

    heave_array = Vector{Float64}[]
    X_date = DateTime[]

    println("Calculating Heave values now!")

    for idx in 1:nrow(f23_df)

        # skip F23 messages whose match vector was not found in Data
        isnothing(f23_df.Data_vector[idx]) && continue

        start_date, start_val, end_val = get_start_end_dates(f23_df, idx)

        # skip records that would start before the first Data element
        start_val > 0 || continue

        print(".")
        heave, north, west = get_hnw(Data, start_val, end_val)

        # ensure we have REC_LENGTH data points
        push!(heave_array, pad_or_truncate(heave, REC_LENGTH))
        push!(X_date, start_date)

    end

    # reduce(hcat, ...) avoids splatting thousands of vectors into hcat; build
    # an explicit empty matrix when nothing qualified
    X = isempty(heave_array) ? Matrix{Float64}(undef, REC_LENGTH, 0) : reduce(hcat, heave_array)

    return (X, X_date)

end    # get_heave()


# Need to check first row of the f23_df in case 23:00 is stored there
#
# Drops the first row when its Date is exactly 23:00:00 AND it has no usable
# Data_vector (missing, nothing or NaN). Otherwise, and for an empty table,
# f23_df is returned unchanged.
function f23_first_row_check(f23_df)
################################

    # Nothing to check in an empty table (first() would throw)
    nrow(f23_df) == 0 && return f23_df

    # Get the first row of the DataFrame
    first_row = first(f23_df)
    dv = first_row.Data_vector

    # Short-circuit keeps isnan() from being called on `nothing`/`missing`
    if Time(first_row.Date) == Time(23, 0, 0) && (ismissing(dv) || isnothing(dv) || isnan(dv))
        f23_df = f23_df[2:end, :]  # Drop the first row
    end

    return(f23_df)

end    # f23_first_row_check()


###############################################################################
###############################################################################
###############################################################################

const REC_LENGTH = 4608
const SAMPLE_FREQUENCY = 2.56 # sample frequency in Hertz
const SAMPLE_LENGTH = 1800 # record length in seconds
const SAMPLE_RATE = Float64(1/SAMPLE_FREQUENCY) # sample spacing in seconds

# Widen screen for better viewing
display(HTML("<style>.jp-Cell { width: 120% !important; }</style>"))

infil = pick_file()

# pick_file() returns an empty string when the dialog is cancelled
isempty(infil) && error("No input file selected")

f23_df, Data = get_hex_array(infil)

# attach to each F23 message the index of its matching entry in Data
f23_df = get_matches(Data, f23_df)

# drop the first F23 row if it is stamped 23:00:00 and has no match in Data
f23_df = f23_first_row_check(f23_df)

X_train, X_date = get_heave(Data, f23_df);

println("\nNow preparing to plot heave for each record!")
flush(stdout)

### Plot each record

In [None]:
##==
# Loop through wave records: at most the first 10 (use length(X_date) as the
# bound to plot every record). min() avoids a BoundsError when fewer exist.
for ii in 1:min(10, length(X_date))

    # Initialize variables
    start_time = X_date[ii]
    heave = X_train[:, ii]
    end_time = start_time + Minute(30)
    # Sample times, rounded to whole microseconds: Microsecond() throws an
    # InexactError for non-integer values
    xvals = start_time .+ Microsecond.(round.(Int, (0:REC_LENGTH-1) ./ SAMPLE_FREQUENCY .* 1_000_000))

    # Plot initialization
    p1 = plot(size=(1200, 300), dpi=100, framestyle=:box, fg_legend=:transparent, bg_legend=:transparent,
        legend=:topright, xtickfont=font(8), ytickfont=font(8),
        grid=true, gridlinewidth=0.125, gridstyle=:dot, gridalpha=1)

    tm_tick = range(start_time, end_time, step=Minute(1))
    ticks = Dates.format.(tm_tick, "MM")

    # Calculate dynamic confidence interval
    confidence_interval = dynamic_z_score_threshold(heave)

    # Identify outliers using modified z-score
    outlier_indices, mod_z_scores = modified_z_score(heave, confidence_interval)
    if !isempty(outlier_indices)
        scatter!(p1, xvals[outlier_indices], heave[outlier_indices],
            markersize=4, markerstrokecolor=:red, markerstrokewidth=1,
            markercolor=:white, markershape=:circle, label="")
    end

    # Plot confidence limits
    confidence_limits = calc_confidence_limits(heave, confidence_interval)
    hline!(p1, [confidence_limits[1], confidence_limits[2]], color=:red, lw=0.5, linestyle=:dash, label="")

    # Plot heave data
    plot!(p1, xvals, heave, xlims=(xvals[1], xvals[end]), lw=0.5, lc=:blue, alpha=0.5,
        xticks=(tm_tick, ticks), label="")

    # Annotate plot with the number of outliers and confidence interval
    num_outliers = length(outlier_indices)
    suspect_string = string("  ", Dates.format(start_time, "yyyy-mm-dd HH:MM"), " - ", num_outliers, " Possible outliers using Confidence Interval of ",
        @sprintf("%.2f", confidence_interval))
    annotate!(p1, xvals[1], maximum(heave) * 0.9, text(suspect_string, :left, 10, :blue))

    display(p1)

end
#==#

### Select records that will NOT be uploaded to model

In [None]:
using Tk

all_dates = Dates.format.(X_date, "yyyy-mm-dd HH:MM")

w = Toplevel("Select Date", 235, 400)
tcl("pack", "propagate", w, false)
f = Frame(w)
pack(f, expand=true, fill="both")

f1 = Frame(f)
lb = Treeview(f1, all_dates)
scrollbars_add(f1, lb)
pack(f1,  expand=true, fill="both")

tcl("ttk::style", "configure", "TButton", foreground="blue", font="arial 16 bold")
b = Button(f, "Ok")
pack(b)

bad_array = []

# Each "Ok" press records the currently selected date (ignored if nothing is selected)
bind(b, "command") do path

    file_choice = get_value(lb);
    isempty(file_choice) || push!(bad_array, file_choice[1])

end

# The button callback only runs when the user clicks, so filtering written
# directly after the window setup would run before anything was selected.
# Remove the selected records when the window is closed instead.
function remove_bad_records(args...)
    destroy(w)  # Close Tk window

    # Find indices of bad_dates in X_date
    bad_cols = findall(x -> x in bad_array, all_dates)
    keep = setdiff(1:size(X_train, 2), bad_cols)

    # Remove the same records from X_train and X_date so they stay aligned
    global X_train = X_train[:, keep]
    global X_date = X_date[keep]
end

tcl("wm", "protocol", w, "WM_DELETE_WINDOW", remove_bad_records)

In [None]:
using Tk

# Convert DateTimes into strings
all_dates = Dates.format.(X_date, "yyyy-mm-dd HH:MM")

# Initialize selection window
w = Toplevel("Select Date", 235, 400)
tcl("pack", "propagate", w, false)
f = Frame(w)
pack(f, expand=true, fill="both")

f1 = Frame(f)
lb = Treeview(f1, all_dates)
scrollbars_add(f1, lb)
pack(f1, expand=true, fill="both")

# Style button
tcl("ttk::style", "configure", "TButton", foreground="blue", font="arial 16 bold")
b = Button(f, "Ok")
pack(b)

# Global array to store selected dates
global bad_array = []

# Collect dates when the button is pressed
bind(b, "command") do path
    file_choice = get_value(lb)
    if !isempty(file_choice)
        push!(bad_array, file_choice[1])
        println("Added to removal list: ", file_choice[1])
    end
end

# Close the window and drop the selected records. X_date is filtered with the
# same indices as the columns of X_train so that code pairing X_date[i] with
# X_train[:, i] (e.g. the plotting loop) stays consistent.
function finalize_selection(args...)
    destroy(w)  # Close Tk window

    # Find indices of bad_dates in X_date
    bad_cols = findall(x -> x in bad_array, all_dates)
    keep = setdiff(1:size(X_train, 2), bad_cols)

    # Remove columns from X_train (and the matching dates) based on `bad_cols`
    global X_train = X_train[:, keep]
    global X_date = X_date[keep]
    println("Removed columns based on selection.")
end

# Bind finalize function to window close event
tcl("wm", "protocol", w, "WM_DELETE_WINDOW", finalize_selection)


In [None]:
# Preview the first and last (up to) 5 records (columns) of X_train.
# n_show keeps the indexing valid when X_train has fewer than 5 columns.
n_show = min(5, size(X_train, 2))
first_5_cols = X_train[:, 1:n_show]  # Get the first columns
last_5_cols = X_train[:, end-n_show+1:end]  # Get the last columns

# Combine them into a new matrix for display
combined = hcat(first_5_cols, last_5_cols)  # Horizontally concatenate

println(combined)  # Display the combined matrix


In [None]:
# Preview the first/last (up to) 5 rows and columns of X_train.
# Counts are clamped so the indexing stays valid for small matrices.
n_rows = min(5, size(X_train, 1))
n_cols = min(5, size(X_train, 2))

first_5_rows = X_train[1:n_rows, :]  # Get the first rows
last_5_rows = X_train[end-n_rows+1:end, :]  # Get the last rows

first_5_cols = X_train[:, 1:n_cols]  # Get the first columns
last_5_cols = X_train[:, end-n_cols+1:end]  # Get the last columns

# Combine the first and last columns into a new matrix for display.
# (No variable may be named `display`: that would clash with Base.display,
# which other cells call to show plots.)
combined = hcat(first_5_cols, last_5_cols)

# Now you can print the combined results
println("First 5 rows:")
println(first_5_rows)

println("\nLast 5 rows:")
println(last_5_rows)

println("\nCombined display (first and last 5 columns):")
println(combined)


In [None]:
# Restore X_train from the X_train_old backup. X_train_old must already exist
# (it is set by the `X_train_old = X_train` cell); otherwise this raises UndefVarError.
X_train = X_train_old

In [None]:
# Show the current X_train (the cell's last expression is displayed by the notebook)
X_train

In [None]:
# Back up X_train. This binds the same array (no copy is made), so the backup
# only stays intact while X_train is reassigned rather than mutated in place.
# NOTE(review): the cells seen here reassign X_train; confirm no in-place edits elsewhere.
X_train_old = X_train

### Initial Code for Training an Autoencoder in Julia (using Flux.jl)

In [None]:
using Flux, Statistics

# Scale every column (record) of X linearly onto [0, 1] using that column's
# minimum and maximum. A constant column has max == min. Its range is taken as
# 1, so the column maps to all zeros instead of NaN (0/0); NaN inputs would
# otherwise propagate into the training loss.
function min_max_normalize_matrix(X)
    min_vals = minimum(X, dims=1)  # Compute min for each column
    ranges = maximum(X, dims=1) .- min_vals  # max - min for each column
    ranges = map(r -> iszero(r) ? one(r) : r, ranges)  # avoid 0/0 for constant columns
    return (X .- min_vals) ./ ranges
end


# Standardise every column (record) of X to zero mean and unit sample standard
# deviation. A constant column has std 0. Its divisor is taken as 1, so it
# maps to all zeros instead of NaN (0/0).
function z_score_normalize_matrix(X)
    mean_vals = mean(X, dims=1)  # Mean for each column
    std_vals = std(X, dims=1)    # Standard deviation for each column
    std_vals = map(s -> iszero(s) ? one(s) : s, std_vals)  # avoid 0/0 for constant columns
    return (X .- mean_vals) ./ std_vals
end


# Return `record` zero-padded (with Float32 zeros) or truncated to exactly
# `target_length` elements (4608 by default). A new array is always returned.
function pad_or_truncate(record, target_length=4608)
####################################################

    n = length(record)
    n >= target_length && return record[1:target_length]

    return vcat(record, zeros(Float32, target_length - n))

end    # pad_or_truncate()


# Build the heave matrix. For every F23 row that was matched in Data
# (Data_vector not `nothing`) and whose record starts at a valid index
# (start_val > 0), decode the heave series with get_hnw and pad/truncate it to
# REC_LENGTH (4608) points. Returns (X, X_date): column j of X is the heave
# record starting at X_date[j]. X is REC_LENGTH × 0 when no record qualifies.
function get_heave(Data, f23_df)
################################

    heave_array = Vector{Float64}[]
    X_date = DateTime[]

    for idx in 1:nrow(f23_df)

        # skip F23 messages whose match vector was not found in Data
        isnothing(f23_df.Data_vector[idx]) && continue

        start_date, start_val, end_val = get_start_end_dates(f23_df, idx)

        # skip records that would start before the first Data element
        start_val > 0 || continue

        print(".")
        heave, north, west = get_hnw(Data, start_val, end_val)

        # ensure we have REC_LENGTH data points
        push!(heave_array, pad_or_truncate(heave, REC_LENGTH))
        push!(X_date, start_date)

    end

    # reduce(hcat, ...) avoids splatting thousands of vectors into hcat; build
    # an explicit empty matrix when nothing qualified
    X = isempty(heave_array) ? Matrix{Float64}(undef, REC_LENGTH, 0) : reduce(hcat, heave_array)

    return (X, X_date)

end    # get_heave()


# Reconstruction error of every record. For each column `x` of X_train_float32
# (one record per column), compute mean((model(x) .- x).^2) and collect the
# results in column order into a Vector{Float32}.
function calc_reconstruction_errors(X_train_float32, model)
####################################################

    return Float32[mean((model(col) .- col) .^ 2) for col in eachcol(X_train_float32)]

end    # calc_reconstruction_errors()

####################################################################
####################################################################
####################################################################
@time begin
# Define autoencoder model
model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),    # Decoder
    Dense(128, 4608)         # Output layer, reconstructs input
)

# Define the loss function (e.g., Mean Squared Error for reconstruction)
loss(x) = Flux.mse(model(x), x)

# Optimiser state for explicit-mode training (Adam with default parameters).
# The implicit Flux.params / train!(loss, params, data, opt) form is
# deprecated in Flux 0.14 and removed in Flux 0.15.
opt = Flux.setup(Adam(), model)

# Reuse X_train / X_date from earlier cells so records removed in the
# selection window stay removed; build them only if they do not exist yet.
if !@isdefined(X_train)
    X_train, X_date = get_heave(Data, f23_df)
end

# Normalize your training data (per record / column)
X_train_normalized = min_max_normalize_matrix(X_train)

# Convert WSE data to Float32
X_train_float32 = Float32.(X_train_normalized)

# Use the converted data for training: 100 full-batch passes
data = Iterators.repeated((X_train_float32,), 100)

# Train the model (explicit mode: the loss receives the model as first argument)
Flux.train!(model, data, opt) do m, x
    Flux.mse(m(x), x)
end

# Per-record reconstruction errors of the TRAINED model (errors from the
# untrained, randomly initialised model carry no information)
reconstruction_errors_model = calc_reconstruction_errors(X_train_float32, model)
end    # @time

### Save model, optimiser, and normalised heave data to file

In [None]:
using Flux, JLD2

# Save the model and optimizer
model_state = Flux.state(model)
opt_state = opt # Flux.setup(Adam(), model)

# `now` itself is not imported (the first cell imports only selected Dates
# names plus the Dates module), so it must be qualified
outfil = "HVA_model_"*Dates.format(Dates.now(), "yyyy_mm_dd_HHMM")*".jld2"

# Save model and optimizer and normalised wave data to a JLD2 file
jldsave(outfil; model_state, opt_state, X_train_float32)


### Recover saved model, optimiser, and normalised heave data from file

In [None]:
using JLD2, Flux

# Load the model and optimizer states from the JLD2 file
infil = pick_file()
isempty(infil) && error("No model file selected")  # pick_file() returns "" when cancelled

loaded_data = jldopen(infil, "r") do file
    old_model_state = file["model_state"]  # Load model state
    old_opt_state = file["opt_state"]      # Load optimizer state
    old_X_train_float32 = file["X_train_float32"]  # Load the previous X_train_float32 data
    return (old_model_state, old_opt_state, old_X_train_float32) # Return all states and data
end

old_model_state, old_opt_state, old_X_train_float32 = loaded_data

# Define the old model architecture (must match the saved model exactly)
old_model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),     # Decoder
    Dense(128, 4608)          # Output layer, reconstructs input
)

# Restore model parameters from the loaded Flux.state tree. loadmodel! copies
# every parameter array (erroring on shape mismatches) instead of assuming
# each layer has exactly `weight` and `bias` fields.
Flux.loadmodel!(old_model, old_model_state)

# Restore the optimizer state directly
old_opt = old_opt_state;  # Assign the loaded optimizer state


### Append new data to existing model and retrain

In [None]:
using Flux
using JLD2

infil = pick_file()
isempty(infil) && error("No model file selected")  # pick_file() returns "" when cancelled

# Load the existing model state and training data. The save cell writes the
# keys "model_state", "opt_state" and "X_train_float32".
saved_model_state, old_X_train_float32 = jldopen(infil, "r") do file
    (file["model_state"], file["X_train_float32"])
end

# Rebuild the architecture and restore the saved parameters
loaded_model = Chain(
    Dense(4608, 128, relu),  # Encoder
    Dense(128, 64, relu),    # Bottleneck
    Dense(64, 128, relu),    # Decoder
    Dense(128, 4608)         # Output layer, reconstructs input
)
Flux.loadmodel!(loaded_model, saved_model_state)

# Fresh Adam state for the retraining run (the saved opt_state is not reused)
opt = Flux.setup(Adam(), loaded_model)

# Load new records and prepare the data with the same normalisation as the
# original training
new_X_train, _ = get_heave(Data, f23_df)  # Replace with your new data fetching function
new_X_train_normalized = min_max_normalize_matrix(new_X_train)
new_X_train_float32 = Float32.(new_X_train_normalized)

# Combine with the previous training data
X_combined = hcat(old_X_train_float32, new_X_train_float32)

# Define the loss function again
loss(x) = Flux.mse(loaded_model(x), x)

# Prepare the data for training (iterating over the new combined data)
data = Iterators.repeated((X_combined,), 100)

# Train the model on the new data (explicit mode)
Flux.train!(loaded_model, data, opt) do m, x
    Flux.mse(m(x), x)
end

outfil = infil

# Save the retrained model, its optimiser state and the combined training data
# (overwrites the file that was loaded)
jldsave(outfil; model_state = Flux.state(loaded_model), opt_state = opt,
        X_train_float32 = X_combined)