# Transformer model for predicting modalities in scRNA-seq

**Authors**<br>Vedu Mallela: GiwoTech, vedu.mallela@gmail.com<br>Simon Lee: UC Santa Cruz, siaulee@ucsc.edu

# Goal of the code

**TODO: explain algorithm**

# Libraries 

Import all files and modules for this competition<br>
*below will provide documentation of the following libraries*<br>
<br>
**scanpy** (**s**ingle **c**ell **an**alysis in **Py**thon) - https://scanpy.readthedocs.io/en/stable/ <br>
**anndata** (**ann**otated **data**) - https://anndata.readthedocs.io/en/latest/ <br>
**matplotlib** - https://matplotlib.org/ <br>
**numpy** - https://numpy.org/doc/stable/ <br>
**pandas** - https://pandas.pydata.org/ <br>
**logging** - https://docs.python.org/3/howto/logging.html <br>
**sklearn** - https://scikit-learn.org/stable/ <br>
<br>
*code begins here*

In [1]:
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import logging

from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# LOAD DATA

In [2]:
path = '/tmp/public/multiome/'
outpath ='../out/'
adata_gex = ad.read_h5ad(path + "multiome_gex_processed_training.h5ad")
adata_atac = ad.read_h5ad(path + "multiome_atac_processed_training.h5ad")

After successfully loading in the data, we can try to begin plotting the batch for the **Assay for Transposase-Accessible Chromatin using sequencing** (ATAC-seq) and **Gene Expression** (GEX) data on the umap interface. These umap projections will be saved to GEX.pdf and ATAC.pdf

In [3]:
#sc.tl.pca(adata_gex)
#sc.pl.umap(adata_gex, color=['batch'],save='_GEX', title='GEX umap Display')
#sc.tl.pca(adata_atac)
#sc.pl.umap(adata_atac, color=['batch'], layer='counts', save='_ATAC', title='ATAC umap Display')

Next we are going to check out all the indivdual cell types occuring in both the ATAC and GEX data. This way we can see all the types of cells from this dataset provided.

In [4]:
#sc.tl.pca(adata_gex)
#sc.pl.umap(adata_gex, color='cell_type',save='_GEX_ct', title='GEX Cell Type umap')
#sc.tl.pca(adata_atac)
#sc.pl.umap(adata_atac, color='cell_type',save='_ATAC_ct', title='ATAC Cell Type umap')

In [8]:
# filter out the data
# Convert anndata objects to dataframes and filter.  
# Genes that show up in < 1% cells are dropped.
# Atac seq data that shows up in < 1% cells are dropped 

gex_df = adata_gex.to_df()
atac_df = adata_atac.to_df()

mask = gex_df>0
total_cells = gex_df.shape[0]
maskdf = mask.sum(axis=0)/total_cells*100 <=1

gex_feature_drop  = list(maskdf.loc[maskdf==True].index.values)

mask = atac_df>0
maskdf = mask.sum(axis=0)/total_cells <=0.01

atac_feature_drop  = list(maskdf.loc[maskdf==True].index.values)

gex_ = gex_df.drop(columns=gex_feature_drop)
atac_ = atac_df.drop(columns=atac_feature_drop)

print('Filtered data set')
print('GEX data: Total cells=' + str(gex_.shape[0]) + ', Number features=' + str(gex_.shape[1]))
print('ATAC data: Total cells=' + str(atac_.shape[0]) + ', Number Features=' + str(atac_.shape[1]))

Filtered data set
GEX data: Total cells=22463, Number features=12160
ATAC data: Total cells=22463, Number Features=53020


In [5]:
from transformers import AutoTokenizer, AutoModelWithLMHead, T5ForConditionalGeneration
from pathlib import Path
from sklearn.model_selection import train_test_split
import torch

In [None]:
#convert to numpy array
atac_numpy_array = np.array(gex_)

In [6]:
# read the sequence and split it into chunks
def read_seq_split(split_dir): 
    split_dir = (split_dir)
    texts = []
    labels = []
    for label_dir in ["gex", "atac"]: # for each label
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())

    return texts, labels

In [7]:
from transformers import AutoTokenizer, AutoModelWithLMHead, T5ForConditionalGeneration, T5Tokenizer
from pathlib import Path
from sklearn.model_selection import train_test_split
import torch
#gex_df = adata_gex.to_df()
#X = gex_df.drop(['target'],axis=1).values   # independant features
#y = gex_df['target'].values                 # dependant variable
#train_texts, val_texts, train_labels, val_labels = train_test_split(x, y, train_texts, train_labels, test_size=.2)
#train_text, train_labels = read_seq_split(gex_df)
#train_text =  

tokenizer = T5Tokenizer.from_pretrained("t5-small")
train_encodings = tokenizer(adata_gex, truncation=True, padding=True)
#val_encodings = tokenizer(val_texts, truncation=True, padding=True)
#test_encodings = tokenizer(test_texts, truncation=True, padding=True)

ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [9]:
#train_texts, train_labels = read_seq_split(adata_gex) 
#test_texts, test_labels = read_seq_split('figures/test')

TypeError: expected str, bytes or os.PathLike object, not AnnData

In [22]:
adata_atac.obs

Unnamed: 0,nCount_peaks,atac_fragments,reads_in_peaks_frac,blacklist_fraction,nucleosome_signal,cell_type,pseudotime_order_ATAC,batch
TAGTTGTCACCCTCAC-1-s1d1,4031.0,5400,0.746481,0.003473,0.642468,Naive CD20+ B,,s1d1
CTATGGCCATAACGGG-1-s1d1,8636.0,19266,0.448251,0.003126,1.220679,CD14+ Mono,,s1d1
CCGCACACAGGTTAAA-1-s1d1,4674.0,6177,0.756678,0.001284,0.692573,CD8+ T,,s1d1
TCATTTGGTAATGGAA-1-s1d1,2803.0,4019,0.697437,0.000714,0.633838,CD8+ T,,s1d1
ACCACATAGGTGTCCA-1-s1d1,1790.0,2568,0.697040,0.003352,0.727660,CD16+ Mono,,s1d1
...,...,...,...,...,...,...,...,...
TAGTAAGCAACTAGGG-8-s3d6,7239.0,10580,0.684216,0.000553,0.866142,HSC,0.23367,s3d6
TGGTCCTTCGGCTAGC-8-s3d6,16056.0,22771,0.705107,0.000810,0.964006,CD4+ T activated,,s3d6
CGCTTGCGTTGTTGGA-8-s3d6,8388.0,14137,0.593337,0.000358,1.215227,pDC,,s3d6
ACCCTCCCAGCCAGTT-8-s3d6,1001.0,1381,0.724837,0.000000,0.714286,CD8+ T,,s3d6


We are now going to print out the number of observations and features of our GEX and ATAC-seq data. 

Few things to note before we proceed:



In [14]:
print(f"The GEX data has {adata_gex.n_obs} observations and {adata_gex.n_vars} features.")
print(f"The ATAC data has {adata_atac.n_obs} observations and {adata_atac.n_vars} features.")

The GEX data has 22463 observations and 13431 features.
The ATAC data has 22463 observations and 116490 features.


In [18]:
# filter out the data
# Convert anndata objects to dataframes and filter.  
# Genes that show up in < 1% cells are dropped.
# Atac seq data that shows up in < 1% cells are dropped 

gex_df = adata_gex.to_df()
atac_df = adata_atac.to_df()

mask = gex_df>0
total_cells = gex_df.shape[0]
maskdf = mask.sum(axis=0)/total_cells*100 <=1

gex_feature_drop  = list(maskdf.loc[maskdf==True].index.values)

mask = atac_df>0
maskdf = mask.sum(axis=0)/total_cells <=0.01

atac_feature_drop  = list(maskdf.loc[maskdf==True].index.values)

gex_ = gex_df.drop(columns=gex_feature_drop)
atac_ = atac_df.drop(columns=atac_feature_drop)

print('Filtered data set')
print('GEX data: Total cells=' + str(gex_.shape[0]) + ', Number features=' + str(gex_.shape[1]))
print('ATAC data: Total cells=' + str(atac_.shape[0]) + ', Number Features=' + str(atac_.shape[1]))

Filtered data set
GEX data: Total cells=22463, Number features=12160
ATAC data: Total cells=22463, Number Features=53020


In [19]:
gex_

Unnamed: 0,AL627309.5,LINC01409,LINC01128,NOC2L,ISG15,C1orf159,SDF4,B3GALT6,UBE2J2,ACAP3,...,MT-ATP8,MT-ATP6,MT-CO3,MT-ND3,MT-ND4L,MT-ND4,MT-ND5,MT-ND6,MT-CYB,AL592183.1
TAGTTGTCACCCTCAC-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,4.410295,0.000000,4.410295,0.000000,0.000000,4.410295,0.0
CTATGGCCATAACGGG-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,2.194758,0.0,...,0.000000,0.000000,2.194758,0.000000,0.000000,0.000000,0.000000,0.000000,13.168547,0.0
CCGCACACAGGTTAAA-1-s1d1,0.410619,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.410619,0.821238,3.284951,0.410619,0.410619,0.821238,1.231857,0.410619,3.284951,0.0
TCATTTGGTAATGGAA-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,2.879966,11.519863,0.000000,0.000000,0.000000,5.759931,0.000000,0.000000,0.0
ACCACATAGGTGTCCA-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,3.743880,13.103581,1.871940,0.000000,3.743880,1.871940,0.000000,9.359701,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TAGTAAGCAACTAGGG-8-s3d6,0.000000,0.0,0.0,0.0,1.915288,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
TGGTCCTTCGGCTAGC-8-s3d6,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
CGCTTGCGTTGTTGGA-8-s3d6,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,3.427911,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
ACCCTCCCAGCCAGTT-8-s3d6,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0


In [20]:
atac_

Unnamed: 0,chr1-181117-181803,chr1-629497-630394,chr1-633515-634474,chr1-778276-779191,chr1-816868-817761,chr1-827067-827948,chr1-842497-843414,chr1-869472-870377,chr1-904343-905196,chr1-906441-907357,...,GL000205.2-88673-89483,GL000205.2-140307-141166,GL000195.1-30407-31261,GL000195.1-32211-33062,GL000219.1-39933-40839,GL000219.1-42172-43054,GL000219.1-44703-45584,GL000219.1-45726-46450,GL000219.1-99257-100160,KI270713.1-21434-22336
TAGTTGTCACCCTCAC-1-s1d1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
CTATGGCCATAACGGG-1-s1d1,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CCGCACACAGGTTAAA-1-s1d1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TCATTTGGTAATGGAA-1-s1d1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ACCACATAGGTGTCCA-1-s1d1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TAGTAAGCAACTAGGG-8-s3d6,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TGGTCCTTCGGCTAGC-8-s3d6,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
CGCTTGCGTTGTTGGA-8-s3d6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0
ACCCTCCCAGCCAGTT-8-s3d6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# TRANSFORMER

The Transformer T-5 small model will take in a custom dataset. This model relies solely on training therefore it is important that we have the proper pretraining before turning in this method for single cell sequencing analysis.

In [2]:
# read the sequence and split it into chunks
def read_seq_split(split_dir): 
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ["", ""]: # for each label
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())

    return texts, labels

Loading up the data so the sequences can be read

Train our transformer using this the train_test_split() function. This wraps input validation and application to input data into a single call

In [None]:
# wraps input validation and application to input data into a single call
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

In [5]:
class MIADataset(torch.utils.data.Dataset): # create a custom dataset for neurips mia model
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

Using Pytorch trainer, we want to train our model. t-5 small and transformers all around rely solely on this training data so this is definatley the most important aspect of the code. Doing so will play a massive role in how we analyze this single cell data. 

In [None]:
#  assuming we want to use trainer in leiu of custom pytorch trainer
# need to change training args based on raz input on the model
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

model = AutoModelWithLMHead.from_pretrained("t5-small")

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()

# Baseline Models given by NeurIPS

a few statistical metrics given to us in the NeurIPS competition that shows how are algorithm performs. <br>
The tests include:<br>
**rmse** - **r**oot **m**ean **s**quare **e**rror<br>
**baseline_linear** - linear regressor test<br>
**baseline_mean** - mean test

In [13]:
def calculate_rmse(true_test_mod2, pred_test_mod2):
    if pred_test_mod2.var["feature_types"][0] == "GEX":
        return  mean_squared_error(true_test_mod2.layers["log_norm"].toarray(), pred_test_mod2.X, squared=False)
    else:
        raise NotImplementedError("Only set up to calculate RMSE for GEX data")

In [14]:
def baseline_linear(input_train_mod1, input_train_mod2, input_test_mod1):
    '''Baseline method training a linear regressor on the input data'''
    input_mod1 = ad.concat(
        {"train": input_train_mod1, "test": input_test_mod1},
        axis=0,
        join="outer",
        label="group",
        fill_value=0,
        index_unique="-", 
    )
    
    # Binarize ATAC 
    if input_train_mod1.var["feature_types"][0] == "ATAC":
        input_mod1.X[input_mod1.X > 1] = 1
    elif input_train_mod2.var["feature_types"][0] == "ATAC":
        input_train_mod2.X[input_mod1.X > 1] = 1
    
    # Do PCA on the input data
    logging.info('Performing dimensionality reduction on modality 1 values...')
    embedder_mod1 = TruncatedSVD(n_components=50)
    mod1_pca = embedder_mod1.fit_transform(input_mod1.X)
    
    logging.info('Performing dimensionality reduction on modality 2 values...')
    embedder_mod2 = TruncatedSVD(n_components=50)
    mod2_pca = embedder_mod2.fit_transform(input_train_mod2.layers["log_norm"])
    
    # split dimred mod 1 back up for training
    X_train = mod1_pca[input_mod1.obs['group'] == 'train']
    X_test = mod1_pca[input_mod1.obs['group'] == 'test']
    y_train = mod2_pca
    
    assert len(X_train) + len(X_test) == len(mod1_pca)
    
    logging.info('Running Linear regression...')
    
    reg = LinearRegression()
    
    # Train the model on the PCA reduced modality 1 and 2 data
    reg.fit(X_train, y_train)
    y_pred = reg.predict(X_test)
    
    # Project the predictions back to the modality 2 feature space
    y_pred = y_pred @ embedder_mod2.components_
    
    pred_test_mod2 = ad.AnnData(
        X = y_pred,
        obs = input_test_mod1.obs,
        var = input_train_mod2.var,
    
    )
    
    # Add the name of the method to the result
    pred_test_mod2.uns["method"] = "linear"
    
    return pred_test_mod2

In [15]:
def baseline_mean(input_train_mod1, input_train_mod2, input_test_mod1):
    '''Dummy method that predicts mean(input_train_mod2) for all cells'''
    logging.info('Calculate mean of the training data modality 2...')
    y_pred = np.repeat(input_train_mod2.layers["log_norm"].mean(axis=0).reshape(-1,1).T, input_test_mod1.shape[0], axis=0)
    
    # Prepare the ouput data object
    pred_test_mod2 = ad.AnnData(
        X=y_pred,
        obs=input_test_mod1.obs,
        var=input_train_mod2.var,
    )
    
    pred_test_mod2.uns["method"] = "mean"

    return pred_test_mod2