In [3]:

import sys
sys.path.append('/gpfs/data/fs71925/dspringer1/Projects/AnaContML/src/')
import torch
from torch.utils.data import DataLoader
#sys.path.insert(1, '../src/');
import load_data
import datetime
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.plugins.environments import LightningEnvironment
import json
import os
import numpy as np
import h5py



def create_datasets(config):
    data = np.load(config["PATH_TRAIN"])
    train, validation = torch.utils.data.random_split(data, [int(data.__len__()*config["SPLIT"]), int(data.__len__())-int(data.__len__()*config["SPLIT"])], generator=torch.Generator().manual_seed(42)) 
    return train, validation


def train():
    ### JSON File contains full information about entire run (model, data, hyperparameters)
    ### TODO 
    MODEL_NAME = "GNN_1_base"
    config = json.load(open('/gpfs/data/fs71925/dspringer1/Projects/AnaContML/_runs/confmod_graph_neural_network_MIT_w100_n100_skew1.json'))[MODEL_NAME]
    print(config)

    ''' Dataloading '''
    train_data, validation_data = create_datasets(config)
    train_data = np.array(train_data)
    validation_data = np.array(validation_data)

    ### > Single HDF5 file containing training and validation data 
    ld = __import__("load_data", fromlist=['object'])
    # data_set = load_data.Dataset_ae(config)
    train_set = getattr(ld, config["DATA_LOADER"])(config, train_data)
    validation_set = getattr(ld, config["DATA_LOADER"])(config, validation_data)

    train_dataloader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)
    validation_dataloader = DataLoader(validation_set, batch_size=config["batch_size"], shuffle=True)


    ''' Model setup '''
    wrapers = __import__("wrappers.wrapers", fromlist=['object'])#.wrapers
    model = getattr(wrapers, config["MODEL_WRAPER"])(config)

    ''' Model loading from save file '''
    if config["continue"] == True:
        SAVEPATH = config["SAVEPATH"]
        checkpoint = torch.load(SAVEPATH)
        model.load_state_dict(checkpoint['state_dict'])
        print(" >>> Loaded checkpoint")

    
    ''' Logging and saving '''
    DATA_NAME = os.path.splitext(os.path.basename(config["PATH_TRAIN"]))[0]
    print(" TRAIN DATA (slurm relevance) ")
    print(config["PATH_TRAIN"])
    print(DATA_NAME)
    print(MODEL_NAME)
    
    PATH = f"/gpfs/data/fs71925/dspringer1/Projects/AnaContML/_runs/saves_MIT_w{config['omega_steps']}/"
    CONFIGURATION = f"{PATH}/{DATA_NAME}_{config['omega_steps']}/save_{config['MODEL_NAME']}_BS{config['batch_size']}_{datetime.datetime.now().date()}"
    logger = TensorBoardLogger(PATH, name=CONFIGURATION)
    
    early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", min_delta=0.00, patience=20, verbose=False)
    checkpoint_callback = ModelCheckpoint(save_top_k=-1)

    
    # ### '''Define (pytorch_lightning) Trainer '''
    # ### > SLURM Training
    # trainer = pl.Trainer(max_epochs=config["epochs"], 
    #                     accelerator=config["device_type"], 
    #                     devices=config["devices"], 
    #                     num_nodes=config["num_nodes"], 
    #                     #strategy='ddp', 
    #                     strategy='ddp_find_unused_parameters_true',
    #                     logger=logger
    #                     # callbacks=[checkpoint_callback]
    #                     )
    # ### > Jupyter Notebook Training
    trainer = pl.Trainer(max_epochs=config["epochs"], 
                         accelerator='gpu', 
                         devices=1, 
                         strategy='auto', 
                         logger=logger, 
                         # log_every_n_steps=1, 
                         plugins=[LightningEnvironment()], 
                         callbacks=[checkpoint_callback]
                        )

    # ### > Jupyter Notebook CPU Training
    # trainer = pl.Trainer(max_epochs=20, accelerator='cpu', devices=1, strategy='auto', logger=logger, plugins=[LightningEnvironment()])
    
    # ''' Train '''
    trainer.fit(model, train_dataloader, validation_dataloader)
    # trainer.fit(model, train_dataloader)
    
    
    # ### ''' Saving configuration file into log folder ''' 
    LOGDIR = trainer.log_dir
    json_object = json.dumps(config, indent=4)
    with open(LOGDIR+"/config.json", "w") as outfile:
        outfile.write(json_object)




def main():
    train()

if __name__ == '__main__':
    main()
# %%



{'MODEL_NAME': 'GNN_1_base', 'MODEL_WRAPER': 'model_wraper_gnn', 'PATH_TRAIN': '/gpfs/data/fs71925/dspringer1/Projects/AnaContML/data_2025_new/w_max_10c0_w_steps_100/n_peaks_2_sigma_0c1_0c9/lambda_20c0/100000_sym_0_asym_MIT_synthetic_skewed/noise_level_0c001_noisy_samples_10/symmetric.npy', 'PATH_VEC': '/gpfs/data/fs71925/dspringer1/Projects/AnaContML/data_2025_new/ctqmc_kernels_100/lambda_200/v.npy', 'SPLIT': 0.8, 'DATA_LOADER': 'Dataset_graph_InvPro', 'continue': True, 'SAVEPATH': '/gpfs/data/fs71925/dspringer1/Projects/AnaContML/_runs/saves_MIT_w100/symmetric_100/save_GNN_1_base_BS100_2025-11-28/version_2/checkpoints/epoch=0-step=8000.ckpt', 'n_nodes': 100, 'weird': False, 'batch_size': 100, 'learning_rate': 5e-05, 'weight_decay': 1e-08, 'epochs': 10, 'omega_steps': 100, 'tau_steps': 100, 'out_dim': 100, 'message_in_dim': 400, 'message_hidden_dim': 96, 'message_out_dim': 100, 'update_hidden_dim': 96, 'update_in_dim': 400, 'update_out_dim': 200, 'pre_pool_hidden_dim': 96, 'pre_pool_o

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params | Mode 
-----------------------------------------------------
0 | model         | GNN_1_base | 870 K  | train
1 | criterion_mse | MSELoss    | 0      | train
-----------------------------------------------------
870 K     Trainable params
0         Non-trainable params
870 K     Total params
3.480     Total estimated model params size (MB)
107       Modules in train mode
0         Modules in eval mode


 >>> Loaded checkpoint
 TRAIN DATA (slurm relevance) 
/gpfs/data/fs71925/dspringer1/Projects/AnaContML/data_2025_new/w_max_10c0_w_steps_100/n_peaks_2_sigma_0c1_0c9/lambda_20c0/100000_sym_0_asym_MIT_synthetic_skewed/noise_level_0c001_noisy_samples_10/symmetric.npy
symmetric
GNN_1_base
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/gpfs/data/fs71925/dspringer1/XInstalls/anaconda/envs/ml_p310/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:484: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/gpfs/data/fs71925/dspringer1/XInstalls/anaconda/envs/ml_p310/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


                                                                           

/gpfs/data/fs71925/dspringer1/XInstalls/anaconda/envs/ml_p310/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Epoch 0:   7%|â–‹         | 576/8000 [24:39<5:17:46,  0.39it/s, v_num=0]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
print("TEST")