In [7]:
# automatically reloads all modules before executing a new cell
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import joblib
from torch.utils.data import DataLoader

from core.dataset import PSMDataset
from core.trainer import Trainer
from core.model import GalSpecNet, MetaModel, Informer, AstroM3

In [11]:
def get_model(config):
    """Build the model architecture selected by ``config['mode']``.

    Args:
        config: Mapping of hyper-parameters (e.g. a W&B run config).
            ``config['mode']`` selects the architecture:

            * ``'photo'``   -> ``Informer`` (photometry model), classification on
            * ``'spectra'`` -> ``GalSpecNet`` (spectra model), classification on
            * ``'meta'``    -> ``MetaModel`` (metadata model), classification on
            * any other mode -> ``AstroM3`` combining all three sub-models; its
              ``classification`` flag is ``True`` only when the mode is ``'all'``.

            Keys prefixed ``p_``, ``s_`` and ``m_`` configure the photometry,
            spectra and metadata sub-models respectively. The metadata input
            width is ``len(config['meta_cols'])``.

    Returns:
        A newly constructed model instance.

    Raises:
        KeyError: if a key required by the selected architecture is missing.
    """
    mode = config['mode']

    if mode == 'photo':
        # The single-modality models are always built with the classification head.
        return Informer(
            classification=True,
            num_classes=config['num_classes'],
            seq_len=config['seq_len'],
            enc_in=config['p_enc_in'],
            d_model=config['p_d_model'],
            dropout=config['p_dropout'],
            factor=config['p_factor'],
            output_attention=config['p_output_attention'],
            n_heads=config['p_n_heads'],
            d_ff=config['p_d_ff'],
            activation=config['p_activation'],
            e_layers=config['p_e_layers']
        )

    if mode == 'spectra':
        return GalSpecNet(
            classification=True,
            num_classes=config['num_classes'],
            dropout_rate=config['s_dropout'],
            conv_channels=config['s_conv_channels'],
            kernel_size=config['s_kernel_size'],
            mp_kernel_size=config['s_mp_kernel_size']
        )

    if mode == 'meta':
        return MetaModel(
            classification=True,
            num_classes=config['num_classes'],
            input_dim=len(config['meta_cols']),
            hidden_dim=config['m_hidden_dim'],
            dropout=config['m_dropout']
        )

    # Every remaining mode builds the multimodal model.
    # NOTE(review): an unrecognised mode (e.g. a typo) also lands here silently —
    # confirm the full set of valid modes and consider rejecting unknown ones.
    return AstroM3(
        classification=(mode == 'all'),
        num_classes=config['num_classes'],
        hidden_dim=config['hidden_dim'],
        fusion=config['fusion'],

        # Photometry model params
        seq_len=config['seq_len'],
        p_enc_in=config['p_enc_in'],
        p_d_model=config['p_d_model'],
        p_dropout=config['p_dropout'],
        p_factor=config['p_factor'],
        p_output_attention=config['p_output_attention'],
        p_n_heads=config['p_n_heads'],
        p_d_ff=config['p_d_ff'],
        p_activation=config['p_activation'],
        p_e_layers=config['p_e_layers'],

        # Spectra model params
        s_dropout=config['s_dropout'],
        s_conv_channels=config['s_conv_channels'],
        s_kernel_size=config['s_kernel_size'],
        s_mp_kernel_size=config['s_mp_kernel_size'],

        # Metadata model params
        m_input_dim=len(config['meta_cols']),
        m_hidden_dim=config['m_hidden_dim'],
        m_dropout=config['m_dropout']
    )

In [33]:
# Fetch the configuration of the W&B run whose weights are published below.
WANDB_RUN_PATH = 'meridk/AstroCLIPResults3/runs/3c2da15u'

api = wandb.Api()
run = api.run(WANDB_RUN_PATH)
# Work on a copy so the local override below does not mutate the Run object's own
# config, which could be pushed back to the W&B server if run.update() were called.
config = dict(run.config)
# Presumably disables W&B logging in downstream project code — confirm in core/.
config['use_wandb'] = False

In [34]:
# Rebuild the architecture from the run's config and restore its best checkpoint.
model = get_model(config)
# Checkpoint lives at "<weights_path>-<run id>/weights-best.pth".
weights_file = os.path.join(f"{config['weights_path']}-{run.id}", 'weights-best.pth')
# map_location='cpu' lets this cell run on machines without the device the
# checkpoint was saved from; load_state_dict then copies the tensors into `model`.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load(weights_file, map_location='cpu', weights_only=True))

<All keys matched successfully>

In [35]:
# Publish the restored model to the Hugging Face Hub.
# NOTE(review): the repo name suggests a CLIP-pretrained AstroM3 in 'all' mode —
# confirm it matches config['mode'] of the run loaded above before pushing.
HF_REPO_ID = 'MeriDK/AstroM3-CLIP-all'
model.push_to_hub(HF_REPO_ID)

model.safetensors:   0%|          | 0.00/66.1M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/MeriDK/AstroM3-CLIP-all/commit/a2374528a1f99f6efd73ccd50ee0f03b511eaf7a', commit_message='Push model using huggingface_hub.', commit_description='', oid='a2374528a1f99f6efd73ccd50ee0f03b511eaf7a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/MeriDK/AstroM3-CLIP-all', endpoint='https://huggingface.co', repo_type='model', repo_id='MeriDK/AstroM3-CLIP-all'), pr_revision=None, pr_num=None)