In [None]:
import scanpy as sc

In [None]:
#pip install torch
import os
import time
import datetime
import numpy as np
from collections import OrderedDict
import tqdm
import argparse
import pickle
import warnings

import scanpy as sc

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys

import sys
#sys.path.append(prj_path)
from drscax.scripts import models as tcgamodels
from drscax.scripts import eval_ as tcgaeval
from drscax.scripts import data as tcgadata
from drscax.scripts import train as tcgatrain

In [None]:
# ---- Run configuration -------------------------------------------------
# Identifiers of the trained run; job_name matches the stem of the
# checkpoint file loaded later in this notebook.
job_name = 'v062_woimmune_bst8layer50k_221012_165814'
data_version = 'cancer_only'
approach = '2'

# Architecture selection.
block = 'bst8layer50k'        # preset name, expanded to a list of layer sizes in a later cell
model = 'cAE'                 # 'cVAE' selects VAELoss for the criterion, anything else MSELoss
layer_norm = True             # forwarded to the model constructor
inject_c1_eachlayer = False   # forwarded to the model constructor
nocondition = True            # True -> the model is built with zero condition classes / embedding dims
binarizeinput = False         # forwarded to the model constructor as sigmoid_out

# Loss / optimisation settings. Only beta is read by the cells visible in
# this notebook; the others are presumably kept from training -- TODO confirm.
beta = 0.0
initial_lr = 0.001
batch_size = 1024
nolrscheduler = False

In [None]:
# Load the AnnData object that the model is evaluated on.
adata = sc.read('PublicationPage/tcga_canceronly_top50klogtfidf_221011.h5ad')

# Pick the loss that corresponds to the configured model type:
# the project VAE loss for 'cVAE', a summed MSE for everything else.
if model == 'cVAE':
    criterion = tcgatrain.VAELoss(beta=beta, reconstruction='cont')
else:
    criterion = nn.MSELoss(reduction='sum')

In [None]:
# Intentionally left without assignments: the run configuration (job_name,
# data_version, block, approach, model, beta, nocondition, binarizeinput,
# initial_lr, batch_size, nolrscheduler, layer_norm, inject_c1_eachlayer) is
# defined once in the configuration cell near the top of the notebook.
# A verbatim second copy in this cell re-assigned the same values and could
# silently drift out of sync with the first, so edit the values there only.

In [None]:
# Expand the named preset into the explicit list of layer sizes that is
# later passed to the model as layer_io: 50000 followed by eight 1024s.
# Once `block` is a list the comparison below is False, so re-running this
# cell leaves it untouched.
if block == 'bst8layer50k':
    block = [50000] + [1024] * 8


In [None]:
# Build the autoencoder from the run configuration.
# Only the 'cAE' architecture can be constructed in this notebook, so fail
# fast on anything else instead of printing a message and letting a later
# cell crash with a NameError on `net`.
if model != 'cAE':
    raise ValueError(
        f"Invalid model {model!r}: this notebook can only build the 'cAE' model."
    )

net = tcgamodels.cAE_v05(
    layer_io=block,
    layer_norm=layer_norm,
    # With nocondition=True the model is built with no condition classes and
    # a zero-width condition embedding; otherwise one class per distinct
    # value of adata.obs['batch'] and an 8-dimensional embedding.
    n_c1_class=0 if nocondition else len(adata.obs['batch'].unique()),
    c1_embed_dim=0 if nocondition else 8,
    inject_c1_eachlayer=inject_c1_eachlayer,
    sigmoid_out=binarizeinput,
    # The inference loop below unpacks two values from net(x), which this
    # flag presumably enables -- TODO confirm against cAE_v05.
    return_latent=True,
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if device.type != 'cpu':
    # Release cached, currently unused GPU memory before moving the model.
    torch.cuda.empty_cache()

# Move the model to the selected device and show its structure.
net = net.to(device)
print(net)

In [None]:
# Load the training checkpoint. map_location remaps the stored tensors onto
# the device selected above, so a checkpoint written from a GPU run can still
# be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only open checkpoints from a
# trusted source.
state_dict = torch.load(
    "PublicationPage/v062_woimmune_bst8layer50k_221012_165814.pt",
    map_location=device,
)

In [None]:
# Restore the trained weights stored under the checkpoint's 'net' key.
# strict=True (the default, spelled out here) makes any missing or
# unexpected parameter name an error.
net.load_state_dict(state_dict['net'], strict=True)

In [None]:
def viz_umap(mat, barcodes, md, use_pca=False, plot_file=None, include_leiden=True, optional_color='tissue', pca_random_state=0):
    """Embed a feature matrix with UMAP and plot it coloured by metadata.

    Parameters
    ----------
    mat : torch.Tensor or array-like
        Feature matrix with one row per entry of `barcodes`. Tensors are
        detached and moved to the CPU before conversion to NumPy.
    barcodes : list
        Index labels used to select (and order) the rows of `md`.
    md : pandas.DataFrame
        Per-observation metadata indexed by barcode. The plot colours by a
        'batch' column, so that column is expected to exist. A 'Sample'
        column is required only when 'tissue' is missing.
    use_pca : bool
        If True, compute a 30-component PCA of the matrix, store it in
        ``obsm['X_pca']`` and build the neighbour graph on it; otherwise the
        neighbour graph is built directly on the matrix.
    plot_file : str or None
        If given, its directory becomes ``sc.settings.figdir`` and its file
        name is passed to ``sc.pl.umap(save=...)``. scanpy treats that string
        as a suffix of its default file name -- verify against the scanpy docs.
    include_leiden : bool
        If True, run Leiden clustering and add 'leiden' to the plotted colours.
    optional_color : str or None
        Extra ``obs`` column to colour by; ``None`` skips it.
    pca_random_state : int or None
        Seed for the PCA solver so the stored components are reproducible.

    Returns
    -------
    AnnData
        The object holding the matrix, metadata, neighbour graph, UMAP and,
        when requested, the PCA and Leiden results.
    """
    # Accept tensors regardless of device / autograd state, as well as arrays.
    if torch.is_tensor(mat):
        features = mat.detach().cpu().numpy()
    else:
        features = np.asarray(mat)

    tdata = sc.AnnData(X=features, obs=md.loc[barcodes, :])
    if 'tissue' not in tdata.obs.columns:
        # Derive the tissue label from the second '_'-separated field of Sample.
        tdata.obs['tissue'] = tdata.obs['Sample'].apply(lambda s: s.split('_')[1])
    if 'n_idx' not in tdata.obs.columns:
        # add a numerical index
        tdata.obs['n_idx'] = np.arange(tdata.shape[0])

    if use_pca:
        # Imported lazily so scikit-learn is only needed for the PCA path.
        from sklearn.decomposition import PCA
        pca = PCA(n_components=30, random_state=pca_random_state)
        pca.fit(tdata.X)
        tdata.obsm['X_pca'] = pca.transform(tdata.X)
        sc.pp.neighbors(tdata, n_pcs=30)
    else:
        sc.pp.neighbors(tdata, n_pcs=0, use_rep=None) # use .X

    if include_leiden:
        sc.tl.leiden(tdata)
    sc.tl.umap(tdata)

    if plot_file is not None:
        sc.settings.figdir, save_name = os.path.split(plot_file)
    else:
        save_name = None

    # Same ordering as before: batch, the optional column, then leiden.
    colors = ['batch']
    if optional_color is not None:
        colors.append(optional_color)
    if include_leiden:
        colors.append('leiden')

    sc.pl.umap(tdata, color=colors,
               save=save_name)
    return tdata

In [None]:
# Put the network into evaluation mode before inference. As the last
# expression of the cell, the returned value is also displayed.
net.eval()

In [None]:
# Run the trained network over every observation in fixed-size chunks and
# collect the inputs (X), reconstructions (Xhat) and latent codes (Z) on the CPU.
N = adata.shape[0]
barcodes = []
chunk_size = 1024
net.eval()
count = 0  # number of rows written so far
# Inference only: disabling autograd avoids building and retaining a graph
# for every forward pass, which otherwise wastes memory for no benefit.
with torch.no_grad():
    for i in tqdm.tqdm(range(0, N, chunk_size)):
        # adata.X is densified one chunk at a time (the .toarray() call
        # implies a sparse matrix), so the full matrix is never densified at once.
        x = torch.tensor(adata.X[i:i+chunk_size].toarray(), dtype=torch.float32)
        x = x.to(device)
        n = x.shape[0]  # the last chunk may be shorter than chunk_size

        xhat, z = net(x)

        barcodes += adata.obs.iloc[i:i+chunk_size].index.to_list()

        if i == 0:
            # Output widths are only known after the first forward pass.
            X = torch.empty(N, x.shape[1])
            Xhat = torch.empty(N, xhat.shape[1])
            Z = torch.empty(N, z.shape[1])
        X[count:count+n] = x.detach().cpu()
        Xhat[count:count+n] = xhat.detach().cpu()
        Z[count:count+n] = z.detach().cpu()
        count += n
        


In [None]:

# UMAP of the latent codes, built on a PCA of Z and without Leiden clustering.
tdata_Z_PCA = viz_umap(
    mat=Z,
    barcodes=barcodes,
    md=adata.obs,
    use_pca=True,
    include_leiden=False,
)
        



In [None]:
# Persist the PCA coordinates of the latent codes (one row per observation).
np.save(file="Z_PCA.npy", arr=tdata_Z_PCA.obsm['X_pca'])

In [None]:
# Persist the barcodes in the same row order as Z_PCA.npy. They are read from
# tdata_Z_PCA, the object whose X_pca was saved; the previous name `tdata_Z`
# is never defined in this notebook and raised a NameError.
np.save("Z_PCA_barcode_index.npy", tdata_Z_PCA.obs.index.tolist())
