In [50]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import torch.nn.functional as F
from datetime import datetime
from torch.utils.data import DataLoader

from core.multimodal.dataset2 import VPSMDatasetV2
from core.multimodal.model import ModelV1

In [51]:
def get_config(random_seed):
    """Return the flat configuration dict for a multimodal contrastive run.

    Args:
        random_seed: Value stored under ``'random_seed'``. This function only
            records it; it does not seed any RNG.

    Returns:
        dict: Configuration with data, model, and training settings.
        ``'weights_path'`` ends in the current local time at minute resolution
        (``%Y-%m-%d-%H-%M``), so calls made in different minutes differ.
    """
    run_stamp = datetime.now().strftime("%Y-%m-%d-%H-%M")

    # NOTE(review): 3 base photometry channels plus 4 extra ones -- presumably
    # the auxiliary channels enabled by aux=True (a sample batch in this
    # notebook has 7 channels). TODO confirm.
    photometry_features = 3 + 4

    return dict(
        project='multimodal-contrastive',
        random_seed=random_seed,
        use_wandb=True,
        save_weights=True,
        weights_path=f'/home/mariia/AstroML/weights/{run_stamp}',
        use_pretrain=None,

        # Data General
        dataset='VPSMDatasetV2',     # 'VPSMDataset' or 'VPSMDatasetV2'
        data_root='/home/mariia/AstroML/data/asassn/',
        file='preprocessed_data/full/spectra_and_v',
        classes=None,
        min_samples=200,
        max_samples=None,

        # Photometry
        v_zip='asassnvarlc_vband_complete.zip',
        v_prefix='vardb_files',
        seq_len=200,
        phased=True,
        clip=False,
        aux=True,

        # Spectra
        lamost_spec_dir='Spectra/v2',
        spectra_v_file='spectra_v_merged.csv',
        z_corr=False,

        # Photometry Model
        p_encoder_layers=8,
        p_d_model=128,
        p_dropout=0,
        p_feature_size=photometry_features,
        p_n_heads=4,
        p_d_ff=512,

        # Spectra Model
        s_hidden_dim=512,
        s_dropout=0,

        # Metadata Model
        m_hidden_dim=512,
        m_dropout=0,

        # MultiModal Model
        model='ModelV1',     # 'ModelV0' or 'ModelV1'
        hidden_dim=256,

        # Training
        batch_size=32,
        lr=1e-3,
        weight_decay=0,
        epochs=50,
        optimizer='AdamW',
        early_stopping_patience=100,

        # Learning Rate Scheduler
        factor=0.3,
        patience=50,
    )

In [9]:
# Seed 42 is only stored as config['random_seed'] by get_config.
config = get_config(random_seed=42)

In [2]:
# Variable-star classes passed to the datasets explicitly (config['classes'] is None).
CLASSES = 'EW SR EA RRAB EB ROT RRC HADS M DSCT'.split()

In [3]:
# Build all three splits from one shared set of options taken from `config`.
# The model config depends on these values (e.g. 'aux' relates to p_feature_size),
# so hard-coding them here could let the two drift apart. 'classes' is None in
# the config, so the explicit CLASSES list is passed instead.
dataset_kwargs = dict(
    classes=CLASSES,
    seq_len=config['seq_len'],
    phased=config['phased'],
    clip=config['clip'],
    aux=config['aux'],
    z_corr=config['z_corr'],
    random_seed=config['random_seed'],
)
train_dataset = VPSMDatasetV2(split='train', **dataset_kwargs)
val_dataset = VPSMDatasetV2(split='val', **dataset_kwargs)
test_dataset = VPSMDatasetV2(split='test', **dataset_kwargs)

In [4]:
tuple(len(split) for split in (train_dataset, val_dataset, test_dataset))

(17282, 2175, 2247)

In [15]:
# Seed torch before constructing the model so weight initialisation is
# reproducible across kernel restarts. Nothing visible in this notebook applies
# config['random_seed'] to torch; ModelV1 may or may not seed internally (unverified).
torch.manual_seed(config['random_seed'])
model = ModelV1(config)

In [13]:
# No shuffling: batches come out in dataset order, so the first batch is always the same samples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
)

In [14]:
# Pull a single batch for inspection.
(
    photometry,
    photometry_mask,
    spectra,
    metadata,
    y,
) = next(iter(train_dataloader))

In [17]:
tuple(t.shape for t in (photometry, photometry_mask, spectra, metadata, y))

(torch.Size([32, 200, 7]),
 torch.Size([32, 200]),
 torch.Size([32, 1, 2575]),
 torch.Size([32, 36]),
 torch.Size([32]))

In [16]:
# One embedding per modality: photometry, spectra, metadata.
(p_emb, s_emb, m_emb) = model.get_embeddings(
    photometry, photometry_mask, spectra, metadata
)

In [18]:
tuple(emb.shape for emb in (p_emb, s_emb, m_emb))

(torch.Size([32, 256]), torch.Size([32, 256]), torch.Size([32, 256]))

In [22]:
# L2-normalise each embedding along its last dimension so the dot products below
# are cosine similarities. F.normalize divides by max(norm, eps) (p=2 by default),
# so an all-zero embedding row yields zeros instead of NaN/inf; for non-zero rows
# the result equals x / x.norm(dim=-1, keepdim=True).
p_emb = F.normalize(p_emb, dim=-1)
s_emb = F.normalize(s_emb, dim=-1)
m_emb = F.normalize(m_emb, dim=-1)

In [38]:
model.logit_scale_ps.exp().clamp(min=1, max=100)

tensor(100., grad_fn=<ClampBackward1>)

In [43]:
# Shape check for the photometry-vs-spectra similarity matrix. The raw parameter
# is exponentiated and clamped to [1, 100] before it scales the logits, so apply
# the same transform here instead of multiplying by the raw parameter.
(torch.clamp(model.logit_scale_ps.exp(), min=1, max=100) * p_emb @ s_emb.T).shape

torch.Size([32, 32])

In [45]:
# Exponentiate each of the model's logit-scale parameters and clamp to [1, 100].
logit_scale_ps, logit_scale_sm, logit_scale_mp = (
    raw_scale.exp().clamp(min=1, max=100)
    for raw_scale in (model.logit_scale_ps, model.logit_scale_sm, model.logit_scale_mp)
)

In [46]:
# Scaled pairwise similarities (cosine when embeddings are L2-normalised):
# row i of logits_xy scores sample i of modality x against every sample of y.
# The scale multiplies the left operand first, matching `a * b @ c` precedence.
logits_ps = torch.matmul(logit_scale_ps * p_emb, s_emb.t())
logits_sm = torch.matmul(logit_scale_sm * s_emb, m_emb.t())
logits_mp = torch.matmul(logit_scale_mp * m_emb, p_emb.t())

In [47]:
# Diagonal targets: sample i in one modality matches sample i in the other.
labels = torch.arange(logits_ps.size(0), dtype=torch.long, device=logits_ps.device)

In [48]:
# Contrastive targets: 0..batch_size-1, i.e. the diagonal of each logits matrix.
labels

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

In [54]:
# Photometry -> spectra direction: each photometry row classifies its matching spectrum.
F.cross_entropy(input=logits_ps, target=labels)

tensor(4.0040, grad_fn=<NllLossBackward0>)

In [55]:
# Spectra -> photometry direction (transposed 2-D logits); CLIP-style training
# typically averages this with the other direction.
F.cross_entropy(logits_ps.T, labels)

tensor(5.3719, grad_fn=<NllLossBackward0>)

In [60]:
# Softmax is monotonic, so the arg-max of the probabilities equals the arg-max of
# the logits; argmax returns the first index on ties, just like torch.max.
probabilities = logits_ps.softmax(dim=1)
predicted_labels = probabilities.argmax(dim=1)

In [61]:
# Predicted spectrum index per photometry row; a correct match equals its own row index (see `labels`).
predicted_labels

tensor([31,  5, 31, 21,  3, 28, 31,  5, 28,  5, 31, 28,  3, 27,  3,  5,  7,  5,
         3, 31,  3, 27,  3, 21, 27,  3,  4, 31,  3,  7, 31,  5])