In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle as pkl
import sys
import torch
from typing import Optional

from mlcg.nn import SumOut

sys.path.insert(0, os.path.join("../.."))
from input_generator.prior_gen import PriorBuilder
from prior_tools import symmetrized_keys_generator, optimal_offset, prior_evaluator

In [None]:
# Directory where the inspection figures are written; created if missing.
plots_path = "./temp_figs"
os.makedirs(plots_path, exist_ok=True)
# select the temperature for your simulations (in K)
temperature = 300
kB = 0.0019872041  # kcal/(mol⋅K)
# inverse temperature 1/(kB*T) in mol/kcal; assumes the prior model works in
# kcal/mol -- adjust kB if it uses other energy units
beta = 1 / (temperature * kB)

In [None]:
# The loaded object is used below as a full SumOut module (not a state_dict).
# Since PyTorch 2.6 `torch.load` defaults to `weights_only=True`, which refuses
# to unpickle arbitrary module classes, so full-model loading must opt out.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
prior_model = torch.load("./prior_model_example.pt", weights_only=False)

In [None]:
# Prior builders (with their collected histograms) saved during prior fitting.
# NOTE: pickle can execute arbitrary code on load -- only open trusted files.
with open("./prior_builders_example.pck", mode="rb") as pck_file:
    prior_builders = pkl.load(pck_file)

In [None]:
# List the available prior builders so the indices used below can be checked.
for builder_index, builder in enumerate(prior_builders):
    print(f"Index: {builder_index}, prior builder: {builder.name}")

In [None]:
def basic_prior_visual_inspection(
    prior_bldr: PriorBuilder,
    prior_model: SumOut,
    beta: float,
    xlims: Optional[tuple[float, float]] = None,
    ylims: Optional[tuple[float, float]] = None,
    keys_per_fig: int = 1,
    savedir: Optional[str] = None,
    n_max_figs: Optional[int] = None,
    n_eval_points: int = 201,
):
    r"""
    Plot the histogram free energy of every bead combination against the prior fit.

    For every bead combination found by the prior builder, this function plots
    the free energy obtained from the collected histogram,
    :math:`-\ln(p) / \beta = -k_B T \ln(p)`, together with the prior fit that
    should match it (shifted by the optimal offset). The purpose of this
    function is to exemplify how to extract the histograms from the prior
    builder and how to correctly compare them to the prior.

    Normally there are too many bead combinations to analyze them one by one.
    This function can easily be modified to filter the histograms and only plot
    the ones that fail according to some criterion.

    Parameters
    ----------
    prior_bldr:
        PriorBuilder object from which the histograms are taken.
    prior_model:
        Prior as a mlcg.nn.SumOut object. It must contain a module whose name
        coincides with the name of `prior_bldr`.
    beta:
        The inverse temperature (1/(kB*T)) in the energy units used by the
        prior model.
    xlims:
        x-axis range over which the prior is evaluated and plotted. The
        appropriate value depends on the prior type and the CG resolution.
        If None, the range of the histogram bin centers is used.
    ylims:
        y-axis limits of the plots. The same considerations as for `xlims`
        apply. If None, matplotlib chooses the limits automatically.
    keys_per_fig:
        Number of bead combinations to plot in every figure. Useful for
        comparing combinations in the same plot. Must be at least 1. The last
        figure may contain fewer combinations.
    savedir:
        Directory where every generated plot is saved as a PNG. It is created
        if it does not exist. If None, plots are only shown.
    n_max_figs:
        Maximum number of figures to plot. If None, all figures are plotted.
    n_eval_points:
        Number of points in `xlims` at which the prior curve is evaluated.

    Raises
    ------
    ValueError
        If `keys_per_fig` is smaller than 1.
    """
    if keys_per_fig < 1:
        raise ValueError(f"keys_per_fig must be >= 1, got {keys_per_fig}")
    name = prior_bldr.name
    # all the bead combinations for which we collected statistics
    keys = list(prior_bldr.histograms.data[name].keys())
    if not keys:
        # nothing to plot; also avoids indexing keys[0] below
        print(f"found 0 keys to plot for prior {name}")
        return
    # possible bead combinations respecting symmetries; materialized once so
    # membership tests neither rescan nor exhaust the helper's result
    unique_keys = set(symmetrized_keys_generator(len(keys[0])))
    # intersecting the keys to avoid duplicating plots
    keys_list = [key for key in keys if key in unique_keys]
    print(f"found {len(keys_list)} keys to plot for prior {name}")
    # get the prior module corresponding to this model
    prior_module = prior_model.models[name].model
    # the bin centers are shared by every bead combination of this prior
    centers = prior_bldr.histograms.bin_centers
    if xlims is None:
        centers_arr = np.asarray(centers)
        xlims = (float(centers_arr.min()), float(centers_arr.max()))
    # evaluation range to see the prior. It is good to check that the prior
    # still extrapolates with an acceptable behaviour
    eval_range = torch.linspace(xlims[0], xlims[1], n_eval_points)
    if savedir is not None:
        os.makedirs(savedir, exist_ok=True)
    # ceiling division: a final, partially filled figure must still be drawn
    n_figures = (len(keys_list) + keys_per_fig - 1) // keys_per_fig
    if n_max_figs is not None:
        n_figures = min(n_figures, n_max_figs)
    for i in range(n_figures):
        fig, ax = plt.subplots()
        # slicing clamps automatically at the end of keys_list
        for key in keys_list[i * keys_per_fig : (i + 1) * keys_per_fig]:
            # getting the histogram data
            data = prior_bldr.histograms.data[name][key]
            # converting probability into free energy; empty bins become +inf
            # (skipped by matplotlib), so silence the divide-by-zero warning
            with np.errstate(divide="ignore"):
                free_energy = -1 * np.log(data) / beta
            # evaluate the prior in the bin centers from the histogram
            prior_fit_values = prior_evaluator(prior_module, key, centers)
            # finding the optimal offset between fit and data
            offset = optimal_offset(np.array(prior_fit_values), free_energy)
            prior_curve = prior_evaluator(prior_module, key, eval_range) + offset
            # plot arrays with relevant labels
            ax.plot(centers, free_energy, lw=3, alpha=0.6, label=f"{key} data")
            ax.plot(eval_range, prior_curve, lw=1, alpha=0.6, label=f"{key} fit")
        # make the plot pretty
        ax.legend()
        ax.grid()
        ax.set_xlabel(f"{name} value")
        # -ln(p)/beta is an energy (kT * -ln p), not a quantity in units of kT
        ax.set_ylabel("FE = -kT ln(p) [prior energy units]")
        if ylims is not None:
            ax.set_ylim(ylims[0], ylims[1])
        ax.set_title(f"Prior fit: {name}")
        if savedir is not None:
            fig.savefig(os.path.join(savedir, f"{name}_{i:03}.png"), dpi=200)
        plt.show()
        # release the figure; many figures otherwise accumulate in memory
        plt.close(fig)
        

In [None]:
# Prior builder 0: at most 3 figures, 5 bead combinations per figure.
basic_prior_visual_inspection(
    prior_builders[0],
    prior_model,
    beta,
    xlims=(2.5, 4.5),
    ylims=(-10.0, 20.0),
    keys_per_fig=5,
    n_max_figs=3,
    savedir=plots_path,
)

In [None]:
# Prior builder 1: at most 2 figures, one bead combination per figure.
basic_prior_visual_inspection(
    prior_builders[1],
    prior_model,
    beta,
    xlims=(-1.1, 1.1),
    ylims=(-10, 15),
    keys_per_fig=1,
    n_max_figs=2,
    savedir=plots_path,
)

In [None]:
# Prior builder 2: a single figure comparing 3 bead combinations.
basic_prior_visual_inspection(
    prior_builders[2],
    prior_model,
    beta,
    xlims=(2, 7),
    ylims=(-10, 15),
    keys_per_fig=3,
    n_max_figs=1,
    savedir=plots_path,
)