In [1]:
import pandas as pd
pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 5)
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
import importlib
sys.path.append('../src/')
sys.path.append('../data/')
import model as mod
importlib.reload(mod)
import data_utils as du
importlib.reload(du)
from sklearn.model_selection import train_test_split
import torch
import os



In [2]:
## lets use the human brain cortex dataset from the original manuscript, but only with the first 3000 most variant genes
scRNA = pd.read_csv(os.path.join('..','data','scrna_reduced_3000.tsv'),sep='\t', index_col=0)
bulkRNA = pd.read_csv(os.path.join('..','data','bulkrna_reduced_3000.tsv'),sep='\t', index_col=0)

## define number of samples
nsamples = 5000
train_size, test_size = 0.8, 0.2

In [3]:
scRNA

Unnamed: 0,MALAT1,NEAT1,...,P2RX7,SYNJ1
Endotelial,18.0,18.0,...,0.0,0.0
Microglia,19.0,11.0,...,1.0,0.0
Endotelial,64.0,20.0,...,0.0,0.0
Neuron,39.0,7.0,...,0.0,0.0
Endotelial,20.0,7.0,...,0.0,0.0
...,...,...,...,...,...
Astrocyte,76.0,39.0,...,0.0,0.0
Oligodendrocyte,121.0,18.0,...,0.0,0.0
Endotelial,18.0,11.0,...,0.0,0.0
Endotelial,46.0,12.0,...,0.0,0.0


In [4]:
#split into train and test
scRNA_train, scRNA_test = train_test_split(scRNA.copy(), stratify=scRNA.index, test_size = 0.2, random_state= 42)

# create pseudobulk for train
xtrain, ytrain, celltypes = du.generate_synthethic(scRNA_train, nsamples = nsamples * train_size)
xtest, ytest, _ = du.generate_synthethic(scRNA_test, nsamples = nsamples * test_size)

## transform and normalize
xtrain, xtest, xbulk = du.transform_and_normalize(xtrain, xtest, bulkRNA.values)

## convert to torch
xtrain, ytrain, xtest, ytest, xbulk = du.convert_to_float_tensors(xtrain, ytrain, xtest, ytest, xbulk)

simulating bulk: 100%|██████████| 4030/4030 [00:08<00:00, 458.60it/s]
simulating bulk: 100%|██████████| 1023/1023 [00:00<00:00, 1660.36it/s]


In [5]:
print(xtrain.shape) #4088 (~5000*0.8) samples x 3000 genes (most variant ones)
print(xtest.shape) #511 (~5000*0.2) samples x 3000 genes (most variant ones)
print(ytrain.shape) #proportions associated to xtrain
print(ytest.shape) #proportions associated to xtest

torch.Size([4030, 3000])
torch.Size([1023, 3000])
torch.Size([4030, 5])
torch.Size([1023, 5])


In [6]:
## define epochs and init sweetwater object
epochs = round(30000/(xtrain.shape[0]/256))
sw = mod.SweetWater(data = (xtrain, ytrain, xtest, ytest), 
                bulkrna = xbulk,
                name = 'Human Brain Cortex', verbose = True, 
                lr = 0.00001, batch_size = 256, epochs = epochs)

# train
sw.run()

Stablishing Early Stopping with patience 10
Stablishing Early Stopping with patience 10
Stablishing Early Stopping with patience 50


P1: Train MSE is: 0.00035, Test MSE is 0.001397:  23%|██▎       | 432/1906 [01:45<05:58,  4.11it/s] 


Early stopping condition achieved


P2: Train MSE is: 0.000451, Test MSE is: 0.00066:   3%|▎         | 64/1906 [00:12<05:46,  5.31it/s] 


Early stopping condition achieved


P3: Train MSE 0.00024, test MSE 0.001023, Train R2 0.9971, Test R2 0.9875:  32%|███▏      | 617/1906 [02:15<04:43,  4.54it/s] 

Early stopping condition achieved





In [7]:
## we can now infer the cell type proportions of our bulkRNA samples
ypredbulkrna = sw.aemodel(xbulk.to(sw.device), mode = 'phase3')
print(pd.DataFrame(ypredbulkrna.detach().cpu(), columns = celltypes))

    Astrocyte  Endotelial  Microglia    Neuron  Oligodendrocyte
0    0.448156    0.000049   0.003208  0.356009         0.192578
1    0.204977    0.000026   0.003653  0.383842         0.407502
2    0.231552    0.000136   0.002030  0.217299         0.548983
3    0.573304    0.000172   0.005909  0.240426         0.180189
4    0.430829    0.005058   0.002496  0.012442         0.549175
..        ...         ...        ...       ...              ...
44   0.245817    0.000043   0.004040  0.369782         0.380318
45   0.587130    0.000564   0.003326  0.026556         0.382425
46   0.359754    0.004363   0.001543  0.023021         0.611319
47   0.461992    0.000093   0.003632  0.200291         0.333991
48   0.442756    0.000110   0.005628  0.188061         0.363445

[49 rows x 5 columns]


In [8]:
## save the model to perform interpretability (see interpretability.ipynb)
torch.save(sw.aemodel.state_dict(), os.path.join('../','data','model_rosmap_weights_3000_genes.pt'))