This notebook calculates per-sample model prediction scores, per-taxon attribution values (used for interpretability), and per-taxon averaged embeddings (used for plotting the taxa). Note that the notebook is currently set up to compute attributions only for the IBD dataset, but it can easily be changed to use the Schirmer/HMP2 and Halfvarson datasets.

In [11]:
!git clone https://github.com/rsvarma/microbiome_transformers.git
!cp microbiome_transformers/finetune_discriminator/dataset.py .

fatal: destination path 'microbiome_transformers' already exists and is not an empty directory.


In [12]:
!pip install transformers --quiet
import torch
import torch.nn as nn
import numpy as np
np.float = float
import tqdm
from dataset import ELECTRADataset
from transformers import ElectraConfig,ElectraForSequenceClassification
from torch.utils.data import DataLoader
from transformers.activations import get_activation

device = "cuda:0"



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [13]:
class ElectraDiscriminator(nn.Module):
    """
    Embedding lookup followed by an ELECTRA sequence classifier.

    Token ids are mapped to vectors by this module's own ``nn.Embedding`` and
    handed to ``ElectraForSequenceClassification`` through ``inputs_embeds``.
    Both parts can be restored from saved weights or created fresh.
    """

    def __init__(self,config:ElectraConfig,embeddings,discriminator = None, embed_layer = None):
        """
        Build the embedding layer and the ELECTRA classifier.

        Args:
            config (ElectraConfig): Configuration for the ELECTRA model; its
                ``vocab_size`` and ``embedding_size`` size the embedding layer.
            embeddings (torch.Tensor): Embedding matrix used as the layer's
                weight when ``embed_layer`` is not given.
            discriminator (str, optional): Location passed to ``from_pretrained``
                to restore the classifier. When falsy, the classifier is built
                from ``config`` alone.
            embed_layer (str, optional): Path to a saved embedding-layer state
                dict. When falsy, ``embeddings`` is used instead.
        """
        super().__init__()

        # The last vocabulary index is registered as the padding index.
        pad_index = config.vocab_size - 1
        self.embed_layer = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_size,
            padding_idx=pad_index,
        )
        if not embed_layer:
            # No checkpoint: replace the weight with the supplied matrix.
            self.embed_layer.weight = nn.Parameter(embeddings)
        else:
            self.embed_layer.load_state_dict(torch.load(embed_layer))

        if not discriminator:
            self.discriminator = ElectraForSequenceClassification(config)
        else:
            self.discriminator = ElectraForSequenceClassification.from_pretrained(
                discriminator, config=config
            )

        self.softmax = nn.Softmax(1)

    def forward(self,data,attention_mask,labels):
        """
        Embed the token ids and run them through the ELECTRA classifier.

        Args:
            data (torch.Tensor): Token ids to embed.
            attention_mask (torch.Tensor): Attention mask forwarded to ELECTRA.
            labels (torch.Tensor): Labels forwarded to ELECTRA for its loss.

        Returns:
            tuple: ``(loss, scores, last_hidden)`` where
                - loss: the ``'loss'`` entry of the classifier output,
                - scores: softmax over dim 1 of the classifier logits,
                - last_hidden: the final entry of the returned hidden states.
        """
        embedded = self.embed_layer(data)
        result = self.discriminator(
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        class_probs = self.softmax(result['logits'])
        return result['loss'], class_probs, result['hidden_states'][-1]

In [14]:
class ElectraClassificationHead(nn.Module):
    """
    Classification head applied to ELECTRA hidden states.

    Takes the hidden state at sequence position 0 and maps it to class
    probabilities through dropout -> dense -> GELU -> dropout -> projection
    -> softmax.
    """
    def __init__(self, config):
        """
        Create the layers of the head.

        Args:
            config (ElectraConfig): Supplies ``hidden_size``, ``num_labels``,
                ``classifier_dropout`` and ``hidden_dropout_prob``.
        """
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)

        # Fall back to the general hidden dropout when no classifier-specific
        # rate is configured.
        drop_prob = config.classifier_dropout
        if drop_prob is None:
            drop_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_prob)

        self.out_proj = nn.Linear(width, config.num_labels)
        self.softmax = nn.Softmax(1)

    def forward(self, features, **kwargs):
        """
        Turn hidden states into class probabilities.

        Args:
            features (torch.Tensor): Hidden states indexed as
                (batch, sequence position, hidden feature).
            **kwargs: Accepted but ignored.

        Returns:
            torch.Tensor: Softmax (over dim 1) of the projected features of the
                first sequence position.
        """
        first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden = self.dense(self.dropout(first_token))
        # although BERT uses tanh here, it seems Electra authors used gelu here
        hidden = get_activation("gelu")(hidden)
        logits = self.out_proj(self.dropout(hidden))
        return self.softmax(logits)

In [15]:
class ElectraEnsembleHead(nn.Module):
    """
    Ensemble head for ELECTRA-based models, combining multiple classification heads.

    This module creates an ensemble of ElectraClassificationHead instances and
    provides methods to set their parameters and perform forward passes.
    """
    def __init__(self, config, num_ffs = 10, device = 'cuda:0'):
        """
        Initialize the ElectraEnsembleHead.

        Args:
            config (ElectraConfig): Configuration for the ELECTRA model.
            num_ffs (int): Number of feed-forward networks in the ensemble.
            device (str): Device that parameters passed to `set_ff` are moved to.
        """
        super().__init__()
        self.num_ffs = num_ffs
        self.device = device
        self.ffs = nn.ModuleList([ElectraClassificationHead(config) for _ in range(self.num_ffs)])

    def set_ff(self, num, params):
        """
        Set the parameters for a specific feed-forward network in the ensemble.

        Args:
            num (int): Index of the feed-forward network to update.
            params (list): Four tensors, in order: dense weight, dense bias,
                out_proj weight, out_proj bias. Each is moved to `self.device`.
        """
        head = self.ffs[num]
        head.dense.weight.data = params[0].to(self.device)
        head.dense.bias.data = params[1].to(self.device)
        head.out_proj.weight.data = params[2].to(self.device)
        head.out_proj.bias.data = params[3].to(self.device)

    def forward(self, x):
        """
        Perform a forward pass through the ensemble head.

        Each head is evaluated exactly once; the detached per-head outputs and
        the average are derived from the same evaluation, so they stay
        consistent even when dropout is active.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: Contains:
                - x (torch.Tensor): Averaged output across all feed-forward networks.
                - pre_avg (list): Detached individual outputs from each feed-forward network.
        """
        outputs = [ff(x) for ff in self.ffs]
        pre_avg = [out.detach() for out in outputs]

        x = sum(outputs) / self.num_ffs
        return x, pre_avg


class ElectraEnsemble(nn.Module):
    """
    Ensemble model combining ELECTRA discriminator with multiple classification heads.

    A single discriminator (restored from the ``run1`` checkpoint) produces the
    last hidden state; each head in the ensemble is restored from the
    classifier weights of its own ``run<k>`` checkpoint and applied to that
    shared hidden state.
    """
    def __init__(self, num_ffs = 10, models_path = '', train_data_path = '', train_labels_path = '', vocab_path = '', device = 'cuda:0'):
        """
        Initialize the ElectraEnsemble model.

        Args:
            num_ffs (int): Number of feed-forward networks in the ensemble.
            models_path (str): Directory containing ``run<k>_epoch0_disc`` folders
                (k = 1..num_ffs) and ``run1_epoch0_embed``.
            train_data_path (str): Path to the training data (``.npy``).
            train_labels_path (str): Path to the training labels (``.npy``).
            vocab_path (str): Path to the vocabulary embeddings.
            device (str): Device to use for computations.
        """
        super().__init__()

        self.load_disc_path = models_path + "/run1_epoch0_disc/pytorch_model.bin"
        self.load_embed_path = models_path + "/run1_epoch0_embed"
        self.config_path = models_path + "/run1_epoch0_disc/config.json"

        self.train_dataset = ELECTRADataset(np.load(train_data_path), vocab_path, np.load(train_labels_path))

        self.num_ffs = num_ffs
        self.device = device

        self.config = ElectraConfig.from_pretrained(self.config_path)
        self.electra = ElectraDiscriminator(self.config, torch.from_numpy(self.train_dataset.embeddings), self.load_disc_path, self.load_embed_path).to(self.device)
        self.ensemble = ElectraEnsembleHead(self.config, self.num_ffs, device = self.device)

        head_keys = ('classifier.dense.weight', 'classifier.dense.bias',
                     'classifier.out_proj.weight', 'classifier.out_proj.bias')
        for i in range(self.num_ffs):
            path_i = models_path + "/run" + str(i+1) + "_epoch0_disc/pytorch_model.bin"
            # Load onto the CPU so the checkpoint's original device does not
            # matter; set_ff moves the four tensors that are actually used.
            di = torch.load(path_i, map_location="cpu")
            params_i = [di[key] for key in head_keys]
            print([param.size() for param in params_i])
            self.ensemble.set_ff(i, params_i)

    def forward(self, x, labels = None, attention_mask = None):
        """
        Perform a forward pass through the ElectraEnsemble model.

        Args:
            x (torch.Tensor): Input tensor of token ids.
            labels (torch.Tensor, optional): Ground truth labels. When omitted,
                a (batch, 1) tensor of ones is used as a placeholder.
            attention_mask (torch.Tensor, optional): Attention mask for input sequence.

        Returns:
            tuple: Contains:
                - preds (torch.Tensor): Predictions from the ensemble.
                - z (torch.Tensor): Last hidden state from the ELECTRA discriminator.
                - pre_avg (list): Individual predictions from each classification head.
        """
        if labels is None:
            labels = torch.ones((x.size()[0], 1)).to(torch.long)
        if attention_mask is not None:
            # Keep the mask on the same device as the inputs and the model.
            attention_mask = attention_mask.to(self.device)
        # Only the hidden state is used; the discriminator's own loss and
        # scores are discarded in favour of the ensemble heads.
        _, _, z = self.electra(x.to(self.device), labels = labels.to(self.device), attention_mask = attention_mask)
        preds, pre_avg = self.ensemble(z)
        return preds, z, pre_avg

    def load_new_discriminator(self, path):
        """
        Load a new discriminator model from the given path.

        Replaces ``self.electra``; the embedding layer is still restored from
        ``self.load_embed_path`` and the ensemble heads are left untouched.

        Args:
            path (str): Path to the new discriminator model weights.
        """
        self.electra = ElectraDiscriminator(self.config, torch.from_numpy(self.train_dataset.embeddings), path, self.load_embed_path).to(self.device)

In [16]:
batch_size = 1

# Load the ensemble model
ee = ElectraEnsemble(num_ffs = 10, models_path = '../data/ensemble',
                     train_data_path = '../data/microbiomedata/halfvarson_otu.npy',
                     train_labels_path = '../data/microbiomedata/halfvarson_IBD_labels.npy',
                     vocab_path = '../data/vocab_embeddings.npy',
                     device = device)

# Load the dataset
current_dataset = ELECTRADataset(np.load('../data/microbiomedata/IBD_train_otu.npy'),
                                 '../data/vocab_embeddings.npy',
                                 np.load('../data/microbiomedata/IBD_train_label.npy'))

data_loader = DataLoader(current_dataset, batch_size=batch_size, num_workers=0,shuffle=False)

ee.eval()

# Calculate base scores prior to leaving-one-out attribution
all_scores = []
data_iter = tqdm.tqdm(enumerate(data_loader),
                            desc="Process dataset",
                            total=len(data_loader),
                            bar_format="{l_bar}{r_bar}")
with torch.no_grad():
    for i, data in data_iter:
        # Sequences are truncated to at most 512 positions.
        token_ids = data['electra_input'].to(device)[:, : 512]
        frequencies = data['species_frequencies'].to(device)[:, : 512]
        label = data['electra_label'].to(device)

        # Attend only to positions whose frequency is non-zero.
        mask = torch.ne(frequencies, 0).to(torch.float)

        pred, z, _ = ee(token_ids, labels=label, attention_mask=mask)
        # Keep the class-1 probability of every sample in the batch, so the
        # saved scores stay aligned with the dataset for any batch_size.
        all_scores.extend(pred[:, 1].tolist())

torch.save(torch.tensor(all_scores), "ibd_base_scores.pth")

FileNotFoundError: [Errno 2] No such file or directory: '../data/ensemble/run1_epoch0_embed'

In [7]:
# We only want to do attributions for classification decisions that the model is confident in, so we compute the 25th and 75th percentiles of the base scores
# and only do attributions for scores to the left of the 25th percentile and to the right of the 75th percentile.

# as_tensor reuses an already-loaded tensor instead of copy-constructing it
# (torch.tensor(existing_tensor) emits a UserWarning and makes a copy).
all_scores = torch.as_tensor(torch.load("ibd_base_scores.pth"))
sorted_scores = torch.sort(all_scores).values
num_scores = len(sorted_scores)
# Cutoffs are the order statistics at the 25% and 75% positions.
less_than_cutoff = sorted_scores[int(num_scores * 0.25)].item()
greater_than_cutoff = sorted_scores[int(num_scores * 0.75)].item()

FileNotFoundError: [Errno 2] No such file or directory: 'ibd_base_scores.pth'

In [2]:
def _leave_one_out_rows(row, num_variants, fill_value):
    """
    Stack a 1-D tensor with copies of itself that each omit one position.

    Args:
        row (torch.Tensor): 1-D tensor to perturb.
        num_variants (int): Number of leading positions to remove, one per copy.
        fill_value: Value appended to each copy so it keeps the length of `row`.

    Returns:
        torch.Tensor: Shape (num_variants + 1, len(row)), same dtype and device
            as `row`. Row 0 is `row` itself; row k + 1 is `row` without
            position k and with `fill_value` appended at the end.
    """
    filler = row.new_full((1,), fill_value)
    variants = [row]
    variants.extend(torch.cat((row[:k], row[k + 1:], filler)) for k in range(num_variants))
    return torch.stack(variants)


def run_attribution(data_path, label_path, pretrained_path, vocab_path, length, less_than_cutoff, greater_than_cutoff, base_scores, epochs):
    """
    Calculate attributions for microbes in the dataset using leave-one-out method.

    For every retained sample, the model is evaluated on the original sequence
    and on one copy per token position with that token removed. The attribution
    of a token is the class-0 output for the original sequence minus the
    class-0 output for the sequence without that token.

    Args:
        data_path (str): Path to the input data file.
        label_path (str): Path to the label file.
        pretrained_path (str): Path to pretrained models. Currently unused.
        vocab_path (str): Path to vocabulary embeddings.
        length (int): Maximum number of leading positions to attribute per sample.
        less_than_cutoff (float): Samples with a base score below this are kept.
        greater_than_cutoff (float): Samples with a base score above this are kept.
        base_scores (torch.Tensor): Pre-calculated base scores, indexed by sample.
        epochs (int): Number of epochs the model was trained. Currently unused.

    Returns:
        tuple: Contains:
            - attribution_dict (dict): Maps token id to
              ``[base score of the first sample it was seen in, list of attributions]``.
            - attribution_list (torch.Tensor): One row per attributed token:
              ``[sample index, token id, attribution, base score]``.
    """
    pad_token = 26728   # id used to fill the freed slot and excluded from the token count
    max_seq_len = 512   # sequences are truncated to this many positions

    attribution_dict = {}
    attribution_list = []

    # Load dataset and create DataLoader
    current_dataset = ELECTRADataset(np.load(data_path),
                                     vocab_path,
                                     np.load(label_path))

    # The loop below indexes base_scores by batch index and only reads row 0 of
    # each batch, so it is only correct with one sample per batch.
    data_loader = DataLoader(current_dataset, batch_size=1, num_workers=0, shuffle=False)

    # Initialize ElectraEnsemble model
    # NOTE(review): models_path is a placeholder and pretrained_path is not
    # consulted; point this at the real ensemble directory before running.
    ee = ElectraEnsemble(num_ffs = 10, models_path = '/path/to/ensemble',
                         train_data_path = data_path,
                         train_labels_path = label_path,
                         vocab_path = vocab_path,
                         device = device)
    # Set model to evaluation mode
    ee.eval()
    data_iter = tqdm.tqdm(enumerate(data_loader),
                                desc="Process dataset",
                                total=len(data_loader),
                                bar_format="{l_bar}{r_bar}")

    with torch.no_grad():
        for i, data in data_iter:
            if i % 50 == 0:
                print(i)
            # Skip samples with base scores that are not confident enough
            base_score = base_scores[i]
            if not (base_score < less_than_cutoff or base_score > greater_than_cutoff):
                continue

            # Prepare input data
            token_ids = data['electra_input'].to(device)[:, :max_seq_len]
            frequencies = data['species_frequencies'].to(device)[:, :max_seq_len]
            label = data['electra_label'].to(device)
            # Plain int so it can be used for sizes and ranges.
            num_tokens = min(int(torch.ne(token_ids, pad_token).sum().item()), length)

            # Row 0 is the unmodified sample; row k + 1 has position k removed.
            # Each tensor keeps the dtype of its source, so frequencies are not
            # truncated to integers before the mask is derived from them.
            del_input = _leave_one_out_rows(token_ids[0], num_tokens, pad_token)
            del_frequencies = _leave_one_out_rows(frequencies[0], num_tokens, 0)
            del_label = (torch.ones((num_tokens + 1, 1), device=device) * label).to(torch.long)

            # Create attention mask: attend only where the frequency is non-zero
            mask = torch.ne(del_frequencies, 0).to(torch.float)

            # Get predictions for all leave-one-out sequences
            pred, _, _ = ee(del_input, labels=del_label, attention_mask=mask)
            # Calculate attributions: original output minus output without token k
            attributions = (pred[0, 0] - pred[1:, 0]).tolist()

            # Store attributions in dictionary
            microbes = token_ids[0, :num_tokens].tolist()
            for microbe, attribution in zip(microbes, attributions):
                if microbe not in attribution_dict:
                    attribution_dict[microbe] = [base_score, []]
                attribution_dict[microbe][1].append(attribution)
                attribution_list.append([i, microbe, attribution, float(base_score)])
    return attribution_dict, torch.tensor(attribution_list)

In [3]:
# Run attribution for IBD dataset (can easily be modified to run for other datasets)
# With cutoffs of 1000 and -1000, run_attribution's skip condition is false for
# every real-valued base score, so all samples are attributed.
ibd_attributions, ibd_at_tensor = run_attribution(
    data_path='../data/microbiomedata/IBD_train_otu.npy',
    label_path='../data/microbiomedata/IBD_train_label.npy',
    pretrained_path='../data/pretrainedmodels',
    vocab_path='../data/vocab_embeddings.npy',
    length=200,
    less_than_cutoff=1000,
    greater_than_cutoff=-1000,
    base_scores=torch.load('ibd_base_scores.pth'),
    epochs=120,
)

torch.save(ibd_at_tensor, "ibd_att_tensor.pth")

NameError: name 'torch' is not defined

We now switch our focus to computing the average embeddings for vocabulary elements on each of IBD, Halfvarson and Schirmer/HMP2 data.

In [8]:
# Load pretrained model (note that we no longer use the ensemble)
models_path = '../data/pretrainedmodels'
load_disc_path = f"{models_path}/5head5layer_epoch120_disc/pytorch_model.bin"
# No saved embedding layer: the discriminator takes its embedding weights from
# the dataset's embedding matrix instead.
load_embed_path = None
config_path = f"{models_path}/5head5layer_epoch120_disc/config.json"

config = ElectraConfig.from_pretrained(config_path)
vocab_embeddings = torch.from_numpy(current_dataset.embeddings)
electra = ElectraDiscriminator(config, vocab_embeddings, load_disc_path, load_embed_path).to(device)

NameError: name 'ElectraDiscriminator' is not defined

In [9]:
# One entry per dataset: [name, path to OTU data, path to labels].
data_paths = [
    ["IBD", '../data/total_IBD_512.npy', '../data/total_IBD_label.npy'],
    ["Halfvarson", '../data/halfvarson_512_otu.npy', '../data/halfvarson_IBD_labels.npy'],
    ["Schirmer", '../data/schirmer_IBD_512_otu.npy', '../data/schirmer_IBD_labels.npy'],
]

In [10]:
def get_average_embeddings(data_path, data_labels, name, ee, pretrained_path, vocab_path, vocab_size = 26727, embedding_size = 200, epochs = 120, device = "cuda:0"):
    """
    Calculate and save average embeddings for taxa in a dataset.

    Every sample is run through the model one at a time; the last hidden state
    at each position is added to the running sum of the token id found at that
    position. The sums are divided by the number of occurrences and saved.

    Args:
        data_path (str): Path to the input data file.
        data_labels (str): Path to the labels file.
        name (str): Name of the dataset (used for output file naming).
        ee (ElectraEnsemble): The ELECTRA ensemble model. Its discriminator is
            replaced by the checkpoint selected through `pretrained_path` and
            `epochs`.
        pretrained_path (str): Path to pretrained models.
        vocab_path (str): Path to vocabulary embeddings.
        vocab_size (int): Size of the vocabulary (default: 26727). Two extra
            rows are allocated beyond this and dropped before saving.
        embedding_size (int): Width of the accumulated vectors (default: 200).
            NOTE(review): must equal the width of the model's last hidden
            state; confirm against the model config.
        epochs (int): Number of epochs the model was trained (default: 120).
        device (str): Device to use for computations (default: "cuda:0").

    Returns:
        None. Saves a (vocab_size, embedding_size) tensor to
        ``epoch_<epochs>_<name>_avg_vocab_embeddings.pth``. Rows of token ids
        that never occur are 0 / 0, i.e. NaN.
    """
    path = f"{pretrained_path}/5head5layer_epoch{epochs}_disc/pytorch_model.bin"
    ee.load_new_discriminator(path)
    ee.eval()
    current_dataset = ELECTRADataset(np.load(data_path),
                                     vocab_path,
                                     np.load(data_labels))

    data_loader = DataLoader(current_dataset, batch_size=1, num_workers=0, shuffle=False)

    data_iter = tqdm.tqdm(enumerate(data_loader),
                                desc="Process dataset",
                                total=len(data_loader),
                                bar_format="{l_bar}{r_bar}")

    taxa_sum_embedding = torch.zeros((vocab_size + 2, embedding_size), device=device)
    num_taxa_summed = torch.zeros((vocab_size + 2, 1), device=device)
    with torch.no_grad():
        for _, data in data_iter:
            token_ids = data['electra_input'].to(device)
            frequencies = data['species_frequencies'].to(device)
            label = data['electra_label'].to(device)
            # Attend only to positions whose frequency is non-zero.
            mask = torch.ne(frequencies, 0).to(torch.float)
            _, z, _ = ee(token_ids, labels=label, attention_mask=mask)

            # index_add_ accumulates correctly when a token id occurs more than
            # once in the sample; `t[ids] += v` would keep only one of the
            # contributions for a repeated id.
            sample_ids = token_ids[0]
            taxa_sum_embedding.index_add_(0, sample_ids, z[0])
            num_taxa_summed.index_add_(0, sample_ids, torch.ones((sample_ids.numel(), 1), device=device))
        taxa_avg_embedding = taxa_sum_embedding / num_taxa_summed
        # Drop the two extra rows allocated beyond vocab_size.
        taxa_avg_embedding = taxa_avg_embedding[:-2]
    torch.save(taxa_avg_embedding.to("cpu"), f"epoch_{epochs}_{name}_avg_vocab_embeddings.pth")

# Compute and save the averaged embeddings for every listed dataset.
for dataset_name, otu_path, labels_path in data_paths:
    get_average_embeddings(otu_path, labels_path, dataset_name, ee, '/path/to/pretrainedmodels', '/path/to/vocab_embeddings.npy', device="cuda:0")

NameError: name 'ee' is not defined