In [1]:
import os
import sys

# Put the project root (the parent of the notebook's working directory) first on sys.path so the
# local `core` package can be imported. The path is made absolute and de-duplicated, so
# re-running this cell leaves exactly one entry instead of stacking more copies of '..'.
PROJECT_ROOT = os.path.abspath('..')
sys.path[:] = [entry for entry in sys.path if entry != PROJECT_ROOT]
sys.path.insert(0, PROJECT_ROOT)

In [2]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
from transformers.models.time_series_transformer.modeling_time_series_transformer import TimeSeriesTransformerEncoder
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy import stats
import seaborn as sns
from sklearn.metrics import confusion_matrix

from core.dataset import MachoDataset
from core.trainer import PredictionTrainer, ClassificationTrainer
from core.model import ClassificationModel

In [3]:
random_seed = 42

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) with the same value,
# and ask cuDNN for deterministic kernels so GPU runs are repeatable.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True

In [4]:
class CustomModel(nn.Module):
    """Sequence classifier: a transformer encoder followed by a linear head.

    Args:
        encoder: module called as ``encoder(inputs_embeds=..., attention_mask=...)`` whose
            output exposes ``last_hidden_state``; it must carry ``config.d_model``.
        device: stored as ``self.device``; nothing in this class reads it (the caller is
            responsible for moving the module).
        num_classes: number of output logits.
        pooling: how the per-step hidden states are reduced to one vector per series.
            ``'first'`` (default, the original behaviour) takes step 0; ``'mean'`` averages
            the steps where ``mask`` is non-zero.

    Raises:
        ValueError: if ``pooling`` is not ``'first'`` or ``'mean'``.
    """

    def __init__(self, encoder, device, num_classes, pooling='first'):
        super().__init__()
        if pooling not in ('first', 'mean'):
            raise ValueError(f"pooling must be 'first' or 'mean', got {pooling!r}")

        self.device = device
        self.encoder = encoder
        self.num_classes = num_classes
        self.pooling = pooling
        self.classifier = nn.Linear(self.encoder.config.d_model, num_classes)

    def forward(self, values, mask):
        """Return class logits of shape (batch, num_classes).

        values: assumed (batch, seq_len) — a trailing feature axis of size 1 is appended
            before it is handed to the encoder as ``inputs_embeds``. TODO confirm shape.
        mask: passed to the encoder as ``attention_mask``; with ``'mean'`` pooling its
            non-zero entries also select the steps that are averaged.
        """
        hidden = self.encoder(inputs_embeds=values.unsqueeze(-1), attention_mask=mask).last_hidden_state

        if self.pooling == 'mean':
            # Masked mean over time; clamp avoids 0/0 for a row whose mask is all zeros.
            weights = mask.to(hidden.dtype).unsqueeze(-1)
            emb = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        else:
            # No CLS token is prepended, so step 0 is simply the first observation's
            # contextualised embedding (attention lets it see the rest of the series).
            emb = hidden[:, 0, :]

        return self.classifier(emb)

In [5]:
def train_epoch():
    """Run one optimisation pass over ``train_dataloader`` and report loss / accuracy.

    Reads the notebook globals ``model``, ``train_dataloader``, ``criterion``,
    ``optimizer`` and ``device``. Prints ``Total Loss: ... Accuracy: ...`` without a
    trailing newline so the validation metrics can follow on the same line.

    Returns:
        tuple[float, float]: (mean loss per sample, accuracy) over the epoch.

    Raises:
        ValueError: if ``train_dataloader`` yields no samples.
    """
    model.train()

    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch in train_dataloader:
        # Each batch has 8 fields; only values, mask and labels are used here.
        _, _, _, values, _, mask, _, labels = batch
        values, mask, labels = values.to(device), mask.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(values, mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion averages over the batch (nn.CrossEntropyLoss default reduction), so
        # re-weight by batch size: a short final batch must not count as much as a full one.
        loss_sum += loss.item() * batch_size
        # argmax of the logits equals argmax of their softmax, so no softmax is needed.
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError('train_dataloader yielded no samples')

    avg_loss = loss_sum / seen
    accuracy = correct / seen
    print(f'Total Loss: {round(avg_loss, 5)} Accuracy: {round(accuracy, 3)}', end=' ')
    return avg_loss, accuracy

In [6]:
def val_epoch():
    """Evaluate ``model`` on ``val_dataloader`` without gradient tracking.

    Reads the notebook globals ``model``, ``val_dataloader``, ``criterion`` and ``device``,
    and prints ``Total Loss: ... Accuracy: ...`` followed by a newline.

    Returns:
        tuple[float, float]: (mean loss per sample, accuracy). The loss is suitable for
        ``scheduler.step(loss)`` with a ReduceLROnPlateau scheduler in ``mode='min'``.

    Raises:
        ValueError: if ``val_dataloader`` yields no samples.
    """
    model.eval()

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in val_dataloader:
            # Each batch has 8 fields; only values, mask and labels are used here.
            _, _, _, values, _, mask, _, labels = batch
            values, mask, labels = values.to(device), mask.to(device), labels.to(device)

            logits = model(values, mask)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            # criterion averages over the batch, so weight by batch size to get a
            # per-sample mean that is not skewed by a short final batch.
            loss_sum += loss.item() * batch_size
            # argmax of the logits equals argmax of their softmax, so no softmax is needed.
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError('val_dataloader yielded no samples')

    avg_loss = loss_sum / seen
    accuracy = correct / seen
    print(f'Total Loss: {round(avg_loss, 5)} Accuracy: {round(accuracy, 3)}')
    return avg_loss, accuracy

In [22]:
# Quick look at the label tensor's dtype (nn.CrossEntropyLoss with class-index targets
# expects integer labels).
# NOTE(review): this cell reads `val_dataloader`, which is only created by a later cell
# (In [11]); on Restart & Run All it raises NameError. Move it below the dataloader cell or drop it.
batch = next(iter(val_dataloader))
_, times, _, values, _, mask, aux, labels = batch
labels.dtype

In [7]:
def plot_confusion(all_true_labels, all_predicted_labels, labels=None):
    """Draw the confusion matrix twice, side by side: raw counts and row-normalised rates.

    Rows are true classes, columns predicted classes.

    Args:
        all_true_labels: ground-truth class ids (any array-like accepted by
            ``sklearn.metrics.confusion_matrix``).
        all_predicted_labels: predicted class ids, same length as ``all_true_labels``.
        labels: optional ordered list of class ids forwarded to ``confusion_matrix``.
            Passing the full class list keeps the matrix square and aligned with class ids
            even when some class never occurs in this sample; it is also used for the tick
            labels. ``None`` keeps the original behaviour (classes present in the data).

    A true class with no samples has no defined rate; its row is left as NaN instead of
    dividing by zero.
    """
    conf_matrix = confusion_matrix(all_true_labels, all_predicted_labels, labels=labels)

    # Row-normalise without a divide-by-zero warning; empty rows stay NaN.
    row_totals = conf_matrix.sum(axis=1, keepdims=True)
    conf_matrix_percent = np.divide(
        conf_matrix,
        row_totals,
        out=np.full(conf_matrix.shape, np.nan),
        where=row_totals != 0,
    )

    tick_labels = 'auto' if labels is None else labels

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 7))

    # Left: absolute counts.
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=axes[0],
                xticklabels=tick_labels, yticklabels=tick_labels)
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')
    axes[0].set_title('Confusion Matrix - Absolute Values')

    # Right: fraction of each true class assigned to each predicted class.
    sns.heatmap(conf_matrix_percent, annot=True, fmt='.2%', cmap='Blues', ax=axes[1],
                xticklabels=tick_labels, yticklabels=tick_labels)
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
    axes[1].set_title('Confusion Matrix - Percentages')

In [8]:
# Central configuration: data locations and hyper-parameters.
# The three paths can be overridden with environment variables so the notebook can run on a
# machine other than the original author's; the defaults are the original hard-coded paths.
# NOTE(review): as far as the visible cells show, 'lags', 'decoder_layers', 'epochs_pre_training',
# 'epochs_fine_tuning', 'mode', 'save_weights' and 'config_from_run' are not read anywhere —
# treat them as a record of intent rather than live settings.
config = {
    'random_seed': random_seed,
    'data_root': os.environ.get('MACHO_DATA_ROOT', '/home/mrizhko/AML/contra_periodic/data/macho/'),
    'balanced_data_root': os.environ.get('MACHO_BALANCED_DATA_ROOT', '/home/mrizhko/AML/AstroML/data/macho-balanced/'),
    'weights_path': os.environ.get('ASTROML_WEIGHTS_PATH', '/home/mrizhko/AML/AstroML/weights/'),

    # Time Series Transformer
    'lags': None,  # NOTE(review): purpose unclear and not passed to TimeSeriesTransformerConfig
    'distribution_output': 'normal',
    'num_static_real_features': 0,  # if 0 we don't use real features
    'num_time_features': 1,
    'd_model': 512,
    'decoder_layers': 4,
    'encoder_layers': 2,
    'dropout': 0,
    'encoder_layerdrop': 0,
    'decoder_layerdrop': 0,
    'attention_dropout': 0,
    'activation_dropout': 0,

    # Data
    'window_length': 200,
    'prediction_length': 0,  # 1 5 10 25 50

    # Training
    'batch_size': 256,
    'lr': 1e-3,
    'weight_decay': 0,
    'epochs_pre_training': 1000,
    'epochs_fine_tuning': 100,

    # Learning Rate Scheduler (ReduceLROnPlateau)
    'factor': 0.3,
    'patience': 10,

    'mode': 'fine-tuning',  # 'pre-training' 'fine-tuning' 'both'
    'save_weights': False,
    'config_from_run': None,  # 'MeriDK/AstroML/qtun67bq'
}

In [9]:
# Pick the compute device: GPU index 2 (the original choice) when it exists, otherwise the
# default GPU, otherwise the CPU. Selecting 'cuda:2' on a machine with fewer than three GPUs
# would only fail later, at the first `.to(device)`.
if torch.cuda.device_count() > 2:
    device = torch.device('cuda:2')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using', device)

In [10]:
# Training split of MACHO, cut to `window_length` steps; reshuffled every epoch.
train_dataset = MachoDataset(
    config['data_root'],
    config['window_length'],
    mode='train',
)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=config['batch_size'],
)

In [11]:
# Validation split, same window length; fixed order so metrics are comparable across epochs.
val_dataset = MachoDataset(
    config['data_root'],
    config['window_length'],
    mode='val',
)
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=config['batch_size'],
)

In [12]:
tuple(map(len, (train_dataset, val_dataset)))  # (train size, val size)

In [13]:
# Hugging Face config for the encoder. Sequence length and output distribution are read from
# `config` instead of repeating the literals, so they cannot drift from the data settings.
transformer_config = TimeSeriesTransformerConfig(
    # NOTE(review): config['prediction_length'] is 0, yet 10 is hard-coded here.
    prediction_length=10,
    context_length=config['window_length'],
    num_time_features=config['num_time_features'],
    num_static_real_features=config['num_static_real_features'],
    encoder_layers=config['encoder_layers'],
    # NOTE(review): config['decoder_layers'] is not passed, so the library default is used.
    d_model=config['d_model'],
    distribution_output=config['distribution_output'],
    scaling=None,
    dropout=config['dropout'],
    encoder_layerdrop=config['encoder_layerdrop'],
    decoder_layerdrop=config['decoder_layerdrop'],
    attention_dropout=config['attention_dropout'],
    activation_dropout=config['activation_dropout']
)
# Overwrite the config's feature_size so the encoder takes a single input channel, matching the
# `values.unsqueeze(-1)` tensors fed to it as `inputs_embeds` (presumably sizes the encoder's
# input projection — confirm against the transformers version in use).
transformer_config.feature_size = 1

In [14]:
# Stand-alone encoder stack (no decoder) built from the config above.
encoder = TimeSeriesTransformerEncoder(config=transformer_config)

In [15]:
# NOTE(review): 8 is presumably the number of MACHO classes — TODO derive it from the dataset.
model = CustomModel(encoder=encoder, device=device, num_classes=8)

In [16]:
# Move the model first, then build the optimiser over its on-device parameters.
model = model.to(device)
# AdamW's default weight_decay is 0.01; pass the configured value explicitly, otherwise
# config['weight_decay'] (0) is silently ignored.
optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=config['factor'], patience=config['patience'], verbose=True)
criterion = nn.CrossEntropyLoss()

In [18]:
# Ten epochs of train + validate; both metric lines print on the same output row.
# NOTE(review): `scheduler` (ReduceLROnPlateau) is never stepped here, so the learning rate
# never decays; the epoch count also ignores config['epochs_fine_tuning'].
for i in range(10):
    print('Epoch', i, end=' ')
    train_epoch()
    val_epoch()

In [61]:
# Collect validation predictions and ground truth, on the CPU, for the confusion matrix.
model.eval()

y_pred = []
y_true = []

with torch.no_grad():
    for batch in val_dataloader:
        _, times, _, values, _, mask, aux, labels = batch
        values, mask = values.to(device), mask.to(device)

        logits = model(values, mask)
        # argmax of the logits equals argmax of their softmax, so no softmax is needed.
        predicted_labels = logits.argmax(dim=1)

        # Move each batch off the accelerator right away instead of accumulating on it.
        y_pred.append(predicted_labels.cpu())
        y_true.append(labels.cpu())

y_pred = torch.hstack(y_pred)
y_true = torch.hstack(y_true)

In [62]:
plot_confusion(all_true_labels=y_true, all_predicted_labels=y_pred)

In [38]:
# Grab one training batch for the interactive shape checks below.
(_, times, _, values, _, mask, aux, labels) = next(iter(train_dataloader))

In [39]:
tuple(t.shape for t in (times, values, mask))

In [40]:
values[..., None].shape  # shape of the encoder input: one trailing feature axis added

In [45]:
# Probe the bare encoder on one batch. `encoder` is a submodule of `model`, which was moved to
# `device`, so the CPU batch must be moved too; no_grad avoids building an autograd graph.
with torch.no_grad():
    encoder_outputs = encoder(
        inputs_embeds=values.unsqueeze(-1).to(device),
        attention_mask=mask.to(device),
    )

In [46]:
encoder_outputs.last_hidden_state.size()

In [58]:
# Shape of the classifier output for the probe batch; inputs go to the model's device and no
# autograd graph is built. The bare name on the last line is what the cell displays.
with torch.no_grad():
    logits_shape = model(values.to(device), mask.to(device)).shape
logits_shape

In [18]:
# Full encoder-decoder model reusing the same config object.
# NOTE(review): transformer_config.feature_size was overwritten to 1 for the bare encoder;
# confirm this is still valid for TimeSeriesTransformerModel's own input features.
embedder = TimeSeriesTransformerModel(config=transformer_config)

In [19]:
# Rebinds `model` to the project ClassificationModel wrapping the embedder.
model = ClassificationModel(device=device, pretrained_model=embedder)


In [20]:
# `model` was just rebound to the ClassificationModel, but the optimizer/scheduler created in
# In [16] hold the earlier CustomModel's parameters. Rebuild them for the current model, or
# the trainer would step parameters that the new model does not use.
# NOTE(review): assumes ClassificationModel is an nn.Module exposing .parameters() — confirm.
optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=config['factor'], patience=config['patience'], verbose=True)
trainer = ClassificationTrainer(model=model, optimizer=optimizer, scheduler=scheduler,
                                criterion=criterion, device=device)

In [21]:
# Train the ClassificationModel for 1000 epochs.
# NOTE(review): `train_dataloader` is passed for both loader arguments. If the second argument
# is the validation loader (signature lives in core.trainer — confirm), anything the trainer
# reports or schedules from it is measured on training data; `val_dataloader` was likely intended.
# The epoch count also ignores config['epochs_fine_tuning'] / config['epochs_pre_training'].
trainer.train(train_dataloader, train_dataloader, epochs=1000)

In [22]:
# Re-point validation at the balanced MACHO split. The second positional argument is the
# window length, as in the original train/val dataset cells; this call previously passed
# config['prediction_length'] (0) there.
val_dataset = MachoDataset(config['balanced_data_root'], config['window_length'], mode='val')
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

In [23]:
# Metrics on the balanced validation split: print whatever trainer.val_epoch returns, then run
# trainer.evaluate on the same loader (output format is defined in core.trainer).
print(trainer.val_epoch(val_dataloader))
trainer.evaluate(val_dataloader)

In [24]:
# Same metrics on the training split, for a train-vs-validation comparison.
print(trainer.val_epoch(train_dataloader))
trainer.evaluate(train_dataloader)