In [1]:
import sys

sys.path.insert(0, '..')

In [2]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import InformerConfig, InformerModel 
from transformers.models.informer.modeling_informer import InformerEncoder
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy import stats
import json 
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sns
from datetime import datetime

from pathlib import Path
from core.multimodal.dataset import collate_fn, ASASSNVarStarDataset
from core.multimodal.dataset2 import VGDataset
from functools import partial
import matplotlib.pyplot as plt
from models.Informer import Informer

In [3]:
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True

In [4]:
data_root = '/home/mariia/AstroML/data/asassn'
vg_file = 'vg_combined.csv'

train_dataset = VGDataset(data_root, vg_file, split='train')
val_dataset = VGDataset(data_root, vg_file, split='val')

Removing objects without periods... 848299 objects left.
Removing objects that have more than 50000 or less than 5000 samples... 368345 objects left.
train split is selected: 294670 objects left.
Removing objects without periods... 848299 objects left.
Removing objects that have more than 50000 or less than 5000 samples... 368345 objects left.
val split is selected: 36833 objects left.


In [14]:
train_dataset.reader_g.getmember('g_band_lcs/ASASSN-V_J075227.80-863500.9.dat')

<TarInfo 'g_band_lcs/ASASSN-V_J075227.80-863500.9.dat' at 0x7f21293d0f40>

In [17]:
train_dataset.reader_g.extractfile('g_band_lcs/ASASSN-V_J075227.80-863500.9.dat')

<ExFileObject name='/home/mariia/AstroML/data/asassn/g_band_lcs-001.tar'>

In [15]:
'g_band_lcs/ASASSN-V_J075227.80-863500.9.dat' in train_dataset.reader_g.getnames()

True

In [16]:
len(train_dataset.reader_g.getnames())

378862

In [5]:
X, mask, y = train_dataset[0]

In [6]:
X.dtype, mask.dtype, X.shape, mask.shape

(dtype('float32'), dtype('float32'), (200, 3), (200,))

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [10]:
def train_epoch():
    model.train()

    total_loss = []
    total_correct_predictions = 0
    total_predictions = 0
    
    for (X, mask, y) in tqdm(train_dataloader):
        X, mask, y = X.to(device), mask.to(device), y.to(device)
    
        optimizer.zero_grad()
    
        logits = model(X[:, :, 1:], mask)
        loss = criterion(logits, y)
        total_loss.append(loss.item())
    
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        _, predicted_labels = torch.max(probabilities, dim=1)
        correct_predictions = (predicted_labels == y).sum().item()
    
        total_correct_predictions += correct_predictions
        total_predictions += y.size(0)
    
        loss.backward()
        optimizer.step()

    print(f'Train Total Loss: {round(sum(total_loss) / len(total_loss), 5)} Accuracy: {round(total_correct_predictions / total_predictions, 3)}')

In [11]:
def val_epoch():
    model.eval()

    total_loss = []
    total_correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for (X, mask, y) in tqdm(val_dataloader):
            X, mask, y = X.to(device), mask.to(device), y.to(device)

            logits = model(X[:, :, 1:], mask)
            loss = criterion(logits, y)
            total_loss.append(loss.item())

            probabilities = torch.nn.functional.softmax(logits, dim=1)
            _, predicted_labels = torch.max(probabilities, dim=1)
            correct_predictions = (predicted_labels == y).sum().item()

            total_correct_predictions += correct_predictions
            total_predictions += y.size(0)

    print(f'Val Total Loss: {round(sum(total_loss) / len(total_loss), 5)} Accuracy: {round(total_correct_predictions / total_predictions, 3)}')

In [19]:
model = Informer(enc_in=2, d_model=64, dropout=0.1, factor=1, output_attention=False, n_heads=4, d_ff=512,
                 activation='gelu', e_layers=2, seq_len=200, num_class=len(train_dataset.target2id))
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print('Using', device)

Using cuda:2


In [24]:
scheduler = ReduceLROnPlateau(optimizer)
scheduler.get_last_lr()

[0.0001]

In [34]:
for i in range(10):
    scheduler.step(0.1)
    print(scheduler.get_last_lr())

[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]
[1.0000000000000004e-08]


[0.0001]

In [28]:
x = torch.ones((612, 200, 2)).to(device)
m = torch.ones((612, 200)).to(device)

In [31]:
np.ones((10))

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [29]:
model(x, m)

tensor([[ 0.0029,  0.0319,  0.5729,  ..., -0.3495,  0.0020, -0.4262],
        [ 0.1701,  0.0945,  0.4277,  ..., -0.2890,  0.3082, -0.5674],
        [ 0.0536, -0.2898,  0.0458,  ..., -0.2095,  0.0891, -0.4013],
        ...,
        [-0.1805,  0.0311,  0.4741,  ...,  0.0406, -0.2420, -0.7261],
        [-0.0577,  0.0041,  0.4081,  ..., -0.2666, -0.0938, -0.6646],
        [ 0.2513, -0.0267,  0.3582,  ..., -0.2757,  0.2092, -0.5088]],
       device='cuda:2', grad_fn=<AddmmBackward0>)

In [13]:
for i in range(10):
    print(f'Epoch {i}')
    train_epoch()
    val_epoch()

Epoch 0


  1%|▍                                                     | 4/576 [01:06<2:38:19, 16.61s/it]


KeyboardInterrupt: 

In [6]:
# datapath = Path('../data/asassn')
# ds_train = ASASSNVarStarDataset(datapath, mode='train', verbose=True, only_periodic=True, recalc_period=False, 
#                                 prime=True, use_bands=['v'], only_sources_with_spectra=False, return_phased=True, 
#                                 fill_value=0, max_samples=100)
# ds_val = ASASSNVarStarDataset(datapath, mode='val', verbose=True, only_periodic=True, recalc_period=False, 
#                               prime=True, use_bands=['v'], only_sources_with_spectra=False, return_phased=True, 
#                               fill_value=0, max_samples=100)

In [9]:
# context_length = 200

# no_spectra_data_keys = ['lcs', 'classes']
# no_spectra_collate_fn = partial(collate_fn, data_keys=no_spectra_data_keys, fill_value=0)

# train_dataloader = DataLoader(ds_train, batch_size=16, shuffle=True, num_workers=0, 
#                               collate_fn=no_spectra_collate_fn)
# val_dataloader = DataLoader(ds_val, batch_size=16, shuffle=False, collate_fn=no_spectra_collate_fn)

In [13]:
def preprocess_batch(batch, masks):
    lcs, classes = batch
    lcs_mask, classes_mask = masks

    # shape now [128, 1, 3, 759], make [128, 3, 759] 
    X = lcs[:, 0, :, :]
    
    # change axises, shape now [128, 3, 759], make [128, 759, 3]
    X = X.transpose(1, 2)
    
    # since mask is the same for time flux and flux err we can make it 2D
    mask = lcs_mask[:, 0, 0, :]

    # context length 200, crop X and MASK if longer, pad if shorter
    if X.shape[1] < context_length:
        X_padding = (0, 0, 0, context_length - X.shape[1], 0, 0)
        mask_padding = (0, context_length - X.shape[1])
        X = F.pad(X, X_padding)
        mask = F.pad(mask, mask_padding, value=True)
    else:
        X = X[:, :context_length, :]
        mask = mask[:, :context_length]

    # the last dimention is (time, flux, flux_err), sort it based on time
    sort_indices = torch.argsort(X[:, :, 0], dim=1)
    sorted_X = torch.zeros_like(X)
    
    for i in range(X.shape[0]):
        sorted_X[i] = X[i, sort_indices[i]]
    
    # rearange indexes for masks as well
    sorted_mask = torch.zeros_like(mask)
    
    for i in range(mask.shape[0]):
        sorted_mask[i] = mask[i, sort_indices[i]]

    # mask should be 1 for values that are observed and 0 for values that are missing
    sorted_mask = 1 - sorted_mask.int()

    # read scales
    with open('scales.json', 'r') as f:
        scales = json.load(f)
        mean, std = scales['v']['mean'], scales['v']['std']

    # scale X
    sorted_X[:, :, 1] = (sorted_X[:, :, 1] - mean) / std
    sorted_X[:, :, 2] = sorted_X[:, :, 2] / std

    # reshape classes to be 1D vector and convert from float to int
    classes = classes[:, 0]
    classes = classes.long()
    
    return sorted_X, sorted_mask, classes

In [17]:
def plot_confusion(all_true_labels, all_predicted_labels):
    # Calculate confusion matrix
    conf_matrix = confusion_matrix(all_true_labels, all_predicted_labels)

    # Calculate percentage values for confusion matrix
    conf_matrix_percent = 100 * conf_matrix / conf_matrix.sum(axis=1)[:, np.newaxis]

    # Plot both confusion matrices side by side
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 20))

    # Plot absolute values confusion matrix
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=axes[0])
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')
    axes[0].set_title('Confusion Matrix - Absolute Values')

    # Plot percentage values confusion matrix
    sns.heatmap(conf_matrix_percent, annot=True, fmt='.0f', cmap='Blues', ax=axes[1])
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
    axes[1].set_title('Confusion Matrix - Percentages')

In [23]:
model = Informer(enc_in=2, d_model=64, dropout=0.1, factor=1, output_attention=False, n_heads=4, d_ff=512,
                 activation='gelu', e_layers=2, seq_len=200, num_class=len(ds_train.target_lookup))
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print('Using', device)

Using cuda:3


In [24]:
for i in range(10):
    print(f'Epoch {i}')
    train_epoch()
    val_epoch()

Epoch 0


100%|████████████████████████████████████████████████████████| 80/80 [00:26<00:00,  3.06it/s]


Train Total Loss: 2.77588 Accuracy: 0.101


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.70it/s]


Val Total Loss: 2.76368 Accuracy: 0.094
Epoch 1


100%|████████████████████████████████████████████████████████| 80/80 [00:23<00:00,  3.46it/s]


Train Total Loss: 2.68069 Accuracy: 0.119


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.71it/s]


Val Total Loss: 2.69956 Accuracy: 0.138
Epoch 2


100%|████████████████████████████████████████████████████████| 80/80 [00:23<00:00,  3.43it/s]


Train Total Loss: 2.6414 Accuracy: 0.134


100%|████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  3.14it/s]


Val Total Loss: 2.70996 Accuracy: 0.144
Epoch 3


100%|████████████████████████████████████████████████████████| 80/80 [00:23<00:00,  3.38it/s]


Train Total Loss: 2.59221 Accuracy: 0.148


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.59it/s]


Val Total Loss: 2.71039 Accuracy: 0.163
Epoch 4


100%|████████████████████████████████████████████████████████| 80/80 [00:23<00:00,  3.44it/s]


Train Total Loss: 2.55976 Accuracy: 0.162


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.05it/s]


Val Total Loss: 2.71159 Accuracy: 0.144
Epoch 5


100%|████████████████████████████████████████████████████████| 80/80 [00:24<00:00,  3.28it/s]


Train Total Loss: 2.54449 Accuracy: 0.166


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.09it/s]


Val Total Loss: 2.6828 Accuracy: 0.144
Epoch 6


100%|████████████████████████████████████████████████████████| 80/80 [00:24<00:00,  3.25it/s]


Train Total Loss: 2.53139 Accuracy: 0.169


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.12it/s]


Val Total Loss: 2.68703 Accuracy: 0.163
Epoch 7


100%|████████████████████████████████████████████████████████| 80/80 [00:22<00:00,  3.60it/s]


Train Total Loss: 2.49137 Accuracy: 0.182


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.42it/s]


Val Total Loss: 2.64842 Accuracy: 0.144
Epoch 8


100%|████████████████████████████████████████████████████████| 80/80 [00:23<00:00,  3.41it/s]


Train Total Loss: 2.46742 Accuracy: 0.195


100%|████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.37it/s]


Val Total Loss: 2.6461 Accuracy: 0.175
Epoch 9


100%|████████████████████████████████████████████████████████| 80/80 [00:23<00:00,  3.35it/s]


Train Total Loss: 2.45454 Accuracy: 0.192


100%|████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  3.15it/s]

Val Total Loss: 2.63687 Accuracy: 0.144





In [8]:
batch, masks = next(iter(train_dataloader))
X, mask, y = preprocess_batch(batch, masks)

In [9]:
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
print('Using', device)

Using cuda:2


In [19]:
model = model.to(device)
X, mask = X.to(device), mask.to(device)

In [12]:
X.shape

torch.Size([512, 200, 3])

In [20]:
with torch.no_grad():
    output = model(X[:, :, 1:], mask)

In [21]:
output.shape

torch.Size([512, 16])