Example notebook for using the Multitask Neural Decoding code

Import packages

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import disable_possible_user_warnings
from torch.utils.data import DataLoader

Define toy dataset

In [2]:
class M1_EMG_Dataset_Toy(pl.LightningDataModule):
    """
    Synthetic two-mode dataset for exercising the decoding models.

    Each split is divided between "mode1" (features centred on 10.0) and "mode2"
    (features centred on 0.0); every feature also receives standard-normal noise.
    Labels depend on ``dataset_type``:
      - "behavioral": one-hot mode labels (mode1 -> class 0, mode2 -> class 1) of
        width ``num_modes``, for testing the ClusterModel sub-model.
      - "emg": features multiplied by a per-mode random decoder matrix of shape
        (num_neurons, num_muscles), for testing the full CombinedModel.

    The training split has ``num_samples`` samples; the validation split has
    ``int(0.2 * num_samples)`` samples generated with the same decoders.

    Raises:
        ValueError: if ``dataset_type`` is not "behavioral" or "emg".
    """

    _DATASET_TYPES = ("behavioral", "emg")
    _MODE_INDEX = {"mode1": 0, "mode2": 1}

    def __init__(self, num_samples, num_neurons, num_muscles, num_modes, batch_size, dataset_type):
        super().__init__()
        # Fail fast instead of hitting an UnboundLocalError deep inside generate_toy_dataset
        if dataset_type not in self._DATASET_TYPES:
            raise ValueError(
                f"dataset_type must be one of {self._DATASET_TYPES}, got {dataset_type!r}"
            )
        self.num_samples = num_samples
        self.num_neurons = num_neurons
        self.num_muscles = num_muscles
        self.num_modes = num_modes
        self.batch_size = batch_size
        self.dataset_type = dataset_type
        # Ground-truth linear decoders (one per mode) used to synthesize EMG labels
        self.decoder_mode1 = torch.randn(self.num_neurons, self.num_muscles)
        self.decoder_mode2 = torch.randn(self.num_neurons, self.num_muscles)

        # Generate toy features and labels; validation split is 20% the size of the training split
        self.train_dataset = self.generate_toy_dataset(self.num_samples)
        self.val_dataset = self.generate_toy_dataset(int(self.num_samples * 0.2))

    def _generate_features(self, num_samples, m1_val):
        """Return a (num_samples, num_neurons) tensor equal to m1_val plus standard-normal noise."""
        return torch.full((num_samples, self.num_neurons), m1_val) + torch.randn(num_samples, self.num_neurons)

    def generate_behavioral(self, num_samples, m1_val, mode):
        """
        Helper function to generate data with behavioral labels for testing ClusterModel sub-model

        Input: (int) num_samples: number of samples
            (float) m1_val: value of the features in this mode
            (str) mode: which mode to generate ("mode1" or "mode2")
        Output: ([num_samples, num_neurons] tensor) features: output features
                ([num_samples, num_modes] int64 tensor) labels: one-hot behavioral labels
        Raises: ValueError if mode is not "mode1" or "mode2"
        """
        if mode not in self._MODE_INDEX:
            raise ValueError(f"mode must be 'mode1' or 'mode2', got {mode!r}")
        mode_ids = torch.full((num_samples,), self._MODE_INDEX[mode], dtype=torch.long)
        labels = F.one_hot(mode_ids, self.num_modes)
        features = self._generate_features(num_samples, m1_val)
        return features, labels

    def generate_emg(self, num_samples, m1_val, decoder):
        """
        Helper function to generate data with EMG labels for testing the full CombinedModel

        Input: (int) num_samples: number of samples
            (float) m1_val: value of the features in this mode
            ([num_neurons, num_muscles] tensor) decoder: linear map from features to EMG
        Output: ([num_samples, num_neurons] tensor) features: output features
                ([num_samples, num_muscles] tensor) labels: output EMG labels
        """
        features = self._generate_features(num_samples, m1_val)
        labels = torch.matmul(features, decoder)
        return features, labels

    def generate_toy_dataset(self, num_samples):
        """
        Generates the final toy dataset
        Input: (int or float) num_samples: total number of samples (truncated to int)
        Output: (list of (feature, label) tuples) dataset: mode1 samples followed by mode2 samples
        Raises: ValueError if self.dataset_type is not "behavioral" or "emg"
        """
        num_samples_total = int(num_samples)
        # mode1 gets the floor half; mode2 takes the remainder so odd totals don't lose a sample
        num_mode1 = num_samples_total // 2
        num_mode2 = num_samples_total - num_mode1
        m1_val_mode1 = 10.0
        m1_val_mode2 = 0.0

        if self.dataset_type == "behavioral":
            # Behavioral (one-hot mode) labels to test out ClusterModel
            features_mode1, labels_mode1 = self.generate_behavioral(num_mode1, m1_val_mode1, "mode1")
            features_mode2, labels_mode2 = self.generate_behavioral(num_mode2, m1_val_mode2, "mode2")
        elif self.dataset_type == "emg":
            # EMG labels to test out full CombinedModel
            features_mode1, labels_mode1 = self.generate_emg(num_mode1, m1_val_mode1, self.decoder_mode1)
            features_mode2, labels_mode2 = self.generate_emg(num_mode2, m1_val_mode2, self.decoder_mode2)
        else:
            raise ValueError(
                f"dataset_type must be one of {self._DATASET_TYPES}, got {self.dataset_type!r}"
            )

        # Iterating a tensor yields its rows, so zip pairs row i of features with row i of labels
        return list(zip(features_mode1, labels_mode1)) + list(zip(features_mode2, labels_mode2))

    def __len__(self):
        # Use the real size so indexing up to len(self) - 1 is always valid
        return len(self.train_dataset)

    def __getitem__(self, index):
        return self.train_dataset[index]

    def collate_fn(self, batch):
        """Stack a list of (feature, label) pairs into {"m1": features, "emg": float labels}."""
        features, labels = zip(*batch)
        return {"m1": torch.stack(features), "emg": torch.stack(labels).float()}

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)


Define model

In [3]:
# Clustering model
class ClusterModel(nn.Module):
    """
    Soft mode-assignment network.

    Holds one scalar-output linear layer (input_dim -> 1) per mode. In forward,
    the per-mode scores are stacked along dim 2 and normalized with a softmax
    over that same dim, giving a weight per mode that sums to 1.

    input_dim: N
    num_modes: d
    """

    def __init__(self, input_dim, num_modes):
        super().__init__()
        per_mode_layers = [nn.Linear(input_dim, 1) for _ in range(num_modes)]
        self.linears = nn.ModuleList(per_mode_layers)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # For x of shape (batch, N) each layer yields (batch, 1); stacking on dim 2 gives (batch, 1, d)
        scores = torch.stack([layer(x) for layer in self.linears], dim=2)
        return self.softmax(scores)
    
    
# Decoding model
class DecoderModel(nn.Module):
    """
    Bank of per-mode linear decoders.

    Holds one linear layer (input_dim -> output_dim) per mode; forward applies
    every layer to the same input and stacks the results along dim 2.

    input_dim: N
    output_dim: M
    num_modes: d
    """

    def __init__(self, input_dim, output_dim, num_modes):
        super().__init__()
        per_mode_layers = [nn.Linear(input_dim, output_dim) for _ in range(num_modes)]
        self.linears = nn.ModuleList(per_mode_layers)

    def forward(self, x):
        # For x of shape (batch, N) each layer yields (batch, M); stacking on dim 2 gives (batch, M, d)
        decoded = [layer(x) for layer in self.linears]
        return torch.stack(decoded, dim=2)


class CombinedModel(nn.Module):
    """
    Mixture-of-decoders model.

    self.cm stores an instance of the cluster model (per-mode softmax weights)
    self.dm stores an instance of the decoding model (per-mode predictions)
    self.ev: when truthy, forward() returns the cluster weights instead of the
             weighted prediction ("eval" mode used to inspect learned clusters)

    For a 2-D input of shape (batch, input_dim), the cluster weights have shape
    (batch, 1, num_modes) and the decoder outputs (batch, output_dim, num_modes);
    their broadcast product is summed over the mode axis to give
    (batch, output_dim).
    """

    def __init__(self, input_dim, output_dim, num_modes, ev):
        super().__init__()
        self.cm = ClusterModel(input_dim, num_modes)
        self.dm = DecoderModel(input_dim, output_dim, num_modes)
        self.ev = ev

    def forward(self, x):
        cluster_weights = self.cm(x)

        # Return softmax outputs if in "eval" mode; skip the decoder since its output would be discarded
        if self.ev:
            return cluster_weights

        decoded = self.dm(x)
        return torch.sum(cluster_weights * decoded, dim=-1)

Define training module

In [4]:
class TrainingModule(LightningModule):
    """
    Lightning wrapper that trains ``model`` with an MSE loss and AdamW.

    model: the nn.Module to train; called on batch["m1"], its output must hold
           the same number of elements as batch["emg"]
    lr: AdamW learning rate
    type: dataset type string, stored as self.dataset_type (not used by the
          training logic here). The parameter name shadows the builtin `type`
          but is kept for backward compatibility with existing callers.
    """

    def __init__(self, model, lr, type):
        super().__init__()
        self.model = model
        self.lr = lr
        self.dataset_type = type

    def forward(self, x):
        # Call the module (not .forward) so any registered nn.Module hooks still run
        return self.model(x)

    def _shared_step(self, batch):
        """Return the MSE between the model's predictions and the batch labels."""
        features = batch["m1"]
        labels = batch["emg"]
        # reshape_as (instead of squeeze) removes singleton axes such as ClusterModel's
        # (batch, 1, d) -> (batch, d) without also collapsing the batch axis of a
        # size-1 batch, which would make mse_loss silently broadcast.
        labels_hat = self.model(features).reshape_as(labels)
        return F.mse_loss(labels_hat, labels)

    def training_step(self, batch):
        train_loss = self._shared_step(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch):
        val_loss = self._shared_step(batch)
        self.log("val_loss", val_loss)
        return val_loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

Load dataset and train the model

In [5]:
# Load in dataset
T = 1000        # number of training samples (validation split is 20% of this)
N = 10          # number of neurons (input features)
M = 3           # number of muscles (EMG outputs)
d = 3           # num_modes
b = 8           # batch size
type = "emg"    # dataset type; NOTE: shadows the builtin `type`, name kept because later cells use it
epochs = 500
lr = 0.0001
save_path = "checkpoints"
ckpt_name = "checkpoints"
dataset = M1_EMG_Dataset_Toy(num_samples=T,
                             num_neurons=N,
                             num_muscles=M,
                             num_modes=d,
                             batch_size=b,
                             dataset_type=type)

# Define model (raw network, then its Lightning training wrapper)
combined_model = CombinedModel(input_dim=N,
                               output_dim=M,
                               num_modes=d,
                               ev=False)
model = TrainingModule(model=combined_model,
                       lr=lr,
                       type=type)

# Define model checkpoints.
# Derive the checkpoint path from save_path/ckpt_name so it always matches where
# ModelCheckpoint writes. Remove a stale file first so the new checkpoint is saved
# under exactly this name (Lightning adds a version suffix such as "-v1" when the
# target already exists) and the evaluation cell loads the fresh one.
filename = os.path.join(save_path, f"{ckpt_name}.ckpt")
if os.path.exists(filename):
    os.remove(filename)
save_callback = ModelCheckpoint(dirpath=save_path, filename=ckpt_name)

# Define trainer
trainer = Trainer(max_epochs=epochs, callbacks=[save_callback], enable_progress_bar=False)

# Fit the model
disable_possible_user_warnings()
trainer.fit(model, train_dataloaders=dataset.train_dataloader(), val_dataloaders=dataset.val_dataloader())


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/andrewshen/Desktop/neural_decoding/decoding_venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name  | Type          | Params
----------------------------------------
0 | model | CombinedModel | 132   
----------------------------------------
132       Trainable params
0         Non-trainable params
132       Total params
0.001     Total estimated model params size (MB)
`T

Evaluate trained model

In [6]:
# Uses N, M, d, lr, type, filename and dataset defined in the training cell above.

# Access train dataset
train = dataset.train_dataset

# Load trained model
model_path = filename
model = CombinedModel(input_dim=N,
                      output_dim=M,
                      num_modes=d,
                      ev=True)
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one
checkpoint = torch.load(model_path, map_location="cpu")
state_dict = checkpoint["state_dict"]
model = TrainingModule(model=model,
                       lr=lr,
                       type=type)
model.load_state_dict(state_dict)
model.eval()

# Extract learned cluster model weights (materialized; .parameters() alone is a one-shot generator)
cm_weights = [p.detach() for p in model.model.cm.parameters()]

# Extract learned decoder model weights
dm_weights = [p.detach() for p in model.model.dm.parameters()]

# Stack every training sample once so inference runs as a single batched forward pass
train_features = torch.stack([sample[0] for sample in train])
train_emgs = torch.stack([sample[1] for sample in train])

torch.set_printoptions(sci_mode=False)
with torch.no_grad():  # inference only: don't build autograd graphs
    # Calculate learned cluster labels, shape (num_samples, num_modes).
    # Ex: [0.0000, 0.0017, 0.9983] for num_modes=3 means the model learned that
    # using 99.83% of mode 3 resulted in best performance for that sample.
    model.model.ev = True
    cluster_probs = model(train_features).squeeze(1)  # (num_samples, 1, d) -> (num_samples, d)
    print(f"Learned cluster probabilities:\n{cluster_probs}")

    # Calculate final predictions, shape (num_samples, num_muscles)
    model.model.ev = False
    final_preds = model(train_features)
    print(f"\nFinal predictions:\n{final_preds}")

# Compare final predictions to output EMG data
print(f"\nEMG values:\n{train_emgs}")

Learned cluster probabilities:
tensor([[    1.0000,     0.0000,     0.0000],
        [    1.0000,     0.0000,     0.0000],
        [    1.0000,     0.0000,     0.0000],
        ...,
        [    0.0153,     0.4129,     0.5718],
        [    0.0138,     0.1447,     0.8415],
        [    0.0333,     0.4064,     0.5604]])

Final predictions:
tensor([[    -0.9288,    -27.6711,      6.1406],
        [     0.7430,    -28.4463,      5.9731],
        [    -0.0005,    -26.6970,      6.7494],
        ...,
        [     1.0318,      0.6768,     -0.7112],
        [     1.0472,      7.9783,      0.7533],
        [     0.7715,      3.1245,      1.0610]])

EMG values:
tensor([[ -5.4241, -25.4261,   7.1858],
        [ -0.6420, -29.8326,   5.9900],
        [ -1.9857, -24.4516,   7.1309],
        ...,
        [  1.0294,   0.6697,  -0.7132],
        [  1.0416,   7.9853,   0.7433],
        [  0.7316,   3.1858,   1.0761]])
