# Multi objective optimization post-processing analysis

In [None]:
using Revise
using FUSE
using JLD2
using Plots;
gr();

## Process data

### Choose run directory

In [None]:
# Root directory of the optimization study; later cells join sub-paths such as
# "opt_runs/" and "result.jld2" onto it.
result_path = "nominal/" * "opt_betaN_cost__Solovev_Kr_flattop24_HTS0.1_qpol2.75_A3.5_Zeff2.0/"

### Read and/or write cache

In [None]:
# Disabled by default: flip to `true` to read (or, with `write_cache = true`, rebuild)
# the extracted-quantities table cached in `gen180/extract.csv`.
if false

    write_cache = false

    cache_path = joinpath(result_path, "gen180/")

    all_dirs = filter(isdir, sort(readdir(cache_path; join=true)))
    println(length(all_dirs))

    # keep only runs that produced dd.json and did not leave an error.txt behind
    dirs = sort([d for d in all_dirs if isfile(joinpath(d, "dd.json")) && !isfile(joinpath(d, "error.txt"))])
    println(length(dirs))

    # with `nothing`, FUSE.extract is handed no run directories (only the cache is read)
    loc = write_cache ? dirs : nothing

    # refresh the extract-function library to pick up in-development extract functions
    IMAS.update_ExtractFunctionsLibrary!()
    outputs = FUSE.extract(
        loc;
        filter_invalid=:cols,
        cache=joinpath(cache_path, "extract.csv"),
        read_cache=true,
        write_cache=write_cache,
    );
end

### Error analysis

In [None]:
using StatsPlots

cache_path = joinpath(result_path, "opt_runs/")
all_dirs = filter(isdir, sort(readdir(cache_path; join=true)))
println(length(all_dirs))

# runs that finished: they wrote dd.json and no error.txt
dirs = sort(filter(x -> !isfile(joinpath(x, "error.txt")) && isfile(joinpath(x, "dd.json")), all_dirs))
println(length(dirs))

errors = FUSE.categorize_errors(all_dirs; do_plot=true, show_first_line=false)

# Most recent run in the `:other` category, kept for manual inspection with the line below.
# Guarded because that category may be absent or empty.
other_errors = get(errors, :other, [])
err = isempty(other_errors) ? nothing : last(other_errors)
#println(read(err * "/error.txt",String))

# Histogram of error categories versus position in the (sorted) run sequence.
# NOTE(review): assumes the sorted directory listing is in generation order with roughly
# the same number of runs per generation — confirm against the run-naming scheme.
ngen = 5  # number of bins ("generations") the runs are split into
nruns = length(all_dirs)
# Bin b holds run indices in (edges[b], edges[b+1]], so every run 1..nruns lands in
# exactly one bin, including the last one and any run sitting on a bin edge.
edges = round.(Int, range(0, nruns; length=ngen + 1))
runs_per_bin = diff(edges)

error_keys = collect(keys(errors))  # fixed order shared by the data columns and the labels
x = Dict()  # error category => index of each of its runs in `all_dirs`
h = Dict()  # error category => number of runs per bin
for key in error_keys
    x[key] = indexin(errors[key], all_dirs)
    h[key] = [count(i -> !isnothing(i) && edges[b] < i <= edges[b+1], x[key]) for b in 1:ngen]
end

data = [h[key] for key in error_keys]
counts = reduce(hcat, data)  # one row per bin, one column per error category
# percentage of the runs in each bin; max(…, 1) avoids 0/0 when a bin is empty
percent = counts ./ max.(runs_per_bin, 1) .* 100

g = groupedbar(percent,
    bar_position=:stack,
    label=permutedims(String.(error_keys)),
    xlabel="Generation",
    ylabel="Percentage of FUSE runs (%)",
    title="Error codes",
    legend=:outerright,
    left_margin=5Plots.mm,
    bottom_margin=8Plots.mm,
)

display(plot(g, size=(800, 450)))


## Visualize optimization results

### Load optimization results file

In [None]:
# Concatenate the per-generation convergence history stored in up to three JLD2 files
# (e.g. an optimization that was restarted). Each file stores one variable whose name
# matches the file stem. A missing file is reported and skipped. Any other load failure
# is logged with its exception instead of being misreported as "not found".
con = []

for (fname, var) in (("result.jld2", "result"), ("result2.jld2", "result2"), ("result3.jld2", "result3"))
    path = joinpath(result_path, fname)
    if !isfile(path)
        display("Could not find $fname")
        continue
    end
    try
        saved = jldopen(path, "r") do file
            file[var]
        end
        # append! mutates the global `con` instead of reassigning it, so this also works
        # outside interactive (soft-scope) contexts such as `include`
        append!(con, saved.convergence)
    catch e
        @warn "Could not load `$var` from $path" exception = (e, catch_backtrace())
    end
end

### Plot the evolution of objectives (f), inputs (x) and constraints (g)

In [None]:
import Metaheuristics: optimize, ECA, SMS_EMOA, SPEA2, TestProblems, pareto_front, Options, convergence
using Plots.PlotMeasures
# generate plots

# Labels for objectives (f), inputs (x) and constraints (g), in the column order of the
# population's f / x / g vectors.
fnames = [ "cost", "βn"]
xnames = ["B0 (T)", "R0 (m)", "P0 (MPa)", "Ip (MA)", "fGW", "fGWped", "Paux (MW)"]
xfac = [1, 1, 1e6, 1e6, 1, 1, 1e6]  # raw x values are divided by these before plotting
gnames = ["Pnet (%)", "Sn (%)", "accessEC (%)", "fLH (%)", "qpol (%)", "ds03 (%)", "TF_j_margin (%)", "OH_j_margin (%)", "TF_stress_margin (%)", "OH_stress_margin (%)"]

color_clim = (0, 5)  # colour range for markers, which are coloured by the first objective

pf = plot(layout=(length(fnames),1), size=(650,650),left_margin=15mm)
px = plot(layout=(length(xnames),1), size=(650,1500),left_margin=15mm)
pg = plot(layout=(length(gnames),1), size=(650,2000),left_margin=15mm)

# One column of markers per generation: x position = generation index.
for (igen, state) in enumerate(con)
    population = state.population

    pop_x = reduce(hcat, [ind.x for ind in population])'
    pop_f = reduce(hcat, [ind.f for ind in population])'
    pop_g = reduce(hcat, [ind.g for ind in population])'

    gen = fill(igen, size(pop_f, 1))
    marker_opts = (marker_z=pop_f[:, 1], clim=color_clim, colorbar=:none, legend=:none)

    for j in eachindex(fnames)
        scatter!(pf[j], gen, pop_f[:, j]; marker_opts..., ylabel=fnames[j], yrange=(0, 20))
    end

    for j in eachindex(xnames)
        scatter!(px[j], gen, pop_x[:, j] / xfac[j]; marker_opts..., ylabel=xnames[j])
    end

    for j in eachindex(gnames)
        scatter!(pg[j], gen, pop_g[:, j]; marker_opts..., ylabel=gnames[j], yrange=(-5, 5))
    end
end

display(pf)
display(px)
display(pg)

### Define a function to scatter-plot the i-th generation

In [None]:
"""
    scatter_gen(con, igen, xlabel, ylabel, xrange, yrange, fnames, xnames, gnames)

Scatter-plot the population of generation `igen` of the convergence history `con`, with
the Pareto-optimal individuals overlaid in red.

`xlabel` / `ylabel` select the plotted quantities. Each must be the first word (the text
before the first space) of an entry of `fnames` (objectives), `xnames` (inputs) or
`gnames` (constraints).

Constraint values are converted back from their normalized form using these globals,
which must be defined before calling: `min_Pelectric`, `min_fLH`, `max_accessEC`,
`max_Sn`, `max_qpol`, `max_ds03`, `min_TF_j_margin`, `min_OH_j_margin`,
`min_TF_stress_margin` and `min_OH_stress_margin`.

Return the resulting plot.
"""
function scatter_gen(
        con,igen,xlabel,ylabel,
        xrange,yrange,
        fnames,xnames,gnames)

    # dictionary key for a label: its first word, e.g. "B0 (T)" -> "B0"
    shortname(name) = String(first(split(name, " ")))

    state = con[igen]
    population = state.population
    front = pareto_front(state)  # one row per Pareto-optimal point, one column per objective

    # assemble population values keyed by short name
    pop = Dict{String,Any}()
    for (i, name) in enumerate(fnames)
        pop[shortname(name)] = [ind.f[i] for ind in population]
    end
    for (i, name) in enumerate(xnames)
        pop[shortname(name)] = [ind.x[i] for ind in population]
    end
    for (i, name) in enumerate(gnames)
        pop[shortname(name)] = [ind.g[i] for ind in population]
    end

    # Locate each Pareto-optimal point in the population by its full objective vector.
    # Matching on a single objective can pick the wrong individual when values repeat.
    pfi = Int[]
    for row in eachrow(front)
        k = findfirst(ind -> isequal(ind.f, row), population)
        if isnothing(k)
            @warn "Pareto-optimal point $(collect(row)) not found in population of generation $igen"
        else
            push!(pfi, k)
        end
    end

    # re-scale constraint values
    @. pop["Pnet"] = min_Pelectric*(1-pop["Pnet"])
    @. pop["fLH"] = min_fLH*(1-pop["fLH"])
    @. pop["accessEC"] = max_accessEC*(1+pop["accessEC"])
    @. pop["Sn"] = max_Sn*(1+pop["Sn"])
    @. pop["qpol"] = max_qpol*(1+pop["qpol"])
    @. pop["ds03"] = max_ds03*(1+pop["ds03"])
    @. pop["TF_j_margin"] = min_TF_j_margin + pop["TF_j_margin"]
    @. pop["OH_j_margin"] = min_OH_j_margin + pop["OH_j_margin"]
    @. pop["TF_stress_margin"] = min_TF_stress_margin + pop["TF_stress_margin"]
    @. pop["OH_stress_margin"] = min_OH_stress_margin + pop["OH_stress_margin"]

    # Pareto-optimal subset of every quantity (objectives, inputs and rescaled constraints)
    pf = Dict(key => vals[pfi] for (key, vals) in pop)

    ## PLOT
    s = scatter(pop[xlabel], pop[ylabel], color=:grey, label="Population", alpha=0.7)
    scatter!(s, pf[xlabel], pf[ylabel], color=:red, label="Pareto-optimal")
    return plot(s, size=(450, 450), title="Generation: " * string(igen),
                xlabel=xlabel, ylabel=ylabel, xrange=xrange, yrange=yrange)
end

### Make a GIF of the population evolution

In [None]:
# Constraint thresholds read (as globals) by `scatter_gen` to convert normalized
# constraint values back to physical values.
min_Pelectric = 200.0 # MW
min_fLH = 1.0
min_TF_j_margin = 1.5
min_OH_j_margin = 1.5
min_TF_stress_margin = 1.0
min_OH_stress_margin = 1.0
max_Sn = 1.5
max_qpol = 2.75e3 # MW/m^2
max_ds03 = 1.0
max_accessEC = 1.0

# quantities on each axis and their limits
xlabel, xrange = "βn", (0,6)
ylabel, yrange = "cost", (0.0,10)

# every generation once, then the final generation repeated 20 times (2 s at 10 fps)
gen_list = [1:length(con); fill(length(con), 20)]

a = @animate for igen in gen_list
    s = scatter_gen(con, igen, xlabel, ylabel, xrange, yrange, fnames, xnames, gnames)
end

g = gif(a, joinpath(result_path, "$(xlabel)_$(ylabel).gif"), fps=10)
display(g)

### Static plots

In [None]:
xlabel = "βn"
ylabel = "qpol"

# let Plots choose the axis limits (`:auto` is the documented default)
xrange = :auto
yrange = :auto

s = scatter_gen(con,gen_list[end],xlabel,ylabel,xrange,yrange,fnames,xnames,gnames)

# Overlay the extracted-quantities table. `outputs` only exists if the cache-reading cell
# was actually executed, so guard against it being undefined.
if @isdefined(outputs)
    scatter!(s,outputs[:,"βn_MHD"],outputs[:,"qpol"],markershape=:+,color=:black)
else
    @warn "`outputs` is not defined; run the cache-reading cell to overlay extracted runs"
end

s

