In [None]:
# Put the project directory at the front of the module search path so the
# imports in the following cells are resolved from it first.
# NOTE(review): this is a machine-specific absolute path (it looks like a PyCharm
# remote-deployment directory) — confirm it exists on the host running the notebook.
import sys
sys.path.insert(0, '/tmp/pycharm_project_87')

In [None]:
import os
from functools import partial
from datetime import datetime
import torch as ch
from os import path
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import json

from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitModel


In [None]:
from torch.utils.data import TensorDataset
from sklearn.preprocessing import StandardScaler
from pytorch_lightning.callbacks.callback import Callback
# Prefer the `LightningLoggerBase` export; when that import fails, fall back to
# `Logger` from `pytorch_lightning.loggers.logger` under the same alias so the
# rest of the notebook does not care which one was found.
# Only ImportError is caught: a bare `except` would also hide unrelated failures
# (e.g. KeyboardInterrupt or an error raised while the package initialises).
try:
    from pytorch_lightning.loggers import LightningLoggerBase
except ImportError:
    from pytorch_lightning.loggers.logger import Logger as LightningLoggerBase


def prepare_dataset(scenario, balanced=False, aggregate=False, rescale=True, use_gpu=True, n_features=84):
    """Turn one train/validation split into a pair of ``TensorDataset`` objects.

    Args:
        scenario: 4-tuple ``(X_train, X_val, y_train, y_val)`` of numpy arrays.
            NOTE(review): X is assumed to carry the feature axis last, with
            ``n_features`` entries per row once flattened — confirm against the
            split file.
        balanced: if True, rebuild the training set from every label-0 sample plus
            an equally sized random draw (``np.random.choice``, i.e. with
            replacement) of label-1 samples.
        aggregate: currently unused by this function.
        rescale: if True, standardise the features with a ``StandardScaler`` that
            is fitted on the training data only and then applied to both splits.
        use_gpu: if True, move every tensor to the CUDA device.
        n_features: row width used when flattening X for the scaler. Defaults to
            84, the value that used to be hard-coded here.

    Returns:
        ``(train_dataset, val_dataset)``; labels are stored as int32 tensors.
    """
    X_train, X_val, y_train, y_val = scenario

    if rescale:
        # Created only when needed so the function works without touching the
        # scaler at all when rescale=False.
        scaler = StandardScaler()
        # Flatten to (rows, n_features) for the scaler, then restore the original shape.
        X_train = scaler.fit_transform(X_train.reshape(-1, n_features)).reshape(X_train.shape)
        X_val = scaler.transform(X_val.reshape(-1, n_features)).reshape(X_val.shape)

    if balanced:
        X_train_neg = X_train[y_train == 0]
        # Draw as many positives as there are negatives.
        pos_idx = np.random.choice(np.where(y_train == 1)[0], X_train_neg.shape[0])
        X_train_pos = X_train[pos_idx]
        X_train = np.concatenate([X_train_neg, X_train_pos])
        y_train = np.concatenate([np.zeros(X_train_neg.shape[0]), np.ones(X_train_pos.shape[0])])

    def _to_tensors(features, labels):
        # Single conversion path for both splits; the device move is the only
        # thing that differs between the GPU and CPU cases.
        tensors = (ch.from_numpy(features), ch.from_numpy(labels.astype(np.int32)))
        if use_gpu:
            tensors = tuple(t.cuda() for t in tensors)
        return tensors

    train_dataset = TensorDataset(*_to_tensors(X_train, y_train))
    val_dataset = TensorDataset(*_to_tensors(X_val, y_val))
    return train_dataset, val_dataset

class DictLogger(LightningLoggerBase):
    """Minimal in-memory Lightning logger: every logged metrics dict is kept in a list."""

    def __init__(self, version):
        super().__init__()
        self._version = version
        self.metrics = []

    @property
    def name(self):
        """Fixed experiment name reported by this logger."""
        return 'optuna'

    @property
    def version(self):
        """Version identifier supplied at construction time."""
        return self._version

    @property
    def experiment(self):
        """There is no backing experiment object; this is always ``None``."""
        return None

    def log_hyperparams(self, params):
        """Accept hyperparameters and record nothing.

        Args:
            params: the hyperparameters handed over by the caller (ignored).
        """
        return None

    def log_metrics(self, metrics, step=None):
        """Append ``metrics`` to ``self.metrics``; ``step`` is not stored."""
        self.metrics.append(metrics)


class MyEarlyStopping(Callback):
    """Request an early stop based on the ``val_auroc`` callback metric.

    After each validation run the callback reads ``val_auroc`` from
    ``trainer.callback_metrics`` and asks the trainer to stop when any of the
    following holds:

    * the value is below 75% of the best value seen so far,
    * more than 10 consecutive validation runs produced no new best value,
    * the current epoch is past 10 and the value is still below 0.55.
    """

    # Class-level defaults, kept for backward compatibility; every instance gets
    # its own copies in __init__.
    best_so_far = 0
    last_improvement = 0

    def __init__(self):
        super().__init__()
        self.best_so_far = 0        # highest val_auroc observed so far
        self.last_improvement = 0   # validation runs since the last new best

    def on_validation_end(self, trainer, pl_module):
        """Update the improvement counters and set ``trainer.should_stop`` if needed."""
        # The pre-training sanity validation only sees a few batches of an
        # untrained model; do not let it influence the stopping state.
        if getattr(trainer, 'sanity_checking', False):
            return

        val_auroc = trainer.callback_metrics['val_auroc'].item()

        if val_auroc > self.best_so_far:
            self.last_improvement = 0
        else:
            self.last_improvement += 1

        # The collapse check deliberately compares against the best value from
        # *before* this run; best_so_far is refreshed only afterwards.
        should_stop = (
            val_auroc < 0.75 * self.best_so_far
            or self.last_improvement > 10
            or (trainer.current_epoch > 10 and val_auroc < 0.55)
        )

        # OR with the existing flag so a stop already requested by the trainer or
        # another callback is never cancelled by this one.
        trainer.should_stop = trainer.should_stop or should_stop

        self.best_so_far = max(val_auroc, self.best_so_far)


In [None]:
# Location of the saved train/validation splits.
# The defaults are the values that used to be hard-coded here; set the
# OPSUM_INPUT_FOLDER / OPSUM_SPLIT_FILE environment variables to run the notebook
# on another machine without editing this cell.
# Alternative folder kept for reference: '/Users/jk1/Downloads'
INPUT_FOLDER = os.environ.get('OPSUM_INPUT_FOLDER', '/home/gl/gsu_prepro_01012023_233050/data_splits')
SPLIT_FILE = os.environ.get('OPSUM_SPLIT_FILE', 'train_data_splits_3M_mRS_0-2_ts0.8_rs42_ns5.pth')

In [None]:
# Use the GPU whenever CUDA is available; on a CPU-only host this evaluates to
# False instead of failing later when tensors are moved with `.cuda()`.
# Set to False by hand to force a CPU run.
use_gpu = ch.cuda.is_available()

In [None]:
# Load the saved splits.
# NOTE: torch's loader unpickles the file, which can execute arbitrary code —
# only open split files that come from a trusted source.
split_path = path.join(INPUT_FOLDER, SPLIT_FILE)
scenarios = ch.load(split_path)

In [None]:
# Only the first of the loaded scenarios is used for this run.
first_scenario = scenarios[0]
single_split_dataset = prepare_dataset(first_scenario, use_gpu=use_gpu)

In [None]:
# Split the prepared pair into its training and validation datasets.
(train_dataset, val_dataset) = single_split_dataset

In [None]:
# Hyperparameters for this single training run.

# -- model architecture (passed to OPSUMTransformer) --
input_dim = 84
model_dim = 1024
num_layers = 6
num_heads = 16
ff_factor = 2
ff_dim = model_dim * ff_factor  # derived from the two values above
dropout = 0.4
pos_encode_factor = 1

# -- optimisation / training loop --
bs = 16  # batch size of the training DataLoader
lr = 4e-4
wd = 1e-4
train_noise = 1e-4
n_lr_warm_up_steps = 5
grad_clip = 0.05  # handed to the trainer as gradient_clip_val

In [None]:
# Training batches are shuffled and a trailing incomplete batch is dropped;
# validation is served unshuffled in batches of up to 1024 samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=bs,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(val_dataset, batch_size=1024)

# In-memory metrics logger, version 0.
logger = DictLogger(0)

In [None]:
# Build the transformer from the hyperparameters defined in the configuration cell.
# num_classes and max_dim are fixed here rather than configured there.
transformer_kwargs = dict(
    input_dim=input_dim,
    num_layers=num_layers,
    model_dim=model_dim,
    dropout=dropout,
    ff_dim=ff_dim,
    num_heads=num_heads,
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor,
)
model = OPSUMTransformer(**transformer_kwargs)


In [None]:
# Wrap the model in its Lightning module together with the optimisation settings.
module = LitModel(model, lr, wd, train_noise, lr_warmup_steps=n_lr_warm_up_steps)

# Pick the Lightning accelerator that matches the device the datasets were built for.
accelerator = 'gpu' if use_gpu else 'cpu'

# Single-device trainer; the epoch cap is high because MyEarlyStopping is expected
# to end the run well before it is reached.
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=1,
    max_epochs=1000,
    logger=logger,
    log_every_n_steps=25,
    enable_checkpointing=True,
    gradient_clip_val=grad_clip,
    callbacks=[MyEarlyStopping()],
)


In [None]:
# Run training with the prepared loaders; validation metrics feed the early-stopping callback.
trainer.fit(
    model=module,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)