In [23]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup
from torch.utils.data import DataLoader
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from tqdm import tqdm
from matplotlib import cm
import seaborn as sns
import matplotlib.lines as mlines
from sklearn.decomposition import PCA
from openTSNE import TSNE
from PIL import Image
import umap
import torch.nn.functional as F
from scipy.spatial.distance import cdist
import umap.plot
import math
from scipy.stats import norm
import joblib
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree

from core.final.dataset import PSMDataset
from core.final.model import GalSpecNet, MetaModel, Informer, AstroModel
from core.final.trainer import Trainer

In [24]:
def get_model(config):
    """Build the network selected by ``config['mode']``.

    ``'photo'`` gives an ``Informer``, ``'spectra'`` a ``GalSpecNet`` and
    ``'meta'`` a ``MetaModel``; every other mode gives an ``AstroModel``.
    The chosen class is constructed with ``config`` and returned.
    """
    mode = config['mode']

    if mode == 'photo':
        return Informer(config)
    if mode == 'spectra':
        return GalSpecNet(config)
    if mode == 'meta':
        return MetaModel(config)

    # Any other mode falls through to the combined model.
    return AstroModel(config)

def get_embs(dataloader):
    """Embed every batch of ``dataloader`` with the notebook-level ``model``.

    Each batch is unpacked into photometry, photometry mask, spectra, metadata
    and labels. The four inputs are moved to the notebook-level ``device`` and
    passed to ``model.get_embeddings`` with gradients disabled.

    Returns:
        ``(p_emb, s_emb, m_emb, labels)``: the three embedding tensors are
        copied to CPU and joined with ``torch.vstack``; the labels are joined
        with ``torch.hstack``.
    """
    p_chunks, s_chunks, m_chunks = [], [], []
    label_chunks = []

    for batch in tqdm(dataloader):
        photometry, photometry_mask, spectra, metadata, labels = batch
        photometry = photometry.to(device)
        photometry_mask = photometry_mask.to(device)
        spectra = spectra.to(device)
        metadata = metadata.to(device)

        with torch.no_grad():
            p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)

            # Keep results on the CPU so accumulated batches do not pile up on the device.
            for chunks, emb in ((p_chunks, p_emb), (s_chunks, s_emb), (m_chunks, m_emb)):
                chunks.append(emb.cpu())
            label_chunks.append(labels)

    all_p_emb = torch.vstack(p_chunks)
    all_s_emb = torch.vstack(s_chunks)
    all_m_emb = torch.vstack(m_chunks)
    all_labels = torch.hstack(label_chunks)

    return all_p_emb, all_s_emb, all_m_emb, all_labels

def get_centers(p_emb, s_emb, m_emb, train_labels, num_classes=10):
    """Compute unit-length class centroids for each modality and for their mean.

    Args:
        p_emb, s_emb, m_emb: embedding tensors with one row per object, all of
            the same shape.
        train_labels: 1-D tensor of integer class ids, one per row.
        num_classes: number of class ids to process (ids ``0 .. num_classes-1``).
            Defaults to 10, the value that used to be hard-coded.

    Returns:
        ``(p_centers, s_centers, m_centers, all_centers)``, each of shape
        ``(num_classes, dim)``. Row ``i`` is the mean embedding of class ``i``
        divided by its L2 norm. ``all_centers`` uses the element-wise average
        of the three modalities. A class with no members yields a NaN row.
    """
    # The fused embedding does not depend on the class, so build it once.
    all_emb = (p_emb + s_emb + m_emb) / 3

    def _unit_centers(emb):
        """Return the stacked, L2-normalised per-class means of ``emb``."""
        centers = []
        for class_id in range(num_classes):
            center = emb[train_labels == class_id].mean(dim=0)
            centers.append(center / center.norm())
        return torch.stack(centers)

    p_centers = _unit_centers(p_emb)
    s_centers = _unit_centers(s_emb)
    m_centers = _unit_centers(m_emb)
    all_centers = _unit_centers(all_emb)

    return p_centers, s_centers, m_centers, all_centers

def plot_obj_umap(dataset, idx, embeddings, targets):
    """Show one object next to its position in a 2-D embedding.

    Three panels are drawn: the light curve folded on the object's period,
    its spectrum, and a scatter of ``embeddings`` coloured by class with the
    selected object drawn larger and outlined in black.

    Args:
        dataset: object exposing ``df``, ``get_vlc``, ``readLRSFits`` and
            ``lamost_spec_dir``.
        idx: positional row of the object in ``dataset.df`` and in ``embeddings``.
        embeddings: 2-D coordinates, one row per object.
        targets: class id per row of ``embeddings``.

    Reads the notebook-level names ``target2id``, ``id2target``, ``id2lb``,
    ``sorted_classes``, ``palette`` and ``marker_size``.
    """
    el = dataset.df.iloc[idx]
    label = target2id[el['target']]
    period = el['org_period']
    l, b = id2lb[el['id']]

    # Fold: replace the time column by (time mod period) / period.
    light_curve = dataset.get_vlc(el['name'])
    phase = (light_curve[:, 0] % period) / period
    photometry = np.vstack((phase, light_curve[:, 1], light_curve[:, 2])).T
    spec_path = os.path.join(dataset.lamost_spec_dir, el['spec_filename'])
    spectra = dataset.readLRSFits(spec_path)

    obj_embedding = embeddings[idx]

    fig, axs = plt.subplots(1, 3, figsize=(24, 8))
    ax_photo, ax_spec, ax_umap = axs[0], axs[1], axs[2]
    ax_photo.plot(photometry[:, 0], photometry[:, 1], '.', label='Photometry')
    ax_spec.plot(spectra[:, 0], spectra[:, 1], label='Spectra')

    # Background: every object, one scatter call per class.
    for class_name in sorted_classes:
        class_id = target2id[class_name]
        class_mask = targets == class_id
        ax_umap.scatter(
            embeddings[class_mask, 0],
            embeddings[class_mask, 1],
            color=palette[class_id],
            label=class_name,
            alpha=1.0,
            s=marker_size,
        )

    # Foreground: the selected object.
    ax_umap.scatter(
        obj_embedding[0],
        obj_embedding[1],
        color=palette[targets[idx]],
        edgecolors='black',
        alpha=1.0,
        s=5 * marker_size,
    )

    # Hand-built legend entries so the marker size is independent of the scatter size.
    handles_colors = []
    for class_name in sorted_classes:
        handles_colors.append(
            mlines.Line2D([], [], color=palette[target2id[class_name]], marker='o',
                          linestyle='None', markersize=8, label=class_name)
        )
    ax_umap.legend(handles=handles_colors, loc='upper right', bbox_to_anchor=(1.15, 1), fontsize=10, title="Classes")

    plt.suptitle(f'period = {period}    label = {id2target[label]}     idx {idx}    L = {l}    B = {b}    id {el["id"]}')
    plt.tight_layout()
    plt.show()


def plot_one_embs_period(embeddings, periods):
    """Scatter the first two columns of ``embeddings`` coloured by ``periods``.

    Uses the ``viridis`` colour map, adds a colour bar labelled ``Period`` and
    reads the notebook-level ``marker_size`` for the point size.
    """
    plt.figure(figsize=(10, 8))

    x_coords = embeddings[:, 0]
    y_coords = embeddings[:, 1]
    points = plt.scatter(x_coords, y_coords, c=periods, cmap='viridis', s=marker_size, alpha=1.0)

    plt.colorbar(points, label='Period')
    plt.xlabel("UMAP Dimension 1")
    plt.ylabel("UMAP Dimension 2")
    plt.show()

In [25]:
# Galactic coordinates (l, b) per object id, read from the three full_lb split CSVs.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
FULL_LB_DIR = '/home/mariia/AstroML/data/asassn/preprocessed_data/full_lb'

org_train, org_val, org_test = (
    pd.read_csv(os.path.join(FULL_LB_DIR, f'spectra_and_v_{split}.csv'))[['id', 'l', 'b']]
    for split in ('train', 'val', 'test')
)

combined_df = pd.concat([org_train, org_val, org_test])
# id -> [l, b]
id2lb = combined_df.set_index('id')[['l', 'b']].T.to_dict('list')

# Rebuild the model from the configuration stored with the W&B run.
run_id = 'MeriDK/AstroCLIPResults3/2wz4ysvn'
api = wandb.Api()
run = api.run(run_id)
config = run.config
config['use_wandb'] = False

if torch.cuda.is_available():
    # Prefer GPU 2 as before, but fall back to GPU 0 on machines with fewer devices.
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

model = get_model(config)
model = model.to(device)
model.eval()

weights_path = os.path.join(config['weights_path'] + '-' + run_id.split('/')[-1], 'weights-best.pth')
# map_location lets a checkpoint saved on another device load on the device chosen above.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=False))

In [26]:
# Switch the dataset to the sub50_lb files before building the splits.
config['file'] = 'preprocessed_data/sub50_lb/spectra_and_v'

train_dataset = PSMDataset(config, split='train')
val_dataset = PSMDataset(config, split='val')
test_dataset = PSMDataset(config, split='test')

# shuffle=False: batches are produced in dataset order.
loader_kwargs = {'batch_size': config['batch_size'], 'shuffle': False}
train_dataloader = DataLoader(train_dataset, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

# Per-modality embeddings (photometry, spectra, metadata) and labels for each split.
train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
val_p_emb, val_s_emb, val_m_emb, val_labels = get_embs(val_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

# Class centroids are computed from the training split only.
p_centers, s_centers, m_centers, all_centers = get_centers(
    train_p_emb, train_s_emb, train_m_emb, train_labels
)

# Fused embedding: element-wise average of the three modalities.
train_emb = (train_p_emb + train_s_emb + train_m_emb) / 3
val_emb = (val_p_emb + val_s_emb + val_m_emb) / 3
test_emb = (test_p_emb + test_s_emb + test_m_emb) / 3

# Project the fused embeddings with the saved reducer.
# NOTE(review): this rebinds the name `umap`, hiding the imported `umap` module from here on.
umap = joblib.load('umap.pkl')
train_umap = umap.transform(train_emb)
val_umap = umap.transform(val_emb)
test_umap = umap.transform(test_emb)

In [27]:
# Class id <-> class name lookups, taken from the test split.
id2target = test_dataset.id2target
target2id = test_dataset.target2id

# One colour per class id; two pairs of entries are swapped in place.
palette = sns.color_palette("Spectral", len(id2target))
palette[3], palette[-4] = palette[-4], palette[3]
palette[1], palette[-2] = palette[-2], palette[1]

# Shared plot settings read by the plotting helpers and scatter cells below.
# They were not defined anywhere in the notebook, so a fresh kernel raised NameError.
# NOTE(review): alphabetical class order and the marker size are assumed values; adjust if the
# original session used different ones.
sorted_classes = sorted(target2id)
marker_size = 12

In [28]:
# 'org_period' column of the training dataframe as a NumPy array.
train_period = np.array(train_dataset.df['org_period'])

In [29]:
# Mapping from Gaia source id (without the 'EDR3 ' prefix) to a Gaia class label.
# NOTE(review): joblib.load unpickles arbitrary objects; only load files you trust.
cls = joblib.load('rot-classes.pkl')

# .copy() makes `ids` an independent frame, so adding a column below does not
# write into a slice of train_dataset.df (SettingWithCopyWarning).
ids = train_dataset.df.loc[train_dataset.df['target'] == 'ROT', ['id']].copy()

# Single vectorised lookup; ids without an entry in `cls` get 'no-label'.
gaia_label_by_id = {'EDR3 ' + source_id: label for source_id, label in cls.items()}
ids['gaia_label'] = ids['id'].map(gaia_label_by_id).fillna('no-label')

gaia_labels = np.array(ids['gaia_label'])
unique_labels = np.unique(gaia_labels)
num_labels = len(unique_labels)

# Integer id <-> Gaia label, in the sorted order returned by np.unique.
gaia_id2target = dict(enumerate(unique_labels))
gaia_target2id = {label: i for i, label in enumerate(unique_labels)}

# UMAP coordinates of the ROT training objects and one colour per Gaia label.
rot_umap = train_umap[train_labels == target2id['ROT']]
rot_palette = sns.color_palette("Spectral", len(unique_labels))

In [30]:
# One panel per Gaia label, laid out on two rows (columns = ceil(num_labels / 2)).
num_cols = (num_labels + 1) // 2
fig, axes = plt.subplots(2, num_cols, figsize=(15, 8))
fig.suptitle('Scatter Plots for Each Class')

# Common axis limits (data range padded by 1) so the panels are directly comparable.
x_min, x_max = rot_umap[:, 0].min() - 1, rot_umap[:, 0].max() + 1
y_min, y_max = rot_umap[:, 1].min() - 1, rot_umap[:, 1].max() + 1
axes = axes.flatten()

for ax, class_name in zip(axes, unique_labels):
    ind = gaia_labels == class_name
    other_ind = ~ind

    # Objects of other labels in faint gray, the current label on top in colour.
    ax.scatter(rot_umap[other_ind, 0], rot_umap[other_ind, 1], color='gray', alpha=0.2, s=marker_size)
    ax.scatter(rot_umap[ind, 0], rot_umap[ind, 1], color=rot_palette[gaia_target2id[class_name]], label=class_name, alpha=1.0, s=marker_size, edgecolors='black', linewidth=0.5)
    ax.set_title(f'Class: {class_name}')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.legend()

# Hide the panels left over when the grid has more cells than labels.
for ax in axes[len(unique_labels):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [31]:
# Periods of the ROT training objects as a 1-D array. The earlier double-bracket
# selection produced an (N, 1) array, unlike train_period.
rot_periods = train_dataset.df.loc[train_dataset.df['target'] == 'ROT', 'org_period'].to_numpy()

In [32]:
# ROT objects with Gaia label 'RS' coloured by period; the remaining ROT objects in faint gray.
ind = gaia_labels == 'RS'
other_ind = ~ind

plt.scatter(rot_umap[other_ind, 0], rot_umap[other_ind, 1], color='gray', alpha=0.2, s=marker_size)
sc = plt.scatter(rot_umap[ind, 0], rot_umap[ind, 1], c=rot_periods[ind], cmap='viridis', alpha=1.0, s=marker_size)
plt.colorbar(sc, label='Period')

In [33]:
# All ROT objects in UMAP space, coloured by period.
sc = plt.scatter(rot_umap[:, 0], rot_umap[:, 1], c=rot_periods, cmap='viridis', alpha=1.0, s=marker_size)
plt.colorbar(sc, label='Period')

In [34]:
# Display the boolean mask that is True where the training label is the 'ROT' class id.
train_labels == target2id['ROT']

In [35]:
# Positions of the ROT objects within the training arrays ...
rot_indices = np.where(train_labels == target2id['ROT'])[0]
# ... and, among those, the positions of the objects whose Gaia label is 'RS'.
# gaia_labels is indexed relative to the ROT subset, hence the two-step lookup.
rs_indices = rot_indices[np.where(gaia_labels == 'RS')[0]]

In [36]:
# Number of 'RS' objects found.
rs_indices.shape

In [37]:
# UMAP coordinates of the RS objects, split into two groups by the first coordinate:
# 1 where it is below 2.6, 0 otherwise.
rs_umap = train_umap[rs_indices]
rs_bin = np.where(rs_umap[:, 0] < 2.6, 1, 0)

In [38]:
# RS objects in the fused-embedding UMAP, coloured by group (rs_bin).
plt.scatter(rs_umap[:, 0], rs_umap[:, 1], c=rs_bin)

In [39]:
# Reload the saved reducer and project each single-modality embedding of the RS objects
# (p = photometry, s = spectra, m = metadata) with it.
umap = joblib.load('umap.pkl')
p_umap = umap.transform(train_p_emb[rs_indices])
s_umap = umap.transform(train_s_emb[rs_indices])
m_umap = umap.transform(train_m_emb[rs_indices])

In [40]:
# Photometry-embedding projection of the RS objects, coloured by rs_bin.
plt.scatter(p_umap[:, 0], p_umap[:, 1], c=rs_bin)

In [41]:
# Spectra-embedding projection of the RS objects, coloured by rs_bin.
plt.scatter(s_umap[:, 0], s_umap[:, 1], c=rs_bin)

In [42]:
# Metadata-embedding projection of the RS objects, coloured by rs_bin.
plt.scatter(m_umap[:, 0], m_umap[:, 1], c=rs_bin)

In [43]:
rs_umap = train_umap[rs_indices]
# rs_indices are positions (from np.where), so rows are selected with .iloc;
# .loc would look them up as index labels. 1 where the period exceeds 10, else 0.
rs_bin_period = np.where(train_dataset.df['org_period'].iloc[rs_indices] > 10, 1, 0)

In [44]:
# RS objects in UMAP space, coloured by the period split (rs_bin_period).
plt.scatter(rs_umap[:, 0], rs_umap[:, 1], c=rs_bin_period)

In [63]:
# Split all ROT objects into two groups by the first UMAP coordinate
# (1 where it is below 2.5, 0 otherwise) and plot the split.
rot_indices = np.where(train_labels == target2id['ROT'])[0]
rot_umap = train_umap[rot_indices]
rot_bin = np.where(rot_umap[:, 0] < 2.5, 1, 0)
plt.scatter(rot_umap[:, 0], rot_umap[:, 1], c=rot_bin)

In [46]:
# Training split of the full_lb CSV; later cells look up per-object columns in it by 'id'.
# NOTE(review): absolute, machine-specific path.
df = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full_lb/spectra_and_v_train.csv')

In [65]:
# Metadata rows of the ROT objects, selected by id in the SAME order as rot_indices.
# An isin() filter would return rows in the CSV's order, which need not match rot_bin.
rot_ids = train_dataset.df['id'].iloc[rot_indices].tolist()
rot_df = df.set_index('id').loc[rot_ids, train_dataset.meta_cols]
assert len(rot_df) == len(rot_bin), "expected exactly one metadata row per ROT object"

g_mag = rot_df['phot_g_mean_mag']  # Apparent magnitude in G band
bp_mag = rot_df['phot_bp_mean_mag']  # Apparent magnitude in BP band
rp_mag = rot_df['phot_rp_mean_mag']  # Apparent magnitude in RP band
parallax = rot_df['parallax']  # Parallax in milliarcseconds (mas)

# Convert parallax to distance in parsecs.
# Non-positive parallaxes give inf/NaN here and those points are not drawn.
distance_pc = 1000 / parallax

# Absolute magnitude from the distance modulus.
M_G = g_mag - 5 * np.log10(distance_pc) + 5

# Colour index G_BP - G_RP
color_index = bp_mag - rp_mag

plt.figure(figsize=(8, 6))
plt.scatter(color_index, M_G, c=rot_bin, marker='o', alpha=0.7)
plt.gca().invert_yaxis()  # Invert y-axis since brighter stars have lower magnitudes
plt.xlabel(r'Color Index $G_{BP} - G_{RP}$', fontsize=12)
plt.ylabel(r'Absolute Magnitude $M_G$', fontsize=12)
plt.title('Color-Magnitude Diagram (CMD)', fontsize=14)
plt.grid(True)
plt.tight_layout()
plt.show()

In [47]:
# Rows of the RS objects, selected by id in the SAME order as rs_indices so they stay
# aligned with rs_bin. An isin() filter would return them in the CSV's order.
rs_ids = train_dataset.df['id'].iloc[rs_indices].tolist()
rs_df = df.set_index('id').loc[rs_ids].reset_index()

In [48]:
# Keep only the metadata columns listed in train_dataset.meta_cols.
rs_df = rs_df[train_dataset.meta_cols]

In [52]:
# Quantities for a colour-magnitude diagram of the RS objects.
g_mag = rs_df['phot_g_mean_mag']  # Apparent magnitude in G band
bp_mag = rs_df['phot_bp_mean_mag']  # Apparent magnitude in BP band
rp_mag = rs_df['phot_rp_mean_mag']  # Apparent magnitude in RP band
parallax = rs_df['parallax']  # Parallax in milliarcseconds (mas)

# Convert parallax to distance in parsecs
# NOTE(review): a zero or negative parallax yields inf/NaN downstream.
distance_pc = 1000 / parallax  # Distance in parsecs

# Calculate absolute magnitude M_G (distance modulus)
M_G = g_mag - 5 * np.log10(distance_pc) + 5

# Calculate color index G_BP - G_RP
color_index = bp_mag - rp_mag

In [60]:
# Colour-magnitude diagram of the RS objects. Points are coloured by rs_bin, which is
# ordered like rs_indices, so color_index / M_G must be in that same order.
plt.figure(figsize=(8, 6))
plt.scatter(color_index, M_G, c=rs_bin, marker='o', alpha=0.7)
plt.gca().invert_yaxis()  # Invert y-axis since brighter stars have lower magnitudes
plt.xlabel(r'Color Index $G_{BP} - G_{RP}$', fontsize=12)
plt.ylabel(r'Absolute Magnitude $M_G$', fontsize=12)
plt.title('Color-Magnitude Diagram (CMD)', fontsize=14)
plt.grid(True)
plt.tight_layout()
plt.show()

In [58]:
# Quick, unlabelled version of the same scatter (y-axis not inverted).
plt.scatter(color_index, M_G, c=rs_bin)

In [337]:
# Print the names of the RS objects in group 0 with the leading 'ASASSN-V' removed,
# as a comma-separated list of quoted strings.
# rs_indices are positions, so rows are selected with .iloc rather than .loc.
for el in train_dataset.df['name'].iloc[rs_indices[rs_bin == 0]]:
    print(f"'{el[len('ASASSN-V'):]}',", end="")

In [220]:
# Fit a decision tree (no depth limit set) that predicts the UMAP group rs_bin
# from the metadata columns in rs_df.
clf = DecisionTreeClassifier(random_state=42)
clf.fit(rs_df, rs_bin)

In [221]:
# Predictions on the same rows the tree was fitted on.
y_pred = clf.predict(rs_df)

In [224]:
# Fraction of correct predictions on the fitting rows (training accuracy, not a held-out score).
(y_pred == rs_bin).sum() / len(rs_bin)

In [228]:
# Draw the fitted tree. feature_names is passed as a plain list rather than a pandas
# Index, which some scikit-learn releases reject during parameter validation.
plt.figure(figsize=(15, 10))
plot_tree(clf, filled=True, feature_names=list(rs_df.columns), class_names=["0", "1"])

In [432]:
"""
Ca II H & K Lines (3968.5 Å & 3933.7 Å):
Indicators of chromospheric activity.
Emission reversals or filled-in absorption cores suggest high activity.
H-alpha Line (6562.8 Å):
Emission or filled-in absorption indicative of active chromospheres.
Ca II Infrared Triplet (8498 Å, 8542 Å, 8662 Å)
"""
def plot_spectra_rs(dataset, idx):
    """Plot the spectrum of the object at positional row ``idx`` of ``dataset.df``.

    The spectrum is read with ``dataset.readLRSFits`` from
    ``dataset.lamost_spec_dir`` and drawn on a single 12x5 axes with a grid and
    a legend placed outside the plot. The figure title shows the object's
    period, class, index, galactic coordinates and id.

    Reads the notebook-level lookups ``target2id``, ``id2target`` and ``id2lb``.
    """
    el = dataset.df.iloc[idx]
    label = target2id[el['target']]
    period = el['org_period']
    l, b = id2lb[el['id']]

    fig, ax = plt.subplots(1, 1, figsize=(12, 5))
    spec_path = os.path.join(dataset.lamost_spec_dir, el['spec_filename'])
    spectra = dataset.readLRSFits(spec_path)
    wavelength, flux = spectra[:, 0], spectra[:, 1]
    ax.plot(wavelength, flux, label='Spectra')

    # Reference lines that can be marked with ax.axvspan(centre - 1, centre + 1, ...) when needed:
    #   Ca II H & K      3933.7, 3968.5 Å
    #   Balmer           Hδ 4101.7, Hγ 4340.5, Hβ 4861.3, Hα 6562.8 Å
    #   Fe I             4383, 4957, 5167, 5328 Å
    #   C2 Swan bands    4737, 5165 Å
    #   Na I D1          5896 Å

    ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=10)
    ax.grid(True)

    plt.suptitle(f'period = {period}    label = {id2target[label]}     idx {idx}    L = {l}    B = {b}    id {el["id"]}')
    plt.tight_layout()
    plt.show()

In [413]:
# First five positions of each RS group (group 0, group 1).
rs_indices[rs_bin == 0][:5], rs_indices[rs_bin == 1][:5]

In [431]:
# Ids of the RS objects in group 0. rs_indices are positions, so rows are
# selected with .iloc rather than .loc.
train_dataset.df['id'].iloc[rs_indices[rs_bin == 0]]

In [433]:
# Spectra of up to five RS objects from group 0. Slicing (instead of indexing
# with range(5)) avoids an IndexError when the group has fewer than five members.
for i in rs_indices[rs_bin == 0][:5]:
    plot_spectra_rs(train_dataset, i)

In [434]:
# Spectra of up to five RS objects from group 1. Slicing (instead of indexing
# with range(5)) avoids an IndexError when the group has fewer than five members.
for i in rs_indices[rs_bin == 1][:5]:
    plot_spectra_rs(train_dataset, i)

In [149]:
# Positions of the class 'M' training objects, their UMAP coordinates, and a split
# on the second UMAP coordinate: 1 where it is above 2.3, 0 otherwise.
ind = train_labels == target2id['M']
m_ind = np.arange(len(ind))[ind]
m_umap = train_umap[m_ind]
m_bin = np.where(m_umap[:, 1] > 2.3, 1, 0)

In [144]:
# Class 'M' objects in UMAP space, coloured by group (m_bin).
plt.scatter(m_umap[:, 0], m_umap[:, 1], c=m_bin)

In [330]:
# Number of class 'M' objects in group 0.
len(m_ind[m_bin == 0])

In [322]:
# Ids of the class 'M' objects in group 1. m_ind holds positions, so rows are
# selected with .iloc rather than .loc.
for el in train_dataset.df['id'].iloc[m_ind[m_bin == 1]]:
    print(el)

In [329]:
# Names of the class 'M' objects in group 0, reduced to the text after the last '-'
# minus its first character, printed as a quoted, comma-separated list.
# m_ind holds positions, so rows are selected with .iloc rather than .loc.
for el in train_dataset.df['name'].iloc[m_ind[m_bin == 0]]:
    print(f"'{el.split('-')[-1][1:]}', ", end='')

In [None]:
# J054342.21+290935.5

In [None]:
# 2137, 4590, 4605, 4610, 7444, 7758 = C

In [175]:
# Reload the saved reducer and project each single-modality embedding of the class 'M'
# objects (p = photometry, s = spectra, m = metadata) with it.
# NOTE(review): m_umap is overwritten here. In cell In [149] it held the fused-embedding
# coordinates of the 'M' objects (the ones m_bin was derived from); from this cell on it
# holds the metadata-embedding projection.
umap = joblib.load('umap.pkl')
p_umap = umap.transform(train_p_emb[m_ind])
s_umap = umap.transform(train_s_emb[m_ind])
m_umap = umap.transform(train_m_emb[m_ind])

In [176]:
# Photometry-embedding projection of the class 'M' objects, coloured by m_bin.
plt.scatter(p_umap[:, 0], p_umap[:, 1], c=m_bin)

In [177]:
# Spectra-embedding projection of the class 'M' objects, coloured by m_bin.
plt.scatter(s_umap[:, 0], s_umap[:, 1], c=m_bin)

In [178]:
# Metadata-embedding projection of the class 'M' objects, coloured by m_bin.
plt.scatter(m_umap[:, 0], m_umap[:, 1], c=m_bin)

In [241]:
# Display the metadata column names used by the dataset.
train_dataset.meta_cols

In [243]:
# Columns fed to the decision tree below.
# NOTE(review): `cols` is the same object as train_dataset.meta_cols, not a copy. If that
# is a list, enabling the .remove() calls below would also change the dataset's own list.
cols = train_dataset.meta_cols
# cols.remove('pmdec_error')
# cols.remove('e_w1_mag')
# cols.remove('e_w2_mag')
# cols.remove('e_w3_mag')

In [244]:
# Display the selected columns.
cols

In [245]:
# NOTE(review): absolute, machine-specific path.
df = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full_lb/spectra_and_v_train.csv')

# Metadata rows of the class 'M' objects, selected by id in the SAME order as m_ind so
# each row lines up with its target in m_bin. An isin() filter would return rows in the
# CSV's order instead.
m_ids = train_dataset.df['id'].iloc[m_ind].tolist()
m_df = df.set_index('id').loc[m_ids, cols]
assert len(m_df) == len(m_bin), "expected exactly one metadata row per class 'M' object"

# Decision tree predicting the UMAP group (m_bin) from the metadata columns.
clf = DecisionTreeClassifier(random_state=42)
clf.fit(m_df, m_bin)

# Accuracy on the fitting rows (training accuracy, not a held-out score).
y_pred = clf.predict(m_df)
print((y_pred == m_bin).sum() / len(m_bin))

# feature_names is passed as a plain list rather than a pandas Index, which some
# scikit-learn releases reject during parameter validation.
plt.figure(figsize=(15, 10))
plot_tree(clf, filled=True, feature_names=list(m_df.columns), class_names=["0", "1"])

In [None]:
"""
Spectral     𝐽−𝐾     	   BP-RP 
M-type	     1.2 – 1.6	   2.5 – 3.5
S-type	     1.6 – 2.2	   3.5 – 5.0
C-type	     2.0 – 3.0	   5.0 – 7.0

8 objs: j_k >= 2.088 => C-Type or S-Type
6 objs: 1.434 <= j_k <= 2.088 && bp_rp <= 2.995 => S-Type?
+2 objs
"""

In [None]:
# Light curve, spectrum and UMAP position for every RS object in group 0.
for el in rs_indices[rs_bin == 0]:
    plot_obj_umap(train_dataset, el, train_umap, train_labels)

# MIRAS FROM THE SMALL CLUSTER

In [260]:
# NOTE(review): plot_spectra is defined in the cell after this one (In [367]); on a fresh
# kernel that cell has to run before this call.
plot_spectra(train_dataset, 681)

In [367]:
def plot_spectra(dataset, idx, mtype='m'):
    """Plot an object's spectrum with reference bands for a spectral type.

    Args:
        dataset: object exposing ``df``, ``readLRSFits`` and ``lamost_spec_dir``.
        idx: positional row of the object in ``dataset.df``.
        mtype: ``'m'`` marks four Balmer lines and one TiO band, ``'s'`` marks
            three ZrO bands; any other value marks one CN band and two C2 Swan
            bands.

    The bands are drawn first and the spectrum on top. The figure title shows
    the object's period, class, index, galactic coordinates and id. Reads the
    notebook-level lookups ``target2id``, ``id2target`` and ``id2lb``.
    """
    el = dataset.df.iloc[idx]
    label = target2id[el['target']]
    period = el['org_period']
    l, b = id2lb[el['id']]

    # (start, end, colour, alpha, legend label), in drawing / legend order.
    if mtype == 'm':
        bands = [
            (6562.8 - 1, 6562.8 + 1, 'cyan', 1.0, 'M-type Hα 6562.8 Å'),
            (4861.3 - 1, 4861.3 + 1, 'cyan', 1.0, 'M-type Hβ 4861.3 Å'),
            (4340.5 - 1, 4340.5 + 1, 'cyan', 1.0, 'M-type Hγ 4340.5 Å'),
            (4101.7 - 1, 4101.7 + 1, 'cyan', 1.0, 'M-type Hδ 4101.7 Å'),
            # The label now states the span that is actually drawn (it used to say 7350 Å).
            (7050, 7300, 'green', 0.3, 'M-type TiO (7050-7300 Å)'),
            # Other TiO regions tried earlier: 4300-4600, 4800-5200, 6150-6300, 7600-8200 Å
            # and narrow marks at 7150, 7600 and 8500 Å.
        ]
    elif mtype == 's':
        bands = [
            (4600, 4900, 'purple', 0.3, 'S-type ZrO (4600-4900 Å)'),
            (5700, 6200, 'pink', 0.3, 'S-type ZrO (5700-6200 Å)'),
            (7400, 7800, 'brown', 0.3, 'S-type ZrO (7400-7800 Å)'),
        ]
    else:
        bands = [
            (4214, 4216, 'yellow', 1.0, 'C-type CN band 4215 Å'),
            (4736, 4738, 'purple', 1.0, 'C-type C2 Swan band 4737 Å'),
            (5164, 5166, 'brown', 1.0, 'C-type C2 Swan band 5165 Å'),
            # Wider regions tried earlier: C2 4300-4700 Å, CN 5600-6200 Å, CN 7000-9000 Å.
        ]

    fig, ax = plt.subplots(1, 1, figsize=(12, 5))
    for start, end, color, alpha, band_label in bands:
        ax.axvspan(start, end, color=color, alpha=alpha, label=band_label)

    spectra = dataset.readLRSFits(os.path.join(dataset.lamost_spec_dir, el['spec_filename']))
    ax.plot(spectra[:, 0], spectra[:, 1], label='Spectra')

    ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=10)
    ax.grid(True)

    plt.suptitle(f'period = {period}    label = {id2target[label]}     idx {idx}    L = {l}    B = {b}    id {el["id"]}')
    plt.tight_layout()
    plt.show()

In [66]:
# Number of class 'M' objects in group 0.
len(m_ind[m_bin == 0])

In [368]:
# Spectra of up to twenty class 'M' objects from group 0, with the M-type bands marked.
# Slicing (instead of indexing with range(20)) avoids an IndexError on a smaller group.
for i in m_ind[m_bin == 0][:20]:
    plot_spectra(train_dataset, i, mtype='m')

In [348]:
# Spectra of every class 'M' object in group 1, with the C-type bands marked.
for el in m_ind[m_bin == 1]:
    plot_spectra(train_dataset, el, mtype='c')

In [None]:
# Light curve, spectrum and UMAP position for every class 'M' object in group 1.
for el in m_ind[m_bin == 1]:
    plot_obj_umap(train_dataset, el, train_umap, train_labels)

# MIRAS FROM THE BIG CLUSTER

In [None]:
# Light curve, spectrum and UMAP position for every class 'M' object in group 0.
for el in m_ind[m_bin == 0]:
    plot_obj_umap(train_dataset, el, train_umap, train_labels)

# RS BIG CLUSTER

In [None]:
# Light curve, spectrum and UMAP position for every RS object in group 0.
for el in rs_indices[rs_bin == 0]:
    plot_obj_umap(train_dataset, el, train_umap, train_labels)

# RS SMALL CLUSTER

In [None]:
# Light curve, spectrum and UMAP position for every RS object in group 1.
for el in rs_indices[rs_bin == 1]:
    plot_obj_umap(train_dataset, el, train_umap, train_labels)