In [None]:
using Pkg 
# Pkg.instantiate()
Pkg.activate(".")  # activate the environment for this notebook
# load the packages
using Flux, Dates, BSON, Measures, JLD2, Noise, JSON, Plots
using Printf, Statistics, ProgressMeter, ParameterSchedulers, LaTeXStrings
using ParameterSchedulers: Stateful , Optimisers
pyplot()

# LSTM

In [None]:
# Series defaults applied to every figure created after this point.
default(; linewidth = 1)

# Bring `results` and `train_results` into scope from the saved LSTM run file.
@load "LSTM plots/results_LSTM.jld2" results train_results

# Epoch counts that get a curve; each one is mapped to its 1-based position in
# the list, and that integer is what the plotting code passes as `color`.
epoch_list = [100, 200, 300, 400]
color_map = Dict(zip(epoch_list, eachindex(epoch_list)))

"""
    collect_mean_std_over_alt_skip5(branch_results, horizon_t, epochs; skip = (5,))

Aggregate one metric across alts for every unrolling horizon.

`branch_results` is indexed as `branch_results[horizon_t][epochs][alt][horizon_u]`;
element `[1]` of each leaf entry is the value that is averaged (presumably the MSE,
going by the function's original description).

# Keywords
- `skip`: unrolling horizons to leave out of the result. Defaults to `(5,)`, which
  reproduces the historical "skip horizon_u = 5" behaviour. Pass `()` to keep all.

# Returns
A tuple `(horizon_us, mean_over_alts, std_over_alts)`:
- `horizon_us`: the retained unrolling horizons, sorted ascending;
- `mean_over_alts`: `Vector{Float64}` with the mean over alts per horizon;
- `std_over_alts`: `Vector{Float64}` with the sample standard deviation over alts
  per horizon (`0.0` when only one alt is available).

# Throws
- `ArgumentError` if no alt is stored for the requested `horizon_t` / `epochs`.
- `KeyError` if `horizon_t`, `epochs`, or a horizon is missing from the data.
"""
function collect_mean_std_over_alt_skip5(branch_results, horizon_t, epochs; skip = (5,))
    per_alt = branch_results[horizon_t][epochs]
    if isempty(per_alt)
        throw(ArgumentError("no alts stored for horizon_t=$horizon_t, epochs=$epochs"))
    end

    alts = collect(keys(per_alt))

    # The horizon grid is read from the first alt; every other alt is indexed with
    # the same horizons below, so they are expected to share that grid.
    horizon_us = filter(h -> h ∉ skip, sort!(collect(keys(per_alt[first(alts)]))))

    mean_over_alts = Float64[]
    std_over_alts = Float64[]

    for h in horizon_us
        vals = [per_alt[alt][h][1] for alt in alts]
        push!(mean_over_alts, mean(vals))
        # `std` of a single sample is NaN, which would blank a ribbon drawn from it;
        # report zero spread in that case instead.
        push!(std_over_alts, length(vals) > 1 ? std(vals) : 0.0)
    end

    return horizon_us, mean_over_alts, std_over_alts
end

# For every branch and training horizon, draw the mean loss over alts against the
# unrolling horizon, one curve per epoch count, with a band of +/- one standard
# deviation. One PNG per (branch, training horizon) pair is written into
# `plots_branch_<branch>/`.
for branch in sort!(collect(keys(results)))
    branch_results = results[branch]

    savepath = "plots_branch_$branch"
    mkpath(savepath)  # no-op when the directory already exists

    for horizon_t in sort!(collect(keys(branch_results)))

        plt = plot(;
            title = L"Branch %$(branch), Training Horizon $s_t$ = %$(horizon_t)",
            xlabel = L"Unrolling Horizon $s_u$",
            ylabel = "Loss",
            legend = :topright,
        )

        for epochs in sort!(collect(keys(branch_results[horizon_t])))

            hrz_u, mse_mean, mse_std =
                collect_mean_std_over_alt_skip5(branch_results, horizon_t, epochs)

            # `epochs` must be one of the keys of `color_map`; anything else
            # raises a KeyError here.
            series_color = color_map[epochs]

            # First pass: the labelled curve, without any fill.
            plot!(
                plt,
                hrz_u, mse_mean;
                markersize = 3.5, marker = :diamond,
                label = L"n_{epochs}=%$epochs",
                color = series_color,
            )

            # Second pass: the same curve, unlabelled, drawn again only to add
            # the std ribbon (presumably so the legend entry stays a plain line).
            plot!(
                plt,
                hrz_u, mse_mean;
                markersize = 3.5, marker = :diamond,
                label = "",
                color = series_color,
                ribbon = mse_std, fillalpha = 0.12,
            )

        end

        # Save this figure explicitly instead of relying on the implicit
        # "current plot", which any other plotting call could have replaced.
        savefig(plt, joinpath(savepath, "mse_vs_unroll_t_$(horizon_t)_meanAlt.png"))
    end
end

In [None]:
# Series defaults for the training-statistics figures.
default(; linewidth = 1, markersize = 3.5)

# Reload the saved LSTM run file; this defines `results` and `train_results`.
@load "LSTM plots/results_LSTM.jld2" results train_results

# Epoch counts to plot, each paired with its 1-based position in the list; that
# position is the value later handed to `color`.
epoch_list = [100, 200, 300, 400]
color_map = Dict(epoch => idx for (idx, epoch) in pairs(epoch_list))

"""
    alt_mean_std_by_horizon(extract, branch_train, horizon_ts, epoch)

Summarise one per-alt quantity over the training horizons that contain `epoch`.

For each `horizon_t` in `horizon_ts` for which `branch_train[horizon_t]` has the key
`epoch`, `extract` is applied to every alt record stored under
`branch_train[horizon_t][epoch]`, and the mean and sample standard deviation of the
extracted values are recorded. Horizons without that epoch are skipped.

Returns `(horizons, means, stds)` as `Vector{Int}`, `Vector{Float64}` and
`Vector{Float64}` of equal length. The standard deviation is `0.0` when only one alt
is available.
"""
function alt_mean_std_by_horizon(extract, branch_train, horizon_ts, epoch)
    horizons = Int[]
    means = Float64[]
    stds = Float64[]

    for horizon_t in horizon_ts
        per_epoch = branch_train[horizon_t]
        haskey(per_epoch, epoch) || continue

        vals = [extract(record) for record in values(per_epoch[epoch])]
        push!(horizons, horizon_t)
        push!(means, mean(vals))
        # `std` of a single sample is NaN, which would blank the ribbon.
        push!(stds, length(vals) > 1 ? std(vals) : 0.0)
    end

    return horizons, means, stds
end

# For every branch, write two figures into `train_plots_branch_<branch>/`:
# training loss and training time against the training horizon, one curve per
# entry of `epoch_list`, each with a +/- one standard deviation band over alts.
for branch in sort!(collect(keys(train_results)))
    branch_train = train_results[branch]
    savepath = "train_plots_branch_$branch"
    mkpath(savepath)  # no-op when the directory already exists

    # Collect horizon_ts
    horizon_ts = sort!(collect(keys(branch_train)))

    # ---- Plot: Training Loss (mean ± std over alts) ----
    plt_loss = plot(;
        title = L"Branch %$branch, Training Loss vs Training Horizon $s_t$",
        xlabel = L"Training Horizon $s_t$",
        ylabel = "Training Loss",
        legend = :topright,
    )

    for epoch in epoch_list
        horizon_filtered, mean_loss, std_loss = alt_mean_std_by_horizon(
            record -> record[:loss_stat_t][1], branch_train, horizon_ts, epoch
        )

        # Labelled curve without fill.
        plot!(
            plt_loss,
            horizon_filtered, mean_loss;
            markersize = 3.5, marker = :diamond,
            label = L"n_{epochs}=%$epoch",
            color = color_map[epoch],
        )

        # Same curve again, unlabelled, only to add the std ribbon.
        plot!(
            plt_loss,
            horizon_filtered, mean_loss;
            markersize = 3.5, marker = :diamond,
            label = "",
            color = color_map[epoch],
            ribbon = std_loss, fillalpha = 0.12,
        )
    end

    # Pass the figure explicitly: `plt_time` is created further down, and relying
    # on the implicit "current plot" is fragile when several figures coexist.
    savefig(plt_loss, joinpath(savepath, "train_loss_vs_horizon_meanAlt.png"))

    # ---- Plot: Training Time (mean ± std over alts) ----
    plt_time = plot(;
        title = L"Branch %$branch, Computational Time vs Training Horizon $s_t$",
        xlabel = L"Training Horizon $s_t$",
        ylabel = "Computational Time [s]",
        legend = :topright,
    )

    for epoch in epoch_list
        horizon_filtered, mean_time, std_time = alt_mean_std_by_horizon(
            record -> record[:train_time], branch_train, horizon_ts, epoch
        )

        # Both axes are switched to log scale by this call.
        plot!(
            plt_time,
            horizon_filtered, mean_time;
            markersize = 3.5, marker = :diamond,
            label = L"n_{epochs}=%$epoch", yaxis = :log10, xaxis = :log10,
            color = color_map[epoch],
        )

        # NOTE(review): on a log axis the lower edge of the band (mean - std)
        # cannot be drawn where it is <= 0 -- confirm the band looks as intended.
        plot!(
            plt_time,
            horizon_filtered, mean_time;
            markersize = 3.5, marker = :diamond,
            label = "",
            color = color_map[epoch],
            ribbon = std_time, fillalpha = 0.12,
        )
    end

    savefig(plt_time, joinpath(savepath, "train_time_vs_horizon_meanAlt.png"))
end


In [None]:
# Load the per-storm losses for branch 1. `loss` is indexed below with
# `(horizon, alt, epochs, storm)` tuples, and element `[1]` of each entry is used.
@load "loss_alt_comp_storms_b1.jld2" loss

horizons = [60, 50, 40, 30, 20]
alts = [0, 1, 2]
epochsL = [100, 200, 300, 400]
storms = [1, 2, 3]

# One figure per epoch count: loss averaged over storms against horizon, with
# one curve per alt. Each figure is displayed and saved as a PNG.
for ep in epochsL

    plt = plot(;
        title = "Branch 1, Unroll Loss vs Horizon (epochs = $ep)",
        xlabel = "Horizon",
        ylabel = "Loss",
        legend = :topright,
        lw = 1,
        xticks = horizons,
    )

    for alt in alts

        # compute mean loss across storms for each horizon
        mean_losses = [
            mean(loss[(h, alt, ep, s)][1] for s in storms)
            for h in horizons
        ]

        # Add the series to `plt` explicitly rather than to the implicit
        # "current plot".
        plot!(
            plt,
            horizons,
            mean_losses;
            label = "alt = $alt",
            marker = :diamond,
            markersize = 3.5,
        )
    end

    display(plt)

    # Optional save; the figure is passed explicitly for the same reason.
    png(plt, "loss_vs_horizon_epoch$(ep)_mean_storms_b1.png")
end


# FRNN

In [None]:
# Load the saved FRNN statistics. Each `@load` brings the named variable into
# scope from the given JLD2 file. `train_stat` and `unroll_stat` are indexed
# below as `[trial, branch]`; `best_stat_t` and `best_stat_u` are indexed as
# `[branch][1]` and `[branch][2]`.
@load "matrixtrain.jld2" train_stat
@load "matrixunroll.jld2" unroll_stat
@load "bestmodel_t.jld2" best_stat_t
@load "bestmodel_u.jld2" best_stat_u

# Number of trials = number of rows of `train_stat`.
trials = size(train_stat,1)

# Dataset handed to `collect_data` below.
data=load("Dataset/model_0d_lake_sep.jld2")

# For each of the two branches: unroll the stored best-unroll model, plot the
# result, and print the train/unroll statistics of the best trials.
# NOTE(review): `collect_data`, `unroll` and `plot_unroll` are not defined in
# this part of the file; their contracts are inferred from the call sites only.
for branch in 1:2
    # NOTE(review): the name `time` hides `Base.time` inside this loop.
    X, Y, time  = collect_data(data, branch, noise_param = 0.1)
    # First column of `X` is used as the initial state.
    x0 = X[:,1]
    # Rows `1+branch` through the last row of `X` (`+` binds tighter than `:`),
    # passed to `unroll` as the forcing input.
    ext_forc = X[1+branch:end,:]
    # `[1]` of the best-unroll record is passed to `unroll` as the model.
    model_r = best_stat_u[branch][1]
    # `mse` is returned by `unroll` but not used afterwards in this loop.
    outputs , mse = unroll(model_r, x0, ext_forc, Y, branch)
    plot_unroll(outputs, Y, time, branch)
    # Index of best training trial: last element of the record's `[2]` entry,
    # converted to `Int` so it can index `train_stat` / `unroll_stat`.
    best_train_idx = Int(best_stat_t[branch][2][end])
    # Index of best unroll trial, read the same way from `best_stat_u`.
    best_unroll_idx = Int(best_stat_u[branch][2][end])

    println("\nBranch $branch")
    @show best_stat_u[branch][2]
    # Statistics of the trial that was best in training.
    println("Best training trial: $best_train_idx")
    println("  Train metrics  : ", train_stat[best_train_idx, branch])
    println("  Unroll metrics : ", unroll_stat[best_train_idx, branch])

    # Statistics of the trial that was best when unrolling.
    println("Best unroll trial: $best_unroll_idx")
    println("  Train metrics  : ", train_stat[best_unroll_idx, branch])
    println("  Unroll metrics : ", unroll_stat[best_unroll_idx, branch])
end

# A trial counts as failed for a branch when its unroll MSE is not below this value.
threshold = 1
fail_count = fill(0, 2)

# Per-branch mean and standard deviation (MSE only), filled in by the loop below.
mean_train, std_train, mean_unroll, std_unroll = (zeros(2) for _ in 1:4)

# One scatter figure per branch: train and unroll MSE of the valid trials, with
# the mean and standard deviation of each set written into the legend.
for branch in 1:2
    p = plot(;
        layout = (1, 1),
        size = (500, 500),
        right_margin = 5mm,
        titlefontsize = 15,
        legend_column = -1,
        legend = :outerbottom,
        legendfontsize = 8,
    )

    # MSE columns for this branch
    unroll_vals = unroll_stat[:, branch]
    train_vals = train_stat[:, branch]

    # Validity is decided on the unroll MSE alone; everything else is a failure.
    valid_mask = unroll_vals .< threshold
    fail_count[branch] = length(valid_mask) - count(valid_mask)

    # Trial indices and values of the valid trials only
    train_x = (1:trials)[valid_mask]
    train_y = train_vals[valid_mask]
    unroll_x = train_x
    unroll_y = unroll_vals[valid_mask]

    # Summary statistics over the valid trials
    mean_train[branch] = mean(train_y)
    std_train[branch] = std(train_y)
    mean_unroll[branch] = mean(unroll_y)
    std_unroll[branch] = std(unroll_y)

    # Second line of each legend entry
    legend_train = @sprintf("μ=%.6f σ=%.6f", mean_train[branch], std_train[branch])
    legend_unroll = @sprintf("μ=%.6f σ=%.6f", mean_unroll[branch], std_unroll[branch])

    # Train points; this call also sets the title, y label and y limits.
    scatter!(
        p, train_x, train_y;
        title = "Branch $branch — MSE",
        ylabel = "MSE",
        ylim = (-0.05, threshold + 0.05),
        markersize = 5, ma = 0.5,
        label = "Train MSE\n$legend_train",
    )

    # Unroll points on the same axes
    scatter!(
        p, unroll_x, unroll_y;
        markersize = 5, ma = 0.5,
        label = "Unroll MSE\n$legend_unroll",
    )

    display(p)
end

# Report how many trials failed per branch, as a count and as a fraction of all trials.
println("Failures per branch:")
for branch in 1:2
    nfail = fail_count[branch]
    println("Branch $branch — MSE failures: $nfail  ($(nfail / trials))")
end

