In [1]:
# Resume domain-adaptation training from a previously saved model that performed well.
# Checkpoint currently used: /tigress/kendrab/analysis-notebooks/model_outs/20-09-24/adda_ndb190550_95mms_classifier_statedict.tar

In [2]:
import sys
import os
import h5py
import copy
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
from  torch.nn.functional import one_hot
from sklearn.utils import shuffle
import torch
from torch import nn
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import torchinfo
import optuna
from datetime import datetime, timezone
import matplotlib.pyplot as plt

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# Floating-point precision used when this notebook builds tensors.
dtype = torch.double

# Get functions from other notebooks
# Later cells call helpers that are not defined in this notebook (get_filenames,
# get_mms_data, format_mms_data, create_adda, train_adda, get_mms_features,
# get_sim_features, ndb_score, generic_outputs_structure) and rely on these
# notebooks having left them in the kernel namespace.
# NOTE(review): which notebook provides which helper is not visible from here -- confirm.
# NOTE(review): the device message is printed twice in this cell's output, so one of
# these notebooks presumably repeats the device detection (and may rebind `device`).
%run /tigress/kendrab/analysis-notebooks/torch_models/utils.ipynb
%run /tigress/kendrab/analysis-notebooks/preproc_utils.ipynb
%run /tigress/kendrab/analysis-notebooks/eval_utils.ipynb
%run /tigress/kendrab/analysis-notebooks/torch_models/import_mms_data.ipynb
%run /tigress/kendrab/analysis-notebooks/torch_models/adda_constructor.ipynb
%run /tigress/kendrab/analysis-notebooks/torch_models/ndb.ipynb

Using cpu device
Using cpu device


In [3]:
# Checkpoints (state dicts) written by an earlier run; they are loaded again in the
# training cell to continue from those weights.
discriminator_path = "/tigress/kendrab/analysis-notebooks/model_outs/20-09-24/adda_ndb190550_95discriminator_statedict.tar"
mms_classifier_path = "/tigress/kendrab/analysis-notebooks/model_outs/20-09-24/adda_ndb190550_95mms_classifier_statedict.tar"

# some sim model hyperparameters
stride = 11  # size (and therefore spacing) of each segment
padding_length = 39  # amount of data on each side of each segment for additional info
# one model input = a segment plus its padding on both sides
input_length = 2*padding_length + stride
batch_size = 11

In [4]:
# ---- run parameters ----
name = 'adda_ndb'  # label used for the output files and in the log header
ndb_samples = 10000  # number of samples requested from each domain for the NDB score
mms_epochs = 1  # number of full passes over the MMS dataset

# ---- model parameters ----
# (forwarded to create_adda and written to the log as hyperparameters)
mms_num_conv = 1
mms_kp_limit = 9  # not needed when loading from file
mms_kernel_size = 3
mms_pool_size = 4
mms_out_channels = 40
mms_dropout_fraction = 0.03457536835651724
mms_learning_rate = 0.0012094769607738786
discrim_length = 3
discrim_width = 132

# MPI rank as a string, or None when the variable is not set in the environment.
rank = os.getenv("OMPI_COMM_WORLD_RANK")

In [5]:
# load a model
# NOTE(review): later cells use `model` and `model_file` without defining them in
# this notebook -- presumably import_model sets them; confirm.
%run /tigress/kendrab/analysis-notebooks/torch_models/import_model.ipynb
# batch_size is set again to the same value as in the configuration cell --
# presumably because import_model overwrites it; TODO confirm.
batch_size = 11
# TODO?: enforce that key variables that need to exist later down the pipeline
# are populated by import_model?



In [6]:
# Build the simulation-domain feature extractor: the loaded model truncated at the
# node 'post_merge_layers.2', whose output is exposed under the key 'features'.
# (get_graph_node_names(model) lists the node names that can be tapped.)
return_nodes = {'post_merge_layers.2': 'features'}
feat_sim = create_feature_extractor(model, return_nodes=return_nodes)

# Run a dummy batch through the extractor to read off the feature shape. The same
# tensor is passed seven times -- presumably one per field component
# (bx, by, bz, ex, ey, ez, jy), as in MMSModel.forward; confirm.
mock_data = torch.ones((batch_size, 1, input_length), dtype=dtype)
tsum = torchinfo.summary(feat_sim, input_data=[mock_data] * 7)
feat_shape = tsum.summary_list[-1].output_size

# The classifier head is the last two layers of the model's final child module.
final_child = list(model.children())[-1]
all_classifier = nn.Sequential(*final_child[-2:])

In [7]:
# Load the sim data
# NOTE(review): `sim_dl` and `sim_dset`, used by the training cell, are not defined
# in this notebook -- presumably created by import_sim_data; confirm.
%run /tigress/kendrab/analysis-notebooks/torch_models/import_sim_data.ipynb

(425291, 1, 89)
(425291, 2, 11)
torch.float64


In [8]:
# Single file name kept at hand for debugging.
debug_filename = "2021-08-18T04-48-00_2021-08-20T04-30-30.h5"
# Collect the MMS data file names. They are not shuffled here; the training loop
# reshuffles the list at the start of every epoch.
global_mms_filenames = get_filenames()

In [9]:
# Record when the run started (UTC). The formatted pieces are reused for the output
# file names and the log header; `start` itself is used to report the run time.
start = datetime.now(tz=timezone.utc)
date_str, time_str = (start.strftime(fmt) for fmt in ("%d-%m-%y", "%H%M%S"))
start_str = f"{date_str}{time_str}"

In [10]:
# wrapper that chains the MMS feature extractor and the classifier head so they can be saved/loaded as one model
class MMSModel(nn.Module):
    """Feature extractor followed by a classifier, exposed as one module.

    Both parts are stored as submodules, so a single state dict covers the
    whole pipeline.
    """

    def __init__(self, feat_extract, classifier):
        """Store the two stages.

        Args:
            feat_extract: callable taking the seven inputs of ``forward`` and
                returning the features.
            classifier: callable mapping those features to logits.
        """
        super().__init__()
        # Registration order is kept as-is: it determines the state-dict key order.
        self.feat_extract = feat_extract
        self.classifier = classifier

    def forward(self, bx, by, bz, ex, ey, ez, jy):
        """Return the classifier output for the features of the seven inputs."""
        return self.classifier(self.feat_extract(bx, by, bz, ex, ey, ez, jy))

In [11]:
# rebuild the domain adaptation structure
# create_adda returns the discriminator and a fresh MMS feature extractor; the weights
# saved by the previous run are then loaded on top of them.
discrim, feat_mms = create_adda(mms_num_conv, mms_kp_limit, mms_kernel_size, mms_pool_size, mms_out_channels, mms_learning_rate,
                                mms_dropout_fraction, discrim_width, discrim_length, feat_shape)
# Deserialize onto the CPU first: load_state_dict copies the values into the modules'
# existing parameters, so a checkpoint written on a GPU node also loads on a
# CPU-only machine.
discrim.load_state_dict(torch.load(discriminator_path, map_location="cpu"))
# mms_classifier wraps the very same feat_mms / all_classifier objects, so loading its
# state dict restores the feature extractor that is trained below.
mms_classifier = MMSModel(feat_mms, all_classifier)
mms_classifier.load_state_dict(torch.load(mms_classifier_path, map_location="cpu"))
# Hyperparameters written to the log file at the end of the run.
hyperparams_adda = {'num_conv': mms_num_conv, 'kp_limit': mms_kp_limit, 'kernel_size':mms_kernel_size, 'pool_size':mms_pool_size,
                    'out_channels':mms_out_channels, 'learning_rate':mms_learning_rate, 'dropout_fraction':mms_dropout_fraction,
                    'discrim_width':discrim_width, 'discrim_length':discrim_length, 'feat_shape':feat_shape}

loss_fn_disc = nn.BCEWithLogitsLoss()
loss_fn_feat = nn.BCEWithLogitsLoss()
# Use the learning rate declared in the parameter cell -- the same value that is
# recorded in hyperparams_adda -- rather than a `learning_rate` global that this
# notebook never defines itself.
optim_disc = torch.optim.Adam(discrim.parameters(), lr=mms_learning_rate)
optim_feat = torch.optim.Adam(feat_mms.parameters(), lr=mms_learning_rate)
# Stay 0 if mms_epochs is 0; otherwise hold the totals of the last epoch.
sum_loss_feat = 0
sum_loss_disc = 0

mms_filenames = global_mms_filenames
for epoch in range(mms_epochs):
    # per-epoch loss records, extended by every MMS file that gets trained on
    loss_feat = []
    loss_disc = []
    print(f"Starting Epoch {epoch+1}")
    mms_filenames = shuffle(mms_filenames)
    sim_iter = iter(sim_dl)  # to make sure we loop through the whole dataset before starting over even as we switch between mms files
    for i, mms_file in enumerate(mms_filenames):
        print(f"getting MMS file {mms_file}, number {i+1} of {len(mms_filenames)}")
        # get the mms data from the file
        mms_data_dict = get_mms_data(mms_file)
        if len(mms_data_dict) == 0:  # data was not loaded, skip this file
            continue
        mms_dl = format_mms_data(mms_data_dict)
        training_step = train_adda(sim_dl, mms_dl, feat_sim, feat_mms, discrim, loss_fn_disc, loss_fn_feat, optim_disc,
                                   optim_feat, iter_source = sim_iter)
        # carry the simulation iterator over to the next file so it resumes where it stopped
        sim_iter = training_step["iter_source"]
        loss_feat += training_step["loss_feat"]
        loss_disc += training_step["loss_disc"]

    # overwritten every epoch: after the loop these are the last epoch's totals
    sum_loss_feat = sum(loss_feat)
    sum_loss_disc = sum(loss_disc)

# calculate ndb score
# Compare features of ndb_samples MMS samples with features of one shuffled
# simulation batch of the same size.
mms_features, _ = get_mms_features(feature_extractor=feat_mms, n=ndb_samples)
ndb_sim_dl = DataLoader(sim_dset, batch_size = ndb_samples, shuffle=True, drop_last=True)
ndb_sim_samples = next(iter(ndb_sim_dl))
sim_features = get_sim_features(feature_extractor = feat_sim, samples = ndb_sim_samples, n = ndb_samples)["features"]
# flatten everything but the first (sample) dimension and drop the autograd graph
mms_features_flat = torch.flatten(mms_features, start_dim=1).detach()
sim_features_flat = torch.flatten(sim_features, start_dim=1).detach()
ndb = ndb_score(sim_features_flat, mms_features_flat)

# write the log and save the adapted weights
print("Saving model...")  # DEBUG

# Differentiate between mpi ranks that started at the same second. A separate name is
# used so that re-running this cell does not keep appending the rank to time_str.
run_time_str = time_str if rank is None else f"{time_str}_{rank}"
log_file, _, _, file_start = generic_outputs_structure("/tigress/kendrab/analysis-notebooks/model_outs/",
                                                        name, date_str, run_time_str)
# Dump information to file
with open(log_file, 'w') as log:
    log.write(f"MMS model {name} domain adapted on {start_str}\n")
    log.write(f"using model file {model_file}\n")
    log.write(f"Feature extractor loss: {sum_loss_feat}\n")
    log.write(f"Discriminator loss: {sum_loss_disc}\n")
    log.write(f"NDB score: {ndb}\n")
    log.write("Hyperparameters:\n")
    for key, value in hyperparams_adda.items():
        log.write(f"{key}\t\t{value}\n")

# mms_classifier already wraps the feat_mms / all_classifier objects that were just
# trained, so its state dict holds the updated weights; no need to rebuild it.
print(torchinfo.summary(mms_classifier))
torch.save(mms_classifier.state_dict(), file_start+"mms_classifier_statedict.tar")
# save the discriminator
torch.save(discrim.state_dict(), file_start+"discriminator_statedict.tar")

# trial ended
end = datetime.now(timezone.utc)
print(f"trial execution time (s): {(end-start).total_seconds()}")

Starting Epoch 1
getting MMS file 2021-06-23T01-43-00_2021-06-23T02-18-00.h5, number 1 of 252




components loaded in order ['B', 'E', 'j', 'time']
training on 425291 source samples and 4744 target samples
discrim loss: 5.689586079177939e-06, feat extract loss: 12.388117550885456, sample 550/4741
discrim loss: 5.625089037738527e-07, feat extract loss: 14.51445509694881, sample 1100/4741
discrim loss: 4.170218218072947e-07, feat extract loss: 14.736624571953216, sample 1650/4741
discrim loss: 9.789309519908373e-07, feat extract loss: 13.889805374630013, sample 2200/4741
discrim loss: 0.00016052788553176184, feat extract loss: 9.143181983814333, sample 2750/4741
discrim loss: 0.005434739265688501, feat extract loss: 10.699710239587302, sample 3300/4741
discrim loss: 0.0, feat extract loss: 38.74845349716259, sample 3850/4741
discrim loss: 2.609288882559864e-09, feat extract loss: 19.61133155552098, sample 4400/4741
Bin number 0 found statistically different with z-score 31.26526997403612 > 1.96
Bin number 1 found statistically different with z-score 11.873790816689906 > 1.96
Bin num