In [None]:
import os

import pandas as pd
import torch as ch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import json
from sklearn.metrics import matthews_corrcoef, average_precision_score, roc_auc_score, accuracy_score
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitEncoderRegressionModel
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers


In [None]:
# Reload edited project modules (e.g. the prediction.* imports) automatically before executing code.
%load_ext autoreload
%autoreload 2 

In [None]:
# Input CV data splits and output directory for logs / checkpoints.
# Both can be overridden with environment variables so the notebook can run on another
# machine (e.g. against a local copy of the preprocessing output) without editing this cell.
data_path = os.environ.get(
    'OPSUM_TTE_DATA_PATH',
    '/mnt/data1/klug/datasets/opsum/short_term_outcomes/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth',
)
output_dir = os.environ.get('OPSUM_TTE_OUTPUT_DIR', '/home/klug/temp/tte_testing')

In [None]:
# CSV with the best hyperparameters of the CV search; in this notebook only its
# 'best_cv_fold' column is read. Overridable via an environment variable.
config_path = os.environ.get(
    'OPSUM_TTE_HP_CONFIG_PATH',
    '/home/klug/temp/checkpoints_short_opsum_transformer_tte_20250109_235614_cv_1/tte_end_transformer_best_hyperparameters.csv',
)

In [None]:
# Use the GPU when one is available; a hard-coded True makes the 'gpu' accelerator and
# CUDA data placement fail on CPU-only machines. Set to False to force CPU training.
use_gpu = ch.cuda.is_available()

In [None]:
# Load the hyperparameter-search results; only the first row is used, as a plain dict.
hp_records = pd.read_csv(config_path).to_dict(orient='records')
hp_model_config = hp_records[0]

# Passed to the regression wrapper as `classification_threshold`.
# NOTE(review): presumably the time-to-event cut-off used to derive classification
# metrics from the regression output — TODO confirm meaning and units.
classification_threshold = 6

In [None]:
# Cross-validation splits of the training data, indexed by CV fold number.
# NOTE(review): torch.load unpickles the file — only load trusted data. Since torch 2.6
# the default weights_only=True may reject non-tensor objects in the splits; if loading
# fails, pass weights_only=False explicitly (trusted file only).
splits = ch.load(data_path)

In [None]:
from prediction.short_term_outcome_prediction.timeseries_decomposition import prepare_subsequence_dataset

# Build train/validation subsequence datasets with time-to-event targets from the
# CV fold selected by the hyperparameter search.
best_fold_split = splits[hp_model_config['best_cv_fold']]
train_dataset, val_dataset = prepare_subsequence_dataset(
    best_fold_split,
    use_gpu=use_gpu,
    use_time_to_event=True,
)

In [None]:
# Hyperparameters for this run. They are set by hand here; only the CV fold is taken
# from the hyperparameter-search CSV.
model_config = {
    'batch_size': 256,
    'num_layers': 2,
    'model_dim': 256,
    'train_noise': 5.868386798073278e-05,
    'weight_decay': 1e-5,
    'dropout': 0.5,
    'num_head': 8,
    'lr': 1e-3,  # or 1e-4
    'n_lr_warm_up_steps': 500,  # try 500 or 1000
    'grad_clip_value': 1,
    # NOTE(review): not read by the training cell of this notebook (no early stopping).
    'early_stopping_step_limit': 10,
    'loss_function': 'log_cosh',
    'scheduler': 'exponential',
    'best_cv_fold': hp_model_config['best_cv_fold'],
}

# When True, the best checkpoint and this config are written to output_dir.
save_model = False

In [None]:
from prediction.short_term_outcome_prediction.timeseries_decomposition import BucketBatchSampler

accelerator = 'gpu' if use_gpu else 'cpu'

# Input feature count is read from the last axis of the first training sample's inputs.
input_dim = train_dataset[0][0].shape[-1]
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

model = OPSUMTransformer(
    input_dim=input_dim,
    num_layers=int(model_config['num_layers']),
    model_dim=int(model_config['model_dim']),
    # Dropout is a probability: int() would truncate e.g. 0.5 to 0 and silently disable it.
    dropout=float(model_config['dropout']),
    ff_dim=int(ff_dim),
    num_heads=int(model_config['num_head']),
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor,
)

# Batches come from the bucket sampler (presumably grouping samples by length via
# idx_to_len_map); shuffling is done in the bucket sampler, not by the DataLoader.
train_bucket_sampler = BucketBatchSampler(train_dataset.idx_to_len_map, model_config['batch_size'])
train_loader = DataLoader(train_dataset, batch_sampler=train_bucket_sampler,
                          shuffle=False, drop_last=False)

val_batch_size = 1024
val_bucket_sampler = BucketBatchSampler(val_dataset.idx_to_len_map, val_batch_size)
val_loader = DataLoader(val_dataset, batch_sampler=val_bucket_sampler)

if save_model:
    # The config JSON is written before any checkpoint, so the directory must exist now.
    os.makedirs(output_dir, exist_ok=True)
    # Keep only the checkpoint with the lowest validation MAE; the filename embeds the
    # monitored metric so the saved file reflects why it was kept.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val_mae",
        mode="min",
        dirpath=output_dir,
        filename="tte_short_opsum_transformer_{epoch:02d}_{val_mae:.4f}",
    )
    # Save the model config next to the checkpoint.
    file_name = 'tte_short_opsum_transformer_config.json'
    with open(os.path.join(output_dir, file_name), 'w') as f:
        json.dump(model_config, f)

    callbacks = [checkpoint_callback]
else:
    callbacks = []

# Prompt for the TensorBoard run name (sub-directory under output_dir/logs).
# An empty answer falls back to Lightning's automatic version_N numbering.
run_name = input('TensorBoard run name (empty for auto): ').strip() or None

logger = DictLogger(0)
tb_logger = pl_loggers.TensorBoardLogger(save_dir=output_dir, name='logs', version=run_name)

module = LitEncoderRegressionModel(
    model,
    lr=model_config['lr'],
    wd=model_config['weight_decay'],
    train_noise=model_config['train_noise'],
    lr_warmup_steps=model_config['n_lr_warm_up_steps'],
    classification_threshold=classification_threshold,
    loss_function=model_config['loss_function'],
    scheduler=model_config['scheduler'],
    debug_mode=True,
)

# NOTE(review): model_config['early_stopping_step_limit'] is not used — no EarlyStopping
# callback is registered in this cell.
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=1,
    max_epochs=50,
    logger=[logger, tb_logger],
    callbacks=callbacks,
    log_every_n_steps=25,
    # Checkpoint only when save_model is set. With True, Lightning adds a default
    # ModelCheckpoint and writes checkpoints even when save_model is False.
    enable_checkpointing=save_model,
    gradient_clip_val=model_config['grad_clip_value'],
)

In [None]:
# Train the model; metrics are logged to the DictLogger and to TensorBoard under output_dir/logs.
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
import gc

# Release memory after training.
# NOTE: model, module and trainer are still bound to these names, so gc.collect() cannot
# free them; uncomment the `del` below to actually release them.
# del model, module, trainer

gc.collect()

# Clear cached GPU memory. This is skipped on CPU-only machines, where
# torch.cuda.synchronize() would raise.
if ch.cuda.is_available():
    ch.cuda.synchronize()  # wait for pending GPU work before releasing the cache
    ch.cuda.empty_cache()

