In [1]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from datetime import datetime
from tqdm import tqdm
import umap
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import seaborn as sns

from core.multimodal.dataset2 import VPSMDatasetV2
from core.multimodal.model import ModelV1

In [2]:
def get_config(random_seed):
    """Assemble the flat experiment configuration dictionary.

    Args:
        random_seed: Value stored under the ``'random_seed'`` key.

    Returns:
        dict: A single-level dictionary of settings. Keys are inserted
        section by section (general, data, photometry, spectra, the three
        per-modality models, multimodal model, training, scheduler).
        ``'weights_path'`` embeds the current time at minute resolution.
    """
    general = {
        'project': 'multimodal-contrastive',
        'random_seed': random_seed,
        'use_wandb': True,
        'save_weights': True,
        'weights_path': f'/home/mariia/AstroML/weights/{datetime.now().strftime("%Y-%m-%d-%H-%M")}',
        'use_pretrain': None,
    }

    data_general = {
        'dataset': 'VPSMDatasetV2',     # 'VPSMDataset' or 'VPSMDatasetV2'
        'data_root': '/home/mariia/AstroML/data/asassn/',
        'file': 'preprocessed_data/full/spectra_and_v',
        'classes': None,
        'min_samples': 200,
        'max_samples': None,
    }

    photometry = {
        'v_zip': 'asassnvarlc_vband_complete.zip',
        'v_prefix': 'vardb_files',
        'seq_len': 200,
        'phased': True,
        'clip': False,
        'aux': True,
    }

    spectra = {
        'lamost_spec_dir': 'Spectra/v2',
        'spectra_v_file': 'spectra_v_merged.csv',
        'z_corr': False,
    }

    # The photometry feature size is a base of 3 widened by 4 extra features.
    # NOTE(review): the extra 4 presumably correspond to the auxiliary
    # inputs enabled by 'aux' — confirm against the dataset.
    base_features, extra_features = 3, 4
    photometry_model = {
        'p_encoder_layers': 8,
        'p_d_model': 128,
        'p_dropout': 0,
        'p_feature_size': base_features + extra_features,
        'p_n_heads': 4,
        'p_d_ff': 512,
    }

    spectra_model = {
        's_hidden_dim': 512,
        's_dropout': 0,
    }

    metadata_model = {
        'm_hidden_dim': 512,
        'm_dropout': 0,
    }

    multimodal_model = {
        'model': 'ModelV1',     # 'ModelV0' or 'ModelV1'
        'hidden_dim': 256,
    }

    training = {
        'batch_size': 32,
        'lr': 1e-3,
        'weight_decay': 0,
        'epochs': 50,
        'optimizer': 'AdamW',
        'early_stopping_patience': 100,
    }

    lr_scheduler = {
        'factor': 0.3,
        'patience': 50,
    }

    # Merge in the same order the sections are listed so key order is stable.
    return {
        **general,
        **data_general,
        **photometry,
        **spectra,
        **photometry_model,
        **spectra_model,
        **metadata_model,
        **multimodal_model,
        **training,
        **lr_scheduler,
    }

In [3]:
# Global experiment settings, the class subset used for every split, and the
# device all tensors and models are moved to.
config = get_config(42)
CLASSES = ['EW', 'SR', 'EA', 'RRAB', 'EB', 'ROT', 'RRC', 'HADS', 'M', 'DSCT']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The seed was previously only stored in `config` and never applied. Seed the
# numpy and torch generators so weight initialisation and DataLoader shuffling
# repeat across "Restart & Run All" executions.
np.random.seed(config['random_seed'])
torch.manual_seed(config['random_seed'])

In [4]:
# Options shared by all three splits. They are read from `config` instead of
# being repeated as literals, so the datasets cannot silently drift from the
# values recorded in the experiment configuration.
dataset_kwargs = dict(
    classes=CLASSES,
    seq_len=config['seq_len'],
    phased=config['phased'],
    clip=config['clip'],
    aux=config['aux'],
    z_corr=config['z_corr'],
    random_seed=config['random_seed'],
)
train_dataset = VPSMDatasetV2(split='train', **dataset_kwargs)
val_dataset = VPSMDatasetV2(split='val', **dataset_kwargs)
test_dataset = VPSMDatasetV2(split='test', **dataset_kwargs)

In [43]:
# Only the training loader shuffles and uses worker processes; validation and
# test batches are produced in dataset order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
)

In [19]:
class ClassificationModelv0(nn.Module):
    """Single linear classification head on top of the ``ModelV1`` encoder.

    The photometry, spectra and metadata embeddings returned by
    ``encoder.get_embeddings`` are concatenated along dim 1 and mapped to
    class logits by one fully connected layer.

    Args:
        config: Configuration object passed unchanged to ``ModelV1``.
        num_classes: Number of output logits.
        freeze: If True, ``requires_grad`` is switched off for every encoder
            parameter, so the encoder receives no gradient updates.
        weights_path: Optional path to a saved encoder ``state_dict``;
            nothing is loaded when it is falsy.
    """

    def __init__(self, config, num_classes, freeze=False, weights_path=None):
        super().__init__()

        # `device` is the notebook-level torch.device.
        self.encoder = ModelV1(config).to(device)

        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False

        if weights_path:
            # map_location remaps tensors saved on another device (e.g. a
            # CUDA checkpoint opened on a CPU-only machine); without it
            # torch.load raises in that situation.
            # NOTE(review): torch.load unpickles the file — only load
            # checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location=device)
            self.encoder.load_state_dict(state_dict)

        # 768 is the width of the concatenated embeddings.
        # NOTE(review): presumably 3 * config['hidden_dim'] (3 * 256) — confirm.
        self.fc = nn.Linear(768, num_classes)

    def forward(self, photometry, photometry_mask, spectra, metadata):
        """Return class logits for one batch.

        All four inputs are forwarded unchanged to
        ``encoder.get_embeddings``; the three resulting embeddings are
        concatenated along dim 1 before the linear layer.
        """
        p_emb, s_emb, m_emb = self.encoder.get_embeddings(photometry, photometry_mask, spectra, metadata)
        emb = torch.cat((p_emb, s_emb, m_emb), dim=1)
        logits = self.fc(emb)

        return logits

In [24]:
class ClassificationModelv1(nn.Module):
    """Two-layer MLP classification head on top of the ``ModelV1`` encoder.

    The photometry, spectra and metadata embeddings returned by
    ``encoder.get_embeddings`` are concatenated along dim 1 and passed
    through ``Linear(768, 128) -> ReLU -> Linear(128, num_classes)``.

    Args:
        config: Configuration object passed unchanged to ``ModelV1``.
        num_classes: Number of output logits.
        freeze: If True, ``requires_grad`` is switched off for every encoder
            parameter, so the encoder receives no gradient updates.
        weights_path: Optional path to a saved encoder ``state_dict``;
            nothing is loaded when it is falsy.
    """

    def __init__(self, config, num_classes, freeze=False, weights_path=None):
        super().__init__()

        # `device` is the notebook-level torch.device.
        self.encoder = ModelV1(config).to(device)

        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False

        if weights_path:
            # map_location remaps tensors saved on another device (e.g. a
            # CUDA checkpoint opened on a CPU-only machine); without it
            # torch.load raises in that situation.
            # NOTE(review): torch.load unpickles the file — only load
            # checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location=device)
            self.encoder.load_state_dict(state_dict)

        # 768 is the width of the concatenated embeddings.
        # NOTE(review): presumably 3 * config['hidden_dim'] (3 * 256) — confirm.
        self.mlp = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, photometry, photometry_mask, spectra, metadata):
        """Return class logits for one batch.

        All four inputs are forwarded unchanged to
        ``encoder.get_embeddings``; the three resulting embeddings are
        concatenated along dim 1 before the MLP head.
        """
        p_emb, s_emb, m_emb = self.encoder.get_embeddings(photometry, photometry_mask, spectra, metadata)
        emb = torch.cat((p_emb, s_emb, m_emb), dim=1)
        logits = self.mlp(emb)

        return logits

In [77]:
class ClassificationModelv2(nn.Module):
    """Three-layer MLP classification head on top of the ``ModelV1`` encoder.

    The photometry, spectra and metadata embeddings returned by
    ``encoder.get_embeddings`` are concatenated along dim 1 and passed
    through ``Linear(768, 256) -> ReLU -> Linear(256, 64) -> ReLU ->
    Linear(64, num_classes)``.

    Args:
        config: Configuration object passed unchanged to ``ModelV1``.
        num_classes: Number of output logits.
        freeze: If True, ``requires_grad`` is switched off for every encoder
            parameter, so the encoder receives no gradient updates.
        weights_path: Optional path to a saved encoder ``state_dict``;
            nothing is loaded when it is falsy.
    """

    def __init__(self, config, num_classes, freeze=False, weights_path=None):
        super().__init__()

        # `device` is the notebook-level torch.device.
        self.encoder = ModelV1(config).to(device)

        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False

        if weights_path:
            # map_location remaps tensors saved on another device (e.g. a
            # CUDA checkpoint opened on a CPU-only machine); without it
            # torch.load raises in that situation.
            # NOTE(review): torch.load unpickles the file — only load
            # checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location=device)
            self.encoder.load_state_dict(state_dict)

        # 768 is the width of the concatenated embeddings.
        # NOTE(review): presumably 3 * config['hidden_dim'] (3 * 256) — confirm.
        self.mlp = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, photometry, photometry_mask, spectra, metadata):
        """Return class logits for one batch.

        All four inputs are forwarded unchanged to
        ``encoder.get_embeddings``; the three resulting embeddings are
        concatenated along dim 1 before the MLP head.
        """
        p_emb, s_emb, m_emb = self.encoder.get_embeddings(photometry, photometry_mask, spectra, metadata)
        emb = torch.cat((p_emb, s_emb, m_emb), dim=1)
        logits = self.mlp(emb)

        return logits

In [None]:
def train_epoch(model):
    """Train ``model`` for one pass over ``train_dataloader`` and report metrics.

    Relies on the notebook-level globals ``train_dataloader``, ``optimizer``,
    ``criterion`` and ``device``. ``optimizer`` must have been created from
    this model's parameters, otherwise ``optimizer.step()`` updates a
    different model.

    Args:
        model: Module called as
            ``model(photometry, photometry_mask, spectra, metadata)`` that
            returns class logits.

    Returns:
        tuple: ``(mean_loss, accuracy)`` — the mean of the per-batch losses
        and the fraction of correctly classified samples. Both are ``nan``
        when the dataloader yields no batches.
    """
    model.train()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    for photometry, photometry_mask, spectra, metadata, y in tqdm(train_dataloader):
        photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
        spectra, metadata, y = spectra.to(device), metadata.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(photometry, photometry_mask, spectra, metadata)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

        # Softmax is monotonic, so the argmax of the raw logits already is
        # the predicted class; no need to normalise first.
        predicted_labels = logits.argmax(dim=1)
        n_correct += (predicted_labels == y).sum().item()
        n_seen += y.size(0)

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
    accuracy = n_correct / n_seen if n_seen else float('nan')

    print(f'Train Loss: {mean_loss} Acc: {accuracy}')
    return mean_loss, accuracy

In [None]:
def val_epoch(model):
    """Evaluate ``model`` on ``val_dataloader`` and report loss and accuracy.

    Relies on the notebook-level globals ``val_dataloader``, ``criterion``
    and ``device``. No gradients are tracked and no parameters are updated.

    Args:
        model: Module called as
            ``model(photometry, photometry_mask, spectra, metadata)`` that
            returns class logits.

    Returns:
        tuple: ``(mean_loss, accuracy)`` — the mean of the per-batch losses
        and the fraction of correctly classified samples. Both are ``nan``
        when the dataloader yields no batches.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    # The whole pass is inference-only, so keep autograd off for all of it.
    with torch.no_grad():
        for photometry, photometry_mask, spectra, metadata, y in tqdm(val_dataloader):
            photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
            spectra, metadata, y = spectra.to(device), metadata.to(device), y.to(device)

            logits = model(photometry, photometry_mask, spectra, metadata)
            batch_losses.append(criterion(logits, y).item())

            # Softmax is monotonic, so the argmax of the raw logits already
            # is the predicted class.
            predicted_labels = logits.argmax(dim=1)
            n_correct += (predicted_labels == y).sum().item()
            n_seen += y.size(0)

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
    accuracy = n_correct / n_seen if n_seen else float('nan')

    print(f'Val Loss: {mean_loss} Acc: {accuracy}')
    return mean_loss, accuracy

In [68]:
def evaluate(model, val_dataloader):
    """Print accuracy figures and plot confusion matrices for ``model``.

    Runs ``model`` in eval mode over every batch of ``val_dataloader``,
    prints the overall accuracy and the mean per-class accuracy, and draws
    two heatmaps side by side: absolute counts and row-normalised
    percentages. Relies on the notebook-level global ``device``.

    Args:
        model: Module called as
            ``model(photometry, photometry_mask, spectra, metadata)`` that
            returns class logits.
        val_dataloader: Any dataloader (validation, test, ...) yielding
            ``(photometry, photometry_mask, spectra, metadata, y)``. Its
            ``dataset`` must expose ``id2target`` for the axis labels.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()
    true_batches = []
    predicted_batches = []

    with torch.no_grad():
        for photometry, photometry_mask, spectra, metadata, y in tqdm(val_dataloader):
            photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
            spectra, metadata = spectra.to(device), metadata.to(device)

            logits = model(photometry, photometry_mask, spectra, metadata)

            # Softmax is monotonic: argmax of the logits is the prediction.
            predicted_batches.append(logits.argmax(dim=1).cpu().numpy())
            # .cpu() is a no-op for CPU labels and keeps this working if the
            # dataloader ever yields labels on another device.
            true_batches.append(y.cpu().numpy())

    if not true_batches:
        raise ValueError('evaluate() received a dataloader that yielded no batches')

    all_true_labels = np.concatenate(true_batches)
    all_predicted_labels = np.concatenate(predicted_batches)

    acc = float(np.mean(all_true_labels == all_predicted_labels))
    print(f'Total Accuracy: {round(acc * 100, 2)}%')

    conf_matrix = confusion_matrix(all_true_labels, all_predicted_labels)

    # Row-normalise to percentages. A class that is predicted but never
    # occurs as a true label has a zero row sum; leave that row at 0 instead
    # of dividing by zero and filling it with NaN.
    row_totals = conf_matrix.sum(axis=1, keepdims=True)
    conf_matrix_percent = np.divide(
        100 * conf_matrix,
        row_totals,
        out=np.zeros(conf_matrix.shape, dtype=float),
        where=row_totals > 0,
    )

    # Mean per-class accuracy over the classes that actually have samples.
    has_support = row_totals[:, 0] > 0
    avg_acc = float(np.diag(conf_matrix_percent)[has_support].mean())
    print(f'Average Accuracy: {round(avg_acc, 2)}%')

    # Take class names from the dataset behind the dataloader being evaluated
    # rather than from the global validation dataset.
    # NOTE(review): indexing by range(len(conf_matrix)) assumes the labels
    # present are exactly 0..n-1 — confirm for datasets with missing classes.
    id2target = val_dataloader.dataset.id2target
    labels = [id2target[i] for i in range(len(conf_matrix))]
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 7))

    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')
    axes[0].set_title('Confusion Matrix - Absolute Values')

    sns.heatmap(conf_matrix_percent, annot=True, fmt='.0f', cmap='Blues', xticklabels=labels, yticklabels=labels,
                ax=axes[1])
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
    axes[1].set_title('Confusion Matrix - Percentages')

In [14]:
# Experiment: v0 head (single linear layer), no checkpoint loaded, encoder not frozen.
model2 = ClassificationModelv0(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path=None,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model2.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model2)
    val_epoch(model2)

In [21]:
# Experiment: v0 head (single linear layer), encoder loaded from checkpoint, not frozen.
model3 = ClassificationModelv0(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model3.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model3)
    val_epoch(model3)

In [22]:
# Experiment: v0 head (single linear layer), encoder loaded from checkpoint and frozen.
model4 = ClassificationModelv0(
    config,
    train_dataset.num_classes,
    freeze=True,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model4.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model4)
    val_epoch(model4)

In [25]:
# Experiment: v1 head (two-layer MLP), no checkpoint loaded, encoder not frozen.
model5 = ClassificationModelv1(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path=None,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model5.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model5)
    val_epoch(model5)

In [35]:
# Experiment: v1 head (two-layer MLP), encoder loaded from checkpoint, not frozen.
model6 = ClassificationModelv1(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model6.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model6)
    val_epoch(model6)

In [36]:
# Experiment: v1 head (two-layer MLP), encoder loaded from checkpoint and frozen.
model7 = ClassificationModelv1(
    config,
    train_dataset.num_classes,
    freeze=True,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model7.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model7)
    val_epoch(model7)

In [78]:
# Experiment: v2 head (three-layer MLP), no checkpoint loaded, encoder not frozen.
model8 = ClassificationModelv2(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path=None,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model8.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model8)
    val_epoch(model8)

In [79]:
# Experiment: v2 head (three-layer MLP), encoder loaded from checkpoint, not frozen.
model9 = ClassificationModelv2(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model9.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model9)
    val_epoch(model9)

In [80]:
# Experiment: v2 head (three-layer MLP), encoder loaded from checkpoint and frozen.
model10 = ClassificationModelv2(
    config,
    train_dataset.num_classes,
    freeze=True,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model10.parameters(), lr=1e-3, weight_decay=0)

# Five epochs, each one training pass followed by one validation pass.
for i in range(5):
    print(f'Epoch {i}')
    train_epoch(model10)
    val_epoch(model10)

In [37]:
# Accuracy and confusion matrices for model2 on the validation dataloader.
evaluate(model=model2, val_dataloader=val_dataloader)

In [38]:
# Accuracy and confusion matrices for model3 on the validation dataloader.
evaluate(model=model3, val_dataloader=val_dataloader)

In [39]:
# Accuracy and confusion matrices for model4 on the validation dataloader.
evaluate(model=model4, val_dataloader=val_dataloader)

In [40]:
# Accuracy and confusion matrices for model5 on the validation dataloader.
evaluate(model=model5, val_dataloader=val_dataloader)

In [41]:
# Accuracy and confusion matrices for model6 on the validation dataloader.
evaluate(model=model6, val_dataloader=val_dataloader)

In [42]:
# Accuracy and confusion matrices for model7 on the validation dataloader.
evaluate(model=model7, val_dataloader=val_dataloader)

In [81]:
# Accuracy and confusion matrices for model8 on the validation dataloader.
evaluate(model=model8, val_dataloader=val_dataloader)

In [82]:
# Accuracy and confusion matrices for model9 on the validation dataloader.
evaluate(model=model9, val_dataloader=val_dataloader)

In [83]:
# Accuracy and confusion matrices for model10 on the validation dataloader.
evaluate(model=model10, val_dataloader=val_dataloader)

In [69]:
# Accuracy and confusion matrices for model2 on the test dataloader.
evaluate(model=model2, val_dataloader=test_dataloader)

In [70]:
# Accuracy and confusion matrices for model3 on the test dataloader.
evaluate(model=model3, val_dataloader=test_dataloader)

In [71]:
# Accuracy and confusion matrices for model4 on the test dataloader.
evaluate(model=model4, val_dataloader=test_dataloader)

In [72]:
# Accuracy and confusion matrices for model5 on the test dataloader.
evaluate(model=model5, val_dataloader=test_dataloader)

In [73]:
# Accuracy and confusion matrices for model6 on the test dataloader.
evaluate(model=model6, val_dataloader=test_dataloader)

In [74]:
# Accuracy and confusion matrices for model7 on the test dataloader.
evaluate(model=model7, val_dataloader=test_dataloader)

In [84]:
# Accuracy and confusion matrices for model8 on the test dataloader.
evaluate(model=model8, val_dataloader=test_dataloader)

In [85]:
# Accuracy and confusion matrices for model9 on the test dataloader.
evaluate(model=model9, val_dataloader=test_dataloader)

In [86]:
# Accuracy and confusion matrices for model10 on the test dataloader.
evaluate(model=model10, val_dataloader=test_dataloader)

In [85]:
# Longer run: v0 head, encoder loaded from checkpoint and not frozen, with the
# learning rate and weight decay taken from `config`.
model = ClassificationModelv0(
    config,
    train_dataset.num_classes,
    freeze=False,
    weights_path='/home/mariia/AstroML/weights/2024-06-13-18-41-lwvpa5fm/weights-49.pth',
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

In [89]:
# train_epoch/val_epoch require the model as an argument; the bare calls
# raised TypeError against the current definitions.
for i in range(10):
    print(f'Epoch {i}')
    train_epoch(model)
    val_epoch(model)

In [86]:
# Rebuild the loss and a fresh Adam optimizer for `model`; a new optimizer
# starts without the state accumulated by the previous one.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(
    model.parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay'],
)

In [74]:
# train_epoch/val_epoch require the model as an argument; the bare calls
# raised TypeError against the current definitions.
for i in range(10):
    print(f'Epoch {i}')
    train_epoch(model)
    val_epoch(model)

In [39]:
# One training pass; train_epoch requires the model argument.
train_epoch(model)

In [40]:
# One validation pass; val_epoch requires the model argument.
val_epoch(model)

In [41]:
# Epochs 1-9, training only; train_epoch requires the model argument.
for i in range(1, 10):
    print(f'Epoch {i}')
    train_epoch(model)

In [53]:
# Epochs 10-49 with validation after each; both helpers require the model
# argument.
for i in range(10, 50):
    print(f'Epoch {i}')
    train_epoch(model)
    val_epoch(model)

In [45]:
# evaluate requires both the model and a dataloader; the bare call raised
# TypeError. NOTE(review): the validation dataloader is assumed to be the
# intended split here — switch to test_dataloader if that was meant.
evaluate(model, val_dataloader)