In [54]:
# --- standard library ---
import importlib
import os
import sys
import warnings

# Silence FutureWarning noise before the heavier imports below get a chance to emit it.
warnings.simplefilter(action='ignore', category=FutureWarning)

# --- third-party ---
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

# --- project-local modules ---
# Extend the import path; presumably these folders hold the local modules imported below.
sys.path.append('../src/')
sys.path.append('../data/')

# Each local module is reloaded right after import so that edits made on disk
# are picked up without restarting the kernel.
import model as mod
importlib.reload(mod)
import load_rosmap as lr
importlib.reload(lr)
import data_utils as du
importlib.reload(du)

# --- display configuration ---
# pandas must already be imported at this point; calling pd.set_option before
# `import pandas as pd` only works when `pd` is left over from an earlier kernel run.
pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 5)

In [41]:
## Load the human brain cortex dataset from the original manuscript.
## The loader returns the single-cell table and the bulk RNA table;
## genes_cutoff = 3000 presumably caps the gene set at 3000 genes -- TODO confirm in load_rosmap.
scRNA, bulkRNA = lr.load_and_filter_hbc_rosmap(genes_cutoff = 3000)

## Total number of pseudobulk samples to simulate ...
nsamples = 5000

## ... and the fractions assigned to the train and test partitions.
train_size = 0.8
test_size = 0.2

In [51]:
# Preview the single-cell table: in the rendered output the row labels are cell types
# and the columns are genes. Kept as the bare last expression so the notebook
# renders it with the rich DataFrame display.
scRNA

Unnamed: 0,MALAT1,NEAT1,...,P2RX7,SYNJ1
Endotelial,18.0,18.0,...,0.0,0.0
Microglia,19.0,11.0,...,1.0,0.0
Endotelial,64.0,20.0,...,0.0,0.0
Neuron,39.0,7.0,...,0.0,0.0
Endotelial,20.0,7.0,...,0.0,0.0
...,...,...,...,...,...
Astrocyte,76.0,39.0,...,0.0,0.0
Oligodendrocyte,121.0,18.0,...,0.0,0.0
Endotelial,18.0,11.0,...,0.0,0.0
Endotelial,46.0,12.0,...,0.0,0.0


In [42]:
# Split the single cells into train and test partitions, stratified on the
# cell-type labels stored in the index so both partitions keep the label proportions.
# The held-out fraction reuses `test_size` (defined alongside `nsamples`) instead of a
# repeated literal, so the split stays in step with the pseudobulk sample counts below.
scRNA_train, scRNA_test = train_test_split(
    scRNA.copy(),
    stratify = scRNA.index,
    test_size = test_size,
    random_state = 42,  # fixed seed so the split is reproducible
)

# Create pseudobulk samples (x) and their associated proportions (y) separately
# for each partition. `celltypes` is kept from the train call and is used later
# as the column labels of the predicted proportions.
xtrain, ytrain, celltypes = du.generate_synthethic(scRNA_train, nsamples = nsamples * train_size)
xtest, ytest, _ = du.generate_synthethic(scRNA_test, nsamples = nsamples * test_size)

## transform and normalize the simulated matrices together with the real bulk RNA values
## NOTE(review): the exact transform lives in data_utils and is not visible here.
xtrain, xtest, xbulk = du.transform_and_normalize(xtrain, xtest, bulkRNA.values)

## convert everything to float torch tensors for the model
xtrain, ytrain, xtest, ytest, xbulk = du.convert_to_float_tensors(xtrain, ytrain, xtest, ytest, xbulk)

simulating bulk: 100%|██████████| 4030/4030 [00:08<00:00, 456.26it/s]
simulating bulk: 100%|██████████| 1023/1023 [00:00<00:00, 1879.80it/s]


In [43]:
# Sanity-check the tensor shapes, one per line, in this order:
#   xtrain -> (train pseudobulk samples, genes); the row count is only approximately nsamples * train_size
#   xtest  -> (test pseudobulk samples, genes); the row count is only approximately nsamples * test_size
#   ytrain -> proportions associated to xtrain
#   ytest  -> proportions associated to xtest
# The gene axis holds 3000 genes (most variant ones).
print(xtrain.shape, xtest.shape, ytrain.shape, ytest.shape, sep = '\n')

torch.Size([4030, 3000])
torch.Size([1023, 3000])
torch.Size([4030, 5])
torch.Size([1023, 5])


In [44]:
## training configuration
batch_size = 256
learning_rate = 0.00001

## Pick the epoch count so that epochs x (samples / batch_size) is about 30000
## (presumably the intended total number of mini-batch updates).
batches_per_epoch = xtrain.shape[0] / batch_size
epochs = round(30000 / batches_per_epoch)

## init the sweetwater object with the simulated train/test data and the real bulk samples
sw = mod.SweetWater(
    data = (xtrain, ytrain, xtest, ytest),
    bulkrna = xbulk,
    name = 'Human Brain Cortex',
    verbose = True,
    lr = learning_rate,
    batch_size = batch_size,
    epochs = epochs,
)

# train (the log below reports three phases, P1-P3, each ending on early stopping)
sw.run()

Stablishing Early Stopping with patience 10
Stablishing Early Stopping with patience 10
Stablishing Early Stopping with patience 50


P1: Train MSE is: 0.000359, Test MSE is 0.001379:  23%|██▎       | 436/1906 [01:19<04:27,  5.49it/s]


Early stopping condition achieved


P2: Train MSE is: 0.000454, Test MSE is: 0.000666:   3%|▎         | 56/1906 [00:08<04:56,  6.24it/s]


Early stopping condition achieved


P3: Train MSE 0.000319, test MSE 0.001089, Train R2 0.9961, Test R2 0.9867:  25%|██▌       | 483/1906 [01:20<03:58,  5.98it/s]

Early stopping condition achieved





In [53]:
## we can now infer the cell type proportions of our bulkRNA samples
# Inference only: run the forward pass without autograd so no computation graph
# is built or kept alive for the predictions.
# NOTE(review): confirm that sw.run() leaves sw.aemodel in eval mode; if the network
# contains dropout / batch-norm layers, call sw.aemodel.eval() before predicting.
with torch.no_grad():
    ypredbulkrna = sw.aemodel(xbulk.to(sw.device), mode = 'phase3')

# One row per bulk sample, one column per cell type (labels come from `celltypes`,
# returned when the training pseudobulk was generated). The tensor is moved to the
# CPU and converted to a NumPy array explicitly before building the DataFrame.
bulk_proportions = pd.DataFrame(ypredbulkrna.detach().cpu().numpy(), columns = celltypes)
print(bulk_proportions)

    Astrocyte  Endotelial  Microglia    Neuron  Oligodendrocyte
0    0.457273    0.000157   0.007229  0.319406         0.215935
1    0.251822    0.000101   0.008102  0.371767         0.368208
2    0.263427    0.000399   0.005669  0.207004         0.523502
3    0.558424    0.000551   0.015134  0.212886         0.213006
4    0.417860    0.007512   0.005068  0.018189         0.551371
..        ...         ...        ...       ...              ...
44   0.290757    0.000155   0.008858  0.347381         0.352849
45   0.570509    0.001234   0.007862  0.034796         0.385599
46   0.367690    0.007774   0.004510  0.031133         0.588893
47   0.475678    0.000289   0.008587  0.186441         0.329005
48   0.448945    0.000333   0.012936  0.183593         0.354193

[49 rows x 5 columns]


In [56]:
## Persist the trained weights (state dict only) so they can be reloaded for the
## interpretability analysis (see interpretability.ipynb).
weights_path = os.path.join('../', 'data', 'model_rosmap_weights_3000_genes.pt')
torch.save(sw.aemodel.state_dict(), weights_path)