In [None]:
# pareto with multi-objective hyperparameter optimization on feature extractor and discriminator losses

In [36]:
import sys
import os
# import gc  # debug that memory leak
# import tracemalloc  # DEBUG THAT MEMORY LEAK
import psutil  # :( mem leak
import h5py
import copy
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
from  torch.nn.functional import one_hot
from sklearn.utils import shuffle
import torch
from torch import nn
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import torchinfo
import optuna
from datetime import datetime, timezone
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
dtype = torch.double

# Get functions from other notebooks
%run /tigress/kendrab/analysis-notebooks/torch_models/utils.ipynb
%run /tigress/kendrab/analysis-notebooks/preproc_utils.ipynb
%run /tigress/kendrab/analysis-notebooks/eval_utils.ipynb
%run /tigress/kendrab/analysis-notebooks/torch_models/import_mms_data.ipynb
%run /tigress/kendrab/analysis-notebooks/torch_models/adda_constructor.ipynb

Using cpu device


In [37]:
# parameters
mms_epochs = 1  # number of times to loop through the entire mms dataset (start with 1 lmao)
name='adda_1'

rank = os.environ.get("OMPI_COMM_WORLD_RANK")


In [38]:
# load a model
%run /tigress/kendrab/analysis-notebooks/torch_models/import_model.ipynb
batch_size = 11
# TODO?: enforce that key variables that need to exist later down the pipeline
# are populated by import_model?

In [39]:
# extract the model's features (before classification step)
# get_graph_node_names(model)
return_nodes = {'post_merge_layers.2' : 'features'}
feat_sim = create_feature_extractor(model, return_nodes=return_nodes)
mock_data = torch.ones((batch_size, 1, input_length), dtype=dtype)
tsum = torchinfo.summary(feat_sim, input_data = [mock_data for i in range(7)])
feat_shape = tsum.summary_list[-1].output_size
# extract the classifier part
all_classifier = nn.Sequential(*list(model.children())[-1][-2:])
torchinfo.summary(all_classifier)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Flatten: 1-1                           --
├─LazyLinear: 1-2                        2,387
Total params: 2,387
Trainable params: 2,387
Non-trainable params: 0

In [40]:
# Load the sim data
%run /tigress/kendrab/analysis-notebooks/torch_models/import_sim_data.ipynb

(425291, 1, 89)
(425291, 2, 11)
torch.float64


In [41]:
# Load the mms data locations and shuffle the files to not be chronological
global_mms_filenames = get_filenames()
debug_filename = "2021-08-18T04-48-00_2021-08-20T04-30-30.h5"

### Time to do some training 

In [42]:
def objective(trial):
    # timing stuff
    start = datetime.now(timezone.utc)  # for timing
    time_str = start.strftime("%H%M%S")
    date_str = start.strftime("%d-%m-%y")
    start_str = date_str + time_str

    # make suggestions
    mms_num_conv = trial.suggest_int('num_conv', 1, 2)  # for separate layers but also for end layer (so really there are num_conv*2 layers)
    mms_kp_limit = int(input_length**(1/(mms_num_conv+1)))  # a non-maximal upper bound on sizes to avoid running out of length
    print(f"kernel eqn limit kp={kp_limit}")
    min_pool_size=2
    mms_kernel_size = trial.suggest_int('kernel_size', 2, min(mms_kp_limit - min_pool_size, 10))  # max size 10 or lower
    mms_pool_size = trial.suggest_int('pool_size', min_pool_size, min(mms_kp_limit - mms_kernel_size, 5))
    mms_out_channels = trial.suggest_int('out_channels', 8, 56)  # like 'filters' in keras
    mms_learning_rate = trial.suggest_float('learning_rate', 0.001, 0.003, log=True)  # CANT CHANGE LOG NOW >:(
    mms_dropout_fraction = trial.suggest_float('dropout', 0, 0.3)    
    discrim_width = trial.suggest_int('discrim_width', 30, 200)
    discrim_length = trial.suggest_int('discrim_length', 1, 3)  # number of layer for discrim
    # build the domain adaptation structure
    discrim, feat_mms = create_adda(mms_num_conv, mms_kp_limit, mms_kernel_size, mms_pool_size, mms_out_channels, mms_learning_rate,
                                    mms_dropout_fraction, discrim_width, discrim_length, feat_shape)
    hyperparams_adda = {'mms_num_conv': mms_num_conv, 'mms_kp_limit': mms_kp_limit, 'mms_kernel_size':mms_kernel_size, 'mms_pool_size':mms_pool_size,
                        'mms_out_channels':mms_out_channels, 'mms_learning_rate':mms_learning_rate, 'mms_dropout_fraction':mms_dropout_fraction,
                        'discrim_width':discrim_width, 'discrim_length':discrim_length, 'feat_shape':feat_shape}

    loss_fn_disc = nn.BCEWithLogitsLoss()
    loss_fn_feat = nn.BCEWithLogitsLoss()
    optim_disc = torch.optim.Adam(discrim.parameters(),lr=mms_learning_rate)
    optim_feat = torch.optim.Adam(feat_mms.parameters(),lr=mms_learning_rate)
    sum_loss_feat = 0
    sum_loss_disc = 0

    mms_filenames = global_mms_filenames
    for epoch in range(mms_epochs):
        loss_feat = []
        loss_disc = []
        print(f"Starting Epoch {epoch+1}")
        mms_filenames = shuffle(mms_filenames)
        sim_iter = iter(sim_dl)  # to make sure we loop through the whole dataset before starting over even as we switch between mms files
        for i, mms_file in enumerate(mms_filenames):
            print(f"getting MMS file {mms_file}, number {i+1} of {len(mms_filenames)}")
            # get the mms data from the file
            process = psutil.Process()  # start debug loop thingy
#            print(process.memory_info().rss/1024/1024)
            mms_data_dict = get_mms_data(mms_file)
#            print(process.memory_info().rss/1024/1024)
            if len(mms_data_dict) == 0:  # data was not loaded, skip this file
                continue
            mms_dl = format_mms_data(mms_data_dict)
#            print(process.memory_info().rss/1024/1024)            
            training_step = train_adda(sim_dl, mms_dl, feat_sim, feat_mms, discrim, loss_fn_disc, loss_fn_feat, optim_disc,
                                       optim_feat, iter_source = sim_iter)
            sim_iter = training_step["iter_source"]
            loss_feat = training_step["loss_feat"]
            loss_disc = training_step["loss_disc"]

        sum_loss_feat = sum(loss_feat)
        sum_loss_disc = sum(loss_disc)
        
    # save if we are retrieving a specific trial
    if study.user_attrs['save']:
        print("Saving model...")  # DEBUG

        if rank is not None:
            time_str += f"_{rank}"  # differentiate between mpi ranks that started at same second
        log_file, _, _, file_start = generic_outputs_structure("/tigress/kendrab/analysis-notebooks/model_outs/",
                                                                name, date_str, time_str)
        # Dump information to file
        with open(log_file, 'w') as log:
            log.write(f"MMS model {name} domain adapted on {start_str}\n")
            log.write(f"using model file {model_file}\n")
            log.write(f"trial number: {trial.number}\n")
            log.write(f"Feature extractor loss: {sum_loss_feat}\n")
            log.write(f"Discriminator loss: {sum_loss_disc}\n")
            log.write("Hyperparameters:\n")
            for key in hyperparams_adda.keys():
                log.write(f"{key}\t\t{hyperparams_adda[key]}\n")
                
        # save the mms classifier
        class MMSModel(nn.Module):
            def __init__(self, feat_extract, classifier):
                super().__init__()
                self.feat_extract = feat_extract
                self.classifier = classifier
            
            def forward(self, bx, by, bz, ex, ey, ez, jy):
                features = self.feat_extract(bx, by, bz, ex, ey, ez, jy)
                logits = self.classifier(features)
                return logits
            
        mms_classifier = MMSModel(feat_mms, all_classifier)
        print(torchinfo.summary(mms_classifier))
        torch.save(mms_classifier.state_dict(), file_start+"mms_classifier_statedict.tar")
        # save the discriminator
        torch.save(discrim.state_dict(), file_start+"discriminator_statedict.tar")              
    
    # trial ended
    end = datetime.now(timezone.utc)
    print(f"trial execution time (s): {(end-start).total_seconds()}")
    return sum_loss_feat, sum_loss_disc

In [43]:
# assume study already made
""" study = optuna.create_study(study_name='adda_optim',storage="mysql+mysqldb://optunauser:Frikkenoptuna@stellar-intel.princeton.edu:47793/adda_1", directions=['minimize','minimize'])"""
# use hyperband pruner 
study = optuna.load_study(study_name='adda_optim',storage="mysql+mysqldb://optunauser:Frikkenoptuna@stellar-intel.princeton.edu:47793/adda_1")
# # regular training
# study.set_user_attr('save',False)
# study.optimize(objective, n_trials=10)
# bringing back the best one to train a new model
trial_num = study.user_attrs['knee_trial_num']
best_params = study.trials[trial_num].params
study.enqueue_trial(best_params, skip_if_exists=False)
study.set_user_attr('save',True)
study.optimize(objective, n_trials=4)
# objective(study.trials[trial_num])

kernel eqn limit kp=9
Starting Epoch 1
getting MMS file 2018-08-01T03-11-30_2018-08-01T13-27-00.h5, number 1 of 252




training on 425291 source samples and 11653 target samples
discrim loss: 1.0857857435195781, feat extract loss: 1.601543583387228, sample 550/11649
discrim loss: 1.2887329279343955, feat extract loss: 0.8786267068848208, sample 1100/11649
discrim loss: 1.2022937773258702, feat extract loss: 1.1551961545679934, sample 1650/11649
discrim loss: 1.3607939228796888, feat extract loss: 0.7027617843800104, sample 2200/11649
discrim loss: 1.1666834678496474, feat extract loss: 0.8065674928514768, sample 2750/11649
discrim loss: 1.502984491326018, feat extract loss: 1.187458661063554, sample 3300/11649
discrim loss: 0.5212991568514973, feat extract loss: 2.1163111132067893, sample 3850/11649
discrim loss: 0.859973149223936, feat extract loss: 1.2526019629726528, sample 4400/11649
discrim loss: 0.571939807182406, feat extract loss: 1.7479786991783504, sample 4950/11649
discrim loss: 0.8170270054001282, feat extract loss: 1.034019050211272, sample 5500/11649
discrim loss: 1.2192400769317169, feat

[I 2024-08-16 10:39:53,027] Trial 3511 finished with values: {'sum_loss_feat': 1.9338242634912497, 'sum_loss_disc': 0.5537149665738788} and parameters: {'num_conv': 2, 'kernel_size': 2, 'pool_size': 2, 'out_channels': 31, 'learning_rate': 0.001057022133105261, 'dropout': 0.04152485670069086, 'discrim_width': 88, 'discrim_length': 3}. 


Saving model...
Layer (type:depth-idx)                   Param #
MMSModel                                 --
├─MMSFeatExtract: 1-1                    --
│    └─Sequential: 2-1                   --
│    │    └─Conv1d: 3-1                  93
│    │    └─LeakyReLU: 3-2               --
│    │    └─AvgPool1d: 3-3               --
│    │    └─Dropout: 3-4                 --
│    │    └─Conv1d: 3-5                  1,953
│    │    └─LeakyReLU: 3-6               --
│    │    └─AvgPool1d: 3-7               --
│    │    └─Dropout: 3-8                 --
│    └─Sequential: 2-2                   --
│    │    └─Conv1d: 3-9                  93
│    │    └─LeakyReLU: 3-10              --
│    │    └─AvgPool1d: 3-11              --
│    │    └─Dropout: 3-12                --
│    │    └─Conv1d: 3-13                 1,953
│    │    └─LeakyReLU: 3-14              --
│    │    └─AvgPool1d: 3-15              --
│    │    └─Dropout: 3-16                --
│    └─Sequential: 2-3                   --
│    

In [44]:
# # plot the loss
# fig, ax = plt.subplots(2)
# ax[0].plot(loss_disc)
# ax[1].plot(loss_feat)
# ax[0].set(title="Discriminator loss", xlabel="training iteration", ylabel="loss")
# ax[1].set(title="MMS Feature Extractor loss", xlabel="training iteration", ylabel="loss")
# fig.savefig("/tigress/kendrab/analysis-notebooks/model_outs/scratchwork/training_losses.svg")  # TODO: save model and training info to its own folder
# plt.close(fig='all')