In [None]:
import os
import pandas as pd
import torch as ch
import numpy as np
import json
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitModel
from prediction.short_term_outcome_prediction.timeseries_decomposition import BucketBatchSampler
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2 

In [None]:
# Local (laptop) alternatives, kept for reference:
# data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
# output_dir = '/Users/jk1/Downloads'

# Directory that receives checkpoints, the saved model config and TensorBoard logs.
output_dir = '/home/klug/temp/enc_testing'

# Serialized data splits; loaded with torch further down and indexed by CV fold.
data_path = '/mnt/data1/klug/datasets/opsum/short_term_outcomes/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'

In [None]:
# CSV written by a previous hyperparameter search; only its first row is used below.
config_path = '/home/klug/temp/checkpoints_short_opsum_transformer_20240814_073845_cv_0/end_transformer_best_hyperparameters.csv'

In [None]:
# Selects the Lightning accelerator ('gpu' vs 'cpu') and is forwarded to the
# dataset preparation step.
use_gpu = True

In [None]:
# Read the hyperparameter CSV and keep its first row as a plain
# {column name: value} dict. Only 'best_cv_fold' is read from it in this notebook.
hp_model_config = pd.read_csv(config_path).to_dict(orient='records')[0]

# NOTE(review): classification_threshold is not referenced by any later cell
# visible in this notebook — confirm whether it is still needed.
classification_threshold = 6
# Passed to LitModel as `imbalance_factor` (converted to a float tensor there).
imbalance_factor = 62

In [None]:
# Manually tuned training configuration. Key order is kept stable because the
# dict is dumped to JSON as-is when the model is saved.
model_config = {
    # data loading
    'batch_size': 256,
    # architecture
    'num_layers': 2,
    'model_dim': 256,
    # regularisation
    'train_noise': 5.868386798073278e-05,
    'weight_decay': 5e-4,
    'dropout': 0.5,
    'num_head': 32,
    # optimisation
    'lr': 1e-5,
    'n_lr_warm_up_steps': 0,
    'grad_clip_value': 1,
    # focal loss parameters
    'alpha': 0.25,
    'gamma': 2.0,
    'early_stopping_step_limit': 10,
    'scheduler': 'exponential',
    'loss_function': 'focal',
}


In [None]:
# Deserialize the data splits; the result is indexed by CV fold in the next step.
# NOTE(review): torch.load unpickles the file — only load trusted data.
splits = ch.load(data_path)

In [None]:
from prediction.short_term_outcome_prediction.timeseries_decomposition import prepare_subsequence_dataset

# Pick the split of the fold recorded as 'best_cv_fold' in the hyperparameter
# file, then build the train/validation datasets from it.
best_fold_split = splits[hp_model_config['best_cv_fold']]
train_dataset, val_dataset = prepare_subsequence_dataset(
    best_fold_split,
    use_gpu=use_gpu,
    use_time_to_event=False,
)

In [None]:
# When True, a ModelCheckpoint callback is registered and the model config is
# written to output_dir as JSON.
save_model = True

In [None]:
# Build everything needed for one training run: model, bucketed data loaders,
# checkpointing, loggers, the Lightning module and the Trainer.

# Seed Python / NumPy / torch RNGs so that weight initialisation and any
# RNG-driven batch shuffling are repeatable between runs of this cell.
seed = 42
pl.seed_everything(seed, workers=True)

accelerator = 'gpu' if use_gpu else 'cpu'

# --- Model -----------------------------------------------------------------
# Input width is taken from the last dimension of the first training sample's
# first element (presumably the feature axis — verify against the dataset).
input_dim = train_dataset[0][0].shape[-1]
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

model = OPSUMTransformer(
    input_dim=input_dim,
    num_layers=int(model_config['num_layers']),
    model_dim=int(model_config['model_dim']),
    dropout=float(model_config['dropout']),
    ff_dim=int(ff_dim),
    num_heads=int(model_config['num_head']),
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor,
)

# --- Data loaders ----------------------------------------------------------
train_bucket_sampler = BucketBatchSampler(train_dataset.idx_to_len_map, model_config['batch_size'])
train_loader = DataLoader(train_dataset, batch_sampler=train_bucket_sampler,
                          # shuffling is done in the bucket sampler
                          shuffle=False, drop_last=False)

val_batch_size = 1024
val_bucket_sampler = BucketBatchSampler(val_dataset.idx_to_len_map, val_batch_size)
val_loader = DataLoader(val_dataset, batch_sampler=val_bucket_sampler)

run_name = 'dim_256_lay2'

# --- Checkpointing and config dump -----------------------------------------
# The config file is opened before Lightning gets a chance to create the
# directory, so make sure it exists up front.
os.makedirs(output_dir, exist_ok=True)

if save_model:
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val_auroc",
        mode="max",
        dirpath=output_dir,
        # Template resolved by Lightning each time a checkpoint is written.
        filename="short_opsum_transformer_{epoch:02d}_{val_auroc:.4f}",
    )
    # Save the model config next to the checkpoints. Epoch and val_auroc are
    # unknown before training, so the file is keyed by the run name instead of
    # reusing the (unformatted) checkpoint filename template.
    file_name = f'short_opsum_transformer_{run_name}_config.json'
    with open(os.path.join(output_dir, file_name), 'w') as f:
        json.dump(model_config, f)

    callbacks = [checkpoint_callback]
else:
    callbacks = []

# --- Loggers ---------------------------------------------------------------
logger = DictLogger(0)
tb_logger = pl_loggers.TensorBoardLogger(save_dir=output_dir, name='logs', version=run_name)

# --- Lightning module and trainer ------------------------------------------
module = LitModel(
    model,
    lr=model_config['lr'],
    wd=model_config['weight_decay'],
    train_noise=model_config['train_noise'],
    lr_warmup_steps=model_config['n_lr_warm_up_steps'],
    imbalance_factor=ch.tensor(imbalance_factor).float(),
    loss_function=model_config['loss_function'],
    alpha=model_config['alpha'],
    gamma=model_config['gamma'],
    scheduler=model_config['scheduler'],
    debug_mode=True,
)
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=1,
    max_epochs=50,
    logger=[logger, tb_logger],
    callbacks=callbacks,
    log_every_n_steps=25,
    # Tie checkpointing to save_model: with enable_checkpointing=True and no
    # ModelCheckpoint callback, Lightning installs a default one and would
    # still write checkpoints even though saving was switched off.
    enable_checkpointing=save_model,
    gradient_clip_val=model_config['grad_clip_value'],
)

In [None]:
# Run training on train_loader with validation on val_loader.
# NOTE(review): the live loss plot presumably comes from LitModel(debug_mode=True) — confirm.
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Changes made during manual tuning (names match the model_config keys):
# - grad_clip_value set to 1
# - lr set to 1e-5
# - weight_decay set to 5e-4
# - dropout set to 0.5
# - loss_function set to focal
# - num_layers set to 2
# - model_dim set to 256
# - num_head set to 32


In [None]:
import gc

# Drop the notebook's references to the training objects. Popping from the
# namespace (instead of `del`) keeps this cell re-runnable: it does not raise
# NameError when the objects were already removed or training never ran.
for _name in ('model', 'module', 'trainer'):
    globals().pop(_name, None)

gc.collect()

# Only touch the CUDA runtime when it exists; on a CPU-only machine
# (e.g. use_gpu = False) torch.cuda.synchronize() would raise.
if ch.cuda.is_available():
    ch.cuda.synchronize()  # let pending GPU work finish first
    ch.cuda.empty_cache()  # then release cached, unused GPU memory

