The code in this notebook uses data from the CELLxGENE Census to train a model that predicts how the expression of other genes might change given a perturbation of one or more genes. It acquires the data (currently all primary-data naive B cells), processes it by pairing up cells and selecting a gene whose expression differs between the two cells of each pair to serve as the perturbation, and trains the model on the result.

The model is an encoder-decoder transformer, which encodes the expression levels of genes in the source cell and then uses the encoded data as the keys and values to do cross-attention with queries derived from the perturbation. Genes are embedded using learnable embeddings, with the embedding then multiplied by the expression level using the ExpressionEmbedding class.

The original inspiration for this was the idea of a model that could take an original cell state and a perturbation of one or more genes, and predict the consequences of that perturbation. This model is not capable of that, since the data it is provided with is not actual perturbation data and so it can only ever learn correlations, but given proper training data the architecture should be suitable.

This notebook is still in development: at present, the loss does not decrease significantly over five epochs. The most likely cause is poor data. Judging by examples pulled from the dataloader, naively selecting the first hundred protein-coding genes in the list means that most of them are either unexpressed or unmeasured, and the cells paired for perturbation may have few measured genes in common. Both problems could be addressed by selecting genes more carefully and by an improved dataloader that caches cells until a suitable partner is found for each one. However, the learning rate or other factors could also be responsible.

In [None]:
#@title install_packages_&_import_libraries

%pip install --quiet -U cellxgene_census
%pip install --quiet tiledbsoma_ml

import cellxgene_census
import tiledbsoma as soma
from sklearn.preprocessing import LabelEncoder
from tiledbsoma_ml import ExperimentDataset
from tiledbsoma_ml import experiment_dataloader
import pandas as pd
import anndata
import numpy as np
import scanpy as sc

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import time
import math
import os

In [None]:
#@title get_census_data

#See https://chanzuckerberg.github.io/cellxgene-census/notebooks/experimental/pytorch.html#Create-an-ExperimentDataPipe for guidance
#on loading data.

#The Census release is pinned explicitly so reruns read the same snapshot.
with cellxgene_census.open_soma(census_version = "2025-01-30") as census:
  experiment = census["census_data"]["homo_sapiens"]
  experiment_rna = experiment.ms["RNA"]
  #Full (unfiltered) gene table, read only to map Ensembl feature_id -> soma_joinid;
  #it is deleted straight afterwards to free memory.
  var_df = experiment_rna.var.read().concat().to_pandas()
  gene_id_map = var_df.set_index("feature_id")["soma_joinid"].to_dict()
  del var_df
  #Same idea for datasets: dataset_id -> soma_joinid.
  datasets_df = census["census_info"]["datasets"].read().concat().to_pandas()
  datasets_id_map = datasets_df.set_index("dataset_id")["soma_joinid"].to_dict()
  del datasets_df
  #generate_presence_mask indexes this with dataset soma_joinids (rows) and gene soma_joinids (columns).
  presence_matrix = cellxgene_census.get_presence_matrix(census, "homo_sapiens")
  #gene_id_map, datasets_id_map, and presence_matrix will be used later
  #to mask out genes that weren't measured from attention, perturbation selection, and loss.
  obs_filter = "is_primary_data == True and cell_type == 'naive B cell'"
  var_filter = "feature_type == 'protein_coding'"

  with experiment.axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter=obs_filter),
    var_query=soma.AxisQuery(value_filter=var_filter)
  ) as query:
    #Set up experimental dataset.
    experiment_dataset = ExperimentDataset(
      query,
      layer_name="raw", #Raw counts; normalization to log(TPM+1) happens later in raw_to_log_TPM.
      obs_column_names=["cell_type", "dataset_id"],
      batch_size=128, #The training loop skips any batch whose size isn't 128, so keep the two in sync.
      shuffle=True,
      seed=111,
    )
    #obs_df is not used elsewhere in this notebook.
    obs_df = query.obs(column_names=["cell_type", "dataset_id"]).concat().to_pandas()
    #Filtered (protein-coding) gene table. raw_to_log_TPM reads its feature_length column and the gene
    #vocabulary is built from its feature_id column, both assuming its rows follow the same order as the
    #columns of X from the dataloader. NOTE(review): confirm query.var() and ExperimentDataset agree on gene order.
    var_df = query.var(column_names=["feature_id", "feature_length"]).concat().to_pandas()
    #Cell type is not currently used, since all the data originates from the same cell type,
    #but in future the pairing function could be modified to select changes in expression from within a cell type,
    #perhaps using a custom dataloader to cache examples drawn from the provided dataloader
    #until a suitable pairing can be found for them.

#Note that experiment_dataset is used here after the query and census contexts have exited;
#presumably it reopens the experiment itself when iterated. TODO confirm against the tiledbsoma_ml docs.
#80/20 train/validation split with a fixed seed.
train_dataset, val_dataset = experiment_dataset.random_split(
  0.8,
  0.2,
  seed=111,
)

#Batching and shuffling are configured on ExperimentDataset above; the loaders use default arguments.
train_dataloader = experiment_dataloader(train_dataset)
val_dataloader = experiment_dataloader(val_dataset)

#The above dataloaders produce entries with format (X, obs) where X is an array of counts
#and obs is a dataframe containing cell type and dataset id.
#Even filtering for protein-coding genes, however, there are ~20,000 genes, which is
#too many given each gene attends to each other gene so the cost of self-attention
#grows quadratically with gene count.

#Currently, I am keeping the first n_selected_genes genes, selected arbitrarily; future versions might
#identify genes which are commonly measured and expressed and filter for those.

#Single source of truth for the subset size, so the column indices and the vocabulary can't drift apart.
n_selected_genes = 100

valid_gene_indices = torch.arange(n_selected_genes) #(valid_gene_count), column indices into the filtered X.

#Establish gene ID vocabulary (for selected genes only).
#The feature_id column is selected by name rather than position, and the rows are derived from
#valid_gene_indices, so the vocabulary always describes exactly the columns the model sees.
gene_list = var_df["feature_id"].iloc[valid_gene_indices.tolist()].tolist()
gene_vocab = {gene_ID: idx for idx, gene_ID in enumerate(gene_list)}
#This vocabulary matches the Ensembl ID of each gene with its index *as provided in the filtered data*.
#Contrast gene_id_map, which provides a similar feature for the index of a gene in the full Census,
#allowing indexing into presence_matrix.

In [None]:
#@title data_processing_setup

#First, we need to normalize input to derive TPM values rather than raw counts.
#Then, we take log(TPM+1) to put the values on a more sensible scale.
#Adding 1 sets the scale with 0 at the bottom,
#since the raw counts are non-negative integers.

def raw_to_log_TPM(counts_array, var_df, valid_genes = None):
  """Convert raw counts to log(TPM+1), optionally keeping only some genes.

  counts_array: tensor (batch_size, gene_count) of raw counts. Its columns must be in the same order
    as the rows of var_df.
  var_df: dataframe with a 'feature_length' column giving each gene's length in bases.
  valid_genes: optional tensor (valid_gene_count) of column indices. If given, only those columns are
    returned. Normalization always uses *all* genes, so the kept TPMs are not renormalized to the subset.

  Returns a floating-point tensor (batch_size, valid_gene_count), or (batch_size, gene_count) without
  valid_genes, on counts_array's device. Floating inputs keep their dtype. A cell with zero total
  counts yields an all-zero row instead of NaN.
  """
  #Work in counts_array's float dtype/device so CUDA inputs work and float32 inputs stay float32.
  out_dtype = counts_array.dtype if counts_array.is_floating_point() else torch.get_default_dtype()
  counts = counts_array.to(out_dtype)
  lengths_kb = torch.as_tensor(
      var_df['feature_length'].to_numpy(dtype=np.float64), device=counts.device
  ).to(out_dtype) / 1000 #(gene_count), converted to kilobases.
  RPK = counts / lengths_kb.unsqueeze(0) #(batch_size, gene_count), RPK = reads per kilobase.
  total_counts = RPK.sum(dim=1, keepdim=True) #(batch_size, 1)
  #A cell with no counts would otherwise compute 0/0 = NaN, which propagates into the loss and, via the
  #optimizer, into every weight. Its RPKs are all 0, so any nonzero divisor gives the intended 0 TPM.
  safe_totals = torch.where(total_counts > 0, total_counts, torch.ones_like(total_counts))
  TPM = (RPK * 1e6) / safe_totals #(batch_size, gene_count), TPM = transcripts per million.
  if valid_genes is not None:
    TPM = torch.index_select(TPM, dim=1, index=valid_genes.to(device=TPM.device, dtype=torch.long))
  return torch.log(TPM + 1) #(batch_size, valid_gene_count)

#Masking function, returns a boolean tensor mask with True on measured genes.

def generate_presence_mask(batch, obs_batch, var_df, presence_matrix, gene_vocab, datasets_id_map, gene_id_map):
  """Return a boolean mask for the provided batch with True on entries containing measured genes.

  batch: tensor (batch_size, gene_count) whose columns are ordered as in gene_vocab. It must *not* be a
    pair_ or perturb_batch, because its rows are matched one-to-one with the rows of obs_batch.
  obs_batch: dataframe with a 'dataset_id' column, one row per row of batch.
  var_df: unused; kept for interface compatibility.
  presence_matrix: sparse matrix/array (num_datasets, total_num_genes) supporting fancy row/column
    indexing and .toarray().
  gene_vocab: dict feature_id -> column index in batch.
  datasets_id_map / gene_id_map: dicts dataset_id / Ensembl ID -> row / column index in presence_matrix.

  Returns a bool tensor of shape (batch_size, gene_count).
  Raises KeyError if a dataset or gene ID is missing from the maps.
  """
  index_to_gene = {col: gene for gene, col in gene_vocab.items()}
  #Explicit integer dtype: for an empty batch np.array([]) would be float64, which sparse indexing rejects.
  dataset_rows = np.array(
      [datasets_id_map[dataset_id] for dataset_id in obs_batch['dataset_id']], dtype=np.int64
  )
  gene_cols = np.array(
      [gene_id_map[index_to_gene[col]] for col in range(batch.shape[1])], dtype=np.int64
  )
  presence = presence_matrix[dataset_rows][:, gene_cols]
  return torch.tensor(presence.toarray(), dtype = torch.bool)

#We now need a function to take processed expression levels and generate the pairs
#and perturbations that will actually be fed to the model.
#For now we will *ignore cell type* and just take arbitrary pairs from within a batch;
#current data is all from the same cell type, so this is not a huge issue,
#but an improved version could integrate a version of this function into a custom dataloader
#and cache cells until a suitable match is found for them.

def process_to_pairs(expression_batch, obs_batch, perturbed_gene_count):
  """Pair up consecutive cells in a batch and pick a perturbation between each pair.

  Rows (0, 1), (2, 3), ... become (source, target) pairs. For each pair, up to perturbed_gene_count
  genes are chosen uniformly at random from the *eligible* genes: those measured in both cells and
  with non-identical expression. The perturbation is target - source on the chosen genes and 0 elsewhere.

  expression_batch: tensor (batch_size, valid_gene_count); batch_size must be even.
  obs_batch: dataframe with a 'dataset_id' column (forwarded to generate_presence_mask).
  perturbed_gene_count: int, the maximum number of genes to perturb per pair. If fewer genes are
    eligible, only the eligible ones are perturbed.

  Relies on the module-level var_df, presence_matrix, gene_vocab, datasets_id_map and gene_id_map.

  Returns (pair_batch, perturb_batch, pair_presence_mask, no_perturb_indices):
    pair_batch (batch_size/2, 2, valid_gene_count), perturb_batch (batch_size/2, valid_gene_count),
    pair_presence_mask (batch_size/2, 2, valid_gene_count) bool, True on measured genes,
    no_perturb_indices (batch_size/2) bool, True for pairs with no eligible gene.
  """
  assert expression_batch.shape[0] % 2 == 0, "Batch size must be divisible by 2"
  n_genes = expression_batch.shape[-1]
  presence_mask = generate_presence_mask(expression_batch, obs_batch, var_df, presence_matrix, gene_vocab, datasets_id_map, gene_id_map)
  presence_mask = presence_mask.to(expression_batch.device) #The mask is built on the CPU.
  pair_batch = expression_batch.reshape(-1, 2, n_genes) #(batch_size/2, 2, valid_gene_count)
  pair_presence_mask = presence_mask.reshape(-1, 2, n_genes) #(batch_size/2, 2, valid_gene_count)
  all_perturbs = pair_batch[:, 1, :] - pair_batch[:, 0, :] #(batch_size/2, valid_gene_count)
  #Eligible genes: nonzero difference AND measured in both members of the pair.
  perturb_nonzero_mask = all_perturbs != 0
  perturb_measured_mask = torch.logical_and(pair_presence_mask[:, 0, :], pair_presence_mask[:, 1, :])
  perturb_double_mask = torch.logical_and(perturb_nonzero_mask, perturb_measured_mask)
  #Random scores, with ineligible genes forced to -1 so that they always rank below every eligible one
  #(torch.rand is in [0, 1), so a plain multiply by the mask could tie an eligible 0.0 with ineligible genes).
  scores = torch.rand_like(all_perturbs, dtype=torch.float)
  scores = torch.where(perturb_double_mask, scores, torch.full_like(scores, -1.0))
  k = min(perturbed_gene_count, n_genes) #topk raises if k exceeds the dimension size.
  _, perturb_topk = torch.topk(scores, k, dim = 1)
  perturb_mask = torch.zeros_like(scores).scatter_(1, perturb_topk, 1.0)
  #When fewer than k genes are eligible, topk still returns k indices, padded with ineligible genes.
  #Without this, their differences would leak into the perturbation, including for unmeasured genes.
  perturb_mask = perturb_mask * perturb_double_mask
  #Pairs with no eligible gene at all:
  no_perturb_indices = ~perturb_double_mask.any(dim = 1) #(batch_size/2)
  perturb_batch = (all_perturbs * perturb_mask).to(torch.float)
  perturb_batch = torch.where(perturb_batch.abs() < 1e-6, 0.0, perturb_batch) #Clean up -0s left by the multiply.
  return pair_batch, perturb_batch, pair_presence_mask, no_perturb_indices

#Genes that were measured but have no expression in either member of a pair often carry little
#information about the perturbation, yet still cost compute during attention. Masking them seems to
#make little difference to performance, but may reduce training noise, so it is supported here.
#This mask is meant for attention over the source cell's representations only, never for the loss:
#every gene should still be predicted, including non-expressed ones.

def generate_pair_mask(pair_batch):
  """Return a bool mask (batch_size/2, 2, valid_gene_count) that is True where a gene is 0 in both pair members.

  pair_batch: tensor (batch_size/2, 2, valid_gene_count). The per-pair result is repeated for both members.
  """
  source_zero = pair_batch[:, 0, :] == 0
  target_zero = pair_batch[:, 1, :] == 0
  both_zero = source_zero & target_zero #(batch_size/2, valid_gene_count)
  return both_zero.unsqueeze(1).expand(-1, 2, -1)

#Now let's wrap all these up into a single function that we can call easily.

def process_batch(X, obs_batch, perturbation_count):
  """Turn one raw dataloader batch into model inputs, targets and masks.

  X: counts for (batch_size, gene_count) cells, as a numpy array or tensor; batch_size must be even.
  obs_batch: dataframe with a 'dataset_id' column.
  perturbation_count: maximum number of genes to perturb per pair.

  Returns (source, target, perturbation, pair_mask, combined_mask):
    source, target, perturbation: (batch_size/2, valid_gene_count) float tensors.
    pair_mask: (batch_size/2, 2, valid_gene_count) bool, True on genes to hide from attention
      (not measured in that cell, or zero in both pair members).
    combined_mask: (batch_size/2, 2, valid_gene_count) bool, True on genes to score in the loss
      (measured, in a pair that received a perturbation).
  """
  X = torch.as_tensor(X, dtype = torch.float) #Accepts numpy arrays or tensors without a copy warning.
  X = raw_to_log_TPM(X, var_df, valid_gene_indices)
  pairs, perturbs, pair_presence_mask, no_perturb_indices = process_to_pairs(X, obs_batch, perturbation_count)
  #True on genes that are not expressed in either member of a pair.
  pair_zeros_mask = generate_pair_mask(pairs) #(batch_size/2, 2, valid_gene_count)
  #Attention mask: hide a gene if it was NOT measured in that cell OR it is zero in both cells.
  #This must be OR; AND would only hide genes that are unmeasured *and* doubly zero.
  #It is not used for the loss, since we still want the network to predict non-expressed genes.
  pair_mask = torch.logical_or(~pair_presence_mask, pair_zeros_mask) #(batch_size/2, 2, valid_gene_count)
  #nn.MultiheadAttention yields NaN when every key is masked, so unmask rows that would be fully masked.
  #A fully masked row has no gene that is measured in that cell and not doubly zero, so no valid perturbation
  #exists. process_to_pairs then flags the pair in no_perturb_indices and combined_mask drops it from the loss.
  pair_mask = pair_mask & ~pair_mask.all(dim = -1, keepdim = True)
  #Loss mask: measured genes, restricted to pairs where a perturbation was possible.
  #It is not used for attention, for the NaN reason above.
  combined_mask = pair_presence_mask & ~no_perturb_indices[:, None, None]
  return pairs[:,0,:], pairs[:,1,:], perturbs, pair_mask, combined_mask

In [None]:
#@title model_setup

#Convenience functions for choosing a torch device.
def cpu():
  """Return the CPU torch.device."""
  cpu_device = torch.device('cpu')
  return cpu_device

def gpu(i=0):
  """Return the torch.device for CUDA device number i (it is not checked for existence)."""
  device_name = f'cuda:{i}'
  return torch.device(device_name)

def num_gpus():
  """Return how many CUDA devices torch can see (0 on CPU-only machines)."""
  visible_devices = torch.cuda.device_count()
  return visible_devices

def try_gpu(i=0):
  """Return CUDA device i if at least i+1 GPUs are visible, otherwise fall back to the CPU."""
  return gpu(i) if i < num_gpus() else cpu()

#Custom embedding for gene expression levels.
class ExpressionEmbedding(nn.Module):
  """Embed numeric gene IDs, scale each gene's vector by its value, and (by default) add a gene-identity vector.

  Without the identity term, any gene whose value is 0 (unexpressed in the encoder input, unperturbed in
  the decoder input) embeds to the zero vector. Since no positional encoding is used, those tokens are then
  indistinguishable. In the decoder this makes every unperturbed gene's query identical, so the model emits
  the same prediction for all of them. The identity embedding is a separate learnable per-gene vector that
  is added unscaled.

  gene_count: vocabulary size. embedding_dim: output size.
  add_identity: set False to reproduce the original scaled-only embedding (e.g. to load weights saved
    before the identity term existed).
  """
  def __init__(self, gene_count, embedding_dim, add_identity = True):
    super().__init__()
    self.embed = nn.Embedding(gene_count, embedding_dim)
    self.identity = nn.Embedding(gene_count, embedding_dim) if add_identity else None

  def forward(self, X, genes):
    #X: expression (or perturbation) levels, shape (batch_size, gene_count).
    #genes: numeric gene tokens, shape (gene_count).
    embedded_genes = self.embed(genes) #(gene_count, n_hiddens)
    embedded_counts = X.unsqueeze(-1) * embedded_genes.unsqueeze(0) #(batch_size, gene_count, n_hiddens)
    if self.identity is not None:
      embedded_counts = embedded_counts + self.identity(genes).unsqueeze(0)
    return embedded_counts

#Convenience class, adapted with some modifications from the D2L (Dive into Deep Learning) textbook.
class AddNorm(nn.Module):
  """Residual connection followed by layer normalization: LayerNorm(X + Dropout(Y)).

  Dropout is applied to the sublayer output Y only when use_dropout is True; otherwise Y passes unchanged.
  """
  def __init__(self, norm_shape, dropout = 0, use_dropout = False):
    super().__init__()
    if use_dropout:
      self.dropout = nn.Dropout(dropout)
    else:
      self.dropout = nn.Identity()
    self.ln = nn.LayerNorm(norm_shape)

  def forward(self, X, Y):
    residual = self.dropout(Y) + X
    return self.ln(residual)

#Building block of the encoder: self-attention over genes, then a position-wise feed-forward network,
#each wrapped in AddNorm.
class TransformerGeneEncoderBlock(nn.Module):
  """One encoder layer. An optional bool pair_mask (batch, genes) hides genes (True entries) as attention keys."""
  def __init__(self, n_hiddens, ff_n_hiddens, n_heads, batch_first, dropout, norm_dropout = False):
    super().__init__()
    self.attention = nn.MultiheadAttention(n_hiddens, n_heads, batch_first=batch_first, dropout=dropout)
    self.addnorm1 = AddNorm(n_hiddens, dropout, use_dropout = norm_dropout)
    expand = nn.Linear(n_hiddens, ff_n_hiddens)
    contract = nn.Linear(ff_n_hiddens, n_hiddens)
    self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
    self.addnorm2 = AddNorm(n_hiddens, dropout, use_dropout = norm_dropout)

  def forward(self, X, pair_mask = None):
    attended = self.attention(X, X, X, key_padding_mask = pair_mask)[0]
    hidden = self.addnorm1(X, attended)
    return self.addnorm2(hidden, self.ffn(hidden))

#The complete encoder. No positional encoding is used,
#since the order of the genes should carry no information.
class TransformerGeneEncoder(nn.Module):
  """Embed source-cell expression and run it through n_blks encoder blocks.

  forward(X, genes, pair_mask) returns (batch_size, gene_count, n_hiddens). pair_mask, if given, is
  forwarded unchanged to every block.
  """
  def __init__(self, vocab_size, n_hiddens, ff_n_hiddens, n_heads, n_blks, dropout, norm_dropout = False, batch_first = True):
    super().__init__()
    self.embed = ExpressionEmbedding(vocab_size, n_hiddens)
    self.n_hiddens = n_hiddens
    self.blks = nn.Sequential()
    for blk_idx in range(n_blks):
      block = TransformerGeneEncoderBlock(n_hiddens, ff_n_hiddens, n_heads, batch_first, dropout, norm_dropout)
      self.blks.add_module(f"block{blk_idx}", block)

  def forward(self, X, genes, pair_mask = None):
    scale = math.sqrt(self.n_hiddens)
    hidden = self.embed(X, genes) * scale
    for blk in self.blks:
      hidden = blk(hidden, pair_mask)
    return hidden

#The building block of the decoder: self-attention, optional cross-attention to the encoder
#output, then a feed-forward network, each wrapped in AddNorm.
class TransformerGeneDecoderBlock(nn.Module):
  """One decoder layer; i is this block's slot in the per-block cache held in state[2]."""
  def __init__(self, n_hiddens, ff_n_hiddens, n_heads, batch_first, dropout, i, norm_dropout = False, use_cross_attention = True):
    super().__init__()
    self.i = i
    self.attention1 = nn.MultiheadAttention(n_hiddens, n_heads, batch_first=batch_first, dropout=dropout)
    self.addnorm1 = AddNorm(n_hiddens, dropout, use_dropout = norm_dropout)
    if use_cross_attention:
      self.attention2 = nn.MultiheadAttention(n_hiddens, n_heads, batch_first=batch_first, dropout=dropout)
      self.addnorm2 = AddNorm(n_hiddens, dropout, use_dropout = norm_dropout)
    else:
      self.attention2 = None
      self.addnorm2 = None
    expand = nn.Linear(n_hiddens, ff_n_hiddens)
    contract = nn.Linear(ff_n_hiddens, n_hiddens)
    self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
    self.addnorm3 = AddNorm(n_hiddens, dropout, use_dropout = norm_dropout)
    self.n_heads = n_heads
    self.use_cross_attention = use_cross_attention

  def forward(self, X, state):
    #state is [encoder_outputs, encoder_pair_mask, per_block_cache]. The cache is a leftover from
    #sequence-model decoders: this block appends its input to cache[self.i] and self-attends over the
    #accumulated result. With a fresh state per forward pass (as init_state provides) that is just X.
    enc_outputs, enc_pair_mask, cache = state[0], state[1], state[2]
    cached = cache[self.i]
    key_values = X if cached is None else torch.cat((cached, X), dim = 1)
    cache[self.i] = key_values

    self_attended, _ = self.attention1(X, key_values, key_values)
    Y = self.addnorm1(X, self_attended)
    if not self.use_cross_attention:
      return self.addnorm3(Y, self.ffn(Y)), state
    #Cross-attention: queries from the decoder, keys/values from the encoder; masked genes are ignored as keys.
    cross_attended, _ = self.attention2(Y, enc_outputs, enc_outputs, key_padding_mask = enc_pair_mask)
    Z = self.addnorm2(Y, cross_attended)
    return self.addnorm3(Z, self.ffn(Z)), state

#The full decoder. As with the encoder, no positional encoding is used.
class TransformerGeneDecoder(nn.Module):
  """Embed the perturbation and run it through n_blks decoder blocks that attend to the encoder output.

  forward(X, genes, state) returns (hidden, state), with hidden of shape (batch_size, gene_count, n_hiddens).
  """
  def __init__(self, vocab_size, n_hiddens, ff_n_hiddens, n_heads, n_blks, dropout, norm_dropout = False, batch_first = True, use_cross_attention = True):
    super().__init__()
    self.embed = ExpressionEmbedding(vocab_size, n_hiddens)
    self.n_hiddens = n_hiddens
    self.n_blks = n_blks
    self.blks = nn.Sequential()
    for blk_idx in range(n_blks):
      block = TransformerGeneDecoderBlock(n_hiddens, ff_n_hiddens, n_heads, batch_first, dropout, blk_idx, norm_dropout, use_cross_attention)
      self.blks.add_module(f"block{blk_idx}", block)

  def init_state(self, enc_outputs, enc_pair_mask = None):
    """Build a fresh decoder state: [encoder outputs, encoder pair mask, one empty cache slot per block]."""
    per_block_cache = [None for _ in range(self.n_blks)]
    return [enc_outputs, enc_pair_mask, per_block_cache]

  def forward(self, X, genes, state):
    scale = math.sqrt(self.n_hiddens)
    hidden = self.embed(X, genes) * scale
    for blk in self.blks:
      hidden, state = blk(hidden, state)
    return hidden, state

#The full model. The reduction of representations to a single value might well be an issue for overall performance; however,
#it's serviceable for now.
class GenePerturbationTransformer(nn.Module):
  """Encoder-decoder model predicting target-cell expression from source expression and a perturbation.

  forward(X, P, genes, pair_mask):
    X: source expression (batch, gene_count). P: perturbation (batch, gene_count).
    genes: gene tokens (gene_count). pair_mask: optional (batch, gene_count) mask, True on source genes
      to hide from attention.
  Returns predictions of shape (batch, gene_count).
  """
  def __init__(self, vocab_size, n_hiddens, ff_n_hiddens, n_enc_heads, n_enc_blks, n_dec_heads, n_dec_blks, dropout, norm_dropout = False):
    super().__init__()
    self.encoder = TransformerGeneEncoder(vocab_size, n_hiddens, ff_n_hiddens, n_enc_heads, n_enc_blks, dropout, norm_dropout)
    self.decoder = TransformerGeneDecoder(vocab_size, n_hiddens, ff_n_hiddens, n_dec_heads, n_dec_blks, dropout, norm_dropout, use_cross_attention=True)
    self.out = nn.Linear(n_hiddens, 1)

  def forward(self, X, P, genes, pair_mask = None):
    if pair_mask is not None and pair_mask.dtype == torch.bool:
      #nn.MultiheadAttention returns NaN for a row whose keys are all masked, and NaN survives the
      #multiply-by-zero loss masking. Let such rows attend to every gene instead.
      pair_mask = pair_mask & ~pair_mask.all(dim = -1, keepdim = True)
    #The mask is applied in the encoder's self-attention as well as to the encoder outputs used as
    #cross-attention keys. Previously it was only used for the latter, although the encoder accepts it.
    enc_outputs = self.encoder(X, genes, pair_mask)
    state = self.decoder.init_state(enc_outputs, pair_mask)
    Z, _ = self.decoder(P, genes, state)
    return self.out(Z).squeeze(-1) #(batch, gene_count); in training batch = batch_size/2 pairs over valid genes.

In [None]:
#@title training_loop

#A notable potential improvement here is that based on testing with next(iter()), initializing an instance of the dataloader for the data used in this version
#takes up to 10 minutes. This means it might be possible to set it up to happen asynchronously on the CPU while the GPU is busy training the model
#and the CPU's load is relatively light, potentially saving significant amounts of time.

perturbed_gene_count = 1

n_hiddens = 128
ff_n_hiddens = 512
n_enc_heads = 4
n_enc_blks = 4
n_dec_heads = 4
n_dec_blks = 8
dropout = 0.2

lr = 1e-3

max_epochs = 5
device = try_gpu()

#Must equal the batch_size given to ExperimentDataset. tiledbsoma_ml's dataloader has no drop_last,
#so the final, smaller batch (any batch of a different size) is skipped below.
expected_batch_size = 128

genes = torch.tensor([gene_vocab[gene_ID] for gene_ID in gene_list])
model = GenePerturbationTransformer(genes.shape[0], n_hiddens, ff_n_hiddens, n_enc_heads, n_enc_blks, n_dec_heads, n_dec_blks, dropout).to(device)
genes = genes.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss() #This may not be the best loss function, but it's fine for now.


def _prepare_batch(X, obs):
  """Process a raw dataloader batch into device tensors (X, Z, P, source_mask, loss_mask), or None if unusable.

  A batch is unusable if it does not contain exactly expected_batch_size cells, or if its loss mask is
  all zeros (no pair had a valid perturbation). Scoring the latter would divide by zero, producing a NaN
  loss and, after optimizer.step(), NaN weights.
  """
  if X.shape[0] != expected_batch_size:
    return None
  X, Z, P, pair_mask, combined_mask = process_batch(X, obs, perturbed_gene_count)
  #combined_mask is True on target genes to score (measured, in a perturbed pair).
  loss_mask = combined_mask[:, 1, :]
  if not loss_mask.any():
    return None
  source_mask = pair_mask[:, 0, :] #True on source-cell genes to hide from attention.
  return X.to(device), Z.to(device), P.to(device), source_mask.to(device), loss_mask.to(device=device, dtype=torch.float)


def _masked_loss(Y, Z, loss_mask):
  """Mean squared error over only the entries where loss_mask is 1."""
  #MSELoss averages over every entry, zeros included; rescaling by numel/kept turns that into the masked mean.
  loss_scale_factor = loss_mask.numel() / torch.sum(loss_mask)
  return loss_fn(Y * loss_mask, Z * loss_mask) * loss_scale_factor


print("Beginning training.")

for epoch in range(max_epochs):
    total_train_loss = torch.tensor(0.0, device = device)
    total_val_loss = torch.tensor(0.0, device = device)
    #Count batches actually used so skipped batches don't dilute the reported averages.
    n_train_batches = 0
    n_val_batches = 0
    time_initial = time.time()
    model.train()
    print(f"Epoch {epoch}: Training ... \t Time:{time.time()-time_initial}")
    for X, obs in train_dataloader:
        batch = _prepare_batch(X, obs)
        if batch is None:
            continue
        X, Z, P, source_mask, loss_mask = batch
        Y = model(X, P, genes, source_mask)
        loss = _masked_loss(Y, Z, loss_mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.detach()
        n_train_batches += 1
    model.eval()
    print(f"Epoch {epoch}: Validating... \t Time:{time.time()-time_initial}")
    with torch.no_grad():
        for X, obs in val_dataloader:
            batch = _prepare_batch(X, obs)
            if batch is None:
                continue
            X, Z, P, source_mask, loss_mask = batch
            Y = model(X, P, genes, source_mask)
            total_val_loss += _masked_loss(Y, Z, loss_mask)
            n_val_batches += 1
    avg_train_loss = total_train_loss.item() / max(n_train_batches, 1)
    avg_val_loss = total_val_loss.item() / max(n_val_batches, 1)
    time_elapsed = time.time() - time_initial
    print(f"Epoch {epoch}: Training Loss {avg_train_loss:.4f}, Validation Loss {avg_val_loss:.4f}, Time Elapsed {time_elapsed}")

print("Training complete!")

In [None]:
#@title download_weights

from google.colab import files

#Saves the trained weights and offers them as a browser download (Colab only).
#NOTE(review): despite the .txt extension this is a torch.save archive; PyTorch convention is .pt/.pth.
filename = 'gene_perturbation_correlate_predictor_weights_v1_5_epochs.txt'
model.to(cpu()) #Moves the model (in place) to the CPU before its state_dict is saved.
torch.save(model.state_dict(), filename)
files.download(filename)

In [None]:
#@title download_vocab

import json

#Repeated from the download_weights cell so this cell also runs on its own.
from google.colab import files

model_key = 'GPCP_v1_e5'
vocab_filename = f'gene_vocab_{model_key}.json'

#Ensembl gene ID -> column index of the model's input; needed to feed new data to the saved weights.
#An explicit encoding makes the output independent of the platform's default locale.
with open(vocab_filename, 'w', encoding='utf-8') as f:
  json.dump(gene_vocab, f)

files.download(vocab_filename)