In [None]:
"""
Plot diverse uncertainty parameters

Created May 2023

@author: vrath
"""

In [None]:
import os
import sys
from sys import exit as error
from time import process_time
from datetime import datetime
import warnings
import random
import functools
from cycler import cycler

In [None]:
import numpy
import scipy.interpolate
import scipy.linalg

In [None]:
import matplotlib
import matplotlib.pyplot
import matplotlib.ticker
import matplotlib.axis

In [None]:
# Make the aempy modules and scripts importable; the installation root is
# taken from the environment (KeyError if AEMPYX_ROOT is not set).
AEMPYX_ROOT = os.environ["AEMPYX_ROOT"]
mypath = [AEMPYX_ROOT + sub for sub in ("/aempy/modules/", "/aempy/scripts/")]
for pth in mypath:
    if pth in sys.path:
        continue
    # Prepend so these directories take precedence over other entries.
    sys.path.insert(0, pth)

In [None]:
from version import versionstrg

In [None]:
import util
import aesys
import inverse
import viz

In [None]:
# Suppress FutureWarning messages for the rest of the session.
warnings.simplefilter("ignore", FutureWarning)
# Scale factor applied to the figure sizes given in centimetres below.
cm = 1 / 2.54

In [None]:
# Output switch; it is not read again in the visible part of this notebook.
OutInfo = True
# Data root taken from the environment (KeyError if AEMPYX_DATA is not set).
AEMPYX_DATA = os.environ["AEMPYX_DATA"]

In [None]:
# Build and print the run banner from the aempy version string.
# NOTE(review): __file__ is normally not defined inside a Jupyter kernel, so
# this cell may raise NameError when run as a notebook - confirm and, if so,
# pass an explicit name instead.
version, _ = versionstrg()
titstrng = util.print_title(version=version, fname=__file__, out=False)
print(titstrng+"\n\n")

In [None]:
# Start-up timestamp; not referenced again in the visible cells.
now = datetime.now()

In [None]:
"""
System-related settings.

Data transformation options:
DataTrans   = 0           raw data
            = 1           natural log of data
            = 2           asinh transformation

A mixed additive/multiplicative error model is applied to the raw data.
If the data are transformed, the errors are transformed as well.
"""
AEM_system = "aem05"
# AEM_system = "genesis"
if "aem05" in AEM_system.lower():
    # Only the forward call, the size vector and the parameter list are kept;
    # the third and fourth return values are discarded.
    FwdCall, NN, _, _, Pars = aesys.get_system_params(System=AEM_system)
    nL, ParaTrans, DataTrans = NN[0], 1, 0
    DatErr_add, DatErr_mult = 50., 0.03
    # All NN[2] data are flagged active for this system.
    data_active = numpy.ones(NN[2], dtype="int8")
    CompDict = Pars[3]
    CompLabl = [*CompDict.keys()]
    print(CompLabl)
    # Pars[0] = numpy.round(Pars[0],1)/1000.

In [None]:
if "genes" in AEM_system.lower():
    # Same layout as for aem05: forward call, size vector, parameter list.
    FwdCall, NN, _, _, Pars = aesys.get_system_params(System=AEM_system)
    nL, ParaTrans, DataTrans = NN[0], 1, 0
    DatErr_add, DatErr_mult = 100., 0.01
    data_active = numpy.ones(NN[2], dtype="int8")
    # Switch off the first 11 data: only the vertical component is used.
    data_active[:11] = 0
    # data_active[10:11]=0  # vertical + "good" horizontals
    CompDict = Pars[3]
    CompLabl = [*CompDict.keys()]

In [None]:
# Panels that can be requested (any subset); the last assignment wins.
# Plotlist = ["model", "sens", "cov", "cor", "respar", "resdat"]
# Plotlist = ["respar"]
Plotlist = ["model"]

# Defaults for switches that the main loop reads regardless of which panels
# are selected: NoHalfspace is used in the sensitivity calculation for every
# site, and PhysAxes in the cov/cor/jac panels. The panel-specific option
# cells below may assign them again; without these defaults a selection such
# as ["model"] stops with a NameError in the main loop.
PhysAxes = True
NoHalfspace = True

In [None]:
# Normalise the panel keys to lower case so the membership tests below match.
Plotlist = list(map(str.lower, Plotlist))

In [None]:
if "model" in Plotlist:
    print("Model will be plotted")
    # Alternatives that were tried before: "lsq", "msq".
    err = "jsq"
    # One entry per curve handed to the model panel (model, lower, upper).
    Modelcolor, Modellines, Modelwidth = ["b", "r", "r"], ["-", ":", ":"], [1, 1, 1]
    # Axis ranges for the parameter and depth axes of the panel.
    ModelLimits, DepthLimits = [1., 10000.], [0., 100.]

In [None]:
if "sens" in Plotlist:
    whichsens = ["raw", "cov", "euc", "cum"]
    print("Sensitivities will be plotted")
    print("   Sensitivity type is ", whichsens)

    # Line styling, one entry per sensitivity curve.
    Senscolor = ["b", "g", "r", "m", "y"]
    Senslines = ["-"] * 5
    Senswidth = [1, 1, 1, 1, 1.]
    # Axis ranges for the sensitivity and depth axes of the panel.
    SensLimits, DepthLimits = [0.001, 2.], [0., 100.]

In [None]:
if "respar" in Plotlist:
    # Other values noted by the author: "too", "euc", "mic" (TODO confirm).
    whichspread = "fro"
    PhysAxes, NoHalfspace = True, True
    print("Parameter resolution will be plotted")
    print("Spread type is "+whichspread)

In [None]:
if "resdat" in Plotlist:
    # Other values noted by the author: "too", "euc", "mic" (TODO confirm).
    whichspread = "fro"
    PhysAxes = True
    print("Data resolution will be plotted")
    print("   Spread type is "+whichspread)

In [None]:
if "cov" in Plotlist:
    # Leave the last (halfspace) entry out where the main loop honours this flag.
    NoHalfspace = True
    print("Posterior covariance matrix will be plotted")

In [None]:
# The panel key is "cor" (this is what the main loop tests); the former key
# "cot" could never match, so this option block was silently skipped.
if "cor" in Plotlist:
    print("Parameter correlation matrix will be plotted")
    NoHalfspace = True

In [None]:
if "jac" in Plotlist:
    # The Jacobian panel drops the last (halfspace) row when this is set.
    NoHalfspace = True
    print("Jacobian matrix will be plotted")

In [None]:
# Site selection mode, matched by substring:
#   contains "rand"            -> NSamples randomly drawn sites per file
#   contains "list" and "pos"  -> explicit site indices (Samplist)
#   contains "list" and "dis"  -> sites nearest to the distances in Distlist
# Anything else makes the main loop plot every site.
# The whole if/elif chain has to live in ONE cell: an `elif` at the start of
# a cell is a SyntaxError.
# Sample = "random"
Sample = "distance list"
if "rand" in Sample:
    NSamples = 1
elif "list" in Sample:
    if "pos" in Sample:
        Samplist = [100, 200]
    if "dis" in Sample:
        Distlist = [1500.]

In [None]:
"""
input format is "npz"
"""
# Directory holding the inversion results and the file-name pattern used to
# select them. The commented lines are settings kept for other data sets.
# NOTE(review): InModDir is a hard-coded absolute path on one machine;
# consider deriving it from AEMPYX_DATA as in the commented AEM05 example.
# GENESIS
# InModDir =  "/home/vrath/DuyguPoster/TD_uncert/"
# SearchStrng ="*901*results.npz"
# PDFCName = "GENESIS_FL901_Uncert-Catalog.pdf"
# AEM05
# InModDir =  AEMPYX_DATA + "/Projects/InvParTest/proc_delete_PLM3s/results_diffop/"
InModDir =  "/home/vrath/DuyguPoster/FD_uncert/"
SearchStrng ="A1*30*results.npz"

In [None]:
# File names are appended by plain string concatenation further down, so the
# directory has to end with a slash.
if not InModDir.endswith("/"):
    InModDir += "/"
print("Models read from dir: %s " % InModDir)
# FileList = "set" #"search"

In [None]:
# How the list of result files is obtained:
#   "search": pattern search for SearchStrng in InModDir via util.get_data_list
#   "set"   : explicit list of file names; shell-style wildcards are expanded
# Previously the search was always run and its result then overwritten by the
# literal entry "A1*results.npz", which cannot be opened as a file name.
FileList = "set"
if "search" in FileList.lower():
    how = ["search", SearchStrng, InModDir]
    # how = ["read", FileList, InModDir]
    print("Method is ", how)
    mod_files = util.get_data_list(how, out=True, sort=True)
else:
    import glob

    set_files = ["A1*results.npz"]
    mod_files = []
    for entry in set_files:
        # Keep base names only: the main loop prepends InModDir itself.
        hits = sorted(glob.glob(InModDir + entry))
        mod_files.extend(os.path.basename(hit) for hit in hits)

In [None]:
# Stop right away when there is nothing to process (error is sys.exit).
ns = numpy.size(mod_files)
if not ns:
    error("No modfiles set!. Exit.")

In [None]:
# Echo the files that will be processed, heading and list on separate lines.
print("Filelist:", mod_files, sep="\n")

In [None]:
# Panel grid (columns x rows) and the [width, height] of one panel.
nhorz, nvert = 3, 2
PlotSize = [8.5*cm] * 2

In [None]:
"""
Plot formats are ".png", ".pdf", or any other
format matplotlib allows.
"""
# Each figure is written once per listed suffix.
PlotFormat = [".pdf", ".png"]
# Plots are stored next to the models; alternatives used before:
# PlotDir = AEMPYX_DATA+"/ClaraUncert/plots/"
# PlotDir = InModDir+"/Lvar1/"
PlotDir = InModDir
print("Plots written to dir: %s " % PlotDir)

In [None]:
# Create the output directory if needed. makedirs also creates missing parent
# directories (plain mkdir fails on e.g. a new ".../plots/Lvar1/"), and
# exist_ok avoids an error if the directory appears in the meantime.
if not os.path.isdir(PlotDir):
    print("File: %s does not exist, but will be created" % PlotDir)
    os.makedirs(PlotDir, exist_ok=True)

In [None]:
# A catalog can only be assembled from pdf output. Without ".pdf" among the
# plot formats the catalog is switched off and the run continues with the
# remaining formats. (Calling error(), i.e. sys.exit, here ended the run
# before the flag could be reset.)
# The if/else has to stay in ONE cell: an `else` at the start of a cell is a
# SyntaxError.
PDFCatalog = True
if ".pdf" in PlotFormat:
    PDFCatName = "AEM05_F11379_Uncert-Catalog.pdf"
else:
    print(" No pdfs generated. No catalog possible!")
    PDFCatalog = False

In [None]:
"""
Set the graphical parameters.
Available styles: print(matplotlib.pyplot.style.available)
"""
FilesOnly = False
matplotlib.pyplot.style.use("seaborn-v0_8-paper")
# Figure and savefig defaults applied to every plot below.
matplotlib.rcParams.update({
    "figure.dpi": 600,
    "axes.linewidth": 0.5,
    "savefig.facecolor": "none",
    "savefig.transparent": True,
    "savefig.bbox": "tight",
})
# Font sizes are handed to the viz routines as [font, label, title].
Fontsize = 6
Labelsize = Fontsize
Titlesize = 6
Fontsizes = [Fontsize, Labelsize, Titlesize]

In [None]:
# Markersize = 4
# [width, height] passed as FigSize to the matrix plots.
FigSize = [8.5*cm] * 2

In [None]:
"""
Colormap names, see
https://matplotlib.org/stable/tutorials/colors/colormaps.html
"""
# Resolution, covariance and correlation matrices share one colormap;
# the Jacobian uses its own.
ColorMapResMat = ColorMapCovMat = ColorMapCorMat = "seismic"
ColorMapJacMat = "jet"

In [None]:
# Grey tones as RGB triples (three equal components each).
Grey20 = (0.2,) * 3
Grey50 = (0.5,) * 3
# Lines = (cycler("linewidth", [1.])
#          * cycler("linestyle", ["-", "--", ":", "-."])
#          * cycler("color", ["r", "g", "b", "m"]))

In [None]:
# When only files are wanted, switch matplotlib to the "cairo" backend.
if FilesOnly:
    matplotlib.use("cairo")

In [None]:
# Collects the paths of the pdf plots that go into the catalog.
if PDFCatalog:
    pdf_list = []

In [None]:
# Main loop, one pass per results file: read results and control file, set up
# the depth axes and choose the sites to be plotted.
for filein in mod_files:
    start = process_time()

    modfile = InModDir + filein
    ctrfile = modfile.replace("_results.npz", "_ctrl.npz")
    fnam, ext = os.path.splitext(os.path.basename(modfile))

    # NOTE(review): allow_pickle=True unpickles objects stored in the file;
    # only load result/control files from a trusted source.
    print("\nResults read from: %s" % modfile)
    results = numpy.load(modfile, allow_pickle=True)

    # Report the control file that is actually opened (the message used to
    # repeat the results file name).
    print("\nCtrl read from: %s" % ctrfile)
    control = numpy.load(ctrfile, allow_pickle=True)
    Runtyp = control["inversion"][0]
    Regfun = control["inversion"][1]
    OptStrng = "Opts: "+Runtyp+"|"+Regfun

    fline = results["fl_name"]
    site_x = results["site_x"]
    site_y = results["site_y"]
    site_z = results["site_dem"]

    m_act = results["mod_act"]
    m_ref = results["mod_ref"]

    site_mod = results["site_modl"]
    site_err = results["site_merr"]
    site_sns = results["site_sens"]
    site_rms = results["site_nrms"]

    site_dact = results["dat_act"]
    site_dobs = results["site_dobs"]
    site_dcal = results["site_dcal"]
    site_derr = results["site_derr"]

    site_jac = results["site_jacd"]
    site_cov = results["site_pcov"]

    # Layer thicknesses are read from the reference model vector; one entry
    # fewer than the number of layers is taken.
    nlyr = inverse.get_nlyr(m_ref)
    dz = m_ref[6*nlyr:7*nlyr-1]

    zn = inverse.set_znodes(dz)
    zm = inverse.set_zcenters(dz)
    DepthN = zn
    # A placeholder depth/thickness is appended for the last layer.
    DepthC = numpy.append(zm, 999.9)
    LayThk = numpy.append(dz, 9999.)

    # Construct site_list. Positions are taken relative to the first site of
    # the file; site_r is the horizontal distance from it.
    site_x = site_x - site_x[0]
    site_y = site_y - site_y[0]
    site_r = numpy.sqrt(numpy.power(site_x, 2.0) + numpy.power(site_y, 2.0))

    site_list = []
    if "rand" in Sample:
        site_list = random.sample(range(len(site_x)), NSamples)

    elif "list" in Sample:
        if "posi" in Sample:
            # Copy, so that appending below never alters the configured list.
            site_list = list(Samplist)
        if "dist" in Sample:
            # For each requested distance take the nearest site.
            for dist in Distlist:
                site_list.append(int(numpy.abs(dist - site_r).argmin()))
    else:
        site_list = numpy.arange(len(site_x))
                

        
    # Per site: rebuild the matrices, derive sensitivities and resolution
    # matrices, and compose the plot file name and title.
    for isite in site_list:

        # The stored covariance and Jacobian are flattened; the numbers of
        # active parameters and data give their shapes.
        npar = numpy.sum(m_act)
        ndat = numpy.sum(site_dact[isite,:])
        cov = site_cov[isite,:].reshape((npar,npar))
        jac = site_jac[isite,:].reshape((ndat,npar))

        dcal = site_dcal[isite,:]
        dact = site_dact[isite,:]
        dcal = inverse.extract_dat(D=dcal, d_act=dact)
        # Diagonal scaling by the inverse calculated data (used for "cum").
        scal = numpy.diag(1./dcal)

        # Correlation matrix: covariance normalised by the square roots of
        # its diagonal.
        v = numpy.sqrt(1./numpy.diag(cov))
        cor = cov*numpy.outer(v,v)

        # Sensitivities, in the order of the legend of the "sens" panel.
        # The "raw" variant was computed and immediately discarded before,
        # so it is no longer evaluated.
        sens = []
        for stype, sjac, strans in (("cov", jac, ["max"]),
                                    ("euc", jac, [" max", "sqr"]),
                                    ("cum", scal@jac, ["max"])):
            svals = inverse.calc_sensitivity(Jac=sjac, UseSigma=True, Type=stype)
            svals = inverse.transform_sensitivity(S=svals, V=LayThk,
                                                  Transform=strans)
            if NoHalfspace:
                svals = svals[:-1]
            sens.append(numpy.abs(svals))

        # Generalized inverse, parameter resolution matrix and spread(s).
        gi = cov@jac.T

        rm = gi@jac
        nm = numpy.sum(rm.diagonal())

        _, mspread0 = inverse.calc_model_resolution(J=jac, G=gi,
                                                    Spread=["frob"])
        _, mspread1 = inverse.calc_model_resolution(J=jac, G=gi,
                                                    Spread=["toomey"])
        _, mspread2 = inverse.calc_model_resolution(J=jac, G=gi,
                                                    Spread=["miller"])

        # Data resolution matrix.
        rd = jac@gi
        nd = numpy.sum(rd.diagonal())

        # File name and title carry the flight line and either the rounded
        # distance along the line or the site index.
        fl = str(numpy.around(fline,2)).replace(".","-")
        if "dist" in Sample:
            site_tag = str(numpy.rint(site_r[isite]))
            PlotFile = "Uncert_FL"+fl+"_site"+site_tag+"m"
            PlotTitle= "Uncert, FL"+fl+" site = "+site_tag+" m"
        else:
            # A doubled "+" here (unary plus on a string) raised a TypeError.
            PlotFile = "Uncert_FL"+fl+"_site"+str(isite)
            PlotTitle= "Uncert, FL"+fl+" site = "+str(isite)

        # One figure per site with an nvert x nhorz grid of equally sized
        # panels. The ratio lists follow nvert/nhorz, so changing the grid
        # size does not produce a length mismatch.
        fig, ax = matplotlib.pyplot.subplots(
            nvert, nhorz,
            figsize=(nhorz*PlotSize[0], nvert*PlotSize[1]),
            gridspec_kw={
                "height_ratios": [1.]*nvert,
                "width_ratios": [1.]*nhorz})

        # Separate the site title from the inversion options; they were
        # joined without any space before.
        fig.suptitle(PlotTitle+" | "+OptStrng, fontsize=Fontsizes[2])
       
        if "model" in Plotlist:

            # Model with lower/upper curves obtained by shifting the natural
            # log of the model by -/+ the stored error. The error is kept in
            # `merr`: the name `error` is the sys.exit alias used elsewhere
            # in this notebook and must not be overwritten.
            model = site_mod[isite, :]
            merr = site_err[isite, :]
            val = numpy.log(model)
            errm = numpy.exp(val-merr)
            errp = numpy.exp(val+merr)
            model = [model, errm, errp]

            viz.plot_depth_prof(
                    ThisAxis=ax[0,0],
                    XScale = "log",
                    PlotType = "steps filled",
                    Depth = [zn],
                    Params = [model],
                    Partyp = "model",
                    DLabel = "depth (m)",
                    PLabel = "resistivity (Ohm m)",
                    Legend = [],
                    Linecolor=Modelcolor,
                    Linetypes=Modellines,
                    Linewidth=Modelwidth,
                    Fillcolor = [Grey50],
                    Fontsizes=Fontsizes,
                    PLimits = ModelLimits,
                    DLimits = DepthLimits,
                    PlotStrng="nRMS = "+str(numpy.around(site_rms[isite][0],2)),
                    StrngPos=[0.05,0.05])
            
        if "sens" in Plotlist:
            # Sensitivity curves over depth (lower-left panel); the legend
            # entries follow the order of the curves in `sens`.
            viz.plot_depth_prof(
                ThisAxis=ax[1,0],
                Partyp="sens",
                PlotType="steps",
                XScale="log",
                Depth=[zn],
                Params=[sens],
                DLabel="depth (m)",
                PLabel="sensitivity (-)",
                Legend=["coverage", "euclidean","cumulative"],
                Linecolor=Senscolor,
                Linetypes=Senslines,
                Linewidth=Senswidth,
                Fillcolor=[Grey50],
                Fontsizes=Fontsizes,
                PLimits=SensLimits,
                DLimits=DepthLimits,
                PlotStrng="",
                StrngPos=[0.05,0.05])
            
                
      
            

        if "cov" in Plotlist:
            # Posterior covariance matrix (upper-middle panel).
            Matrix = cov

            # A tick on every fifth layer; the last layer index is left out.
            xticks = numpy.arange(nlyr)[0:-1:5]
            yticks = xticks
            xticklabels = xticks.astype(str)
            if PhysAxes:
                # Label the vertical axis with the rounded depth values.
                yticklabels = numpy.rint(DepthC[yticks]).astype(int).astype(str)
                AxLabels = [" layer #"," depth (m)"]
            else:
                yticklabels = yticks.astype(str)
                AxLabels =  ["layer #","layer #"]

            Aspect, TickStr = "equal", ["", ""]
            AxTicks, AxTickLabels = [xticks, yticks], [xticklabels, yticklabels]

            viz.plot_matrix(
                ThisAxis=ax[0,1],
                Matrix=Matrix,
                FigSize=FigSize,
                ColorMap=ColorMapCovMat,
                Aspect=Aspect,
                AxLabels=AxLabels,
                AxTicks=AxTicks,
                AxTickLabels=AxTickLabels,
                TickStr=TickStr,
                Fontsizes=Fontsizes,
                PlotStrng="p-covar",
                StrngPos=[0.05,0.05])
                

        if "cor" in Plotlist:
            # Parameter correlation matrix (lower-middle panel).
            Matrix = cor

            # A tick on every fifth layer; the last layer index is left out.
            xticks = numpy.arange(nlyr)[0:-1:5]
            xticklabels = xticks.astype(str)
            yticks = xticks
            if PhysAxes:
                # Label the vertical axis with the rounded depth values.
                yticklabels = numpy.rint(DepthC[yticks]).astype(int).astype(str)
                AxLabels = [" layer #"," depth (m)"]
            else:
                yticklabels = yticks.astype(str)
                AxLabels =  ["layer #","layer #"]

            # Set the aspect here as the other matrix panels do; it used to
            # be taken from whichever panel happened to run before (or was
            # undefined if none had).
            Aspect = "equal"
            AxTicks = [xticks, yticks]
            AxTickLabels =  [xticklabels, yticklabels]
            TickStr=["", ""]

            viz.plot_matrix(
                 ThisAxis=ax[1,1],
                 Matrix=Matrix,
                 FigSize=FigSize,
                 ColorMap=ColorMapCorMat,
                 TickStr=TickStr,
                 AxLabels=AxLabels,
                 AxTicks=AxTicks ,
                 AxTickLabels=AxTickLabels,
                 Fontsizes=Fontsizes,
                 Aspect =Aspect,
                 PlotStrng="p-cor",
                 StrngPos=[0.05,0.05])
                
   

        if "respar" in Plotlist:
            # Parameter resolution matrix (upper-right panel), optionally
            # without the last (halfspace) row and column.
            Matrix = rm[:-1,:-1] if NoHalfspace else rm

            # The annotation is the trace of the full resolution matrix.
            Np = numpy.sum(numpy.diag(rm))
            PlotStrng = " Npar = "+numpy.around(Np,1).astype(str)
            StrngPos=[0.05,0.1]

            # A tick on every fifth layer; the last layer index is left out.
            xticks = numpy.arange(nlyr)[0:-1:5]
            yticks = xticks
            xticklabels = xticks.astype(str)
            if PhysAxes:
                yticklabels = numpy.rint(DepthC[yticks]).astype(int).astype(str)
                AxLabels = [" layer #"," depth (m)"]
            else:
                yticklabels = yticks.astype(str)
                AxLabels =  ["layer #","layer #"]

            Aspect, TickStr = "equal", ["", ""]
            AxTicks, AxTickLabels = [xticks, yticks], [xticklabels, yticklabels]

            viz.plot_matrix(
                ThisAxis=ax[0,2],
                Matrix=Matrix,
                FigSize=FigSize,
                ColorMap=ColorMapResMat,
                Aspect=Aspect,
                AxLabels=AxLabels,
                AxTicks=AxTicks,
                AxTickLabels=AxTickLabels,
                TickStr=TickStr,
                Fontsizes=Fontsizes,
                PlotStrng=PlotStrng,
                StrngPos=StrngPos)

        if "resdat" in Plotlist:
            # Data resolution matrix (lower-right panel); the annotation is
            # its trace.
            Matrix = rd

            Nd = numpy.sum(numpy.diag(rd))
            PlotStrng = " Ndat = "+numpy.around(Nd,1).astype(str)
            StrngPos=[0.05,0.1]

            # The system name is compared in lower case, as in the settings
            # cells at the top of the notebook.
            if "aem05" in AEM_system.lower():
                # Was spelled "Axlabels", so the labels handed to the plot
                # were stale or undefined whenever PhysAxes is False.
                AxLabels =  ["data #","data #"]
                xticks = numpy.arange(8)
                xticklabels = xticks.astype(str)
                yticks =  xticks
                yticklabels = xticklabels

                if PhysAxes:
                    AxLabels =  ["data #"," frequency (kHz)"]
                    # Pars[0] is scaled by 1e-3 and repeated to cover all
                    # eight data; every second value is used as a label.
                    pars = Pars[0]*1.e-3
                    vals = numpy.concatenate((pars, pars))
                    iticks = numpy.array([0, 2, 4 , 6])
                    yticks = yticks[iticks]
                    yticklabels = numpy.round(vals,2).astype(str)[iticks]

            if "genes" in AEM_system.lower():
                AxLabels =  ["data #","data #"]
                xticks = numpy.arange(11)
                xticklabels = xticks.astype(str)

                yticks = xticks
                yticklabels = xticklabels

                if PhysAxes:
                    AxLabels =  ["data #"," window center (1e-6 s)"]
                    vals = Pars[0]*1000.
                    iticks = numpy.array([0, 2, 4, 6, 8, 10])
                    yticks = yticks[iticks]
                    yticklabels = numpy.round(vals,1).astype(str)[iticks]

            Aspect = "equal"
            AxTicks = [xticks, yticks]
            AxTickLabels =  [xticklabels, yticklabels]
            TickStr=["", ""]

            viz.plot_matrix(
                ThisAxis=ax[1,2],
                Matrix=Matrix,
                FigSize=FigSize,
                ColorMap=ColorMapResMat,
                AxLabels=AxLabels,
                AxTicks=AxTicks,
                AxTickLabels=AxTickLabels,
                TickStr=TickStr,
                Fontsizes=Fontsizes,
                Aspect =Aspect,
                PlotStrng=PlotStrng,
                StrngPos=StrngPos)
                                 
        if "jac" in Plotlist:
            # Jacobian, transposed so that rows are parameters; the last
            # (halfspace) row is optionally dropped.
            Matrix = jac.T
            if NoHalfspace:
                Matrix  = Matrix[:-1]

            # A tick on every fifth layer; the last layer index is left out.
            yticks = numpy.arange(nlyr)[0:-1:5]

            if PhysAxes:
                yticklabels = numpy.rint(DepthC[yticks]).astype(int).astype(str)
                ylabel = "depth (m)"
            else:
                yticklabels = yticks.astype(str)
                ylabel =  "layer #"

            # The system name is compared in lower case, as in the settings
            # cells at the top of the notebook.
            if "aem05" in AEM_system.lower():
                xlabel = "data #"
                xticks = numpy.arange(8)
                xticklabels = xticks.astype(str)

                if PhysAxes:
                    xlabel = "frequency (kHz)"
                    pars = Pars[0]*1.e-3
                    vals = numpy.concatenate((pars, pars))
                    iticks = numpy.array([0, 2, 4 , 6])
                    xticks = xticks[iticks]
                    xticklabels = numpy.round(vals,2).astype(str)[iticks]

            if "genes" in AEM_system.lower():
                xlabel =  "data #"
                xticks = numpy.arange(11)
                xticklabels = xticks.astype(str)

                if PhysAxes:
                    xlabel = "window center (1e-6 s)"
                    vals = Pars[0]*1000.
                    iticks = numpy.array([0, 2, 4, 6, 8, 10])
                    xticks = xticks[iticks]
                    xticklabels = numpy.round(vals,1).astype(str)[iticks]

            Aspect = "auto" #aspect
            AxLabels = [xlabel, ylabel]
            AxTicks  = [xticks, yticks]
            AxTickLabels =  [xticklabels, yticklabels]
            TickStr=["", ""]

            # NOTE(review): no panel of the grid is passed here
            # (ThisAxis=None); presumably viz.plot_matrix then draws into a
            # figure of its own - confirm, it is not saved by the code below.
            viz.plot_matrix(
                  ThisAxis=None,
                  Matrix=Matrix,
                  FigSize=FigSize,
                  ColorMap=ColorMapJacMat,
                  TickStr=TickStr,
                  AxLabels=AxLabels,
                  AxTicks=AxTicks ,
                  AxTickLabels=AxTickLabels,
                  Fontsizes=Fontsizes,
                  Aspect =Aspect,
                  PlotStrng="jacobian",
                  StrngPos=[0.05,0.05])

        # Lay out and save the figure of THIS site. These statements used to
        # follow the site loop, so with several sites per file only the last
        # figure was written and catalogued. They also address `fig`
        # explicitly instead of whatever pyplot regards as the current figure.
        fig.tight_layout()

        for F in PlotFormat:
            fig.savefig(PlotDir+PlotFile+F)

        if PDFCatalog:
            pdf_list.append(PlotDir+PlotFile+".pdf")

        # In file-only runs, release the figure so open figures do not pile up.
        if FilesOnly:
            matplotlib.pyplot.close(fig)

In [None]:
# Merge the collected pdf plots into one catalog file in the plot directory.
if PDFCatalog:
    catalog_file = PlotDir + PDFCatName
    viz.make_pdf_catalog(PDFList=pdf_list, FileName=catalog_file)
    # print(str(len(pdf_list))+" collected to "+PlotDir+PDFCatName)