# Transformer model for predicting modalities in scRNA-seq

**Authors**<br>Vedu Mallela: GiwoTech, vedu.mallela@gmail.com<br>Simon Lee: UC Santa Cruz, siaulee@ucsc.edu

# Goal of the code

**TODO: explain algorithm**

# Libraries 

Import all libraries and modules used in this competition.<br>
*Documentation links for each library are listed below.*<br>
<br>
**scanpy** (**s**ingle **c**ell **an**alysis in **Py**thon) - https://scanpy.readthedocs.io/en/stable/ <br>
**anndata** (**ann**otated **data**) - https://anndata.readthedocs.io/en/latest/ <br>
**matplotlib** - https://matplotlib.org/ <br>
**numpy** - https://numpy.org/doc/stable/ <br>
**pandas** - https://pandas.pydata.org/ <br>
**logging** - https://docs.python.org/3/howto/logging.html <br>
**sklearn** - https://scikit-learn.org/stable/ <br>
<br>
*code begins here*

In [1]:
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import logging

from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# LOAD DATA

In [2]:
# Input directory holding the processed multiome training files, and the
# directory used for generated outputs.
path = '/tmp/public/multiome/'
outpath = '../out/'

# Read each modality's processed training AnnData object.
gex_file = path + "multiome_gex_processed_training.h5ad"
atac_file = path + "multiome_atac_processed_training.h5ad"
adata_gex = ad.read_h5ad(gex_file)
adata_atac = ad.read_h5ad(atac_file)

After loading the data, we plot the batch labels for the **Assay for Transposase-Accessible Chromatin using sequencing** (ATAC-seq) and **Gene Expression** (GEX) data on a UMAP embedding. Scanpy saves these figures with the suffixes `_GEX` and `_ATAC` (e.g. `figures/umap_GEX.pdf` and `figures/umap_ATAC.pdf`). The plotting code below is currently commented out.

In [3]:
#sc.tl.pca(adata_gex)
#sc.pl.umap(adata_gex, color=['batch'],save='_GEX', title='GEX umap Display')
#sc.tl.pca(adata_atac)
#sc.pl.umap(adata_atac, color=['batch'], layer='counts', save='_ATAC', title='ATAC umap Display')

Next, we look at the individual cell types present in both the ATAC and GEX data, so we can see which cell types this dataset contains.

In [4]:
#sc.tl.pca(adata_gex)
#sc.pl.umap(adata_gex, color='cell_type',save='_GEX_ct', title='GEX Cell Type umap')
#sc.tl.pca(adata_atac)
#sc.pl.umap(adata_atac, color='cell_type',save='_ATAC_ct', title='ATAC Cell Type umap')

# Filter out features present in ≤1% of cells

In [5]:
# Convert the AnnData objects to cells x features DataFrames so they can be
# filtered with pandas. The filtering itself (dropping features present in at
# most 1% of cells) happens in a later cell. Each conversion is done once:
# repeating to_df() on large matrices only doubles time and memory.
gex_df = adata_gex.to_df()
atac_df = adata_atac.to_df()

gex_df_col = np.array(gex_df.columns.values)
atac_df_col = np.array(atac_df.columns.values)

In [11]:
# Keep only the cells (rows) present in both modalities, in GEX row order.
# Build a new list instead of removing items from the list being iterated
# (which silently skips elements), and use a set for O(1) membership tests.
atac_df_row = list(atac_df.index.values)
atac_row_set = set(atac_df_row)
gex_df_row = [row for row in gex_df.index.values if row in atac_row_set]

# Initialize a dictionary (one entry per shared cell) to store the results.
gex_dictionary = dict.fromkeys(gex_df_row)

print(len(gex_dictionary))
print(len(gex_df))

22463
22463


In [12]:
# Drop features detected (value > 0) in at most 1% of cells.
# Each modality is divided by its own cell count, so the threshold stays
# correct even if the GEX and ATAC matrices have different numbers of rows.
max_cell_fraction = 0.01
total_cells = gex_df.shape[0]

gex_prevalence = (gex_df > 0).sum(axis=0) / total_cells
gex_feature_drop = list(gex_prevalence[gex_prevalence <= max_cell_fraction].index.values)

atac_prevalence = (atac_df > 0).sum(axis=0) / atac_df.shape[0]
atac_feature_drop = list(atac_prevalence[atac_prevalence <= max_cell_fraction].index.values)

gex_ = gex_df.drop(columns=gex_feature_drop)
atac_ = atac_df.drop(columns=atac_feature_drop)

print('Filtered data set')
print('GEX data: Total cells=' + str(gex_.shape[0]) + ', Number features=' + str(gex_.shape[1]))
print('ATAC data: Total cells=' + str(atac_.shape[0]) + ', Number Features=' + str(atac_.shape[1]))

Filtered data set
GEX data: Total cells=22463, Number features=12160
ATAC data: Total cells=22463, Number Features=53020


# Filtered Gene Expression and ATAC data

In [13]:
# Show the filtered GEX DataFrame (cells x retained genes).
gex_

Unnamed: 0,AL627309.5,LINC01409,LINC01128,NOC2L,ISG15,C1orf159,SDF4,B3GALT6,UBE2J2,ACAP3,...,MT-ATP8,MT-ATP6,MT-CO3,MT-ND3,MT-ND4L,MT-ND4,MT-ND5,MT-ND6,MT-CYB,AL592183.1
TAGTTGTCACCCTCAC-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,4.410295,0.000000,4.410295,0.000000,0.000000,4.410295,0.0
CTATGGCCATAACGGG-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,2.194758,0.0,...,0.000000,0.000000,2.194758,0.000000,0.000000,0.000000,0.000000,0.000000,13.168547,0.0
CCGCACACAGGTTAAA-1-s1d1,0.410619,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.410619,0.821238,3.284951,0.410619,0.410619,0.821238,1.231857,0.410619,3.284951,0.0
TCATTTGGTAATGGAA-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,2.879966,11.519863,0.000000,0.000000,0.000000,5.759931,0.000000,0.000000,0.0
ACCACATAGGTGTCCA-1-s1d1,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,3.743880,13.103581,1.871940,0.000000,3.743880,1.871940,0.000000,9.359701,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TAGTAAGCAACTAGGG-8-s3d6,0.000000,0.0,0.0,0.0,1.915288,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
TGGTCCTTCGGCTAGC-8-s3d6,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
CGCTTGCGTTGTTGGA-8-s3d6,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,3.427911,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
ACCCTCCCAGCCAGTT-8-s3d6,0.000000,0.0,0.0,0.0,0.000000,0.0,0.0,0.0,0.000000,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0


In [14]:
# Show the filtered ATAC DataFrame (cells x retained peaks).
atac_

Unnamed: 0,chr1-181117-181803,chr1-629497-630394,chr1-633515-634474,chr1-778276-779191,chr1-816868-817761,chr1-827067-827948,chr1-842497-843414,chr1-869472-870377,chr1-904343-905196,chr1-906441-907357,...,GL000205.2-88673-89483,GL000205.2-140307-141166,GL000195.1-30407-31261,GL000195.1-32211-33062,GL000219.1-39933-40839,GL000219.1-42172-43054,GL000219.1-44703-45584,GL000219.1-45726-46450,GL000219.1-99257-100160,KI270713.1-21434-22336
TAGTTGTCACCCTCAC-1-s1d1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
CTATGGCCATAACGGG-1-s1d1,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CCGCACACAGGTTAAA-1-s1d1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TCATTTGGTAATGGAA-1-s1d1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ACCACATAGGTGTCCA-1-s1d1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TAGTAAGCAACTAGGG-8-s3d6,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TGGTCCTTCGGCTAGC-8-s3d6,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
CGCTTGCGTTGTTGGA-8-s3d6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0
ACCCTCCCAGCCAGTT-8-s3d6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# Import Transformer libraries

In [2]:
from transformers import AutoTokenizer, AutoModelWithLMHead, T5ForConditionalGeneration
from pathlib import Path
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time

# Model Architecture

In [3]:
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder wrapper built from caller-supplied parts.

    ``src_embed`` + ``encoder`` turn the source into a memory tensor;
    ``tgt_embed`` + ``decoder`` consume the target together with that memory.
    ``generator`` is stored but not applied by ``forward``; callers apply it
    to the decoder output themselves.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        # Tuple assignment runs left to right, so sub-modules are registered
        # in the order encoder, decoder, src_embed, tgt_embed, generator.
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Encode ``src`` and decode ``tgt`` against the resulting memory."
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        "Embed the source and run it through the encoder."
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        "Embed the target and run it through the decoder with the encoder memory."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, memory, src_mask, tgt_mask)

In [4]:
class Generator(nn.Module):
    "Project model states to vocabulary log-probabilities (linear layer + log-softmax)."

    def __init__(self, d_model, vocab):
        super().__init__()
        # Linear map from the model width to the vocabulary size.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)

# Encoder

In [5]:
def clones(module, N):
    "Return an nn.ModuleList of N independent deep copies of ``module``."
    layers = nn.ModuleList()
    for _ in range(N):
        layers.append(copy.deepcopy(module))
    return layers

In [6]:
class Encoder(nn.Module):
    "Stack of N copies of an encoder layer followed by a final LayerNorm."

    def __init__(self, layer, N):
        super().__init__()
        # ``layer`` is deep-copied N times; it must expose ``size`` (the
        # feature width) for the final normalization.
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Run ``x`` through every layer with the same ``mask``, then normalize."
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

In [7]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with a learned gain and bias.

    Unlike ``torch.nn.LayerNorm``, this uses ``Tensor.std`` (the sample,
    Bessel-corrected standard deviation) and adds ``eps`` to the standard
    deviation rather than to the variance.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        scale = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / scale + self.b_2


In [8]:
class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The input is normalized before the sublayer runs, and the sublayer output
    (after dropout) is added back onto the un-normalized input.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply ``sublayer`` to the normalized input and add the residual."
        update = sublayer(self.norm(x))
        return x + self.dropout(update)

In [9]:
class EncoderLayer(nn.Module):
    "One encoder block: self-attention then feed-forward, each in a residual wrapper."

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # sublayer[0] wraps self-attention; sublayer[1] wraps the feed-forward net.
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Apply masked self-attention, then the feed-forward network."
        def attend(h):
            return self.self_attn(h, h, h, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

# Decoder

In [10]:
class Decoder(nn.Module):
    "Stack of N copies of a decoder layer followed by a final LayerNorm."

    def __init__(self, layer, N):
        super().__init__()
        # ``layer`` is deep-copied N times and must expose ``size`` for the final norm.
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Run ``x`` through every layer against the encoder ``memory``, then normalize."
        out = x
        for decoder_layer in self.layers:
            out = decoder_layer(out, memory, src_mask, tgt_mask)
        return self.norm(out)

In [11]:
class DecoderLayer(nn.Module):
    "One decoder block: masked self-attention, source attention, then feed-forward."

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One residual wrapper per sublayer: self-attn, source-attn, feed-forward.
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Attend to the target (tgt_mask), then to ``memory`` (src_mask), then feed forward."
        self_block, cross_block, ffn_block = self.sublayer
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        x = cross_block(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        return ffn_block(x, self.feed_forward)

In [12]:
def subsequent_mask(size):
    """
    Causal mask of shape (1, size, size) as a bool tensor.

    Entry [0, i, j] is True when j <= i, i.e. position i may attend to itself
    and earlier positions but not to later ones.
    """
    allowed = np.tril(np.ones((1, size, size), dtype=bool))
    return torch.from_numpy(allowed)


# Attention

In [13]:
def attention(query, key, value, mask=None, dropout=None):
    """
    Scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` where ``d_k`` is the size of the
    last dimension of ``query``. Scores where ``mask == 0`` are replaced by
    -1e9 before the softmax, so those positions get (near-)zero weight. If
    ``dropout`` is given it is applied to the attention weights.

    Returns ``(output, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights


In [14]:
class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention with ``h`` heads over a ``d_model``-wide input.

    Four ``d_model x d_model`` linear layers are used: the first three project
    query, key and value; the last projects the concatenated head outputs.
    The most recent attention weights are kept in ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % h == 0
        # Per-head width; keys and values share the same width.
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Insert a head axis so the same mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
        batch_size = query.size(0)

        def split_heads(projection, tensor):
            # (batch, len, d_model) -> (batch, h, len, d_k)
            projected = projection(tensor).view(batch_size, -1, self.h, self.d_k)
            return projected.transpose(1, 2)

        q_proj, k_proj, v_proj, out_proj = self.linears
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        heads, self.attn = attention(query, key, value, mask=mask,
                                     dropout=self.dropout)

        # (batch, h, len, d_k) -> (batch, len, h * d_k), then the output projection.
        merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return out_proj(merged)

# Position-wise Feed Forward-Networks

In [15]:
class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward net on the last dimension: w_2(dropout(relu(w_1(x))))."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # expand d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)   # project back d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))

# Embeddings and Softmax

In [16]:
class Embeddings(nn.Module):
    "Token-id embedding lookup, scaled by sqrt(d_model)."

    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

# Positional Encoding

In [17]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even feature columns hold ``sin(pos / 10000^(2i/d_model))`` and odd
    columns the matching ``cos``. The table is precomputed for ``max_len``
    positions and registered as the buffer ``pe`` with shape
    ``(1, max_len, d_model)``. Odd ``d_model`` values are supported.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once; frequencies are built in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Dimension 1 of ``x`` is treated as the position axis. Registered
        # buffers do not require gradients, so no (deprecated) autograd
        # ``Variable`` wrapper is needed.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

# Model Construction

In [18]:
def make_model(src_vocab, tgt_vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    Build an EncoderDecoder Transformer from hyperparameters.

    Args:
        src_vocab: size of the source vocabulary (source embedding rows).
        tgt_vocab: size of the target vocabulary (target embedding rows and
            generator outputs).
        N: number of encoder layers and of decoder layers.
        d_model: model/embedding width; must be divisible by ``h``.
        d_ff: hidden width of the position-wise feed-forward networks.
        h: number of attention heads.
        dropout: dropout rate used throughout, including inside attention.

    Returns:
        An ``EncoderDecoder`` whose parameters with more than one dimension
        are Xavier-uniform initialized.
    """
    c = copy.deepcopy
    # Pass ``dropout`` through so attention honors the requested rate instead
    # of silently using MultiHeadedAttention's own default.
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))

    # Initialize weight matrices with Glorot / fan_avg. The in-place
    # ``xavier_uniform_`` replaces the deprecated ``xavier_uniform``.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

# Reshape Data for T5 Input

In [16]:
import numpy as np
import sklearn
# Import the submodule explicitly: a bare ``import sklearn`` does not
# guarantee that ``sklearn.preprocessing`` is loaded.
from sklearn.preprocessing import binarize

# Copy the filtered DataFrames into NumPy arrays (cells x features).
# np.array always copies, so the in-place binarize below cannot alias gex_.
gex_numpy_matrix = np.array(gex_)
atac_numpy_matrix = np.array(atac_)
print(type(atac_numpy_matrix))
print(gex_numpy_matrix.shape)

# Binarize the GEX data: values > 1.0 become 1, all others 0.
# NOTE(review): the feature filter earlier treats any value > 0 as
# "expressed", whereas this threshold also zeroes values in (0, 1];
# confirm 1.0 is the intended cutoff.
genex_ = binarize(gex_numpy_matrix, threshold=1.0, copy=False)

<class 'numpy.ndarray'>
(22463, 12160)


In [None]:
# For each cell, record the names of its active genes (binarized value == 1)
# as one space-separated string with a trailing space, e.g.
# {'TAGTTGTCACCCTCAC-1-s1d1': 'GENE1 GENE2 '}.
gex_[:] = genex_  # replace the gex data with the binary data

gex_df_col = list(gex_df.columns.values)  # all genes, before filtering

# Take names from the *filtered* frame so positions line up with gex_'s
# columns; gex_df still holds every gene, so its column positions differ.
gene_names = gex_.columns.to_numpy()

for key in gex_dictionary:
    is_active = gex_.loc[key].to_numpy() == 1
    gex_dictionary[key] = ''.join(name + ' ' for name in gene_names[is_active])

# Printing the whole dictionary would dump every gene name of every cell;
# show its size and one (truncated) example entry instead.
print(len(gex_dictionary))
example_key = next(iter(gex_dictionary), None)
if example_key is not None:
    print(example_key, gex_dictionary[example_key][:200])

#store the atac seq and gex of s1d1 for example in a .txt file 
#ex. {'TAGGTA': 'chr1-633515-634474 chr1-633525-634474 ', 'en': 'That is good.'}
#call wrapper function 
#class TextLineTask(FunctionTask): found in text-to-text-transfer-transformer/t5/data/dataset_providers.py
#call t5 small model

In [None]:
# Display the binarized GEX array produced by binarize() in the earlier cell
# (entries are 0 or 1).
genex_
#looks like the numpy array works


In [None]:
# read the sequence and split it into chunks
def read_seq_split(split_dir, label_dirs=("gex", "atac")):
    """
    Read every file under ``split_dir/<label_dir>`` for each label directory.

    Args:
        split_dir: path (str or Path) to the split's root directory.
        label_dirs: sub-directory names to read; each name is also used as
            the label of the files found in it.

    Returns:
        ``(texts, labels)``: parallel lists with one entry per file, holding
        the file's text and the name of the sub-directory it came from.
    """
    from pathlib import Path

    split_dir = Path(split_dir)  # accept plain strings as well as Path objects
    texts = []
    labels = []
    for label_dir in label_dirs:
        # Sorted for a deterministic order; skip anything that is not a regular file.
        for text_file in sorted((split_dir / label_dir).iterdir()):
            if text_file.is_file():
                texts.append(text_file.read_text(encoding="utf-8"))
                labels.append(label_dir)
    return texts, labels

In [None]:
from transformers import AutoTokenizer, AutoModelWithLMHead, T5ForConditionalGeneration, T5Tokenizer
from pathlib import Path
from sklearn.model_selection import train_test_split
import torch
#gex_df = adata_gex.to_df()
#X = gex_df.drop(['target'],axis=1).values   # independent features
#y = gex_df['target'].values                 # dependent variable
#train_texts, val_texts, train_labels, val_labels = train_test_split(x, y, train_texts, train_labels, test_size=.2)
#train_text, train_labels = read_seq_split(gex_df)
#train_text =  

tokenizer = T5Tokenizer.from_pretrained("t5-small")
# The tokenizer accepts a string or a list of strings, not an AnnData object.
# Tokenize the per-cell active-gene strings stored in gex_dictionary.
train_encodings = tokenizer(list(gex_dictionary.values()), truncation=True, padding=True)
#val_encodings = tokenizer(val_texts, truncation=True, padding=True)
#test_encodings = tokenizer(test_texts, truncation=True, padding=True)

In [None]:
#train_texts, train_labels = read_seq_split(adata_gex) 
#test_texts, test_labels = read_seq_split('figures/test')

In [None]:
# Show the observation (per-cell) annotation table of the ATAC AnnData object.
adata_atac.obs

Next, we print the number of observations (cells) and features in our GEX and ATAC-seq data.

A few things to note before we proceed (TODO: list them):



In [None]:
# Report the size of each modality: observations (cells) and features.
for modality_name, modality_data in (("GEX", adata_gex), ("ATAC", adata_atac)):
    print(f"The {modality_name} data has {modality_data.n_obs} observations and {modality_data.n_vars} features.")

# TRANSFORMER

The T5-small Transformer model takes a custom dataset as input. Its behavior depends entirely on the data it is trained on, so it is important to train (fine-tune) it properly before applying this method to single-cell sequencing analysis.

In [None]:
# read the sequence and split it into chunks
def read_seq_split(split_dir, label_dirs=("gex", "atac")):
    """
    Read every file under ``split_dir/<label_dir>`` for each label directory.

    Args:
        split_dir: path (str or Path) to the split's root directory.
        label_dirs: sub-directory names to read; each name is also used as
            the label of the files found in it.

    Returns:
        ``(texts, labels)``: parallel lists with one entry per file, holding
        the file's text and the name of the sub-directory it came from.
    """
    from pathlib import Path

    split_dir = Path(split_dir)  # accept plain strings as well as Path objects
    texts = []
    labels = []
    for label_dir in label_dirs:
        # Sorted for a deterministic order; skip anything that is not a regular file.
        for text_file in sorted((split_dir / label_dir).iterdir()):
            if text_file.is_file():
                texts.append(text_file.read_text(encoding="utf-8"))
                labels.append(label_dir)
    return texts, labels

Load the data so the sequences can be read.

Split the data into training and validation sets with `train_test_split()`, which wraps input validation and the split itself into a single call.

In [None]:
# wraps input validation and application to input data into a single call
# Split texts/labels 80/20 into training and validation sets, then tokenize
# each split (plus a separate test set).
# NOTE(review): train_texts, train_labels and test_texts are not defined by
# any active cell visible here (the read_seq_split calls are commented out),
# so this cell raises NameError as written; define them first.
# NOTE(review): consider passing random_state=... for a reproducible split.
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

In [None]:
class MIADataset(torch.utils.data.Dataset):  # custom dataset for the NeurIPS MIA model
    """
    Map-style dataset pairing tokenizer output with per-example labels.

    ``encodings`` maps field names (e.g. ``input_ids``) to per-example
    sequences; ``labels`` holds one entry per example. Each item is a dict of
    tensors containing every encoding field plus ``'labels'``. The dataset
    length is ``len(labels)``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {}
        for field, values in self.encodings.items():
            item[field] = torch.tensor(values[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

# Training

We train the model with the Hugging Face `Trainer` (built on PyTorch). T5-small, like Transformer models in general, depends heavily on its training data. This is therefore the most important part of the code, since it largely determines how well we can analyze this single-cell data.

In [None]:
#  assuming we want to use trainer in lieu of custom pytorch trainer
# need to change training args based on raz input on the model
from transformers import Trainer, TrainingArguments

# Wrap the tokenized splits from the previous cell as datasets.
# NOTE(review): MIADataset converts each label with torch.tensor, so for T5
# the labels must be tokenized target sequences (token ids), not raw strings.
train_dataset = MIADataset(train_encodings, train_labels)
val_dataset = MIADataset(val_encodings, val_labels)

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

# T5ForConditionalGeneration (already imported) replaces the deprecated
# AutoModelWithLMHead for the seq2seq T5 model.
model = T5ForConditionalGeneration.from_pretrained("t5-small")

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()

In [19]:
class Batch:
    """
    Hold a batch of source (and optionally target) token ids with masks.

    Args:
        src: source token ids; positions equal to ``pad`` are masked out.
        trg: optional target token ids. When given, the decoder input
            ``trg`` drops the last position and the prediction target
            ``trg_y`` drops the first (the two are shifted by one).
        pad: padding token id.
    """

    def __init__(self, src, trg=None, pad=0):
        self.src = src
        # Padding mask with an extra axis inserted at dim -2 for broadcasting.
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            # Number of non-padding target tokens.
            self.ntokens = (self.trg_y != pad).sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """
        Create a mask that hides padding and future positions of ``tgt``.

        Declared static because it is called as ``self.make_std_mask(trg, pad)``;
        as a plain method ``self`` would be bound to ``tgt`` and the call fails.
        """
        tgt_mask = (tgt != pad).unsqueeze(-2)
        return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)

In [None]:
def run_epoch(data_iter, model, loss_compute):
    """
    Run one pass over ``data_iter``, printing throughput every 50 steps.

    Args:
        data_iter: iterable of batch objects exposing ``src``, ``trg``,
            ``src_mask``, ``trg_mask``, ``trg_y`` and ``ntokens``
            (e.g. ``Batch``).
        model: callable taking ``(src, trg, src_mask, trg_mask)``.
        loss_compute: callable ``(out, trg_y, ntokens) -> loss``. run_epoch
            itself performs no backward pass or optimizer step.

    Returns:
        Total loss divided by the total number of tokens, or 0.0 when no
        tokens were seen (e.g. an empty ``data_iter``).
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        # Call the model itself (not .forward) so nn.Module hooks still run.
        out = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            # A coarse clock can report a zero interval; avoid dividing by it.
            elapsed = max(time.time() - start, 1e-9)
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                  (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    if not total_tokens:
        return 0.0
    return total_loss / total_tokens

# Baseline Models given by NeurIPS

A few statistical metrics and baselines provided by the NeurIPS competition to show how our algorithm performs. <br>
The tests include:<br>
**rmse** - **r**oot **m**ean **s**quare **e**rror<br>
**baseline_linear** - linear regression baseline<br>
**baseline_mean** - mean-prediction baseline

In [None]:
def calculate_rmse(true_test_mod2, pred_test_mod2):
    """
    RMSE between true and predicted GEX values, averaged over features.

    The root-mean-square error is computed per feature (column) over cells
    and then averaged across features. This is the same quantity as
    ``sklearn.metrics.mean_squared_error(..., squared=False)`` with the
    default ``multioutput='uniform_average'``. It is computed directly
    because scikit-learn deprecated and then removed the ``squared`` argument.

    Args:
        true_test_mod2: AnnData-like object whose ``layers["log_norm"]``
            holds the true values (scipy sparse or dense).
        pred_test_mod2: AnnData-like object whose ``X`` holds the predictions
            (scipy sparse or dense) and whose ``var["feature_types"]`` names
            the modality.

    Returns:
        The feature-averaged RMSE as a float.

    Raises:
        NotImplementedError: if the predicted modality is not GEX.
        ValueError: if the true and predicted matrices differ in shape.
    """
    def to_dense(matrix):
        # scipy sparse matrices expose toarray(); dense inputs pass through.
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return np.asarray(matrix, dtype=float)

    # .iloc: positional access on a string-indexed Series.
    if pred_test_mod2.var["feature_types"].iloc[0] != "GEX":
        raise NotImplementedError("Only set up to calculate RMSE for GEX data")

    y_true = to_dense(true_test_mod2.layers["log_norm"])
    y_pred = to_dense(pred_test_mod2.X)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: true {y_true.shape} vs predicted {y_pred.shape}"
        )

    per_feature_rmse = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=0))
    return float(np.mean(per_feature_rmse))

In [None]:
def baseline_linear(input_train_mod1, input_train_mod2, input_test_mod1):
    '''
    Baseline method training a linear regressor on SVD-reduced input data.

    Train and test modality-1 data are concatenated (outer join on features,
    missing values filled with 0) and reduced to 50 components with
    TruncatedSVD. The training modality-2 ``log_norm`` layer is reduced to 50
    components as well. A LinearRegression maps the modality-1 embedding to
    the modality-2 embedding; test predictions are then projected back to the
    modality-2 feature space through the modality-2 SVD components.

    Args:
        input_train_mod1: AnnData of training cells, modality 1.
        input_train_mod2: AnnData of the same training cells, modality 2; its
            ``layers["log_norm"]`` is the regression target.
        input_test_mod1: AnnData of test cells, modality 1.

    Returns:
        AnnData with predictions in ``X``, ``obs`` from ``input_test_mod1``,
        ``var`` from ``input_train_mod2`` and ``uns["method"] == "linear"``.
    '''
    input_mod1 = ad.concat(
        {"train": input_train_mod1, "test": input_test_mod1},
        axis=0,
        join="outer",
        label="group",
        fill_value=0,
        index_unique="-",
    )

    # Binarize ATAC: clip values above 1 down to 1.
    # .iloc: positional access on a string-indexed Series.
    if input_train_mod1.var["feature_types"].iloc[0] == "ATAC":
        input_mod1.X[input_mod1.X > 1] = 1
    elif input_train_mod2.var["feature_types"].iloc[0] == "ATAC":
        # Mask modality 2 with its own values; modality 1's mask has a
        # different shape. NOTE(review): this modifies the caller's
        # input_train_mod2.X in place, and X is not read below (the SVD uses
        # layers["log_norm"]).
        input_train_mod2.X[input_train_mod2.X > 1] = 1

    # Do PCA on the input data
    logging.info('Performing dimensionality reduction on modality 1 values...')
    embedder_mod1 = TruncatedSVD(n_components=50)
    mod1_pca = embedder_mod1.fit_transform(input_mod1.X)

    logging.info('Performing dimensionality reduction on modality 2 values...')
    embedder_mod2 = TruncatedSVD(n_components=50)
    mod2_pca = embedder_mod2.fit_transform(input_train_mod2.layers["log_norm"])

    # split dimred mod 1 back up for training
    X_train = mod1_pca[input_mod1.obs['group'] == 'train']
    X_test = mod1_pca[input_mod1.obs['group'] == 'test']
    y_train = mod2_pca

    assert len(X_train) + len(X_test) == len(mod1_pca)

    logging.info('Running Linear regression...')

    reg = LinearRegression()

    # Train the model on the PCA reduced modality 1 and 2 data
    reg.fit(X_train, y_train)
    y_pred = reg.predict(X_test)

    # Project the predictions back to the modality 2 feature space
    y_pred = y_pred @ embedder_mod2.components_

    pred_test_mod2 = ad.AnnData(
        X=y_pred,
        obs=input_test_mod1.obs,
        var=input_train_mod2.var,
    )

    # Add the name of the method to the result
    pred_test_mod2.uns["method"] = "linear"

    return pred_test_mod2

In [None]:
def baseline_mean(input_train_mod1, input_train_mod2, input_test_mod1):
    '''
    Dummy method that predicts mean(input_train_mod2) for all cells.

    Every test cell gets the per-feature mean of the training modality-2
    ``log_norm`` layer. ``input_train_mod1`` is unused; it is kept so all
    baselines share one signature.

    Returns:
        AnnData with one identical prediction row per cell of
        ``input_test_mod1`` in ``X``, ``obs`` from ``input_test_mod1``,
        ``var`` from ``input_train_mod2`` and ``uns["method"] == "mean"``.
    '''
    logging.info('Calculate mean of the training data modality 2...')
    # np.asarray turns the np.matrix returned by a scipy sparse matrix's
    # .mean() into a plain ndarray; reshape gives one (1, n_features) row for
    # sparse and dense inputs alike.
    feature_means = np.asarray(input_train_mod2.layers["log_norm"].mean(axis=0)).reshape(1, -1)
    y_pred = np.repeat(feature_means, input_test_mod1.shape[0], axis=0)

    # Prepare the output data object
    pred_test_mod2 = ad.AnnData(
        X=y_pred,
        obs=input_test_mod1.obs,
        var=input_train_mod2.var,
    )

    pred_test_mod2.uns["method"] = "mean"

    return pred_test_mod2