In [1]:
import copy
import os
import sys
from datetime import datetime
from getpass import getpass
from urllib.parse import quote_plus

import h5py
import numpy as np
import optuna
import torch
import torchinfo
import torcheval.metrics as tem  # YAY METRICS
from sklearn.utils import shuffle
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# Every tensor built later in this notebook is cast to this (float64) dtype.
dtype = torch.double

# Get functions from other notebooks
# NOTE(review): several names used below are never defined in this notebook
# (e.g. batch_subsects, batch_unpadded_subsects, plot_prc, plt,
# ConfusionMatrixDisplay, confusion_matrix); presumably they come from these
# %run notebooks — confirm. They are hidden dependencies for Restart & Run All.
%run /tigress/kendrab/analysis-notebooks/preproc_utils.ipynb
%run /tigress/kendrab/analysis-notebooks/eval_utils.ipynb

Using cpu device


In [2]:
# load a model
# NOTE(review): this notebook is presumably where `model`, `batch_size` and
# `learning_rate` (and possibly `input_length`, `padding_length`, `stride`) come
# from — the DataLoader cell notes batch_size is "from import_model", and the
# others are not defined in this notebook. Confirm.
%run /tigress/kendrab/analysis-notebooks/torch_models/import_model.ipynb



In [3]:
# Path prefix for every output file this notebook writes (confusion-matrix PDF,
# precision-recall plot). The stem looks like a run timestamp — TODO confirm.
file_start = "/tigress/kendrab/analysis-notebooks/model_outs/scratchwork/14-06-24F143444_"

### Functions

In [4]:
def test_loop(dataloader, model, loss_fn, threshold = 0.5):
    # convert threshold from probability to logit
    threshold_logit = np.log(threshold/(1-threshold))
    model.eval()
    pred_list = []
    size = len(dataloader.dataset)  # number of samples
    stride = dataloader.dataset[0][0].shape[-1]
    tot_points = size*stride
    num_batches = len(dataloader)
    test_loss_sum, correct = 0, 0

    with torch.no_grad():
        for _, _, bx, by, bz, ex, ey, ez, jy, _, _, y in dataloader:
            pred = model(bx, by, bz, ex, ey, ez, jy)
            pred_list.append(pred.cpu().numpy())
            test_loss_sum += loss_fn(pred, y).item()  # .item() fetches the python scalar
            # number of correct per-point predictions
            correct += ((pred > threshold_logit) == y).type(torch.float).sum().item()
    tot_pred = np.concatenate(pred_list, axis=0)
    test_loss_sum /= num_batches
    correct /= tot_points
    print(f"Test Error: \n Accuracy: {(100*correct):>0.5f}%, Avg loss: {test_loss_sum:>8f} \n")    
    return tot_pred

### Import the test dataset

In [5]:
''' LOAD AND PREPROCESS THE DATA'''
basedir = '/tigress/kendrab/21032023/'
test_num = 9

# Directory holding this run's per-index HDF5 sample files.
totdir = f"{basedir}{test_num}/new_better/"
# One file per index 5, 10, ..., 55 (11 files).
readpaths = [totdir + f"100samples_idx{j}_bxbybzjyvzexeyez.hdf5" for j in range(5, 60, 5)]

In [6]:
# Per-sample accumulators: one entry per row of each HDF5 dataset, in file order.
idx_list = []  # to keep track of which file what sample came from
s_list = []
bx_list = []
by_list = []
bz_list = []
ex_list = []
ey_list = []
ez_list = []
jy_list = []
x0_list = []
x1_list = []
topo_list = []


def _append_samples_from(filepath, file_idx):
    """Read one HDF5 sample file and append its rows to the module-level *_list accumulators.

    For every row of ``bx_mms_smooth``, ``idx_list`` gets an integer array of the same
    length filled with ``file_idx``, so each point can be traced back to its file.
    ``topo`` categories are reduced mod 2 (cat 0,2 are not plasmoids, cat 1,3 are)
    and stored as torch tensors.
    """
    with h5py.File(filepath, 'r') as file:
        bx_rows = file['bx_mms_smooth'][:]  # read once; used for both idx and bx
        idx_list.extend(np.full(len(row), file_idx) for row in bx_rows)  # check this structure!!!
        s_list.extend(file['s'][:])
        bx_list.extend(bx_rows)
        by_list.extend(file['by_mms'][:])
        bz_list.extend(file['bz_mms_smooth'][:])
        ex_list.extend(file['ex_mms'][:])
        ey_list.extend(file['ey_mms'][:])
        # vx_mms is simulation vz  thus the filename
        # NOTE(review): the key read here is 'ez_mms' — confirm which quantity the note means.
        ez_list.extend(file['ez_mms'][:])
        jy_list.extend(file['jy_mms'][:])
        x0_list.extend(file['x_mms'][:])
        x1_list.extend(file['z_mms'][:])
        topo_list.extend(torch.from_numpy(t.astype(int) % 2) for t in file['topo'][:])


for file_idx, filepath in enumerate(readpaths):
    _append_samples_from(filepath, file_idx)

noplasmoids_dir = '/tigress/kendrab/06022023/'
noplasmoids_path = noplasmoids_dir+'new_better/'+"100samples_idx50_bxbybzjyvzexeyez.hdf5"  # last file for test dataset
# Give the no-plasmoid file the next unused index instead of reusing the last
# plasmoid file's index, so samples from the two files remain distinguishable.
_append_samples_from(noplasmoids_path, len(readpaths))


In [7]:
'''CHONK UP THE DATA'''
# chunk into sliding windows
# NOTE TOPO HAS DIFFERENT SEGMENT LENGTHS THAN THE INPUTS (stride vs. 2*padding+stride)

# Fixed seed so the shuffled segment order (and therefore batch composition)
# is identical on every Restart & Run All.
SHUFFLE_SEED = 0

# Per-point metadata/labels: unpadded windows (same length as topo, per the NOTE above).
idx = batch_unpadded_subsects(idx_list, padding_length, stride)
s = batch_subsects(s_list, input_length, stride)  # not going through training so don't need to shape right
# Model inputs: padded windows with a singleton axis inserted at position 1
# (presumably the channel axis the model expects — TODO confirm in import_model).
bx, by, bz, ex, ey, ez, jy = (
    np.expand_dims(batch_subsects(series, input_length, stride), 1)
    for series in (bx_list, by_list, bz_list, ex_list, ey_list, ez_list, jy_list)
)
x0 = batch_unpadded_subsects(x0_list, padding_length, stride)
x1 = batch_unpadded_subsects(x1_list, padding_length, stride)
topo = batch_unpadded_subsects(topo_list, padding_length, stride)

# shuffle the segments so they aren't adjacent to overlapping/similar segments
idx, s, bx, by, bz, ex, ey, ez, jy, x0, x1, topo = \
    shuffle(idx, s, bx, by, bz, ex, ey, ez, jy, x0, x1, topo, random_state=SHUFFLE_SEED)

In [8]:
# numpy arrays -> torch tensors on `device`, all cast to the notebook-wide `dtype`
# (topo must be floating point anyway for BCEWithLogitsLoss targets).
def _to_device_tensor(arr):
    """Return `arr` (numpy array or existing tensor) as a `dtype` tensor on `device`.

    torch.as_tensor, unlike torch.from_numpy, also accepts tensors, so re-running
    this cell after the variables were already converted no longer raises.
    """
    return torch.as_tensor(arr).to(device, dtype=dtype)


idx, s, bx, by, bz, ex, ey, ez, jy, x0, x1, topo = (
    _to_device_tensor(arr) for arr in (idx, s, bx, by, bz, ex, ey, ez, jy, x0, x1, topo)
)

In [9]:
# Bundle the per-segment tensors into one Dataset. Field order matters:
# test_loop unpacks each batch positionally as
# (idx, s, bx, by, bz, ex, ey, ez, jy, x0, x1, topo).
test_dset = TensorDataset(*(idx, s, bx, by, bz, ex, ey, ez, jy, x0, x1, topo))
# DataLoader does not shuffle by default; segment order was already shuffled upstream.
test_dl = DataLoader(test_dset, batch_size = batch_size)  # batch_size from import_model

### Test the model

In [13]:
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')  # We are allowing reduction bc backward()
                                            # needs a scalar or to specify a different gradient.
                                                # Easier this way. Probably.
# NOTE(review): `opt` is not used by the evaluation below; kept only in case another cell relies on it.
opt = torch.optim.Adam(model.parameters(),lr=learning_rate)
metric = tem.BinaryAUPRC(device=device)


def _plot_confusions(y_true, y_pred, path):
    """Save raw, true-normalized and pred-normalized confusion matrices, stacked vertically, to `path`."""
    fig, axes = plt.subplots(3, figsize=(6, 10))
    panels = ((None, "non-normalized"), ('true', "normalized true"), ('pred', "normalized pred"))
    for ax, (normalize, label) in zip(axes, panels):
        ax.set(title=f"Testing Confusion, {label}")
        cm = confusion_matrix(y_true, y_pred, normalize=normalize)
        ConfusionMatrixDisplay(cm).plot(ax=ax, cmap='Greys', colorbar=False)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig='all')


print("Testing performance")
test_1d_pred = test_loop(test_dl, model, loss_fn).flatten()  # per-point logits
test_1d = topo.cpu().numpy().flatten() # for confusion matrix
num_per_cat = [np.sum(test_1d == i) for i in range(2)]
print(f"cat_breakdown\t\t{num_per_cat}")

# logits -> probabilities once; reused for both the AUPRC metric and the PR curve
test_1d_probs = torch.sigmoid(torch.from_numpy(test_1d_pred))
test_1d_target = torch.from_numpy(test_1d)
metric.update(test_1d_probs, test_1d_target)

# Plot confusion matrices (logit > 0  <=>  probability > 0.5)
test_1d_pred_int = test_1d_pred > 0
_plot_confusions(test_1d, test_1d_pred_int, file_start+"confusion"+".pdf")

# plot precision-recall curve for test.
# Fed probabilities rather than logits (torcheval doesn't seem to support logits here),
# in float32 as before.
test_1d_prob = test_1d_probs.to(device, dtype=torch.float32)
precision, recall, thresh = tem.functional.binary_binned_precision_recall_curve(
    test_1d_prob, test_1d_target.to(device, dtype=torch.float32), threshold=100)
# last point for thresh = 1 is repeated, so we delete it
plot_prc(precision[:-1], recall[:-1], thresh, file_start+"prc_test", title="Precision-Recall curve, test dataset")
print(f"AUPRC: {metric.compute()}")

Testing performance
Test Error: 
 Accuracy: 66.23090%, Avg loss: 0.599586 

cat_breakdown		[358877, 114849]
torch.Size([100])
AUPRC: 0.5629443526268005


In [None]:
# find number of completed trials
# The DB password is taken from the environment (or prompted for) instead of being
# hard-coded in the notebook. NOTE(review): the password previously committed in this
# cell should be treated as leaked and rotated.
_optuna_password = quote_plus(os.environ.get("OPTUNA_DB_PASSWORD") or getpass("Optuna DB password: "))
study = optuna.load_study(
    study_name='model_f_study',
    storage=f"mysql+mysqldb://optunauser:{_optuna_password}@stellar-intel.princeton.edu:47793/model_f",
    pruner=optuna.pruners.HyperbandPruner(),
)

# Count successfully finished trials; deepcopy=False avoids copying every trial just to count.
comp_trials = len(study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)))

print(comp_trials)