In [53]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from datetime import datetime

from core.spectra.model import GalSpecNet
from core.metadata.model import MetaClassifier
from models.Informer import Informer

In [41]:
# Class labels; this list is stored in the experiment config as config['classes'].
CLASSES = [
    'EW', 'SR', 'EA', 'RRAB', 'EB',
    'ROT', 'RRC', 'HADS', 'M', 'DSCT',
]

In [55]:
def get_config(random_seed):
    """Assemble the flat experiment configuration dictionary.

    Args:
        random_seed: Value stored under ``'random_seed'``; this function does
            not use it for anything else.

    Returns:
        dict: A new dictionary of settings. ``'weights_path'`` embeds the
        current local time at minute resolution, and ``'p_feature_size'`` is
        increased by 4 when ``'aux'`` is truthy.
    """
    # Timestamp that makes the weights directory unique per (minute of) run.
    run_stamp = datetime.now().strftime("%Y-%m-%d-%H-%M")
    # Number of extra photometry features added when 'aux' is enabled.
    aux_feature_count = 4

    general = {
        'project': 'AstroCLIPResults',
        'random_seed': random_seed,
        'use_wandb': True,
        'save_weights': True,
        'weights_path': f'/home/mariia/AstroML/weights/{run_stamp}',
        'use_pretrain': 'CLIP/home/mariia/AstroML/weights/2024-07-25-14-18-es6hl0nb/weights-41.pth',
        # 'use_pretrain': None,
    }

    data_general = {
        'dataset': 'VPSMDatasetV2Spectra',     # 'VPSMDataset' or 'VPSMDatasetV2'
        'data_root': '/home/mariia/AstroML/data/asassn/',
        'file': 'preprocessed_data/full/spectra_and_v',
        'classes': CLASSES,
        'min_samples': None,
        'max_samples': None,
        'noise': False,  # for train data only
        'noise_coef': 2,
    }

    photometry_data = {
        'v_zip': 'asassnvarlc_vband_complete.zip',
        'v_prefix': 'vardb_files',
        'seq_len': 200,
        'phased': True,
        'clip': False,
        'aux': True,
    }

    spectra_data = {
        'lamost_spec_dir': 'Spectra/v2',
        'spectra_v_file': 'spectra_v_merged.csv',
        'z_corr': False,
    }

    photometry_model = {
        'p_encoder_layers': 8,
        'p_d_model': 128,
        'p_dropout': 0.2,
        'p_feature_size': 3,
        'p_n_heads': 4,
        'p_d_ff': 512,
    }

    spectra_model = {
        's_hidden_dim': 512,
        's_dropout': 0.2,
    }

    metadata_model = {
        'm_hidden_dim': 512,
        'm_dropout': 0.2,
    }

    multimodal_model = {
        'model': 'ModelV1',     # 'ModelV0' or 'ModelV1'
        'hidden_dim': 1024,
        'ps_coef': 1,
        'mp_coef': 1,
        'sm_coef': 1,
    }

    training = {
        'batch_size': 64,
        'lr': 1e-3,
        'weight_decay': 1e-3,
        'epochs': 100,
        'optimizer': 'AdamW',
        'early_stopping_patience': 10,
    }

    lr_scheduler = {
        'factor': 0.3,
        'patience': 5,
    }

    # Merge the sections in their original order so key order is unchanged.
    config = {
        **general,
        **data_general,
        **photometry_data,
        **spectra_data,
        **photometry_model,
        **spectra_model,
        **metadata_model,
        **multimodal_model,
        **training,
        **lr_scheduler,
    }

    # Auxiliary inputs widen the photometry feature vector.
    if config['aux']:
        config['p_feature_size'] = config['p_feature_size'] + aux_feature_count

    return config

In [56]:
# Build the experiment configuration; 42 is stored as config['random_seed'].
config = get_config(random_seed=42)

In [58]:
# Metadata classifier configured from the experiment config. The class count is
# derived from config['classes'] instead of a literal 10 so it cannot drift
# from the label list.
model = MetaClassifier(
    hidden_dim=config['m_hidden_dim'],
    num_classes=len(config['classes']),
    dropout=config['m_dropout'],
)

In [59]:
# Display the module's repr (its layer structure) as the cell output.
model

MetaClassifier(
  (layer1): Sequential(
    (0): Linear(in_features=36, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
  )
  (layer2): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
  )
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

In [72]:
class MetaModel(nn.Module):
    """Two-hidden-layer MLP followed by a separate linear output layer.

    The hidden stack is stored as ``self.model`` (an ``nn.Sequential`` whose
    linear layers sit at indices 0 and 3) and the output layer as ``self.fc``.
    These attribute names determine the ``state_dict`` keys, so they must stay
    as they are for checkpoints keyed ``model.0.*`` / ``model.3.*`` to load.

    Args:
        num_classes: Output size of the final linear layer.
        input_dim: Size of the input feature vector.
        hidden_dim: Width of both hidden linear layers.
        dropout: Dropout probability applied after each ReLU.
    """

    def __init__(self, num_classes, input_dim=36, hidden_dim=512, dropout=0.5):
        super().__init__()

        # Layers are created in the same order as they are applied.
        hidden_stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.model = nn.Sequential(*hidden_stack)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Run ``x`` through the hidden stack and return the output of ``fc``."""
        return self.fc(self.model(x))

In [73]:
# Instantiate with 10 output classes; input_dim, hidden_dim and dropout keep
# the MetaModel defaults.
# NOTE(review): the default dropout (0.5) differs from config['m_dropout'] (0.2)
# used for the MetaClassifier instance — confirm this is intended.
model = MetaModel(num_classes=10)

In [74]:
# Display the module's repr (its layer structure) as the cell output.
model

MetaModel(
  (model): Sequential(
    (0): Linear(in_features=36, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
  )
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

In [60]:
# Checkpoint file the pretrained parameters are read from.
checkpoint_path = '/home/mariia/AstroML/weights/2024-07-25-14-18-es6hl0nb/weights-41.pth'

# map_location='cpu' lets the checkpoint load on a machine that lacks the
# device it was saved from. weights_only=True restricts unpickling to tensors
# and plain containers.
weights = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

In [69]:
# Keep only the entries under the 'metadata_encoder.' namespace and strip that
# prefix from their keys. The same prefix string (trailing dot included) is
# used for both the filter and the slice, so a key that merely begins with
# 'metadata_encoder' (without the dot) can no longer slip through and be
# truncated incorrectly.
ENCODER_PREFIX = 'metadata_encoder.'
filtered_weights = {
    name[len(ENCODER_PREFIX):]: tensor
    for name, tensor in weights.items()
    if name.startswith(ENCODER_PREFIX)
}

In [75]:
# Show which parameter names remain after filtering and prefix stripping.
filtered_weights.keys()

dict_keys(['model.0.weight', 'model.0.bias', 'model.3.weight', 'model.3.bias'])

In [76]:
# strict=False is required because the filtered checkpoint has no 'fc.*'
# entries, but it also hides every other key mismatch, so check the outcome
# explicitly instead of relying on reading the printed result.
load_result = model.load_state_dict(filtered_weights, strict=False)
assert not load_result.unexpected_keys, (
    f'checkpoint keys not present in the model: {load_result.unexpected_keys}'
)
assert all(key.startswith('fc.') for key in load_result.missing_keys), (
    f'model parameters not found in the checkpoint: {load_result.missing_keys}'
)
# Still display the missing/unexpected key report as the cell output.
load_result

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])