# Learned Kernels
This notebook contains all of the code needed to set up and run the SMoLK model on the various PPG quality datasets.

## Setup
Here, we import all of the libraries needed to run this code and set the random seed for reproducibility. Note that we set `device` to "cuda" to run on a CUDA-enabled GPU. If you do not have a CUDA-enabled GPU, you can set this to "cpu" instead, though it will be much slower.

In [None]:
import os #path/directory stuff
import pickle

#Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F

#Math
from sklearn.metrics import f1_score, r2_score
from scipy.signal import savgol_filter
import numpy as np
import copy
import random

# Reproducibility: feed one shared seed to Python's `random`, NumPy and PyTorch
# (CPU generator first, then the CUDA generator), in that order.
seed = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

# Restrict cuDNN to deterministic algorithms
torch.backends.cudnn.deterministic = True

#Progress bar
from tqdm import tqdm

# Device used for training and evaluation: the GPU when CUDA is usable, otherwise
# fall back to the (much slower) CPU instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Data
Here, we load the training and test sets. We use only the DaLiA training set for training, and the DaLiA test set for testing. We load the other datasets to test out-of-distribution performance. These datasets are all included in the GitHub repository.

In [None]:
# Root directory of the PPG data and the processed sub-directory of every dataset
base = "PPG Data"
subdirs = [
    "new_PPG_DaLiA_test/processed_dataset",
    "new_PPG_DaLiA_train/processed_dataset",
    "TROIKA_channel_1/processed_dataset",
    "WESAD_all/processed_dataset",
]


def _load_ppgs_and_labels(subdir):
    """Load the `scaled_ppgs.npy` / `seg_labels.npy` pair stored under `base/subdir`."""
    folder = os.path.join(base, subdir)
    ppgs = np.load(os.path.join(folder, "scaled_ppgs.npy"))
    labels = np.load(os.path.join(folder, "seg_labels.npy"))
    return ppgs, labels


# Only the DaLiA train split is used for training
X_train, Y_train = _load_ppgs_and_labels(subdirs[1])

# Everything else is held out for testing
DaLiA_X, DaLiA_Y = _load_ppgs_and_labels(subdirs[0])
TROIKA_X, TROIKA_Y = _load_ppgs_and_labels(subdirs[2])
WESAD_X, WESAD_Y = _load_ppgs_and_labels(subdirs[3])

# Define the model
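Below is the definition of the Learned Kernels model. It consists of three parallel `Conv1d` kernel groups (kernel lengths 192, 96, and 64); their rectified outputs are scaled by learned per-kernel weights, summed across all kernels, and passed through a sigmoid.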

In [None]:
class LearnedFilters(nn.Module):
    """Three parallel groups of learned 1-D kernels whose weighted responses are summed.

    Each group is a single-input-channel `Conv1d` with `num_kernels` output channels,
    "same" padding and its own kernel length (192, 96 and 64 samples). Every kernel
    also owns one learnable scalar weight (`w1`, `w2`, `w3`), initialised to zero.

    Args:
        num_kernels (int): Number of kernels in each of the three groups.
    """

    def __init__(self, num_kernels=24):
        super().__init__()

        # Kernel groups, from the longest kernel to the shortest
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=num_kernels, kernel_size=192, stride=1, padding="same", bias=True)
        self.conv2 = nn.Conv1d(in_channels=1, out_channels=num_kernels, kernel_size=96, stride=1, padding="same", bias=True)
        self.conv3 = nn.Conv1d(in_channels=1, out_channels=num_kernels, kernel_size=64, stride=1, padding="same", bias=True)

        # One learned scalar weight per kernel, applied after the ReLU
        self.w1 = nn.Parameter(torch.zeros(num_kernels), requires_grad=True)
        self.w2 = nn.Parameter(torch.zeros(num_kernels), requires_grad=True)
        self.w3 = nn.Parameter(torch.zeros(num_kernels), requires_grad=True)

    def forward(self, x):
        """Return one sigmoid-squashed value per input sample.

        Args:
            x (torch.Tensor): Single-channel signals, expected as (batch, 1, length).

        Returns:
            torch.Tensor: Values in (0, 1), reshaped to (batch, -1).
        """
        n_signals = x.shape[0]

        # ReLU response of every group, scaled by that group's per-kernel weights
        groups = ((self.conv1, self.w1), (self.conv2, self.w2), (self.conv3, self.w3))
        weighted = [F.relu(conv(x)) * weight[None, :, None] for conv, weight in groups]

        # Stack all kernels on the channel axis and collapse that axis by summation
        summed = torch.cat(weighted, dim=1).sum(dim=1).view(n_signals, -1)

        return torch.sigmoid(summed)

# Train the models
Training all 90 models (9 filter counts × 10 folds) should take around a day on a decent consumer GPU.

In [None]:
### Setup
filter_nums = [4, 8, 16, 24, 32, 64, 128, 256, 512] #number of filters (per kernel group) to train models for
folds = 10 #number of models ("folds") trained per filter count; every one is trained on the full training set
epochs = 512 #number of epochs to train for
lr = 0.01 #learning rate
wd = 1e-4 #weight decay
decay_range = [1.0, 0.2] #start and end factors of the linear learning-rate decay

save_dir = "models" #directory to save models to
# exist_ok=True makes this safe to re-run and avoids the check-then-create race
os.makedirs(save_dir, exist_ok=True)

## Main training loop
Here, we train `folds` models for each entry of `filter_nums`. Keep in mind that the final model will have `filter_num*3` filters, since `filter_num` is the number of filters in each kernel group, of which there are three.

In [None]:
# Prepare the training data once: it is identical for every model, so there is no
# need to re-copy, re-normalize and re-upload it for each of the trained models.
x = copy.deepcopy(X_train) #we deepcopy the data so we don't modify the original

#normalize each signal by subtracting its mean and dividing by its standard deviation
for i in range(0, len(x)):
    x[i] = (x[i] - np.mean(x[i]))/np.std(x[i])

# Convert the input and output data to PyTorch tensors and move them to the device
# (the signal length is inferred from the data rather than hard-coded)
x = torch.tensor(x, dtype=torch.float32, device=device).reshape(len(x), 1, -1)
y = torch.tensor(Y_train, dtype=torch.float32, device=device)
num_samples = x.shape[0]

# Loop through different numbers of filters
for filter_num in filter_nums:
    # Split the input data into smaller chunks if the number of filters is greater than 32 to avoid running out of memory
    split = 32 if filter_num > 32 else 1

    # Loop through different versions of the model
    for fold in range(0, folds):
        # Initialize a new instance of the LearnedFilters class with the current number of filters
        net = LearnedFilters(filter_num).to(device)

        # Compute the total number of trainable model parameters and print it
        params = sum(p.numel() for p in net.parameters() if p.requires_grad)
        print(f"Training kernel with {filter_num} filters (fold {fold+1})...")
        print("Num params: %i" % params)

        # Initialize the optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=wd)

        # Initialize a linear learning rate scheduler
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=decay_range[0], end_factor=decay_range[1], total_iters=epochs)

        # Initialize a progress bar for visualization
        pbar = tqdm(range(0, epochs))

        loss_hist = [] #history of losses (used to update the progress bar primarily)

        # Train the model (full-batch gradient descent, one optimizer step per epoch)
        for step in pbar:
            # Zero out the gradients
            optimizer.zero_grad()

            # Initialize the total loss to 0
            total_loss = 0

            # Accumulate gradients chunk by chunk. tensor_split covers *every* sample even
            # when the dataset size is not a multiple of `split` (integer-dividing the
            # length would silently drop the remainder from training).
            for x_chunk, y_chunk in zip(torch.tensor_split(x, split), torch.tensor_split(y, split)):
                if x_chunk.shape[0] == 0:
                    continue # more chunks than samples: nothing to do for this chunk

                out = net(x_chunk) # Forward pass
                # Weight each chunk's mean loss by its share of the samples so that the
                # accumulated gradient equals the gradient of the full-batch mean loss
                loss = F.binary_cross_entropy(out.view(-1), y_chunk.view(-1)) * (x_chunk.shape[0] / num_samples)
                total_loss += loss.item() # Add the loss to the total loss
                loss.backward() # Backward pass

            optimizer.step() # Take an optimizer step
            scheduler.step() # Adjust the learning rate

            loss_hist.append(total_loss) # Add the total loss to the loss history
            pbar.set_description("Loss: %.5f" % np.mean(loss_hist[-256:])) # Update the progress bar with the recent mean loss

        # Save the model (the whole module is pickled, not just its state dict)
        torch.save(net, os.path.join(save_dir, f"learned_filters_{filter_num}_{fold}.pt"))

# Test the models
In this section, we simply run each model fold against the various test sets.

In [None]:
def test_model(net, X, Y):
    """
    Test the given network on the provided data.

    Every sample is normalized to zero mean and unit standard deviation, passed
    through the network, smoothed with a Savitzky-Golay filter (window 151, order 3)
    and thresholded at 0.5. The score is computed over the predictions of all
    samples concatenated together.

    Args:
        net (nn.Module): The trained model. Its state dict must contain 'w1', which
            is used to look up the model's dtype and device.
        X (list): The list of data samples.
        Y (list): The list of corresponding labels.

    Returns:
        float: The DICE score of the model's predictions.
    """
    dtype = net.state_dict()['w1'].dtype # Get the data type of the model
    device = net.state_dict()['w1'].device # Get the device of the model

    preds = [] # To store the model predictions
    true = [] # To store the true labels

    # Inference only: no_grad avoids building an autograd graph for every sample
    with torch.no_grad():
        for i in range(len(X)):
            x = X[i]
            x = (x - np.mean(x))/np.std(x) # Normalize the data

            # Explicit (batch=1, channel=1, length) input, the same layout used for training
            y_pred = net(torch.tensor(x, dtype=dtype, device=device).view(1, 1, -1)) # Predictions

            # Apply a Savitzky-Golay filter to the predictions and convert to binary form
            smoothed = savgol_filter(y_pred.cpu().float().numpy()[0], 151, 3)
            preds.extend(np.float32(smoothed > 0.5))
            true.extend(Y[i]) # Concatenate labels as list

    # f1_score expects (y_true, y_pred). For binary labels the F1 score is equivalent to the DICE score
    return f1_score(true, preds)

def test_suite(net, verbose=True):
    """
    Evaluate the network on the DaLiA, TROIKA and WESAD test sets, in that order.

    Args:
        net (nn.Module): The trained model.
        verbose (bool, optional): Print each DICE score as soon as it is computed. Defaults to True.

    Returns:
        list: The DICE scores, ordered [DaLiA, TROIKA, WESAD].
    """
    # (signals, labels, report template) for every benchmark
    benchmarks = (
        (DaLiA_X, DaLiA_Y, "DaLiA DICE score: %.4f"),
        (TROIKA_X, TROIKA_Y, "TROIKA DICE score: %.4f"),
        (WESAD_X, WESAD_Y, "WESAD DICE score: %.4f"),
    )

    results = []
    for signals, labels, report in benchmarks:
        score = test_model(net, signals, labels)
        results.append(score)
        if verbose:
            print(report % score)

    return results

In [None]:
test_results = {} #will store the results of each model, keyed by number of filters

# Loop through different numbers of filters
for filter_num in filter_nums:
    print(f"Testing kernel with {filter_num} filters...")
    # Loop through different versions of the model
    pbar = tqdm(range(0, folds))
    results = []
    for fold in pbar: #loop through all the folds
        # Load the model. The checkpoints are whole pickled modules (torch.save(net, ...)),
        # which recent PyTorch releases refuse to load unless weights_only=False is passed.
        # NOTE: unpickling can execute arbitrary code, so only load checkpoints you trained yourself.
        net = torch.load(os.path.join(save_dir, f"learned_filters_{filter_num}_{fold}.pt"), map_location=device, weights_only=False)

        # Test the model
        results.append(test_suite(net, verbose=False))

        # Update the progress bar with the running mean over the folds tested so far
        pbar.set_description("DaLiA: %.4f, TROIKA: %.4f, WESAD: %.4f" % tuple(np.mean(results, axis=0)))

    test_results[filter_num] = np.transpose(results) #transpose the results so that the rows are the datasets and the columns are the folds

In [None]:
# Persist the per-model test scores next to the saved models
results_path = os.path.join(save_dir, "test_results.pkl")
with open(results_path, "wb") as results_file:
    pickle.dump(test_results, results_file)

# Kernel Pruning
This code removes similar kernels from each model and tests the results. We defer to the methods section of our manuscript for a detailed discussion of the pruning process.

In [None]:
def similarity(v1, v2):
    """
    Cosine similarity of two tensors.

    Both inputs are scaled to unit norm, multiplied element-wise and summed.

    Args:
        v1, v2 (torch.Tensor): The tensors to compare (same shape).

    Returns:
        float: The cosine similarity between v1 and v2.
    """
    unit_v1 = v1 / v1.norm()
    unit_v2 = v2 / v2.norm()
    overlap = unit_v1 * unit_v2

    return overlap.sum().item()

def compute_param_num(num_conv1, num_conv2, num_conv3):
    """
    Total parameter count of a network with the given number of kernels per group.

    A kernel contributes its taps (192, 96 or 64 depending on the group), one bias
    and one learned weight.

    Args:
        num_conv1, num_conv2, num_conv3 (int): The number of kernels in each kernel group.

    Returns:
        int: The total number of parameters in the network.
    """
    kernel_counts = (num_conv1, num_conv2, num_conv3)

    tap_params = sum(count * taps for count, taps in zip(kernel_counts, (192, 96, 64)))
    bias_params = sum(kernel_counts) # 1 per kernel
    weight_params = sum(kernel_counts) # 1 per kernel

    return tap_params + bias_params + weight_params

def get_most_similar_kernels(similarity_flat, coords):
    """
    Order kernel index pairs by ascending similarity, so the most similar pairs come last.

    Args:
        similarity_flat (np.array): The similarity score of every pair.
        coords (np.array): The kernel index pair belonging to each score.

    Returns:
        np.array: `coords` reordered by ascending similarity score.
    """
    ascending = np.argsort(similarity_flat)
    return coords[ascending]


def compute_similarity(state_dict, conv_i, num_kernels):
    """
    Cosine similarity of every unordered pair of distinct kernels in one kernel group.

    Args:
        state_dict (dict): The state dict of the network.
        conv_i (int): Index of the kernel group (selects `conv{conv_i}.weight`).
        num_kernels (int): Number of kernels in the group.

    Returns:
        tuple: Two numpy arrays, the similarity scores and the matching (j, i) kernel
        index pairs with j > i.
    """
    weight_key = f'conv{conv_i}.weight'
    scores = []
    pairs = []

    # Visit each pair once: `high` always runs over the indices above `low`
    for low in range(num_kernels):
        for high in range(low + 1, num_kernels):
            scores.append(similarity(state_dict[weight_key][low], state_dict[weight_key][high]))
            pairs.append((high, low))

    return np.asarray(scores), np.asarray(pairs)


def prune(state_dict, conv_i, num_kernels, prune_ratio):
    """
    Prune the least important kernels from a kernel group based on cosine similarity and kernel importance.

    The `int(num_kernels*prune_ratio)` most similar kernel pairs of the group are
    processed. For each pair, the kernel with the smaller importance
    (`w * mean(|kernel|)`) is removed: its importance and bias are folded into the
    kernel that is kept, and its weight and bias are zeroed.

    Args:
        state_dict (dict): The state dict of the network. It is modified in place.
        conv_i (int): Index of the kernel group.
        num_kernels (int): Number of kernels in each group.
        prune_ratio (float): The proportion of kernels to prune.

    Returns:
        dict: The updated state dict after pruning (the same object that was passed in).
    """
    # Prune if the ratio is greater than zero, otherwise do nothing
    if not prune_ratio > 0:
        return state_dict

    # Number of most-similar pairs to process. This must be checked explicitly: when it
    # rounds down to zero, slicing with [-0:] would select *every* pair instead of none.
    num_pairs = int(num_kernels*prune_ratio)
    if num_pairs == 0:
        return state_dict

    # Compute similarity of kernels (only now that we know there is something to prune)
    sim_flat, coords = compute_similarity(state_dict, conv_i, num_kernels)

    # Sort the pairs by ascending similarity, so the most similar pairs come last
    most_similar_kernels = get_most_similar_kernels(sim_flat, coords)

    # Iterate over the most similar kernels
    for item in most_similar_kernels[-num_pairs:]:
        # Calculate weights for two kernels under consideration
        item0_weight = state_dict[f'w{conv_i}'][item[0]]*state_dict[f'conv{conv_i}.weight'][item[0]].abs().mean()
        item1_weight = state_dict[f'w{conv_i}'][item[1]]*state_dict[f'conv{conv_i}.weight'][item[1]].abs().mean()

        # Decide which kernel to remove and which to keep (the one with the larger weight is kept)
        if item0_weight > item1_weight:
            remove, keep, keep_weight, remove_weight = item[1], item[0], item0_weight, item1_weight
        else:
            remove, keep, keep_weight, remove_weight = item[0], item[1], item1_weight, item0_weight

        # Update state_dict: fold the removed kernel's contribution into the kept kernel
        state_dict[f'w{conv_i}'][keep] = (keep_weight + remove_weight) / state_dict[f'conv{conv_i}.weight'][keep].abs().mean()
        state_dict[f'conv{conv_i}.bias'][keep] += state_dict[f'conv{conv_i}.bias'][remove]

        # this step is actually enough to "prune the kernel", for computation/measurement sake. In reality, we'd want to remove the kernel from the network for a speedup
        state_dict[f'w{conv_i}'][remove] = 0.0
        state_dict[f'conv{conv_i}.bias'][remove] = 0.0

        # this is an extra step used for counting the kernel that we remove in the end
        # no matter what this value is set to, it will not have any effect on the network since the weight is set to zero
        # however, this value *does* need to be non-zero, since PyTorch handles completely zeroed convolutions a bit weirdly
        # and you'll get weird results convolving a purely 0 kernel
        state_dict[f'conv{conv_i}.weight'][remove] = 1e-5

    return state_dict

def prune_network(net, num_kernels, prune_ratio):
    """
    Prune each of the model's three kernel groups with its own ratio.

    Args:
        net (nn.Module): The model.
        num_kernels (int): Number of kernels in each group.
        prune_ratio (list): The proportion of kernels to prune in each group
            (entry 0 for group 1, entry 1 for group 2, entry 2 for group 3).

    Returns:
        nn.Module: The pruned network (the same object that was passed in).
    """
    weights = net.state_dict()

    # Kernel groups are numbered from 1, the ratios are indexed from 0
    for group in (1, 2, 3):
        ratio = prune_ratio[group - 1]
        weights = prune(weights, group, num_kernels, ratio)

    net.load_state_dict(weights)
    return net


def count_nonzero_weights(state_dict, num_kernels):
    """
    Count the kernels that have not been pruned in each kernel group of the model.

    A kernel counts as pruned when all of its taps equal 1e-5, the (arbitrary)
    marker value written by the prune function.

    Args:
        state_dict (dict): The state dict of the network.
        num_kernels (int): Number of kernels in each group.

    Returns:
        list: The number of remaining (non-pruned) kernels in each group.
    """
    remaining = []

    for group in (1, 2, 3):
        kernels = state_dict[f'conv{group}.weight']
        pruned = sum(1 for k in range(num_kernels) if (kernels[k] == 1e-5).all())
        remaining.append(num_kernels - pruned)

    return remaining


# Initiate lists to store results
pre_prunes = [] # DICE scores before pruning
post_prunes = [] # DICE scores after pruning
reductions = [] # Reduction in parameters
num_kernels = 128 # Number of kernels in each kernel group

# Parameter count of the unpruned model (the same for every fold)
full_param_count = compute_param_num(num_kernels, num_kernels, num_kernels)

# Iterate over models and prune
for j in range(folds):
    # Load the network from the same directory the training loop saved it to.
    # The checkpoints are whole pickled modules (torch.save(net, ...)), which recent
    # PyTorch releases refuse to load unless weights_only=False is passed.
    # NOTE: unpickling can execute arbitrary code, so only load checkpoints you trained yourself.
    net = torch.load(os.path.join(save_dir, f"learned_filters_{num_kernels}_{j}.pt"), map_location=device, weights_only=False)

    print("-------Before pruning-------")
    # Test the network before pruning
    pre_prune = test_suite(net, verbose=True)
    pre_prunes.append(pre_prune)

    # Define the pruning ratio for each layer, essentially, this is the proportion of kernel *pairs* for which one pair will be removed
    # In other words, if the pruning ratio is 1.0, then *at most* half of all the kernels for that layer will be removed
    # However, it is not *guaranteed* that half will be removed, since the similarity ordering can cause some kernels to be the most similar to multiple other kernels
    # thus, this kernel could be removed first, leading to the other pairs to have "already been pruned".
    # This is also why we have to manually compute the number of parameters removed
    prune_ratio = [0.35, 0.0, 0.0]

    # Prune the network
    net = prune_network(net, num_kernels, prune_ratio)

    # Count the kernels that survived pruning in each group
    nonzero_weights = count_nonzero_weights(net.state_dict(), num_kernels)
    new_kernel_num = nonzero_weights

    # Compute the new parameter count and the reduction
    new_param_count = compute_param_num(new_kernel_num[0], new_kernel_num[1], new_kernel_num[2])
    reduction_percentage = (1 - new_param_count/full_param_count)*100
    reductions.append(reduction_percentage)

    print(f"\nRemoved {reduction_percentage:.2f}% of params")
    print("-------After pruning-------")

    # Test the network after pruning
    post_prune = test_suite(net, verbose=False)
    post_prunes.append(post_prune)

    # Print the results
    print(f"DaLiA DICE score: {post_prune[0]:.4f} ({post_prune[0]/pre_prune[0]*100:.2f}% of original)")
    print(f"TROIKA DICE score: {post_prune[1]:.4f} ({post_prune[1]/pre_prune[1]*100:.2f}% of original)")
    print(f"WESAD DICE score: {post_prune[2]:.4f} ({post_prune[2]/pre_prune[2]*100:.2f}% of original)")
    print("=====================================")