# Explore the catalog with interactive Plotly graphs

In [15]:
import json
import numpy as np
import pandas as pd

import os
from io import StringIO

import matplotlib.pyplot as plt
import plotly.express as px


# Catalog file to load (alternatively point this at "dummy.json").
file_path = "../data/post_mt_systems.json"

In [16]:
# === Load JSON ===
def read_json_file(file_path):
    """Load and return the parsed contents of a JSON file.

    Parameters:
        file_path: str or path-like
            Path to the JSON file.

    Returns:
        The decoded JSON value (the notebook cells below expect a list of
        dicts), or None if the file cannot be opened, decoded, or parsed.
        The error is printed rather than raised so a notebook cell can
        decide how to proceed.
    """
    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error reading file: {e}")
        return None

# === Helpers to extract catalog columns ===
def extract_array(data, key):
    """Extract ``key`` from every system record and return a NumPy array.

    Records that lack ``key``, or store JSON ``null`` (None) for it, are
    filled with NaN. When the present values are lists/tuples of one common
    length (e.g. the 3-element rows that ``plotly_vars`` indexes as columns
    0/1/2), the filler is a NaN vector of that length. This keeps the result
    a regular 2-D array instead of a ragged one, which NumPy rejects.

    Parameters:
        data: list of dict
            Parsed catalog records.
        key: str
            Record key to extract.

    Returns:
        np.ndarray, or None if the values cannot be combined into an array
        (the error is printed).
    """
    try:
        raw = [entry.get(key) for entry in data]
        # A single common length among list-valued entries means this column
        # holds fixed-size vectors, so missing rows need a vector filler.
        lengths = {len(v) for v in raw if isinstance(v, (list, tuple))}
        if len(lengths) == 1:
            filler = [np.nan] * lengths.pop()
        else:
            filler = np.nan
        return np.array([filler if v is None else v for v in raw])
    except (AttributeError, TypeError, ValueError) as e:
        # AttributeError: a record is not a dict; TypeError: data is not
        # iterable; ValueError: values still cannot form a regular array.
        print(f"Error extracting {key}: {e}")
        return None

def extract_multiple(data, keys):
    """Build a dict mapping each key in ``keys`` to ``extract_array(data, key)``.

    Keys are inserted in the order given; a key whose extraction failed maps
    to None.
    """
    columns = {}
    for name in keys:
        columns[name] = extract_array(data, name)
    return columns


In [17]:
# === Interactive Plotting with Plotly ===
def _central_with_errors(values, name):
    """Split rows of [lower, central, upper] into (central, err_lo, err_hi).

    Raises ValueError naming the column if ``values`` is None (e.g. a failed
    extraction) or is not a 2-D array with at least 3 columns.
    """
    if values is None:
        raise ValueError(f"Column {name!r} is None; its extraction likely failed")
    # dtype=float turns any None placeholders into NaN so the arithmetic works.
    arr = np.asarray(values, dtype=float)
    if arr.ndim != 2 or arr.shape[1] < 3:
        raise ValueError(
            f"Column {name!r} must be 2-D with columns [lower, central, upper]; "
            f"got shape {arr.shape}"
        )
    central = arr[:, 1]
    return central, central - arr[:, 0], arr[:, 2] - central


def plotly_vars(col1, col2, catalog_data, data=None, logx=False, logy=False):
    """
    Plot col2 vs. col1 from catalog_data (e.g. output of extract_multiple),
    optionally using metadata (from JSON) for hover and coloring.

    Each column is expected to hold rows of [lower, central, upper]; the
    central value is plotted and the bounds become asymmetric error bars.

    Parameters:
        col1, col2: str
            Keys to plot on x and y respectively (e.g. 'q', 'Period')
        catalog_data: dict of str -> np.ndarray
            Typically from extract_multiple()
        data: list of dict
            Raw JSON data for metadata (optional but enables color, hover).
            Must have one record per catalog row.
        logx, logy: bool
            Whether to use logarithmic axes

    Returns:
        plotly.graph_objects.Figure

    Raises:
        ValueError: if a column is missing/malformed or ``data`` has a
            different length than the catalog columns.
    """
    x, xerr_lo, xerr_hi = _central_with_errors(catalog_data[col1], col1)
    y, yerr_lo, yerr_hi = _central_with_errors(catalog_data[col2], col2)

    N = len(x)
    if data is None:
        # Default metadata: single "Unknown" class, uniform marker size.
        system_name = [""] * N
        type1 = [""] * N
        type2 = ["Unknown"] * N
        marker_size = [4] * N
    else:
        if len(data) != N:
            raise ValueError(
                f"data has {len(data)} records but the catalog columns have {N} rows"
            )
        system_name = [entry.get("System Name", "") for entry in data]
        type1 = [entry.get("Type1", "") for entry in data]
        type2 = [entry.get("Type2", "Unknown") for entry in data]
        # Type2 == "WD" rows get the small marker size (1); all others get 5.
        marker_size = [1 if t2 == "WD" else 5 for t2 in type2]

    fig = px.scatter(
        x=x,
        y=y,
        color=type2,
        hover_data={"System Name": system_name, "Type1": type1},
        error_x=xerr_hi,
        error_x_minus=xerr_lo,
        error_y=yerr_hi,
        error_y_minus=yerr_lo,
        size=marker_size,
        size_max=5,
        # "color" label gives the legend a meaningful title instead of "color".
        labels={"x": col1, "y": col2, "color": "Type2"},
    )

    fig.update_layout(
        width=900,
        height=600,
        xaxis_title=col1,
        yaxis_title=col2,
        xaxis_title_font=dict(size=18),
        yaxis_title_font=dict(size=18),
        legend=dict(font=dict(size=16)),
    )

    fig.update_yaxes(tickfont=dict(size=14))
    fig.update_xaxes(tickfont=dict(size=14))

    if logx:
        fig.update_xaxes(type="log")
    if logy:
        fig.update_yaxes(type="log")

    return fig


In [18]:
# Load the JSON catalog from the path configured at the top of the notebook
data = read_json_file(file_path)
if not data:
    # read_json_file returns None on failure; an empty catalog is also unusable.
    raise RuntimeError(f"No systems loaded from {file_path!r}")
# Show the keys available on a system record
print(list(data[0].keys()))

# Extract the columns used by the plots below
catalog_data = extract_multiple(data, ["M1", "q", "Period"])


['System Name', 'Type1', 'Type2', 'Detection Method', 'Reference', 'Notes', 'RA', 'Dec', 'Period', 'Eccentricity', 'M1', 'M1_sin3i', 'M2', 'M2_sin3i', 'q', 'Mass Function']


In [23]:
# Plot Period (y, log scale) against q (x)
plotly_vars("q", "Period", catalog_data, data=data, logy=True)
