In [1]:
import base64
import io

In [2]:
## https://github.com/BrownParticleAstro/dmtools/blob/versioned/basecode/main/app/libraries/plot_operations/v0/plot_operations.py

In [3]:
# for plotting
import matplotlib as mpl
mpl.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec

# for legend
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import numpy as np

# for axes
from matplotlib.ticker import LogFormatterMathtext

In [4]:
def get_current_plots_data():
    """Read ./data/current_plot.json and return its parsed JSON content.

    The path is relative to the current working directory. Whatever
    ``json.load`` produces for the file (dict, list, ...) is returned as-is.
    """
    import json

    # Parse straight from the open handle; the file is closed on exit.
    with open('./data/current_plot.json', 'r') as plot_file:
        return json.load(plot_file)

# The value returned by get_current_plots_data() is a Python dict (or list, depending on the JSON structure)
#print(data)

In [5]:
#https://github.com/BrownParticleAstro/dmtools/blob/versioned/basecode/main/app/libraries/migration/v0/trace.py

In [6]:
def get_color_mpl(color_in):
    """Translate a dmtools colour code into a matplotlib colour/alpha dict.

    Args:
        color_in: colour code such as 'k', 'Red', 'ltb' or 'g40'.

    Returns:
        dict with keys 'color' (matplotlib colour name) and 'alpha'.
        Codes that are not recognised map to opaque black.
    """
    # Grey shade codes: the digits after the first character give the alpha in percent.
    grey_shades = ('g10', 'g20', 'g30', 'g40', 'g50', 'g60', 'g70', 'g80', 'g90', 'G60')
    if color_in in grey_shades:
        try:
            opacity = int(color_in[1:]) / 100
        except:
            # Same catch-all fallback as before: an unparsable shade is drawn opaque.
            opacity = 1
        return {'color': 'grey', 'alpha': opacity}

    # (aliases, matplotlib colour, alpha) - the aliases of the rules do not overlap.
    fixed_rules = (
        (('k', 'black', 'Blk'), 'black', 1),
        (('r', 'red', 'Red', 'dkr'), 'red', 1),
        (('dkg', 'DkG', 'green', 'Grn'), 'green', 1),
        (('ltg', 'LtG'), 'green', 0.5),
        (('ltr', 'LtR'), 'red', 0.5),
        (('b',), 'blue', 1),
        (('ltb', 'LtB', 'Blue', 'Blu', 'DkB'), 'blue', 0.5),
        (('c', 'Cyan', 'cyan'), 'cyan', 1),
        (('g', 'grey'), 'grey', 1),
        (('m', 'magenta', 'Mag'), 'magenta', 1),
        (('y', 'yellow', 'Yel'), 'yellow', 1),
        (('w', 'white'), 'white', 1),
    )
    for aliases, colour_name, opacity in fixed_rules:
        if color_in in aliases:
            return {'color': colour_name, 'alpha': opacity}

    # Anything unknown is drawn in opaque black.
    return {'color': 'black', 'alpha': 1}

    
def get_style_mpl(color_in, style_in):
    """Build the matplotlib style dict for a dmtools colour/style pair.

    Args:
        color_in: dmtools colour code, resolved through get_color_mpl.
        style_in: dmtools style name such as 'dash', 'Fill', 'cross' or 'triu'.

    Returns:
        dict with keys 'color', 'alpha', 'linestyle', 'linewidth', 'marker',
        'markersize', 'fill' and 'style'. The 'alpha' taken from the colour is
        always replaced by the alpha of the chosen style. Unknown styles get
        the plain solid-line settings.
    """
    setting_names = ('linestyle', 'linewidth', 'marker', 'markersize', 'alpha', 'fill', 'style')

    # One row per style: (aliases, then one value per entry of setting_names).
    style_rows = (
        (('dot', 'dotted', 'Dot'),      ':',    1, None, 0,  1,   False, 'dot'),
        (('dash', 'Dash'),              '--',   1, None, 0,  1,   False, 'dash'),
        (('fill', 'Fill'),              None,   0, None, 0,  0.3, True,  'fill'),
        (('Line', 'line', 'lines'),     '-',    1, None, 0,  1,   False, 'line'),
        (('point',),                    'None', 1, '.',  8,  1,   False, 'point'),
        (('cross', 'Cross'),            'None', 1, 'x',  8,  1,   False, 'cross'),
        (('circle',),                   'None', 1, 'o',  8,  1,   False, 'circle'),
        (('plus',),                     'None', 1, '+',  8,  1,   False, 'plus'),
        (('asterisk', 'star'),          'None', 1, '*',  12, 1,   False, 'star'),
        (('pentagon', 'pent'),          'None', 1, 'p',  10, 1,   False, 'pentagon'),
        (('hex', 'hexagon'),            'None', 1, 'h',  10, 1,   False, 'hexagon'),
        (('triu', 'triangle-up'),       'None', 1, '^',  10, 1,   False, 'triangle-up'),
        (('trid', 'triangle-down'),     'None', 1, 'v',  10, 1,   False, 'triangle-down'),
        (('tril', 'triangle-left'),     'None', 1, '<',  10, 1,   False, 'triangle-left'),
        (('trir', 'triangle-right'),    'None', 1, '>',  10, 1,   False, 'triangle-right'),
    )
    # Settings used when style_in matches none of the rows above.
    solid_line = ('-', 1, None, 0, 1, False, 'line')

    chosen = next((row[1:] for row in style_rows if style_in in row[0]), solid_line)

    # Start from the colour attributes, then layer the style settings on top.
    trace_style = dict(get_color_mpl(color_in))
    trace_style.update(zip(setting_names, chosen))
    return trace_style

def get_clean_color_style(color_in, style_in):
    """Normalise a dmtools colour/style pair to canonical names.

    Args:
        color_in: dmtools colour code, resolved through get_color_mpl.
        style_in: dmtools style name or alias (e.g. 'Dot', 'trid', 'hex').

    Returns:
        tuple (clean_trace_color, clean_trace_style): the matplotlib colour
        name from get_color_mpl and the canonical style name. Unknown styles
        map to 'line'.
    """
    clean_trace_color = get_color_mpl(color_in)['color']

    # (aliases, canonical name); the first matching row wins.
    style_aliases = (
        (('dot', 'dotted', 'Dot'), 'dotted'),
        (('dash', 'Dash'), 'dash'),
        (('fill', 'Fill'), 'fill'),
        (('Line', 'line', 'lines'), 'line'),
        (('point',), 'point'),
        (('cross', 'Cross'), 'cross'),
        (('circle',), 'circle'),
        # NOTE(review): 'plus' is reported as 'cross' here, while get_style_mpl
        # keeps 'plus' as its own style - kept as-is, confirm this is intended.
        (('plus',), 'cross'),
        (('asterisk', 'star'), 'star'),
        (('pentagon', 'pent'), 'pentagon'),
        (('hex', 'hexagon'), 'hexagon'),
        # A bare 'triangle' resolves to triangle-up (it was listed first).
        (('triu', 'triangle', 'triangle-up'), 'triangle-up'),
        # The three rows below previously compared style_in == (tuple), which
        # is never true for a string, so these styles always fell back to 'line'.
        (('trid', 'triangle-down'), 'triangle-down'),
        (('tril', 'triangle-left'), 'triangle-left'),
        (('trir', 'triangle-right'), 'triangle-right'),
    )

    clean_trace_style = 'line'
    for aliases, canonical_name in style_aliases:
        if style_in in aliases:
            clean_trace_style = canonical_name
            break

    return clean_trace_color, clean_trace_style

In [7]:
#https://github.com/BrownParticleAstro/dmtools/blob/versioned/basecode/main/app/libraries/helper_operations/v0/helper_operations.py

In [8]:
# Multiplier that converts a value in the given unit to eV.
unit_factors = dict(
    eV=1e0,
    keV=1e3,
    MeV=1e6,
    GeV=1e9,
    TeV=1e12,
)

def normalize_unit(unit):
    """Reduce a mass-unit string to its bare energy unit.

    Everything from the first '/c' onwards ('/c^2', '/c2', '/c²', ...) is
    dropped and surrounding whitespace is removed, e.g. 'GeV/c^2' -> 'GeV'.

    Args:
        unit: unit string such as 'GeV', ' MeV/c2 ' or 'keV /c^2'.

    Returns:
        The unit text before '/c', stripped of whitespace.
    """
    # Strip AFTER removing the '/c...' suffix so that a space in front of the
    # slash ('GeV /c^2') does not survive as a trailing blank.
    return unit.split('/c', 1)[0].strip()

def convert_mass_units(value, from_unit, to_unit):
    """
    Convert a mass value (or array) from one energy/c^2 unit to another.
    from_unit and to_unit can be like 'GeV', 'GeV/c^2', 'MeV/c2', etc.

    Args:
        value: number (or object supporting * and / with floats) to convert.
        from_unit: unit the value is currently expressed in.
        to_unit: unit to convert the value to.

    Returns:
        The converted value, or 0 when the arithmetic on ``value`` fails
        (e.g. ``value`` is None or a string).

    Raises:
        ValueError: if either unit, after normalize_unit, is not a key of
            unit_factors.
    """
    from_unit_norm = normalize_unit(from_unit)
    to_unit_norm = normalize_unit(to_unit)

    unsupported = [name for name in (from_unit_norm, to_unit_norm) if name not in unit_factors]
    if unsupported:
        # Name the offending unit(s) so the caller can see what was rejected.
        raise ValueError(
            f"Unsupported unit(s) {unsupported}. Supported units: {list(unit_factors.keys())}"
        )

    try:
        value_eV = value * unit_factors[from_unit_norm]
        return value_eV / unit_factors[to_unit_norm]
    except (TypeError, ValueError, ArithmeticError):
        # Best-effort behaviour kept: a value that cannot be converted yields 0.
        # Only arithmetic/type failures are absorbed, so KeyboardInterrupt and
        # programming errors are no longer silently swallowed.
        return 0

# Units accepted by get_x_label for the x-axis label (same names as the keys of unit_factors).
allowed_units = ['eV', 'keV', 'MeV', 'GeV', 'TeV']

def get_x_label(selected_unit):
    """Return the mathtext x-axis label "WIMP Mass [<unit>/c^2]".

    Args:
        selected_unit: unit string, with or without a '/c^2'-style suffix.

    Returns:
        The label string for the normalised unit. A unit that is not in
        allowed_units does not raise; the label silently falls back to GeV.
    """
    unit_name = normalize_unit(selected_unit)
    label_unit = unit_name if unit_name in allowed_units else 'GeV'
    return r"$\mathrm{WIMP\ Mass}\ [\mathrm{" + label_unit + r"}/c^{2}]$"

In [9]:
# Load the cached plot description and pull out its two top-level sections.
data = get_current_plots_data()
plot_node, display_data = [
    data.get('dmtools_current_plot').get(section)
    for section in ('plot_node', 'display_data')
]

In [10]:
# Split the plot node into its record and its properties.
plot_record, plot_properties = [
    plot_node.get(part) for part in ('plot_record', 'plot_properties')
]

In [11]:
#display_data

In [12]:
#data_dict = display_data[0]
#data_0 = data_dict.get('data')

In [13]:
# For every display entry, print the 'id' of the data_record belonging to the
# first element of its 'data' list.
for d in display_data:
    dd = d.get('data')
    #print(dd)
    dr = dd[0].get('data_record')
    drid = dr.get('id')
    print(drid)

241
210
235
254
218
217


In [14]:
# def set_plot_data(conn, plot_id, data_id=-1, schema="data", function_mode="local"):
def set_plot_data(conn='', plot_id=-1, data_id=-1, schema="data", function_mode="local"):
    """Render the current dmtools plot with matplotlib and return it as base64 SVG.

    In "local" mode the plot description is read with get_current_plots_data()
    and its 'dmtools_current_plot' entry is drawn: every trace of every data
    node of every display entry is plotted on log-log axes, using the style
    returned by get_style_mpl for the display's colour/style.

    Args:
        conn: unused in local mode (kept for signature compatibility).
        plot_id: fallback plot id for the title; replaced by the id stored in
            the plot record whenever a plot node is present.
        data_id: unused in local mode.
        schema: unused in local mode.
        function_mode: only "local" is implemented in this notebook. The
            original (commented-out) remote path fetched the data through
            get_plot_display_data_nodes_cached / cache_get.

    Returns:
        tuple (plot_name, dmtools_plot_url): the plot name and the SVG image
        encoded as a base64 string.

    Raises:
        NotImplementedError: if function_mode is not "local".
    """
    if function_mode != "local":
        # Previously this path left plot_data_cache unbound and failed later
        # with an UnboundLocalError; fail early with a clear message instead.
        raise NotImplementedError(
            "set_plot_data only supports function_mode='local' in this notebook"
        )

    plot_data_cache = get_current_plots_data().get('dmtools_current_plot') or {}

    # Process the plot nested json (missing sections degrade to empty dicts)
    plot_node = plot_data_cache.get("plot_node") or {}
    plot_record = plot_node.get("plot_record") or {}
    plot_properties = plot_node.get("plot_properties") or {}
    plot_name = plot_properties.get('name', 'Plot')
    if plot_node:
        plot_id = plot_record.get('id', -1)

    # This allows the plot to respond to changes to the x axis units
    selected_unit = plot_properties.get('xUnits', 'GeV/c^2')
    plot_type = plot_properties.get('plotType')
    # For this plot type y is divided by the mass in GeV, and x is plotted in GeV.
    divide_by_mass = plot_type == "Cross Section / Mass [in GeV] vs Mass[GeV]"
    target_unit = 'GeV' if divide_by_mass else selected_unit

    def _axis_limit(key, missing, empty):
        """Read a numeric plot property; `missing` applies when the key is
        absent, `empty` when it is present but None/''. A numeric 0 is a real value."""
        raw = plot_properties.get(key, missing)
        return empty if raw in (None, '') else float(raw)

    # Create the figure.
    # The GridSpec leaves a margin around the axes so the plot is not too
    # tight to the edge of its container.
    fig = plt.figure(figsize=(10, 10), linewidth=0, edgecolor='#D0D6DB', facecolor='#D0D6DB')
    try:
        gs = gridspec.GridSpec(64, 64)
        ax = fig.add_subplot(gs[2:63, 2:63])
        ax.set_facecolor('white')
        ax.tick_params(axis='both', labelsize=16)

        # Iterate over each display entry and its data nodes to plot the traces
        for display in plot_data_cache.get("display_data") or []:
            display_properties = display.get('display_properties') or {}
            style = display_properties.get('style', 'line')
            color = display_properties.get('color', 'black')

            # Convert the dmtools style and color combination into the
            # matplotlib configuration (depends only on the display entry).
            style_kwargs = get_style_mpl(color, style)
            # get_style_mpl flags every fill alias ('fill', 'Fill'), so use its
            # flag instead of comparing the raw style string.
            is_fill = bool(style_kwargs.get('fill'))

            # 'style' and 'fill' are not valid for a line in matplotlib;
            # ax.plot fails if they are passed in.
            line_plot_kwargs = {
                key: value for key, value in style_kwargs.items()
                if key not in ('style', 'fill')
            }
            # A fill should have all the marker, line and fill definitions removed.
            fill_plot_kwargs = {
                key: value for key, value in style_kwargs.items()
                if key not in ('style', 'fill', 'marker', 'markersize', 'linestyle', 'linewidth')
            }

            for data_node in display.get("data") or []:
                data_properties = data_node.get('data_properties') or {}

                # Rescaling factors with fallback values
                y_rescale = float(data_properties.get('yRescale', 1))
                x_rescale = float(data_properties.get('xRescale', 1))
                # NOTE(review): the old fallback was the integer 1, which crashed
                # in normalize_unit; GeV is assumed here to match the plot default.
                x_units = data_properties.get('xUnits') or 'GeV'
                list_data = data_properties.get('values', [[[0.0, 0.0], [0.0, 1.0]]])

                # Plot each trace - a trace is a line
                for trace in list_data:
                    x = [float(item[0]) * x_rescale for item in trace]
                    unit_x = [convert_mass_units(val, x_units, target_unit) for val in x]
                    y = [float(item[1]) * y_rescale for item in trace]
                    if divide_by_mass:
                        # A zero mass cannot be divided by; NaN leaves a gap in the trace.
                        y = [yi / xi if xi else float('nan') for yi, xi in zip(y, unit_x)]

                    # Plot with line and optional fill style
                    ax.plot(unit_x, y, **line_plot_kwargs)
                    if is_fill:
                        # Fill against the same (unit-converted) x values as the line.
                        ax.fill_between(unit_x, y, **fill_plot_kwargs)

        # Set scale, titles, and labels after all data is plotted
        ax.set_xscale('log')
        ax.set_yscale('log')

        # Manage the x and y axis range. y limits are exponents of 10.
        # NOTE(review): yMin and yMax share the default -42, and the xMin/xMax
        # defaults (0, 3 / 10000) look like a mix of exponents and raw values;
        # the values are kept unchanged here - confirm the intended defaults.
        ymin_exp = _axis_limit('yMin', '-42', -42)
        ymax_exp = _axis_limit('yMax', '-42', -42)
        xmin = _axis_limit('xMin', '0', 0)
        xmax = _axis_limit('xMax', '3', 10000)

        ax.set_xlim([xmin, xmax])
        # Convert to 10**exponent form
        ax.set_ylim([10**ymin_exp, 10**ymax_exp])

        ax.set_ylabel(r"$\mathrm{Cross\ Section}\ [cm^{2}]\ (\mathrm{normalized\ to\ nucleon})$", fontsize=18)

        # The x axis label is responsive to the units selected
        ax.set_xlabel(get_x_label(selected_unit), fontsize=18)
        plot_title_default = r"$\mathrm{WIMP\ Mass\ vs\ Cross\ Section\ (Plot\ ID:\ " + str(plot_id) + r")}$"
        plot_title = plot_name + ' (' + str(plot_id) + ')' if plot_name else plot_title_default
        ax.set_title(plot_title, fontsize=18)

        # The following is required for the scientific notation to display
        # correctly. It relies on
        #   mpl.rcParams['axes.unicode_minus'] = False
        #   from matplotlib.ticker import LogFormatterMathtext
        ax.xaxis.set_major_formatter(LogFormatterMathtext())
        ax.yaxis.set_major_formatter(LogFormatterMathtext())
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontname('DejaVu Sans')

        # Save to a BytesIO object so the image can be embedded in an HTML page
        img = io.BytesIO()
        fig.savefig(img, format='svg', dpi=100, bbox_inches='tight', pad_inches=0.25)
    finally:
        # Always close the figure, also on errors, to free memory
        plt.close(fig)

    # Encode to base64
    dmtools_plot_url = base64.b64encode(img.getvalue()).decode('utf8')

    # In the web application the image is written to the cache
    # (cache_set("dmtools_plot_url", ...)) and picked up by the javascript;
    # in this notebook it is simply returned.
    return plot_name, dmtools_plot_url

In [15]:
# Render the plot and check that the returned payload is well-formed base64
plot_name, dmtools_plot_url = set_plot_data()

from IPython.display import Image, display

try:
    # validate=True makes b64decode reject characters outside the base64
    # alphabet; without it they are silently discarded and the check cannot fail on them.
    base64.b64decode(dmtools_plot_url, validate=True)
    print("Base64 is valid")
except Exception as e:
    print("Base64 is invalid:", e)


Base64 is valid


In [16]:
# Report the size of the base64-encoded SVG returned by set_plot_data()
print('Base64 string length:', len(dmtools_plot_url))

Base64 string length: 101416


In [17]:
# Display the base64-encoded SVG inline in Jupyter by embedding it in an
# <img> tag as a data URI (rendered at most 600px wide)
from IPython.display import HTML, display

display(HTML(f'<img src="data:image/svg+xml;base64,{dmtools_plot_url}" style="max-width:600px;"/>'))