In [1]:
import sys
sys.path.insert(0, '..')

import chromadb
from chromadb import Client, Settings
import torch
import os
import wandb
import random
import numpy as np
import torch
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import joblib

from core.dataset import PSMDataset
from core.model import GalSpecNet, MetaModel, Informer, AstroM3

In [2]:
CLASSES = ['EW', 'SR', 'EA', 'RRAB', 'EB', 'ROT', 'RRC', 'HADS', 'M', 'DSCT']

METADATA_COLS = [
    'mean_vmag',  'phot_g_mean_mag', 'e_phot_g_mean_mag', 'phot_bp_mean_mag', 'e_phot_bp_mean_mag', 'phot_rp_mean_mag', 'e_phot_rp_mean_mag',
    'bp_rp', 'parallax', 'parallax_error', 'parallax_over_error', 'pmra', 'pmra_error', 'pmdec',
    'pmdec_error', 'j_mag', 'e_j_mag', 'h_mag', 'e_h_mag', 'k_mag', 'e_k_mag', 'w1_mag', 'e_w1_mag',
    'w2_mag', 'e_w2_mag', 'w3_mag', 'w4_mag', 'j_k', 'w1_w2', 'w3_w4', 'pm', 'ruwe', 'l', 'b'
]

PHOTO_COLS = ['amplitude', 'period', 'lksl_statistic', 'rfr_score']

METADATA_FUNC = {
    "abs": [
        "mean_vmag",
        "phot_g_mean_mag",
        "phot_bp_mean_mag",
        "phot_rp_mean_mag",
        "j_mag",
        "h_mag",
        "k_mag",
        "w1_mag",
        "w2_mag",
        "w3_mag",
        "w4_mag",
    ],
    "cos": ["l"],
    "sin": ["b"],
    "log": ["period"]
}

BOOK = ["id", "source_id", "asassn_name", "other_names", "raj2000", "dej2000", "epoch_hjd", "gdr2_id", "allwise_id", "apass_dr9_id", "edr3_source_id", "galex_id", "tic_id"]

LAMOST_DIR = '/home/mariia/AstroML/data/asassn/Spectra/v2'

In [3]:
run_id = 'MeriDK/AstroCLIPOptuna3/nksu4l24'
api = wandb.Api()
run = api.run(run_id)
config = run.config
config['use_wandb'] = False

model = AstroM3(config)
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

weights_path = os.path.join(config['weights_path'] + '-' + run_id.split('/')[-1], f'weights-best.pth')
model.load_state_dict(torch.load(weights_path, weights_only=False))

<All keys matched successfully>

In [4]:
# this functions is from 028-meta-change.ipynb
def transform(df):
    for transformation_type, value in METADATA_FUNC.items():
        if transformation_type == "abs":
            for col in value:
                df[col] = (
                    df[col] - 10 + 5 * np.log10(np.where(df["parallax"] <= 0, 1, df["parallax"]))
                )
        elif transformation_type == "cos":
            for col in value:
                df[col] = np.cos(np.radians(df[col]))
        elif transformation_type == "sin":
            for col in value:
                df[col] = np.sin(np.radians(df[col]))
        elif transformation_type == "log":
            for col in value:
                df[col] = np.log10(df[col])

In [5]:
df_org = pd.read_csv('/home/mariia/AstroML/data/asassn/asassn_catalog_full.csv')
df_v = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full_lb/v.csv')
df_s = pd.read_csv('/home/mariia/AstroML/data/asassn/Spectra/lamost_spec.csv')
scaler = joblib.load('/home/mariia/AstroML/data/asassn/preprocessed_data/full_lb/scaler.pkl')

# change the name column in df_org so it matches df_v
df_org['asassn_name'] = df_org['asassn_name'].apply(lambda x: x.replace(' ', ''))

# drop filename duplicates from spectra df
df_s = df_s.drop_duplicates(subset=['spec_filename'])

# save the period, it'll be used for plots
df_v.loc[:, 'org_period'] = df_v['period']

# do the metadata and photo aux transformations
transform(df_v)

# normalize the metadata and photo aux
cols = METADATA_COLS + PHOTO_COLS
df_v.loc[:, cols] = scaler.transform(df_v[cols])

  df_org = pd.read_csv('/home/mariia/AstroML/data/asassn/asassn_catalog_full.csv')
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [6]:
ds = PSMDataset(config)

In [7]:
chroma_client = chromadb.PersistentClient(path='/home/mariia/AstroML/notebooks/chromadb')
photo_collection = chroma_client.get_or_create_collection(name='photo')

In [None]:
for _, el in tqdm(df_v.iterrows(), total=len(df_v)):
    photometry = ds.get_vlc(el['name'])
    photometry, photometry_mask = ds.preprocess_lc(photometry, None, list(el[PHOTO_COLS]))
    
    photometry, photometry_mask = torch.from_numpy(photometry), torch.from_numpy(photometry_mask)
    photometry, photometry_mask = photometry.unsqueeze(0), photometry_mask.unsqueeze(0)
    photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
    
    with torch.no_grad():
        p_emb = model.photometry_encoder(photometry, photometry_mask)
        p_emb = model.photometry_proj(p_emb)
        p_emb = p_emb.squeeze().detach().cpu().numpy()
        
    photo_collection.add(
        embeddings=p_emb,
        metadatas={'name': el['name'], 'target': el['target']},
        ids=el['id']
    )

 14%|████████▉                                                        | 59652/435615 [1:24:26<8:06:10, 12.89it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

 44%|████████████████████████████                                    | 191332/435615 [4:29:23<5:15:19, 12.91it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

 73%|██████████████████████████████████████████████▍                 | 316218/435615 [7:42:45<3:16:52, 10.11it/s]