In [1]:
import logging as logger
import os
import pickle
import sys
from collections import OrderedDict
from time import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table

sys.path.insert(0, "../")

from mcfa.mcfa import MCFA

%matplotlib inline

In [2]:
# Abundance catalogue with one row per (star, element) pair (see the display below).
barklem_catalog_path = "../catalogs/barklem_t3.fits"
barklem_abundances = Table.read(barklem_catalog_path)


In [3]:
# Raw catalogue: one row per (star, element) pair, with log-eps abundances and errors.
barklem_abundances

Name,El,logEpsX,e_logEpsX,e_logepsx_lc,__X_Fe_,e__X_Fe_,e__x_fe__lc,N,N3
Unnamed: 0_level_1,Unnamed: 1_level_1,[---],[---],[---],[---],[---],[---],Unnamed: 8_level_1,Unnamed: 9_level_1
str13,str2,float32,float32,float32,float32,float32,float32,int16,int16
CS 22175-007,Al,2.54,0.17,0.25,-1.12,0.18,0.26,1,1
CS 22175-007,Ba,-1.15,0.21,0.25,-0.47,0.18,0.26,2,1
CS 22175-007,C,5.77,0.19,0.26,0.19,0.18,0.27,1,1
CS 22175-007,Ca,3.87,0.12,0.2,0.31,0.15,0.25,9,5
CS 22175-007,Ce,,,,,,,10,0
CS 22175-007,Co,2.35,0.15,0.22,0.24,0.15,0.24,6,4
CS 22175-007,Cr,2.52,0.17,0.24,-0.35,0.16,0.24,3,3
CS 22175-007,Eu,,,,,,,4,0
CS 22175-007,Fe,4.69,0.13,0.18,0.0,0.14,0.22,55,35
...,...,...,...,...,...,...,...,...,...


In [4]:
# Solar abundances on the log-eps scale (hydrogen pinned at 12.0), keyed by
# element symbol. Per the variable name, these are the Asplund et al. (2009)
# values. `solar_abundance` is the alias the rest of the notebook reads.
asplund_2009 = dict(
    Pr=0.72, Ni=6.22, Gd=1.07, Pd=1.57, Pt=1.62, Ru=1.75, S=7.12, Na=6.24,
    Nb=1.46, Nd=1.42, Mg=7.6, Li=1.05, Pb=1.75, Re=0.26, Tl=0.9, Tm=0.1,
    Rb=2.52, Ti=4.95, As=2.3, Te=2.18, Rh=0.91, Ta=-0.12, Be=1.38, Xe=2.24,
    Ba=2.18, Tb=0.3, H=12.0, Yb=0.84, Bi=0.65, W=0.85, Ar=6.4, Fe=7.5,
    Br=2.54, Dy=1.1, Hf=0.85, Mo=1.88, He=10.93, Cl=5.5, C=8.43, B=2.7,
    F=4.56, I=1.55, Sr=2.87, K=5.03, Mn=5.43, O=8.69, Ne=7.93, P=5.41,
    Si=7.51, Th=0.02, U=-0.54, Sn=2.04, Sm=0.96, V=3.93, Y=2.21, Sb=1.01,
    N=7.83, Os=1.4, Se=3.34, Sc=3.15, Hg=1.17, Zn=4.56, La=1.1, Ag=0.94,
    Kr=3.25, Co=4.99, Ca=6.34, Ir=1.38, Eu=0.52, Al=6.45, Ce=1.58, Cd=1.71,
    Ho=0.48, Ge=3.65, Lu=0.1, Au=0.92, Zr=2.58, Ga=3.04, In=0.8, Cs=1.08,
    Cr=5.64, Cu=4.19, Er=0.92,
)
solar_abundance = asplund_2009

In [5]:
def atomic_number(element):
    """Return the atomic number (proton count) of a chemical element.

    Parameters
    ----------
    element : str
        Element symbol. Matching is case-insensitive and ignores surrounding
        whitespace, so "fe", " Fe " and "FE" are all accepted.

    Returns
    -------
    int
        The atomic number, e.g. 1 for "H", 26 for "Fe", 92 for "U".

    Raises
    ------
    ValueError
        If `element` is not a known element symbol.
    """
    # Main-group and d-block layout, one period per line. The f-block rows
    # (lanthanoids, actinoids) are spliced in after Ba and Ra below so that
    # list position follows atomic number. Period 7 is complete through Og.
    periodic_table = """H                                                  He
                        Li Be                               B  C  N  O  F  Ne
                        Na Mg                               Al Si P  S  Cl Ar
                        K  Ca Sc Ti V  Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
                        Rb Sr Y  Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I  Xe
                        Cs Ba Lu Hf Ta W  Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
                        Fr Ra Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og"""

    lanthanoids    =   "La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb"
    actinoids      =   "Ac Th Pa U  Np Pu Am Cm Bk Cf Es Fm Md No"

    symbols = periodic_table.replace(" Ba ", " Ba " + lanthanoids + " ") \
        .replace(" Ra ", " Ra " + actinoids + " ").split()

    symbol = element.strip().title()
    try:
        # symbols[0] is H (Z = 1), so the atomic number is the index plus one.
        return symbols.index(symbol) + 1
    except ValueError:
        raise ValueError(f"Unknown element symbol: {element!r}") from None


In [6]:
def restructure_barklem_data(data, solar=None):
    """Pivot a long (star, element) abundance table into one row per star.

    Parameters
    ----------
    data : astropy.table.Table
        Long-format table with one row per (star, element) pair and columns
        ``Name``, ``El``, ``logEpsX``, ``e_logEpsX`` and ``e_logepsx_lc``.
    solar : dict, optional
        Solar log-eps abundances keyed by element symbol (e.g. ``"Fe"``).
        Defaults to the notebook-level ``solar_abundance``.

    Returns
    -------
    astropy.table.Table
        One row per unique star name (sorted), with a ``name`` column and, for
        every element ``x`` present in ``data`` (lower-cased, sorted):

        - ``x_h``: ``logEpsX`` minus the solar log-eps value, i.e. [X/H];
        - ``x_h_err_rel``: copied from ``e_logEpsX``;
        - ``x_h_err_abs``: copied from ``e_logepsx_lc``.

        Star/element combinations absent from ``data`` are left as NaN.
    """
    if solar is None:
        solar = solar_abundance

    def parse_element(symbol):
        # Column-name form of an element symbol, e.g. " Fe" -> "fe".
        return symbol.strip().lower()

    # np.unique already returns its values sorted.
    unique_names = np.unique(data["Name"])
    unique_elements = np.unique([parse_element(el) for el in data["El"]])
    N = len(unique_names)

    # Constant-time lookup of a star's output row, instead of scanning
    # unique_names once per group.
    row_of = {name: i for i, name in enumerate(unique_names)}

    restructured_data = OrderedDict(name=unique_names)
    for element in unique_elements:
        for suffix in ("h", "h_err_rel", "h_err_abs"):
            restructured_data[f"{element}_{suffix}"] = np.full(N, np.nan)

    for group in data.group_by("Name").groups:
        index = row_of[group["Name"][0]]
        for row in group:
            symbol = row["El"].strip()
            element = parse_element(symbol)
            # Title-case the symbol so the solar lookup tolerates the same case
            # variations that the column naming (lower-casing) already does.
            restructured_data[f"{element}_h"][index] = row["logEpsX"] - solar[symbol.title()]
            restructured_data[f"{element}_h_err_rel"][index] = row["e_logEpsX"]
            restructured_data[f"{element}_h_err_abs"][index] = row["e_logepsx_lc"]

    return Table(data=restructured_data)
        

In [7]:
# Wide table: one row per star, columns <el>_h / <el>_h_err_rel / <el>_h_err_abs.
data = restructure_barklem_data(data=barklem_abundances)

In [8]:
# Restructured catalogue: one row per star; NaN wherever an element was not measured.
data

name,al_h,al_h_err_rel,al_h_err_abs,ba_h,ba_h_err_rel,ba_h_err_abs,c_h,c_h_err_rel,c_h_err_abs,ca_h,ca_h_err_rel,ca_h_err_abs,ce_h,ce_h_err_rel,ce_h_err_abs,co_h,co_h_err_rel,co_h_err_abs,cr_h,cr_h_err_rel,cr_h_err_abs,eu_h,eu_h_err_rel,eu_h_err_abs,fe_h,fe_h_err_rel,fe_h_err_abs,la_h,la_h_err_rel,la_h_err_abs,mg_h,mg_h_err_rel,mg_h_err_abs,mn_h,mn_h_err_rel,mn_h_err_abs,nd_h,nd_h_err_rel,nd_h_err_abs,ni_h,ni_h_err_rel,ni_h_err_abs,sc_h,sc_h_err_rel,sc_h_err_abs,sm_h,sm_h_err_rel,sm_h_err_abs,sr_h,sr_h_err_rel,sr_h_err_abs,ti_h,ti_h_err_rel,ti_h_err_abs,v_h,v_h_err_rel,v_h_err_abs,y_h,y_h_err_rel,y_h_err_abs,zn_h,zn_h_err_rel,zn_h_err_abs,zr_h,zr_h_err_rel,zr_h_err_abs
str13,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64
CS 22175-007,-3.91000003815,0.170000001788,0.25,-3.32999997616,0.209999993443,0.25,-2.66000001907,0.189999997616,0.259999990463,-2.47000011444,0.119999997318,0.20000000298,,,,-2.64000009537,0.15000000596,0.219999998808,-3.12000001907,0.170000001788,0.239999994636,,,,-2.80999994278,0.129999995232,0.180000007153,,,,-2.44999990463,0.119999997318,0.180000007153,-3.58999996662,0.170000001788,0.219999998808,,,,-2.51999995232,0.20000000298,0.300000011921,-2.65999999046,0.180000007153,0.230000004172,,,,-2.53999998689,0.20000000298,0.300000011921,-2.39000005722,0.170000001788,0.219999998808,,,,,,,,,,,,
CS 22186-023,-3.67000002861,0.15000000596,0.239999994636,-3.69999998093,0.189999997616,0.230000004172,-2.46000020981,0.189999997616,0.25,-2.57000001907,0.119999997318,0.20000000298,,,,-2.66000007629,0.159999996424,0.219999998808,-2.95999993324,0.159999996424,0.25,,,,-2.71999979019,0.129999995232,0.180000007153,,,,-2.61000022888,0.119999997318,0.180000007153,-3.16000001907,0.15000000596,0.209999993443,,,,-2.66000005722,0.170000001788,0.280000001192,-2.68999999166,0.170000001788,0.230000004172,,,,-2.76999999851,0.189999997616,0.300000011921,-2.50999994278,0.159999996424,0.209999993443,,,,-3.09999998569,0.180000007153,0.230000004172,,,,,,
CS 22186-025,-3.85000009537,0.209999993443,0.270000010729,-2.76999997377,0.239999994636,0.310000002384,-3.60000007629,0.20000000298,0.270000010729,-2.59,0.119999997318,0.20000000298,,,,-2.73000000954,0.15000000596,0.219999998808,-3.29000009537,0.159999996424,0.230000004172,,,,-2.86999988556,0.140000000596,0.189999997616,,,,-2.61000022888,0.119999997318,0.180000007153,-3.57999997616,0.159999996424,0.209999993443,,,,-3.10000011444,0.189999997616,0.25,-2.71000000238,0.180000007153,0.230000004172,,,,-3.07000000298,0.25,0.34999999404,-2.54999990463,0.15000000596,0.20000000298,,,,,,,,,,,,
CS 22886-042,-3.74999995232,0.170000001788,0.259999990463,-2.85000001669,0.239999994636,0.310000002384,-2.71000020981,0.189999997616,0.259999990463,-2.43999990463,0.119999997318,0.20000000298,,,,-2.65000008583,0.140000000596,0.209999993443,-2.9999998951,0.170000001788,0.259999990463,,,,-2.67999982834,0.119999997318,0.180000007153,,,,-2.46999988556,0.119999997318,0.180000007153,-3.22999995232,0.15000000596,0.209999993443,,,,-2.75999996185,0.189999997616,0.280000001192,-2.65,0.170000001788,0.230000004172,,,,-2.74000000477,0.20000000298,0.300000011921,-2.45,0.159999996424,0.209999993443,,,,-2.89000000715,0.180000007153,0.230000004172,,,,,,
CS 22892-052,-3.66000003815,0.159999996424,0.259999990463,-1.80999999523,0.159999996424,0.209999993443,-1.98999994278,0.189999997616,0.259999990463,-2.70999988556,0.119999997318,0.20000000298,,,,-2.99999999046,0.15000000596,0.219999998808,-3.39,0.170000001788,0.25,-1.41999997616,0.180000007153,0.230000004172,-2.94999980927,0.140000000596,0.189999997616,-1.85999999046,0.180000007153,0.230000004172,-2.85,0.129999995232,0.189999997616,-3.66000001907,0.119999997318,0.20000000298,-1.73000000238,0.189999997616,0.230000004172,-2.8299998951,0.20000000298,0.289999991655,-2.94999999702,0.170000001788,0.219999998808,,,,-2.24000000477,0.159999996424,0.230000004172,-2.83000011444,0.15000000596,0.20000000298,,,,-2.46999999046,0.180000007153,0.230000004172,,,,,,
CS 22945-028,-3.62000007629,0.180000007153,0.259999990463,-2.84000002623,0.230000004172,0.280000001192,-2.50000017166,0.189999997616,0.259999990463,-2.30999979019,0.109999999404,0.20000000298,,,,-2.5999998951,0.159999996424,0.219999998808,-2.87000001907,0.180000007153,0.259999990463,,,,-2.65999984741,0.129999995232,0.180000007153,,,,-2.36000022888,0.129999995232,0.180000007153,-3.24999993324,0.180000007153,0.230000004172,,,,-2.6100001049,0.219999998808,0.300000011921,-2.68000000119,0.180000007153,0.230000004172,,,,-2.49999999523,0.209999993443,0.300000011921,-2.38000006676,0.159999996424,0.209999993443,,,,,,,,,,,,
CS 22957-013,-3.5599998951,0.20000000298,0.259999990463,-3.23999994278,0.209999993443,0.270000010729,-2.58000009537,0.189999997616,0.259999990463,-2.42999991417,0.119999997318,0.20000000298,,,,-2.52999996185,0.159999996424,0.230000004172,-3.14999999046,0.180000007153,0.259999990463,,,,-2.63999986649,0.140000000596,0.189999997616,,,,-2.48999986649,0.109999999404,0.180000007153,-3.18,0.159999996424,0.219999998808,,,,-2.70000001907,0.209999993443,0.300000011921,-2.56000002623,0.180000007153,0.230000004172,,,,-2.81000000134,0.20000000298,0.310000002384,-2.41000003815,0.15000000596,0.209999993443,,,,,,,,,,,,
CS 22958-083,-3.65000004768,0.170000001788,0.259999990463,-3.82999997616,0.189999997616,0.239999994636,-2.19000022888,0.20000000298,0.259999990463,-2.4499998951,0.119999997318,0.20000000298,,,,-2.57999991417,0.159999996424,0.219999998808,-2.86000002861,0.170000001788,0.270000010729,,,,-2.78999996185,0.119999997318,0.180000007153,,,,-2.40999994278,0.119999997318,0.180000007153,-3.10000007629,0.159999996424,0.219999998808,,,,-2.47999999046,0.189999997616,0.289999991655,-2.74000000358,0.189999997616,0.239999994636,,,,-2.75000000268,0.209999993443,0.310000002384,-2.43000001907,0.159999996424,0.209999993443,,,,,,,,,,,,
CS 22960-010,-3.54999990463,0.140000000596,0.209999993443,,,,-1.87000005722,0.170000001788,0.239999994636,-2.36999997139,0.119999997318,0.20000000298,,,,,,,-2.7499998951,0.15000000596,0.20000000298,,,,-2.65000009537,0.119999997318,0.170000001788,,,,-2.42000017166,0.129999995232,0.189999997616,-2.96999996185,0.15000000596,0.20000000298,,,,-2.69000002861,0.159999996424,0.209999993443,-2.46000000238,0.180000007153,0.230000004172,,,,,,,-2.38000006676,0.159999996424,0.209999993443,,,,,,,,,,,,
CS 29491-069,-3.8099998951,0.140000000596,0.230000004172,-2.52000000358,0.219999998808,0.289999991655,-2.66999977112,0.189999997616,0.25,-2.50000008583,0.119999997318,0.20000000298,,,,-2.71000002861,0.159999996424,0.219999998808,-3.06000007629,0.159999996424,0.239999994636,-1.76000000954,0.180000007153,0.230000004172,-2.80999994278,0.129999995232,0.180000007153,,,,-2.54999980927,0.109999999404,0.180000007153,-3.42000000954,0.170000001788,0.219999998808,,,,-2.81999990463,0.180000007153,0.259999990463,-2.77999999523,0.170000001788,0.219999998808,,,,-2.63999999583,0.180000007153,0.270000010729,-2.52999992371,0.159999996424,0.209999993443,,,,-2.77999999285,0.180000007153,0.230000004172,,,,,,


In [9]:
# Column names: `name`, then `<el>_h`, `<el>_h_err_rel`, `<el>_h_err_abs` per element.
data.dtype.names

('name',
 'al_h',
 'al_h_err_rel',
 'al_h_err_abs',
 'ba_h',
 'ba_h_err_rel',
 'ba_h_err_abs',
 'c_h',
 'c_h_err_rel',
 'c_h_err_abs',
 'ca_h',
 'ca_h_err_rel',
 'ca_h_err_abs',
 'ce_h',
 'ce_h_err_rel',
 'ce_h_err_abs',
 'co_h',
 'co_h_err_rel',
 'co_h_err_abs',
 'cr_h',
 'cr_h_err_rel',
 'cr_h_err_abs',
 'eu_h',
 'eu_h_err_rel',
 'eu_h_err_abs',
 'fe_h',
 'fe_h_err_rel',
 'fe_h_err_abs',
 'la_h',
 'la_h_err_rel',
 'la_h_err_abs',
 'mg_h',
 'mg_h_err_rel',
 'mg_h_err_abs',
 'mn_h',
 'mn_h_err_rel',
 'mn_h_err_abs',
 'nd_h',
 'nd_h_err_rel',
 'nd_h_err_abs',
 'ni_h',
 'ni_h_err_rel',
 'ni_h_err_abs',
 'sc_h',
 'sc_h_err_rel',
 'sc_h_err_abs',
 'sm_h',
 'sm_h_err_rel',
 'sm_h_err_abs',
 'sr_h',
 'sr_h_err_rel',
 'sr_h_err_abs',
 'ti_h',
 'ti_h_err_rel',
 'ti_h_err_abs',
 'v_h',
 'v_h_err_rel',
 'v_h_err_abs',
 'y_h',
 'y_h_err_rel',
 'y_h_err_abs',
 'zn_h',
 'zn_h_err_rel',
 'zn_h_err_abs',
 'zr_h',
 'zr_h_err_rel',
 'zr_h_err_abs')

In [10]:
# Keep only the abundance columns ("<el>_h"), dropping `name` and the error columns.
elements = list(filter(lambda column: column.endswith("_h"), data.dtype.names))

In [11]:
# The [X/H] abundance columns, one per element.
elements

['al_h',
 'ba_h',
 'c_h',
 'ca_h',
 'ce_h',
 'co_h',
 'cr_h',
 'eu_h',
 'fe_h',
 'la_h',
 'mg_h',
 'mn_h',
 'nd_h',
 'ni_h',
 'sc_h',
 'sm_h',
 'sr_h',
 'ti_h',
 'v_h',
 'y_h',
 'zn_h',
 'zr_h']

In [12]:
# Matrix of abundances: one row per star, one column per element.
X = np.transpose(np.array([data[column] for column in elements]))

In [13]:
# (number of stars, number of elements)
X.shape

(253, 22)

In [14]:
# Number of stars with a finite abundance for every element.
np.isfinite(X).all(axis=1).sum()

6

In [15]:
# For each element, the number of stars with a finite abundance.
finites = {element: np.sum(np.isfinite(data[element])) for element in elements}
finites

{'al_h': 239,
 'ba_h': 220,
 'c_h': 249,
 'ca_h': 253,
 'ce_h': 13,
 'co_h': 223,
 'cr_h': 248,
 'eu_h': 68,
 'fe_h': 253,
 'la_h': 33,
 'mg_h': 245,
 'mn_h': 237,
 'nd_h': 35,
 'ni_h': 247,
 'sc_h': 247,
 'sm_h': 9,
 'sr_h': 245,
 'ti_h': 250,
 'v_h': 47,
 'y_h': 154,
 'zn_h': 38,
 'zr_h': 48}

In [16]:
# Elements ordered by number of measurements, most-measured first. Sorting
# ascending and then reversing puts tied counts in reverse insertion order
# (which is why fe_h is listed before ca_h).
sorted_finites = sorted(finites.items(), key=lambda item: item[1])
sorted_finites.reverse()
for name, count in sorted_finites:
    print(name, count)

fe_h 253
ca_h 253
ti_h 250
c_h 249
cr_h 248
sc_h 247
ni_h 247
sr_h 245
mg_h 245
al_h 239
mn_h 237
co_h 223
ba_h 220
y_h 154
eu_h 68
zr_h 48
v_h 47
zn_h 38
nd_h 35
la_h 33
ce_h 13
sm_h 9


In [17]:
# Number of elements to model: the D most frequently measured ones are kept below.
D = 15

In [18]:
# Names of the D most frequently measured elements.
use_elements = [name for name, _ in sorted_finites[:D]]

In [19]:
# The D most frequently measured elements, most-measured first.
use_elements

['fe_h',
 'ca_h',
 'ti_h',
 'c_h',
 'cr_h',
 'sc_h',
 'ni_h',
 'sr_h',
 'mg_h',
 'al_h',
 'mn_h',
 'co_h',
 'ba_h',
 'y_h',
 'eu_h']

In [20]:
# Re-order the selected elements by atomic number so the columns follow the
# periodic table, then build the matrix of those columns.
symbols = [name.split("_")[0] for name in use_elements]
idx = np.argsort(np.array([atomic_number(symbol) for symbol in symbols]))
label_names = [use_elements[position] for position in idx]

X = np.array([data[name] for name in label_names]).T

In [21]:
# Keep only stars with a finite abundance for every selected element.
# X is rebuilt from `data` so this cell is idempotent: filtering an already
# filtered X would shrink `keep` below len(data) and break the later
# `[keep]` indexing of the error columns.
X_all = np.array([data[ln] for ln in label_names]).T
keep = np.all(np.isfinite(X_all), axis=1)
X = X_all[keep]

In [24]:
# Persist the filtered matrix, its column names and the row mask to
# barklem.pkl. `pickle` is imported in the top imports cell.
with open("barklem.pkl", "wb") as fp:
    pickle.dump((X, label_names, keep), fp)

# Last expression, so the shape is actually displayed (it was previously a
# discarded mid-cell expression).
X.shape

In [None]:
# Quick first fit with n_components=1 and n_latent_factors=3.
model = MCFA(n_latent_factors=3, n_components=1)
model.fit(X)

In [None]:

# Plot each column of the fitted "A" parameter against the elements. This
# assumes one row per element and one column per latent factor, as the x
# ticks below do. NOTE(review): presumably A holds the factor loadings.
A = model.theta_[model.parameter_names.index("A")]
element_symbols = [ln.split("_")[0] for ln in label_names]

# squeeze=False keeps `axes` 2-D even for a single latent factor;
# plt.subplots(1) would otherwise return a bare, non-iterable Axes.
fig, axes = plt.subplots(model.n_latent_factors, squeeze=False,
                         figsize=(10, 10))
for i, ax in enumerate(axes[:, 0]):
    ax.plot(A.T[i])
    ax.set_xticks(np.arange(len(element_symbols)))
    ax.set_xticklabels(element_symbols)


In [None]:
N, D = X.shape

print(N, D)

# Grid search over the number of latent factors J and mixture components K.
Js = np.arange(1, 10 + 1).astype(int)
Ks = np.arange(1, 10 + 1).astype(int)

# Results indexed [j, k]; entries stay NaN for fits that raise.
BICs = np.full((Js.size, Ks.size), np.nan)
opt_times = np.full((Js.size, Ks.size), np.nan)
log_likelihoods = np.full((Js.size, Ks.size), np.nan)

write_results = False
results_path = "results/barklem-gridsearch-J{J}-K{K}.pkl"

# Only create the output folder when results will actually be written.
results_folder = os.path.dirname(results_path)
if write_results and results_folder:
    os.makedirs(results_folder, exist_ok=True)

# NOTE(review): random_seed=None presumably leaves initialisation unseeded, so
# the grid is not reproducible run-to-run; set an int seed to pin it.
mcfa_kwds = dict(max_iter=1000, n_init=5, tol=1e-5, verbose=0,
                 random_seed=None)

c, C = (0, Js.size * Ks.size)

for k, K in enumerate(Ks):
    for j, J in enumerate(Js):

        c += 1
        print("{}/{}: J = {}, K = {}".format(c, C, J, K))

        kwds = mcfa_kwds.copy()
        kwds.update(n_components=K, n_latent_factors=J)

        model = MCFA(**kwds)

        t_init = time()
        try:
            model.fit(X)

        except Exception:
            # Catch Exception rather than everything: a bare `except:` would
            # also swallow KeyboardInterrupt, so the sweep could not be stopped
            # from Jupyter. The [j, k] result entries simply stay NaN.
            logger.exception("Exception in fitting at J = {}, K = {}".format(J, K))
            continue

        t_opt = time() - t_init

        # Save results.
        opt_times[j, k] = t_opt
        BICs[j, k] = model.bic(X)
        log_likelihoods[j, k] = model.log_likelihood_

        if write_results:
            result = dict(kwds=kwds,
                          t_opt=t_opt,
                          bic=BICs[j, k],
                          tau=model.tau_,
                          theta=model.theta_,
                          n_iter=model.n_iter_,
                          log_likelihood=model.log_likelihood_)

            path = results_path.format(J=J, K=K)
            with open(path, "wb") as fp:
                pickle.dump(result, fp, -1)

            print("Results written to {}".format(path))



In [None]:
def scatter_grid_search(Js, Ks, Zs, z_percentiles=None, cbar_label=None,
                        figsize=(10, 10), highlight_z_index=None, **kwargs):
    """Scatter-plot a grid-search metric over the (J, K) grid, coloured by value.

    Parameters
    ----------
    Js, Ks : array_like
        Grid values of J (x axis) and K (y axis).
    Zs : ndarray, shape (len(Js), len(Ks))
        Metric at each grid point, indexed ``Zs[j, k]``. NaN entries are
        ignored when computing percentiles and choosing highlighted points.
    z_percentiles : pair of float, optional
        Percentiles of the finite values used as the colour-scale limits
        (``vmin``, ``vmax``).
    cbar_label : str, optional
        If given, a colorbar is drawn with this label.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    highlight_z_index : int or slice, optional
        Index into the finite values sorted ascending; the selected point(s)
        are circled in red (e.g. ``0`` = minimum, ``-1`` = maximum).
    **kwargs
        Extra keyword arguments for ``Axes.scatter``; they override the
        defaults ``s=10, cmap="viridis"``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    Jm, Km = np.meshgrid(Js, Ks)
    # meshgrid arrays have shape (len(Ks), len(Js)); transposing Zs gives the
    # same layout so x, y and z line up element by element.
    x, y = (Jm.flatten(), Km.flatten())
    z = Zs.T.flatten()

    kwds = dict(s=10, cmap="viridis")
    if z_percentiles is not None:
        vmin, vmax = np.nanpercentile(z, z_percentiles)
        kwds.update(vmin=vmin, vmax=vmax)

    kwds.update(kwargs)

    fig, ax = plt.subplots(figsize=figsize)
    scat = ax.scatter(x, y, c=z, **kwds)

    ax.set_xlabel(r"$J$")
    ax.set_ylabel(r"$K$")

    if cbar_label is not None:
        # Attach the colorbar to this figure/axes explicitly instead of to
        # pyplot's "current" figure.
        cbar = fig.colorbar(scat, ax=ax)
        cbar.set_label(cbar_label)

    ax.set_xticks(np.unique(x).astype(int))
    ax.set_xticklabels(np.unique(x).astype(str))

    if highlight_z_index is not None:
        ok = np.where(np.isfinite(z))[0]
        # When every grid point failed (all-NaN z) there is nothing to
        # highlight; indexing the empty array would raise IndexError.
        if ok.size > 0:
            indices = ok[np.argsort(z[ok])][highlight_z_index]
            ax.scatter(x[indices], y[indices], zorder=-1,
                       s=100, lw=5, edgecolor="r", facecolor="none")

    fig.tight_layout()
    return fig


In [None]:
# Log-likelihood over the (J, K) grid: colour scale clipped to the 16th-84th
# percentiles, highest value circled.
fig = scatter_grid_search(Js, Ks, log_likelihoods,
                          cbar_label=r"$\mathcal{L}$", z_percentiles=[16, 84],
                          highlight_z_index=-1, s=50)

In [None]:
# BIC over the (J, K) grid: colour scale clipped to the 1st-25th percentiles,
# lowest value circled. The label is wrapped in $...$ so matplotlib's mathtext
# parses it; a bare "\textrm{BIC}" would be rendered literally.
fig = scatter_grid_search(Js, Ks, BICs, s=50,
                          z_percentiles=[1, 25], cbar_label=r"$\mathrm{BIC}$",
                          highlight_z_index=0)

In [None]:
# Refit with K = 2 components and J = 7 latent factors, using the same
# fitting options as the grid search.
model = MCFA(**mcfa_kwds, n_latent_factors=7, n_components=2)
model.fit(X)

In [None]:
# Visualise X in the fitted model's latent space (plot provided by MCFA itself).
fig = model.plot_latent_space(X)

In [None]:

# Plot each column of the refitted "A" parameter against the elements. This
# assumes one row per element and one column per latent factor, as the x
# ticks below do. NOTE(review): presumably A holds the factor loadings.
A = model.theta_[model.parameter_names.index("A")]
element_symbols = [ln.split("_")[0] for ln in label_names]

# squeeze=False keeps `axes` 2-D even for a single latent factor;
# plt.subplots(1) would otherwise return a bare, non-iterable Axes.
fig, axes = plt.subplots(model.n_latent_factors, squeeze=False,
                         figsize=(10, 10))
for i, ax in enumerate(axes[:, 0]):
    ax.plot(A.T[i])
    ax.set_xticks(np.arange(len(element_symbols)))
    ax.set_xticklabels(element_symbols)


In [None]:
# Plot the fitted "psi" parameter, one point per element as the x ticks assume.
# NOTE(review): presumably the per-element specific variances — confirm in MCFA.
psi_index = model.parameter_names.index("psi")
psi = model.theta_[psi_index]

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(psi)
tick_positions = np.arange(D)
ax.set_xticks(tick_positions)
ax.set_xticklabels([name.split("_")[0] for name in label_names])

In [None]:
# Shape of the fitted psi parameter (compare with the number of elements, D).
psi.shape

In [None]:
# Per-star uncertainties for the selected elements, restricted via `keep` to
# the same stars (rows) as X.
X_err_rel, X_err_abs = (
    np.array([data[f"{ln}_{kind}"] for ln in label_names]).T[keep]
    for kind in ("err_rel", "err_abs")
)

In [None]:
# Mean reported uncertainty per element across the kept stars. The two series
# come from e_logEpsX ("err_rel") and e_logepsx_lc ("err_abs"); a legend makes
# it possible to tell which colour is which.
fig, ax = plt.subplots()
ax.plot(np.mean(X_err_rel, axis=0), c="tab:blue", label="err_rel (e_logEpsX)")
ax.plot(np.mean(X_err_abs, axis=0), c="tab:red", label="err_abs (e_logepsx_lc)")
ax.set_xticks(np.arange(D))
ax.set_xticklabels([ln.split("_")[0] for ln in label_names])
ax.set_ylabel("mean uncertainty")
ax.legend();

In [None]:
from mcfa import mpl_utils

# Pairwise scatter of the abundances, coloured by model.tau_.T[0].
# NOTE(review): presumably each star's membership weight for component 0.
element_symbols = [ln.split("_")[0] for ln in label_names]
fig = mpl_utils.corner_scatter(X, c=model.tau_.T[0], s=30, cmap="coolwarm",
                               label_names=element_symbols)
