### Generate cell type vector representations from the latent-space predictions of a specified model

In [1]:
# Load packages
import scanpy as sc
from functions import train as trainer
from models import model_encoder as model_encoder
import torch.nn as nn

In [2]:
# --- Configuration -----------------------------------------------------------
# Number of highly variable genes (HVGs). The same value is used for the
# encoder's `input_dim`, the trainer's `HVGs` argument and the checkpoint
# directory name, so it is defined once here instead of as three literals.
# NOTE(review): presumably the encoder takes one input feature per HVG and the
# checkpoint folders follow the "<N>_HVGs_seed_<seed>" naming — confirm before
# changing this value.
N_HVGS = 2000
# Width of the encoder output, passed as `output_dim` to CellType2VecModel.
LATENT_DIM = 100

# Input files (relative to the notebook's working directory).
DATA_PATH = '../../data/processed/immune_cells/merged/Oetjen_merged.h5ad'
PATHWAYS_PATH = '../../data/processed/pathway_information/all_pathways.json'
GENE2VEC_PATH = '../../data/raw/gene2vec_embeddings/gene2vec_dim_200_iter_9_w2v.txt'

# Where the representations are written and where the model is loaded from;
# both are handed to `generate_representation` below.
save_path = "cell_type_vector_representation/CellTypeRepresentations.csv"
model_path = f'trained_models/Bone_marrow/Encoder/{N_HVGS}_HVGs_seed_42'

# --- Data --------------------------------------------------------------------
adata = sc.read(DATA_PATH, cache=True)

# Expose the patient identifier under the "batch" column, which is the key
# passed to the trainer via `batch_keys` below.
adata.obs["batch"] = adata.obs['patientID']

# --- Model -------------------------------------------------------------------
model = model_encoder.CellType2VecModel(
    input_dim=N_HVGS,
    output_dim=LATENT_DIM,
    drop_out=0.2,
    act_layer=nn.ReLU,
    norm_layer=nn.BatchNorm1d,
)

# --- Training environment ----------------------------------------------------
# HVG bucketing and gene2vec embeddings are switched off here; their companion
# arguments (`HVG_buckets`, `gene2vec_path`) are still passed unchanged.
train_env = trainer.train_module(
    data_path=adata,
    pathways_file_path=PATHWAYS_PATH,
    num_pathways=300,
    pathway_gene_limit=10,
    save_model_path=model_path,
    HVG=True,
    HVGs=N_HVGS,
    HVG_buckets=1000,
    use_HVG_buckets=False,
    Scaled=False,
    target_key="cell_type",
    batch_keys=["batch"],
    use_gene2vec_emb=False,
    gene2vec_path=GENE2VEC_PATH,
)

# --- Representations ---------------------------------------------------------
# Keep the result under `predictions`; the next cell displays it.
predictions = train_env.generate_representation(
    data_=adata,
    model=model,
    model_path=model_path,
    save_path=save_path,
)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  ret = ret.dtype.type(ret / rcount)


In [3]:
# Display the object returned by train_env.generate_representation(...) in the
# previous cell (bare last expression, so the notebook renders it richly).
predictions

Unnamed: 0,Basophils,CD4+ NKT-like cells,CD8+ NKT-like cells,Classical Monocytes,Effector CD4+ T cells,Erythroid-like and erythroid precursor cells,Memory CD4+ T cells,Memory CD8+ T cells,Myeloid Dendritic cells,Naive B cells,Naive CD4+ T cells,Naive CD8+ T cells,Natural killer cells,Neutrophils,Non-classical monocytes,Plasmacytoid Dendritic cells,Platelets,Pro-B cells,Progenitor cells,γδ-T cells
0,-0.179498,1.939742,0.601815,-5.915797,1.062007,-2.697641,0.794474,0.777824,0.736930,-0.504706,-0.702802,0.139358,3.070304,-2.088998,-2.066032,5.587245,-1.002717,-2.315610,-2.103700,0.742763
1,-1.649277,-0.082276,0.615923,5.032358,1.596658,-3.114418,0.750786,-1.242858,1.821720,-0.627049,0.570235,-0.597256,-0.977977,1.936912,5.939540,-1.135905,3.235129,-1.308736,-1.071515,-0.983287
2,-2.313008,2.798322,0.836900,-5.656613,1.527840,-1.653958,2.055003,0.215869,-3.113821,2.138955,1.001686,2.882901,-0.400904,-2.675512,-8.339398,-1.261993,-3.508720,-3.703548,-1.668980,0.059794
3,-0.711427,-4.452841,-0.983694,-1.261437,-0.877589,1.316771,0.173155,0.054376,5.392364,1.895579,-1.633438,0.905843,-2.103499,-0.410585,3.597668,2.450263,-1.946889,-0.543067,2.254851,0.235371
4,-3.118710,1.921815,1.964768,-1.115444,0.582747,-1.902479,0.957529,-0.360685,-6.990749,1.394755,1.709006,2.518027,0.336029,-5.461646,2.144653,-1.260826,-2.199515,-1.818061,-1.286270,-0.076423
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,1.360314,1.264868,-0.698935,7.804600,-1.608112,2.685769,-1.706336,-0.119752,-2.282422,1.087429,-1.176598,-1.485863,2.941271,1.951273,2.940501,1.158736,1.096178,0.606947,-0.601071,-0.283183
96,3.199102,-2.243653,1.661418,-4.833344,-0.799139,5.453993,-1.345021,-1.471891,0.925032,-2.137302,-0.034433,-1.257670,2.221239,-5.716187,5.696217,1.386607,-0.531979,2.752495,-1.239810,1.332157
97,1.762484,-0.991697,-1.076119,0.930751,0.245750,-0.742895,0.285836,1.329533,6.375932,-4.923123,-4.660348,-1.603902,1.632754,1.446519,4.798259,5.471748,-0.516209,-4.842942,1.821637,1.228021
98,0.394461,3.178756,0.146710,4.953727,-1.736486,1.845619,-1.050124,-1.079245,-2.334530,-0.518914,-0.153567,0.059782,-0.411625,-3.100235,-0.074944,-1.220603,-1.864380,0.126450,1.008860,0.227258
