In [None]:
import pickle
import os
import numpy as np
import shutil
import pandas as pd
import seaborn as sns
from models.gaussian_mixture import remove_outliers, gaussian_mixture
from preprocessing.read_winter import load_winter
from graphs.mixture_fit import combined_fit_mixture
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.image as mpimg

# 1. Create combined graphs using exoplanet data

In [None]:
def combined_gaussians(show_graph=False, save_graph=False, sigma=2, max_stars=1047):
    """Fit a 2-component Gaussian mixture per star in every density file and plot the fits together.

    For every file in ``data/densities/dr3`` whose name contains ``"_0_200000"``,
    each of the first ``max_stars`` entries goes through three steps:

    * it is outlier-clipped with ``remove_outliers``;
    * it is fitted with a two-component ``gaussian_mixture``;
    * it is labelled with its host name from the Gaia EDR3 star-label table.

    Then, for every star index, the fits from all files are drawn on one figure by
    ``combined_fit_mixture``.

    Parameters
    ----------
    show_graph : bool
        Forwarded to ``combined_fit_mixture``.
    save_graph : bool
        If True, ``figures/combined_gaussians`` is wiped and recreated, and its path
        is forwarded to ``combined_fit_mixture`` as ``fig_dir``.
    sigma : float
        Clipping width forwarded to ``remove_outliers`` (default 2).
    max_stars : int
        Number of leading entries kept from each density file (default 1047).

    Raises
    ------
    FileNotFoundError
        If no density file matching ``"_0_200000"`` exists.
    """
    density_dir = "data/densities/dr3"
    fig_dir = None
    if save_graph:
        fig_dir = "figures/combined_gaussians"
        # Start from an empty directory so figures from earlier runs don't linger.
        if os.path.isdir(fig_dir):
            shutil.rmtree(fig_dir)
        # makedirs (not mkdir) also creates the parent "figures" directory if missing.
        os.makedirs(fig_dir)

    labels = pd.read_csv("data/crossmatch/dr3/gaiaedr3_star_labels.csv",
                         dtype={"source_id": str, "Host": str}, nrows=1172)

    # os.listdir order is arbitrary; sort so the order of the per-file curves in
    # each combined figure is reproducible across runs and machines.
    file_names = sorted(name for name in os.listdir(density_dir) if "_0_200000" in name)
    if not file_names:
        raise FileNotFoundError(f"no density files matching '_0_200000' in {density_dir}")

    results = []
    for file_name in file_names:
        # NOTE(review): pickle.load can execute arbitrary code -- only load trusted,
        # locally generated files here.
        with open(os.path.join(density_dir, file_name), "rb") as f:
            densities = pickle.load(f)

        # Feature-set tag for the legend, taken from the file name. It depends only
        # on the file, so it is computed once per file.
        if "5d" in file_name:
            prefix = "_".join(file_name.rsplit("_", 5)[1:4])
        else:
            prefix = file_name.rsplit("_", 4)[2]

        stars = []
        for entry in densities[:max_stars]:
            # log10 of the host density (entry[1]), given a leading axis and transposed.
            target = np.expand_dims(np.log10(entry[1]), axis=0).T

            # Remove outliers outside `sigma` before fitting.
            data = remove_outliers(entry[4], sigma=sigma)

            model, _scores = gaussian_mixture(data, [target], components=2, scores_only=False)

            # entry[0] is matched against labels["source_id"]; column 1 of the label
            # table is used as the host name.
            host_name = labels[labels["source_id"] == entry[0]].values[0][1]

            stars.append((model, data, [prefix + "_" + host_name, target]))
        results.append(stars)

    n_files = len(results)
    # The i-th figure combines the i-th star of every file, so all files are
    # assumed to list the same stars in the same order. Stop at the shortest file.
    n_stars = min(len(stars) for stars in results)
    for star_idx in range(n_stars):
        models = [stars[star_idx][0] for stars in results]
        data = [stars[star_idx][1] for stars in results]
        hosts = [stars[star_idx][2] for stars in results]

        print(hosts)
        combined_fit_mixture(models, data, hosts, n_files, fig_dir=fig_dir,
                             show_graph=show_graph, save_graph=save_graph)

In [None]:
# Render the combined mixture fits inline only; nothing is written to disk.
combined_gaussians(save_graph=False, show_graph=True)

# 2. Create planet mass vs. semi-major axis graphs for exoplanets

In [None]:
def mass_sma(exoplanets_file, features_file, high_threshold=0.84, low_threshold=0.16):
    """Plot log10 planet mass against log10 semi-major axis, split by the host's P_high.

    Exoplanets whose ``st_age`` lies strictly between 1 and 4.5 are joined on
    ``Host`` with the classification features. Two KDE + scatter panels are drawn:
    hosts with ``gm_p_high < low_threshold`` (left, blue) and hosts with
    ``gm_p_high > high_threshold`` (right, red).

    Parameters
    ----------
    exoplanets_file : str
        CSV file name inside ``data/initial_datasets``. It must provide ``gaia_id``,
        ``st_age``, ``pl_bmasse`` and ``pl_orbsmax``.
    features_file : str
        CSV file name inside ``data/classification/dr3``. It must provide ``Host``
        and ``gm_p_high``.
    high_threshold, low_threshold : float
        ``gm_p_high`` cut-offs for the high and low groups (defaults 0.84 / 0.16).

    Returns
    -------
    pandas.DataFrame
        The low group, with columns ``Host``, ``gm_p_high``, ``pl_bmasse``,
        ``pl_orbsmax`` (already log10) and ``mass`` (log10 of ``pl_bmasse``).
    """
    exoplanets_dir = "data/initial_datasets"
    class_dir = "data/classification/dr3"
    # NOTE(review): skiprows=28 skips the file's header block, presumably the archive's
    # comment preamble -- confirm if the export is regenerated.
    ex = pd.read_csv(os.path.join(exoplanets_dir, exoplanets_file), skiprows=28)
    features = pd.read_csv(os.path.join(class_dir, features_file), index_col=0)
    features["Host"] = features["Host"].astype(str)

    # Strip the "Gaia DR2 " prefix so the archive ids match the feature table's Host ids.
    ex["Host"] = ex["gaia_id"].str.replace("Gaia DR2 ", "", regex=False)
    ex = ex[(ex["st_age"] > 1) & (ex["st_age"] < 4.5)]

    # .assign builds a new frame. Writing into the column slice directly would
    # raise SettingWithCopyWarning.
    planets = pd.merge(features, ex, on="Host")[["Host", "gm_p_high", "pl_bmasse", "pl_orbsmax"]].assign(
        pl_orbsmax=lambda d: np.log10(d["pl_orbsmax"]),
        mass=lambda d: np.log10(d["pl_bmasse"]),
    )

    high = planets[planets["gm_p_high"] > high_threshold]
    low = planets[planets["gm_p_high"] < low_threshold]

    fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
    panels = (
        (axes[0], low, "Blues", "b", f"gm_p_high < {low_threshold}"),
        (axes[1], high, "Reds", "r", f"gm_p_high > {high_threshold}"),
    )
    for ax, group, cmap, color, title in panels:
        # fill=True is the seaborn >= 0.11 spelling of the deprecated shade=True.
        sns.kdeplot(x=group["pl_orbsmax"], y=group["mass"], ax=ax, cmap=cmap, fill=True)
        sns.scatterplot(x=group["pl_orbsmax"], y=group["mass"], ax=ax, color=color)
        ax.set_xlim(-2, 1)
        ax.set_ylim(-1, 3.5)
        ax.set_xticks([-2, -1, 0, 1])
        ax.set_yticks([-1, 0, 1, 2, 3, 3.5])
        ax.set_xlabel("log10(pl_orbsmax)")
        ax.set_ylabel("log10(pl_bmasse)")
        ax.set_title(title)

    return low

In [None]:
# 5D features (radial velocity dropped). Bind a descriptive name rather than `df`,
# which section 3 uses for a different frame.
low_5d = mass_sma("exoplanets.csv", "features_densities_gaiaedr3_5d_drop_rv_0_200000.csv")

In [None]:
# 6D features. Bind a descriptive name rather than `df`, which section 3 uses for a
# different frame.
low_6d = mass_sma("exoplanets.csv", "features_densities_gaiaedr3_6d_0_200000.csv")

# 3. Comparison with A. Winter's P_high results

In [None]:
def hue(row):
    """Classify how well our P_high (``gm_p_high``) agrees with Winter's (``logPhigh``).

    Both values are compared against a lower (0.16) and an upper (0.84) threshold.
    The first rule that applies wins.

    Returns
    -------
    int
        0 -- strong mismatch: one value is above the upper threshold and the other
             is below the lower threshold.
        1 -- one value is above the upper threshold and the other is below it.
        2 -- one value is below the lower threshold and the other is above it.
        3 -- ambiguous: both values lie in the band between the thresholds.
        4 -- matching: both values are above the upper threshold, or both are below
             the lower one.

    NOTE(review): every comparison is False for NaN, so a missing value lands in 4
    ("matching"). The same holds for a few pairs sitting exactly on a threshold.
    """
    ours = row["gm_p_high"]
    theirs = row["logPhigh"]
    lower, upper = 0.16, 0.84

    if (ours > upper and theirs < lower) or (ours < lower and theirs > upper):
        return 0
    if (ours > upper and theirs < upper) or (ours < upper and theirs > upper):
        return 1
    if (ours < lower and theirs > lower) or (ours > lower and theirs < lower):
        return 2
    if (ours < upper and theirs > lower) or (ours > lower and theirs < upper):
        return 3
    return 4

In [None]:
# Inputs for the comparison: our 6D classification features, the Gaia EDR3 star
# labels, and A. Winter's published results.
features_path = "data/classification/dr3/features_densities_gaiaedr3_6d_0_200000.csv"
labels_path = "data/crossmatch/dr3/gaiaedr3_star_labels.csv"

df = pd.read_csv(features_path, index_col=0, nrows=1172, dtype={"source_id": str, "Host": str})
labels = pd.read_csv(labels_path, nrows=1172, dtype={"source_id": str, "Host": str})
winter = load_winter()

In [None]:
# In the features table, "Host" holds the Gaia source id. Expose it as "source_id"
# so it can be joined with the star labels. The cell avoids inplace mutation and is
# guarded, so re-running it immediately is a no-op instead of a KeyError.
if "Host" in df.columns:
    df = df.assign(source_id=df["Host"]).drop(columns="Host")

In [None]:
# Attach the star labels (inner join on the Gaia source id).
df = df.merge(labels, on="source_id")

In [None]:
# Join with A. Winter's table on host name and keep only the two P_high estimates.
# Winter's "logPhigh" is treated as log10(P_high) and converted back to a linear
# value. The column keeps its name because later cells read it. The explicit .copy()
# makes df1 an independent frame, so adding "hue" raises no SettingWithCopyWarning.
df1 = (
    pd.merge(df, winter, on="Host")
    .assign(logPhigh=lambda d: 10 ** d["logPhigh"])
    .loc[:, ["Host", "gm_p_high", "logPhigh"]]
    .copy()
)
df1["hue"] = df1.apply(hue, axis=1)

In [None]:
# Colour for each agreement category; list position == integer code returned by `hue`.
colors = ["blue", "m", "red", "orange", "green"]
legend_labels = [
    "Missmatched Plow for Phigh",
    "Missmatched Phigh for ambigous",
    "Missmatched Plow for ambigous",
    "Ambigous",
    "Matching Phigh and Plow",
]

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 8), facecolor="w")

# A dict palette pins each category code to its colour. With a plain list, seaborn
# assigns colours only to the levels present in the data, so a missing category
# would shift the colours out of sync with the hand-built legend below.
palette = dict(enumerate(colors))
sns.scatterplot(x=df1["gm_p_high"], y=df1["logPhigh"], hue=df1["hue"],
                palette=palette, legend=False, ax=ax)

handles = [mpatches.Patch(color=color, label=label) for color, label in zip(colors, legend_labels)]
ax.legend(handles=handles, bbox_to_anchor=(0., 1.02, 1., .102),
          loc="lower left", ncol=3, mode="expand", borderaxespad=0.)

ax.set_xlabel("Density prediction (All neighbours)")
ax.set_ylabel("Winter density prediction (600 random neighbours)")

os.makedirs("report_images", exist_ok=True)
# bbox_inches="tight" keeps the legend, which sits above the axes, inside the saved image.
fig.savefig("report_images/winter_comparison.png", bbox_inches="tight")

## Gaussian mixture fit comparison graphs

In [None]:
# Side-by-side Gaussian-mixture fits: A. Winter's figures in the top row, ours
# (6D, exoplanet hosts only) in the bottom row, one column per star.
stars_to_compare = ["HD175541", "WASP-12"]
image_rows = [
    [mpimg.imread(f"winter_figures/{star}.png") for star in stars_to_compare],
    [mpimg.imread(f"figures/densities_gaiaedr3_6d_only-exoplanets/{star}.png") for star in stars_to_compare],
]

# squeeze=False keeps `axes` 2-D even if only one star is compared.
fig, axes = plt.subplots(len(image_rows), len(stars_to_compare), figsize=(12, 10),
                         facecolor="w", squeeze=False)
for row_axes, row_images in zip(axes, image_rows):
    for ax, image in zip(row_axes, row_images):
        ax.imshow(image)
        ax.axis("off")

fig.tight_layout()
os.makedirs("report_images", exist_ok=True)
fig.savefig("report_images/star_density_comparison.png")