### Resnet-based model for funsies
This one uses only the magnetic field for now, until I decide how to feed in multiple images instead of a single image.

In [1]:
import torch
from torch import nn
import torchvision as tv
from torch.utils.data import TensorDataset, DataLoader
from  torch.nn.functional import one_hot
import h5py
from sklearn.utils import shuffle
import sys
import datetime as dt

# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Floating-point precision applied when data arrays are moved to the device.
dtype = torch.double
   
# Get functions from other notebooks
%run /tigress/kendrab/analysis-notebooks/preproc_utils.ipynb
%run /tigress/kendrab/analysis-notebooks/eval_utils.ipynb

# Timezone-aware (UTC) start time; used to name outputs and to report total runtime.
start = dt.datetime.now(tz=dt.timezone.utc)
date_str, time_str = start.strftime("%d-%m-%y"), start.strftime("%H%M%S")
start_str = f"{date_str}{time_str}"

Using cuda device


### Make the model

In [2]:
# Label for this model; passed to the output-naming helper and written to the log.
model_name = "B"

# --- hyperparameters ---
stride = 10  # size (and therefore spacing) of each segment
padding_length = 10  # extra points kept on each side of a segment for context
input_length = 2 * padding_length + stride  # points per padded input segment
thinning_factor = [0.8, None]  # element 0 is used when rebalancing the training set
learning_rate = 0.01
epochs = 4
# Gathered for the log file; insertion order is the order the entries are written.
hyperparams = dict(
    learning_rate=learning_rate,
    input_length=input_length,
    stride=stride,
    epochs=epochs,
    thinning_factor=thinning_factor,
)

# --- other parameters ---
batch_size = 64  # not tuned; unclear what value performs best

In [3]:
# Start from a ResNet-50 with pretrained ImageNet weights.
model = tv.models.resnet50(weights="IMAGENET1K_V2")
# Replace the last layer to give the right output shape: 2 class scores for each
# of the `stride` points in a segment, reshaped to (batch, 2, stride).
# The input width is read from the existing head instead of being hard-coded.
model.fc = nn.Sequential(
    nn.Linear(in_features=model.fc.in_features, out_features=2*stride, bias=True),
    nn.Unflatten(1, (2, stride)),
)
# Move to the selected device, using the same precision (`dtype`) as the data
# tensors so model and inputs cannot drift apart if `dtype` is changed.
model = model.to(device=device, dtype=dtype)
# freeze everything except this last layer
# TRY DIFFERENT LAYERS FROZEN/UNFROZEN
# model.requires_grad_(False)
# model.fc.requires_grad_(True)

### Define the training and testing loops

In [4]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train `model` for one pass over `dataloader`.

    Each batch from `dataloader` must be a 6-tuple; only the third element
    (the model input) and the sixth element (the target) are used.

    Args:
        dataloader: iterable of 6-tuples of tensors, exposing `.dataset`.
        model: module to optimize; switched to training mode.
        loss_fn: callable `(pred, target) -> loss tensor`.
        optimizer: optimizer holding the model's parameters.

    Prints the mean loss and the running sample count every 100 batches.
    """
    model.train()
    size = len(dataloader.dataset)  # total number of samples in the dataset
    samples_seen = 0
    for batch, (_, _, B, _, _, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(B)
        loss = loss_fn(pred, y)

        # Backpropagation (sum() is a no-op on an already-reduced scalar loss
        # and makes an unreduced loss usable with backward()).
        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()

        # Count real samples so a short final batch is reported correctly.
        samples_seen += B.shape[0]
        if batch % 100 == 0:
            # .item() prints a plain number rather than the tensor repr.
            print(f"mean loss: {loss.mean().item()}, sample {samples_seen}/{size}")

In [5]:
def test_loop(dataloader, model, loss_fn):
    model.eval()
    pred_list = []
    size = len(dataloader.dataset)  # number of samples
    tot_points = size*stride
    num_batches = len(dataloader)
    test_loss_sum, correct = 0, 0

    with torch.no_grad():
        for _, _, B, _, _, y in dataloader:
            pred = model(B)
            pred_list.append(pred.cpu().numpy())
            test_loss_sum += loss_fn(pred, y).item()  # .item() fetches the python scalar
            # number of correct per-point predictions
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
    ##### MAKE SURE DIAGNOSTICS ARE CALCULATED CORRECTLY
    tot_pred = np.concatenate(pred_list, axis=0)
    test_loss_sum /= num_batches
    correct /= tot_points
    print(f"Test Error: \n Accuracy: {(100*correct):>0.5f}%, Avg loss: {test_loss_sum:>8f} \n")
    return tot_pred

### Load and preprocess the data

In [6]:
# TODO use command line args or something easier than throwing it here
basedir = '/tigress/kendrab/21032023/'
noplasmoids_dir = '/tigress/kendrab/06022023/'


def _read_sample_file(filepath, file_idx):
    """Read every sample stored in one HDF5 file.

    Args:
        filepath: path of the HDF5 file to open read-only.
        file_idx: integer tag recorded for every point of every sample, to
            keep track of which file a sample came from.

    Returns:
        Tuple of 8 lists with one entry per sample, in the order
        (idx, s, bx, by, bz, x0, x1, topo). `idx` entries are arrays filled
        with `file_idx`, one value per point of the matching `bx` sample.
        `topo` entries are one-hot tensors with 2 classes.
    """
    with h5py.File(filepath, 'r') as file:
        bx_samples = list(file['bx_mms_smooth'][:])
        # cat 0,2 are not plasmoids, cat 1,3 are -> the parity is the binary label
        topo_samples = [one_hot(torch.from_numpy(sample.astype(int) % 2), num_classes=2)
                        for sample in file['topo'][:]]
        return (
            [np.full(len(sample), file_idx) for sample in bx_samples],
            list(file['s'][:]),
            bx_samples,
            list(file['by_mms'][:]),
            list(file['bz_mms_smooth'][:]),
            list(file['x_mms'][:]),
            list(file['z_mms'][:]),
            topo_samples,
        )


def _extend_all(targets, additions):
    """Extend each list in `targets` in place with the matching list in `additions`."""
    for target, extra in zip(targets, additions):
        target.extend(extra)


readpaths = [f"{basedir}{i}/100samples_idx{j}_bxbybzjyvzexeyez.hdf5"
             for i in range(10) for j in range(5, 60, 5)]

idx_list = []  # to keep track of which file what sample came from
s_list = []
bx_list = []
by_list = []
bz_list = []
x0_list = []
x1_list = []
topo_list = []
all_lists = (idx_list, s_list, bx_list, by_list, bz_list, x0_list, x1_list, topo_list)

train_idx = None
split_file_idx = int(.7*len(readpaths))  # roughly 70-30 train-test split for now

for idx, filepath in enumerate(readpaths):
    _extend_all(all_lists, _read_sample_file(filepath, idx))
    if idx == split_file_idx:
        # Everything read so far (this file included) goes to training.
        train_idx = len(bx_list)

print(len(bx_list))
# do train test split
idx_train_list = idx_list[:train_idx]  # to keep track of which file what sample came from
s_train_list = s_list[:train_idx]
bx_train_list = bx_list[:train_idx]
by_train_list = by_list[:train_idx]
bz_train_list = bz_list[:train_idx]
x0_train_list = x0_list[:train_idx]
x1_train_list = x1_list[:train_idx]
topo_train_list = topo_list[:train_idx]
train_lists = (idx_train_list, s_train_list, bx_train_list, by_train_list,
               bz_train_list, x0_train_list, x1_train_list, topo_train_list)

idx_test_list = idx_list[train_idx:]
s_test_list = s_list[train_idx:]
bx_test_list = bx_list[train_idx:]
by_test_list = by_list[train_idx:]
bz_test_list = bz_list[train_idx:]
x0_test_list = x0_list[train_idx:]
x1_test_list = x1_list[train_idx:]
topo_test_list = topo_list[train_idx:]
test_lists = (idx_test_list, s_test_list, bx_test_list, by_test_list,
              bz_test_list, x0_test_list, x1_test_list, topo_test_list)

# BUT WAIT THERE'S MORE! Include the slices from plain ol current sheets. Split 50-50 between train and test
noplasmoids_paths = [noplasmoids_dir+f"100samples_idx{j}_bxbybzjyvzexeyez.hdf5"
                     for j in range(5, 55, 5)]
half = len(noplasmoids_paths) // 2

for k in range(half):
    # File tags continue after the plasmoid files, i.e. they index into
    # readpaths + noplasmoids_paths. (Previously the stale loop variable `idx`
    # tagged all of these files with the last plasmoid file's index.)
    # training part
    _extend_all(train_lists,
                _read_sample_file(noplasmoids_paths[k], len(readpaths) + k))
    # testing part
    _extend_all(test_lists,
                _read_sample_file(noplasmoids_paths[k+half], len(readpaths) + k + half))

11000


### Chunk the train and test sets into windows, shuffle, and rebalance

In [7]:
# Chunk every series into sliding windows.
# NOTE: the inputs (s, bx, by, bz) use padded windows of length
# input_length = 2*padding_length + stride, while idx, x0, x1 and topo use the
# unpadded variant, so topo segments differ in length from the inputs.
def _padded_windows(series_list):
    """Windows via batch_subsects with this notebook's input_length and stride."""
    return batch_subsects(series_list, input_length, stride)


def _unpadded_windows(series_list):
    """Windows via batch_unpadded_subsects with this notebook's padding_length and stride."""
    return batch_unpadded_subsects(series_list, padding_length, stride)


def _with_channel_axis(windows):
    """Insert a length-1 axis at position 1."""
    return np.expand_dims(windows, 1)


idx_train = _unpadded_windows(idx_train_list)
# s_train is left without the extra axis (unlike s_test below): it does not go
# through training, so its shape is not adjusted.
s_train = _padded_windows(s_train_list)
bx_train = _with_channel_axis(_padded_windows(bx_train_list))
by_train = _with_channel_axis(_padded_windows(by_train_list))
bz_train = _with_channel_axis(_padded_windows(bz_train_list))
x0_train = _unpadded_windows(x0_train_list)
x1_train = _unpadded_windows(x1_train_list)
# Swap axes 1 and 2 of the windowed one-hot labels.
topo_train = np.swapaxes(_unpadded_windows(topo_train_list), 1, 2)

idx_test = _unpadded_windows(idx_test_list)
s_test = _with_channel_axis(_padded_windows(s_test_list))
bx_test = _with_channel_axis(_padded_windows(bx_test_list))
by_test = _with_channel_axis(_padded_windows(by_test_list))
bz_test = _with_channel_axis(_padded_windows(bz_test_list))
x0_test = _unpadded_windows(x0_test_list)
x1_test = _unpadded_windows(x1_test_list)
topo_test = np.swapaxes(_unpadded_windows(topo_test_list), 1, 2)

# Shuffle all arrays of a set together, so that a segment is no longer next to
# its overlapping/similar neighbours while every array stays aligned.
(idx_train, s_train, bx_train, by_train,
 bz_train, x0_train, x1_train, topo_train) = shuffle(
    idx_train, s_train, bx_train, by_train, bz_train, x0_train, x1_train, topo_train)

(idx_test, s_test, bx_test, by_test,
 bz_test, x0_test, x1_test, topo_test) = shuffle(
    idx_test, s_test, bx_test, by_test, bz_test, x0_test, x1_test, topo_test)

# Rebalance the training set only: the model struggles on plasmoids, which are
# underrepresented, so segments with the null label [1,0] are thinned.
train_arrays = [idx_train, s_train, bx_train, by_train, bz_train, x0_train, x1_train]
[idx_train, s_train, bx_train, by_train, bz_train, x0_train, x1_train], topo_train = \
    rebalance_ctrl_group(train_arrays, topo_train, null_label=[1,0],
                         thinning_factor = thinning_factor[0])


Total batch: 369510
Number of null samples: 296973
Number of non-null samples: 72537
With thinning factor 0.8 will remove 237578 null samples


### Transform Bx,y,z to one 2d 3-channel "image"

In [None]:
def _fill_self_similarity(out, field, field_range):
    """Write the 2d "image" of one field component into `out`, in place.

    For every sample, out[n, i, j] = f[i] * (1 - |f[i] - f[j]| / range), where
    f is the sample's series and range is its peak-to-peak value.

    Args:
        out: array of shape (batch, length, length) to fill.
        field: array of shape (batch, 1, length).
        field_range: array of shape (batch, 1); peak-to-peak of each series.
    """
    values = field[:, 0, :]
    # A constant series has zero range and zero differences; dividing 0 by 0
    # would put NaNs into the network input. Using 1 instead yields f[i].
    safe_range = np.where(field_range == 0, 1.0, field_range)
    # One vectorized pass per row i, instead of one per (i, j) pair.
    for i in range(values.shape[-1]):
        anchor = values[:, i:i + 1]
        out[:, i, :] = anchor * (1 - np.abs(anchor - values) / safe_range)


# End up with shape (batch, 3, input_length, input_length)
img_shape_train = (bx_train.shape[0], 3, input_length, input_length)
b_img_train = np.zeros(img_shape_train)
img_shape_test = (bx_test.shape[0], 3, input_length, input_length)
b_img_test = np.zeros(img_shape_test)

# calculate range over last dim for the transformation
bx_train_r = np.ptp(bx_train, axis=-1)
by_train_r = np.ptp(by_train, axis=-1)
bz_train_r = np.ptp(bz_train, axis=-1)
bx_test_r = np.ptp(bx_test, axis=-1)
by_test_r = np.ptp(by_test, axis=-1)
bz_test_r = np.ptp(bz_test, axis=-1)

# bx -> 0, by -> 1, bz -> 2
train_components = ((bx_train, bx_train_r), (by_train, by_train_r), (bz_train, bz_train_r))
test_components = ((bx_test, bx_test_r), (by_test, by_test_r), (bz_test, bz_test_r))

for channel, (field, field_range) in enumerate(train_components):
    # b_img_train[:, channel] is a view, so the helper fills the image in place.
    _fill_self_similarity(b_img_train[:, channel], field, field_range)

for channel, (field, field_range) in enumerate(test_components):
    _fill_self_similarity(b_img_test[:, channel], field, field_range)

### Move to GPU and make datasets

In [None]:
# numpy arrays to torch tensors (while crying about how many lines of code this is surely there is a better way)
idx_train = torch.from_numpy(idx_train).to(device, dtype=dtype)
s_train = torch.from_numpy(s_train).to(device, dtype=dtype)
b_img_train = torch.from_numpy(b_img_train).to(device, dtype=dtype)
x0_train = torch.from_numpy(x0_train).to(device, dtype=dtype)
x1_train = torch.from_numpy(x1_train).to(device, dtype=dtype)
topo_train = torch.from_numpy(topo_train).to(device, dtype=dtype)

idx_test = torch.from_numpy(idx_test).to(device, dtype=dtype)
s_test = torch.from_numpy(s_test).to(device, dtype=dtype)
b_img_test = torch.from_numpy(b_img_test).to(device, dtype=dtype)
x0_test = torch.from_numpy(x0_test).to(device, dtype=dtype)
x1_test = torch.from_numpy(x1_test).to(device, dtype=dtype)
topo_test = torch.from_numpy(topo_test).to(device, dtype=dtype)

In [None]:
# Bundle the tensors into Datasets. The position of each tensor matters:
# train_loop/test_loop read element 2 as the model input and element 5 as the target.
train_tensors = (idx_train, s_train, b_img_train, x0_train, x1_train, topo_train)
test_tensors = (idx_test, s_test, b_img_test, x0_test, x1_test, topo_test)
train_dset = TensorDataset(*train_tensors)
test_dset = TensorDataset(*test_tensors)

# DataLoaders that batch the training and test data
train_dl = DataLoader(train_dset, batch_size=batch_size)
test_dl = DataLoader(test_dset, batch_size=batch_size)

### Compile and train model

In [None]:
# Cross-entropy reduced to its mean: backward() needs a scalar (or an explicit
# gradient argument), and letting the loss reduce itself is the easier route.
loss_fn = nn.CrossEntropyLoss(reduction='mean')

# Adam over every model parameter.
opt = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# One training pass followed by one evaluation on the test set, per epoch.
for t in range(epochs):
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train_loop(train_dl, model, loss_fn, opt)
    test_loop(test_dl, model, loss_fn)

### Make output directories if they do not exist and set up output file names

In [None]:
# Root folder for all model outputs; the helper derives the log file, the
# confusion-matrix file and the sample-file prefix from it plus the model name
# and the start date/time.
outputs_root = "/tigress/kendrab/analysis-notebooks/model_outs/"
log_file, cf_file, samplefile_start = generic_outputs_structure(
    outputs_root, model_name, date_str, time_str)

### Save model

In [None]:
# Everything needed to resume or reuse the run: epoch count, model weights,
# optimizer state and the loss function object.
checkpoint = {
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
    'loss_fn': loss_fn,
}
torch.save(checkpoint, f"{samplefile_start}_modelfile.tar")

## To load:
# model = TheModelClass(*args, **kwargs)
# optimizer = TheOptimizerClass(*args, **kwargs)

# checkpoint = torch.load(PATH)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss_fn = checkpoint['loss_fn']

# model.eval()
# # - or -
# model.train()


### Observe the results, dump information to file

In [None]:
model.eval()


def _report_split(title, loader, onehot_labels, log):
    """Evaluate one data split, then log and print its category breakdown.

    Returns (raw predictions, flattened true classes, flattened predicted
    classes, per-category point counts).
    """
    log.write(f"{title}\n")
    print(title)
    raw_pred = test_loop(loader, model, loss_fn)
    labels = onehot_labels.cpu().numpy()
    truth_1d = np.argmax(labels, axis=1).flatten()  # for confusion matrix
    pred_1d = np.argmax(raw_pred, axis=1).flatten()
    counts = [np.sum(labels[:, cat, :] == 1) for cat in range(2)]
    log.write(f"cat_breakdown\t\t{counts}\n")
    print(f"cat_breakdown\t\t{counts}")
    return raw_pred, truth_1d, pred_1d, counts


with torch.no_grad(), open(log_file, 'w') as log:
    # Header: what was trained, with which loss and hyperparameters.
    log.write(f"Model {model_name} trained on {start_str}\n")
    log.write(f"loss function \t\t{loss_fn!r}\n")
    log.write("Hyperparameters:\n")
    for key, value in hyperparams.items():
        log.write(f"{key}\t\t{value}\n")

    train_topo_pred, train_1d, train_1d_pred, num_per_cat = _report_split(
        "Training performance", train_dl, topo_train, log)
    # TODO CALCULATE RECALL PER CATEGORY

    test_topo_pred, test_1d, test_1d_pred, num_per_cat = _report_split(
        "Testing performance", test_dl, topo_test, log)

    # Wall-clock time since the notebook's start timestamp.
    end = dt.datetime.now(dt.timezone.utc)
    log.write(f"runtime_seconds\t\t{(end-start).total_seconds()}")

### Plot confusion matrices

In [None]:
# Plot train and test confusion matrices from the flattened true/predicted
# class labels computed during evaluation. cf_file comes from
# generic_outputs_structure — presumably the path the figure is saved to
# (the helper is defined outside this notebook; confirm there).
plt_traintest_cf_matrices(train_1d, train_1d_pred, test_1d, test_1d_pred, cf_file)