In [1]:
import sys
sys.path.insert(0, '..')

import os
import wandb
import random
import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau, LinearLR
from torch.utils.data import DataLoader
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt

from core.final.dataset import PSMDataset
from core.final.model import Informer, GalSpecNet, MetaModel, AstroModel
from core.final.loss import CLIPLoss
from core.final.trainer import Trainer

In [2]:
# Class labels; handed to the run as config['classes'] (their count becomes config['num_classes']).
CLASSES = [
    'EW', 'SR', 'EA', 'RRAB', 'EB',
    'ROT', 'RRC', 'HADS', 'M', 'DSCT',
]

# Metadata feature columns; handed to the run as config['meta_cols'].
# Order is significant only insofar as downstream code relies on it, so it is kept exactly.
METADATA_COLS = [
    'mean_vmag',
    'phot_g_mean_mag', 'e_phot_g_mean_mag',
    'phot_bp_mean_mag', 'e_phot_bp_mean_mag',
    'phot_rp_mean_mag', 'e_phot_rp_mean_mag',
    'bp_rp',
    'parallax', 'parallax_error', 'parallax_over_error',
    'pmra', 'pmra_error',
    'pmdec', 'pmdec_error',
    'j_mag', 'e_j_mag',
    'h_mag', 'e_h_mag',
    'k_mag', 'e_k_mag',
    'w1_mag', 'e_w1_mag',
    'w2_mag', 'e_w2_mag',
    'w3_mag', 'w4_mag',
    'j_k', 'w1_w2', 'w3_w4',
    'pm', 'ruwe',
    'l', 'b',
]

# Per-object photometry columns; handed to the run as config['photo_cols'].
# Their count is added to the photometry encoder input width when config['p_aux'] is set.
PHOTO_COLS = [
    'amplitude',
    'period',
    'lksl_statistic',
    'rfr_score',
]

In [3]:
def get_model(config):
    """Build the model selected by ``config['mode']`` and optionally load CLIP weights.

    Args:
        config: Run configuration dict. Reads ``config['mode']`` ('photo',
            'spectra' or 'meta' pick the matching single-modality model; any
            other value builds ``AstroModel``) and ``config['use_pretrain']``
            (falsy, or a string; a string starting with ``'CLIP'`` is treated
            as ``'CLIP' + <checkpoint path>``). The whole dict is also passed
            to the model constructor.

    Returns:
        The constructed model. It is not moved to any device here.
    """
    mode = config['mode']

    # Single-modality modes map to their own model; everything else is multimodal.
    single_modality_models = {'photo': Informer, 'spectra': GalSpecNet, 'meta': MetaModel}
    model = single_modality_models.get(mode, AstroModel)(config)

    pretrain = config['use_pretrain']
    if pretrain and pretrain.startswith('CLIP'):
        # The checkpoint path follows the 4-character 'CLIP' tag.
        # map_location='cpu' lets a checkpoint saved from a GPU be loaded on any
        # host; the notebook moves the model to its device after this call.
        weights = torch.load(pretrain[4:], map_location='cpu', weights_only=True)

        # Name of the sub-module inside the CLIP checkpoint that corresponds to
        # each single-modality model. Other modes load the checkpoint as-is.
        encoder_names = {
            'photo': 'photometry_encoder',
            'spectra': 'spectra_encoder',
            'meta': 'metadata_encoder',
        }
        weights_prefix = encoder_names.get(mode)

        if weights_prefix:
            # Match on "<name>." (including the separator) so that the slice
            # below always removes exactly what was matched.
            dotted_prefix = weights_prefix + '.'
            weights = {
                key[len(dotted_prefix):]: value
                for key, value in weights.items()
                if key.startswith(dotted_prefix)
            }

        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated rather than raising.
        model.load_state_dict(weights, strict=False)

    return model


def get_schedulers(config, optimizer):
    """Create the main LR scheduler and, if requested, a linear warmup scheduler.

    Args:
        config: Dict providing ``'scheduler'`` ('ExponentialLR' or
            'ReduceLROnPlateau'), the matching hyper-parameters (``'gamma'``, or
            ``'factor'`` and ``'patience'``), ``'warmup'`` and, when warmup is
            enabled, ``'warmup_epochs'``.
        optimizer: Optimizer both schedulers are attached to.

    Returns:
        ``(scheduler, warmup_scheduler)``; the second item is ``None`` when
        ``config['warmup']`` is falsy.

    Raises:
        NotImplementedError: For any other ``config['scheduler']`` value.
    """
    kind = config['scheduler']

    # Reject unknown names up front, before anything is attached to the optimizer.
    if kind not in ('ExponentialLR', 'ReduceLROnPlateau'):
        raise NotImplementedError(f"Scheduler {config['scheduler']} not implemented")

    if kind == 'ExponentialLR':
        scheduler = ExponentialLR(optimizer, gamma=config['gamma'])
    else:
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=config['factor'],
            patience=config['patience'],
        )

    # The warmup scheduler is created after the main one, on the same optimizer.
    warmup_scheduler = None
    if config['warmup']:
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=1e-5,
            end_factor=1,
            total_iters=config['warmup_epochs'],
        )

    return scheduler, warmup_scheduler


def set_random_seeds(random_seed):
    """Seed torch (CPU and CUDA), NumPy and ``random`` with the same value.

    Also switches cuDNN to deterministic mode and turns its benchmark mode off.

    Args:
        random_seed: Seed passed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_config():
    """Assemble the run configuration dict.

    Steps, in order:
      1. Start from the defaults written below.
      2. If ``config['config_from']`` names a wandb run, copy an allow-listed
         set of hyper-parameters and feature flags from that run's config.
      3. Derive the encoder input sizes from the *final* ``p_aux`` / ``s_aux`` /
         ``s_err`` flags. This has to follow step 2: those flags are among the
         copied keys, and sizing from the defaults first would leave
         ``p_enc_in`` / ``s_conv_channels[0]`` out of step with the flags.
      4. Force gradient clipping on with a clip value of 5.

    Returns:
        dict: A freshly built configuration (nothing is shared between calls).
    """
    config = {
        'project': 'AstroCLIPResults',
        'mode': 'clip',    # 'clip' 'photo' 'spectra' 'meta' 'all'
        'config_from': None,    # 'meridk/AstroCLIPResults/d2u52yml',
        'random_seed': 42,  # 42, 66, 0, 12, 123
        'use_wandb': True,
        'save_weights': False,
        'weights_path': f'/home/mariia/AstroML/weights/{datetime.now().strftime("%Y-%m-%d-%H-%M")}',
        # 'use_pretrain': 'CLIP/home/mariia/AstroML/weights/2024-08-14-14-05-zmjau1cu/weights-51.pth',
        'use_pretrain': None,
        'freeze': False,

        # Data General
        'data_root': '/home/mariia/AstroML/data/asassn/',
        'file': 'preprocessed_data/full_lb/spectra_and_v',
        'classes': CLASSES,
        'num_classes': len(CLASSES),
        'meta_cols': METADATA_COLS,
        'photo_cols': PHOTO_COLS,
        'min_samples': None,
        'max_samples': None,

        # Photometry
        'v_zip': 'asassnvarlc_vband_complete.zip',
        'v_prefix': 'vardb_files',
        'seq_len': 200,
        'phased': False,
        'p_aux': True,

        # Spectra
        'lamost_spec_dir': 'Spectra/v2',
        's_mad': True,     # if True use mad for norm else std
        's_aux': True,
        's_err': True,
        's_err_norm': True,

        # Photometry Model
        'p_enc_in': 3,
        'p_d_model': 128,
        'p_dropout': 0.2,
        'p_factor': 1,
        'p_output_attention': False,
        'p_n_heads': 4,
        'p_d_ff': 512,
        'p_activation': 'gelu',
        'p_e_layers': 8,

        # Spectra Model
        's_dropout': 0.2,
        's_conv_channels': [1, 64, 64, 32, 32],
        's_kernel_size': 3,
        's_mp_kernel_size': 4,

        # Metadata Model
        'm_hidden_dim': 512,
        'm_dropout': 0.2,

        # MultiModal Model
        'hidden_dim': 512,
        'fusion': 'avg',  # 'avg', 'concat'

        # Training
        'batch_size': 512,
        'lr': 0.001,
        'beta1': 0.9,
        'beta2': 0.999,
        'weight_decay': 0.01,
        'epochs': 100,
        'early_stopping_patience': 6,
        'scheduler': 'ReduceLROnPlateau',  # 'ExponentialLR', 'ReduceLROnPlateau'
        'gamma': 0.9,  # for ExponentialLR scheduler
        'factor': 0.3,  # for ReduceLROnPlateau scheduler
        'patience': 3,  # for ReduceLROnPlateau scheduler
        'warmup': True,
        'warmup_epochs': 10,
        'clip_grad': True,
        'clip_value': 5
    }

    # Keys that may be taken over from a previous wandb run.
    copied_keys = (
        'p_dropout', 's_dropout', 'm_dropout', 'lr', 'beta1', 'weight_decay', 'epochs',
        'early_stopping_patience', 'factor', 'patience', 'warmup', 'warmup_epochs', 'clip_grad', 'clip_value',
        'use_pretrain', 'freeze', 'phased', 'p_aux', 's_aux', 's_err',
    )

    if config['config_from']:
        print(f"Copying params from the {config['config_from']} run")
        old_config = wandb.Api().run(config['config_from']).config

        for key in old_config:
            if key in copied_keys:
                config[key] = old_config[key]

    # Derived sizes, computed only once the feature flags are final.
    if config['p_aux']:
        config['p_enc_in'] += len(config['photo_cols']) + 2     # +2 for mad and delta t

    if config['s_aux']:
        config['s_conv_channels'][0] += 1

    if config['s_err']:
        config['s_conv_channels'][0] += 1

    # Applied last, so these win over any clip settings copied from the source run.
    config['clip_grad'] = True
    config['clip_value'] = 5

    return config

In [4]:
# Build the run configuration, then seed every RNG from it before any dataset or model is created.
config = get_config()
set_random_seeds(config['random_seed'])

In [17]:
# Train / validation datasets are both built from the same config dict.
train_dataset = PSMDataset(config, split='train')
val_dataset = PSMDataset(config, split='val')

# Training batches are shuffled and the final short batch is dropped (drop_last=True),
# so every training batch holds exactly config['batch_size'] samples.
# Validation keeps its order and uses the DataLoader default of loading in the main process.
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device)

Using cuda


In [18]:
# Build the model and move it to the target device before the optimizer is created,
# so the optimizer is handed the parameters as they exist on that device.
model = get_model(config)
model = model.to(device)

optimizer = Adam(model.parameters(), lr=config['lr'], betas=(config['beta1'], config['beta2']),
                 weight_decay=config['weight_decay'])
scheduler, warmup_scheduler = get_schedulers(config, optimizer)
# 'clip' mode trains with CLIPLoss; every other mode uses cross-entropy.
criterion = CLIPLoss() if config['mode'] == 'clip' else torch.nn.CrossEntropyLoss()

# The Trainer bundles everything above; train_epoch reaches into it through the global `trainer`.
trainer = Trainer(model=model, optimizer=optimizer, scheduler=scheduler, warmup_scheduler=warmup_scheduler,
                  criterion=criterion, device=device, config=config)

In [19]:
def train_epoch(train_dataloader):
    """Run one CLIP training pass over ``train_dataloader`` and record per-step losses.

    Relies on the notebook-level ``trainer`` object (its ``model``,
    ``optimizer``, ``device``, ``clip_value``, ``zero_stats`` and ``step_clip``)
    instead of receiving it as an argument, so the cell that creates
    ``trainer`` must have been executed first.

    Args:
        train_dataloader: Iterable yielding 5-tuples
            ``(photometry, photometry_mask, spectra, metadata, labels)``.

    Returns:
        dict: Keys ``'step_loss'``, ``'loss_ps'``, ``'loss_sm'`` and
        ``'loss_mp'``, each mapping to a list with one Python scalar per
        optimisation step, in step order.
    """
    losses = {'step_loss': [], 'loss_ps': [], 'loss_sm': [], 'loss_mp': []}

    trainer.model.train()
    trainer.zero_stats()
    device = trainer.device

    # Labels come with every batch but are not passed to the CLIP step, so they
    # are unpacked and left alone rather than copied to the device each step.
    for photometry, photometry_mask, spectra, metadata, _labels in tqdm(train_dataloader):
        photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
        spectra, metadata = spectra.to(device), metadata.to(device)

        trainer.optimizer.zero_grad()
        loss, loss_ps, loss_sm, loss_mp = trainer.step_clip(photometry, photometry_mask, spectra, metadata)

        # .item() stores plain numbers, so the history holds no tensor references.
        losses['step_loss'].append(loss.item())
        losses['loss_ps'].append(loss_ps.item())
        losses['loss_sm'].append(loss_sm.item())
        losses['loss_mp'].append(loss_mp.item())

        loss.backward()
        # Gradient norm is clipped on every step before the optimizer update.
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), trainer.clip_value)
        trainer.optimizer.step()

    return losses

In [20]:
# Run 5 training epochs, keeping each epoch's loss dict in order.
# Appending inside the loop means completed epochs survive if the cell is interrupted.
all_losses = []

for i in range(5):
    losses = train_epoch(train_dataloader)
    all_losses.append(losses)

100%|██████████████████████████████████████████████████████████████████████████████████| 33/33 [02:05<00:00,  3.79s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 33/33 [01:53<00:00,  3.44s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 33/33 [01:58<00:00,  3.59s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 33/33 [02:02<00:00,  3.71s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 33/33 [01:56<00:00,  3.52s/it]


In [21]:
# Iterate the whole training loader once (no training), keeping only the final batch it yields.
last_batch = None

for batch in tqdm(train_dataloader):
    last_batch = batch

100%|██████████████████████████████████████████████████████████████████████████████████| 33/33 [01:55<00:00,  3.50s/it]


In [22]:
# Shape of the first tensor of the final batch (the photometry tensor, per the unpacking order in train_epoch).
last_batch[0].shape

torch.Size([512, 200, 9])

In [23]:
# Same shape check on the first batch of a fresh pass over the loader.
first_batch = next(iter(train_dataloader))
first_batch[0].shape

torch.Size([512, 200, 9])

In [None]:
# One list of total per-step losses for each epoch that was run.
[el['step_loss'] for el in all_losses]