In [None]:
import os
import scanpy
import glob
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl

In [None]:
# Locations of the raw input files and of the intermediate/temporary outputs.
DATA_DIR = './data/train'
TEMP_DIR = './reference_data/temp_store/'
# glob returns matches in arbitrary (filesystem-dependent) order; sort so that
# every run walks the input files in the same order.
file_paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.h5ad")))

In [None]:
# Compile:
#   1) gene_dict: gene name -> number of occurrences across all input files
#   2) all_barcodes: obs index values (cell barcodes) of all input files

gene_dict = {}
all_barcodes = []

for f in file_paths:
    # Only the var/obs annotations are needed here, so open the file in backed mode.
    adata_backed = scanpy.read(f, backed='r')
    try:
        # NOTE(review): a gene that appears more than once in a single file's
        # var['gene'] is counted once per occurrence - confirm that is intended.
        for g in adata_backed.var['gene'].values:
            gene_dict[g] = gene_dict.get(g, 0) + 1

        all_barcodes += list(adata_backed.obs.index.values)
    finally:
        # Backed mode keeps the underlying file open until it is closed explicitly.
        adata_backed.file.close()


In [None]:
# Histogram of the per-gene occurrence counts collected in gene_dict.
freqs = np.asarray(list(gene_dict.values()))

fig, ax = plt.subplots()
ax.hist(freqs)
# Label the figure so it can be read without the surrounding code.
ax.set(
    xlabel='Occurrences of gene across input files',
    ylabel='Number of genes',
    title='Gene frequency across input files',
)
plt.show()

In [None]:
import glob
import os
import pickle as pkl
import time

import matplotlib.pyplot as plt
import numpy as np
import scanpy

# Locations of the raw input files and of the intermediate/temporary outputs.
DATA_DIR = './data/train'
TEMP_DIR = './reference_data/temp_store/'
# glob returns matches in arbitrary (filesystem-dependent) order; sort so that
# files are processed, and batch ids handed out, in the same order on every run.
file_paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.h5ad")))

# Compile:
#   1) gene_dict: gene name -> number of occurrences across all input files
#   2) all_barcodes: obs index values (cell barcodes) of all input files

gene_dict = {}
all_barcodes = []

for f in file_paths:
    # Only the var/obs annotations are needed here, so open the file in backed mode.
    adata_backed = scanpy.read(f, backed='r')
    try:
        # NOTE(review): a gene that appears more than once in a single file's
        # var['gene'] is counted once per occurrence - confirm that is intended.
        for g in adata_backed.var['gene'].values:
            gene_dict[g] = gene_dict.get(g, 0) + 1

        all_barcodes += list(adata_backed.obs.index.values)
    finally:
        # Backed mode keeps the underlying file open until it is closed explicitly.
        adata_backed.file.close()

# Specify reference gene list: keep genes whose count in gene_dict exceeds 5

reference_gene_list = [gene for gene, count in gene_dict.items() if count > 5]

# Save reference gene list

# TEMP_DIR may not exist yet on a fresh run; create it before writing into it.
os.makedirs(TEMP_DIR, exist_ok=True)

with open(TEMP_DIR + 'reference_genes.list', 'wb') as handle:
    pkl.dump(reference_gene_list, handle, protocol=pkl.HIGHEST_PROTOCOL)
    
def log_mean(g, X):
    """Mean of the log of per-row totals for a group of rows.

    Parameters
    ----------
    g : object exposing ``.values`` holding row positions into ``X``
        (used as the group passed by ``groupby(...).transform``).
    X : 2-D matrix indexable as ``X[rows, :]`` and summable along axis 1.

    Returns
    -------
    numpy.float32
        ``mean(log(row_sum))`` over the selected rows.
    """
    selected_rows = X[g.values, :]
    row_totals = selected_rows.sum(axis=1)
    return np.mean(np.log(row_totals)).astype(np.float32)

def log_var(g, X):
    """Variance of the log of per-row totals for a group of rows.

    Parameters
    ----------
    g : object exposing ``.values`` holding row positions into ``X``
        (used as the group passed by ``groupby(...).transform``).
    X : 2-D matrix indexable as ``X[rows, :]`` and summable along axis 1.

    Returns
    -------
    numpy.float32
        ``var(log(row_sum))`` over the selected rows (population variance,
        i.e. numpy's default ``ddof=0``).
    """
    selected_rows = X[g.values, :]
    row_totals = selected_rows.sum(axis=1)
    return np.var(np.log(row_totals)).astype(np.float32)

# obs columns that are turned into batch ids further down
batch_keys = ['sample_name','study_name']

b_size = 20000 # Block size parameter for random chunking of adatas
n_cells = len(all_barcodes)
# Ceiling division: block_mapping below assigns ids 0 .. (n_cells-1)//b_size, so a
# trailing, partially filled block must be counted too. With floor division its
# cells were never written (and nothing at all was written when n_cells < b_size).
block_n = -(-n_cells // b_size)

# Shuffle the cells and cut the shuffled order into consecutive runs of b_size.
# NOTE(review): the permutation is unseeded, so block membership differs between runs.
random_inds = np.random.permutation(np.arange(n_cells))
block_mapping = {all_barcodes[random_inds[i]]:i//b_size for i in range(n_cells)}

print('Splitting data into :',block_n, ' blocks')
# Create temp folder for saving adata blocks

os.makedirs(TEMP_DIR, exist_ok=True)

# Chunk adatas into blocks - load in n adata (nad) at once to save on file write number

nad = 3

# Wall-clock reference taken before processing starts (not read again in this cell).
start = time.time()

batch_counters = {b:0 for b in batch_keys}  # next free batch id per batch key
split_paths = []                            # every chunk/split file written below
batch_dict = {}                             # "batch_<k>" -> total number of batch ids handed out

# Single all-zero row that carries the reference genes. Outer-joining each dataset
# with it and then selecting reference_gene_list gives every dataset the same
# columns; the zero row itself has no non-zero entries and is dropped by the
# cell filter below. It does not change between files, so build it once.
ref = scanpy.AnnData(X=np.zeros((1,len(reference_gene_list)),dtype=np.float32),var={'gene':reference_gene_list})
ref.var = ref.var.set_index(ref.var['gene'])

for ii in range(0,len(file_paths),nad):
    adatas = []

    # Process up to `nad` files per chunk; the slice handles the short final chunk.
    for file_idx, f_path in enumerate(file_paths[ii:ii+nad], start=ii):

        print('Reading File: ', file_idx, ' out of ', len(file_paths), ' file name is ', f_path)
        adata_ = scanpy.read(f_path)

        if not adata_.X.dtype == np.float32:
            adata_.X = adata_.X.astype(np.float32)

        adata_ = scanpy.concat([ref,adata_],join='outer')
        adata_ = adata_[:,reference_gene_list]

        #Filter cells with very few gene reads

        gene_counts = adata_.X.getnnz(axis=1)
        mask = gene_counts > 300
        adata_ = adata_[mask,:].copy()

        if adata_.n_obs == 0:
            # Nothing survived the filter. codes.max() below would be NaN and
            # would corrupt the running batch counters, so skip this file.
            continue

        #Batch calcs - assign batch ID that's unique across all datasets

        for i,b in enumerate(batch_keys):
            batch_id = "batch_" + str(i+1)
            codes = adata_.obs[b].astype("category").cat.codes.astype(int)
            # Offset this file's local codes by the ids already used by earlier files.
            adata_.obs[batch_id] = codes + batch_counters[b]
            batch_counters[b] += int(codes.max()) + 1
            # Record the running total across all files, not just this file's count.
            batch_dict[batch_id] = batch_counters[b]

        #Add local_l_mean_key and local_l_var_key to adata.obs

        # Positional row index, used by log_mean/log_var to select rows of X per group.
        adata_.obs['int_index'] = list(range(adata_.shape[0]))

        for i in range(len(batch_keys)):

            header_m = "l_mean_batch_" + str(i+1)
            adata_.obs[header_m] = adata_.obs.groupby(batch_keys[i])["int_index"].transform(log_mean, adata_.X)
            header_v = "l_var_batch_" + str(i+1)
            adata_.obs[header_v] = adata_.obs.groupby(batch_keys[i])["int_index"].transform(log_var, adata_.X)

        adatas.append(adata_)

    if not adatas:
        # Every file of this chunk was empty after filtering; nothing to write.
        continue

    adata_chunk = scanpy.concat(adatas, join="inner", index_unique=None)
    # NOTE(review): block_mapping is keyed by the obs index values gathered earlier;
    # this lookup assumes obs['barcode'] holds the same values - confirm.
    adata_chunk.obs['block'] = adata_chunk.obs['barcode'].apply(lambda x : block_mapping[x])

    print('Writing split : ', ii//nad)
    for i in range(block_n):

        split_path = TEMP_DIR + 'chunk_' + str(ii//nad) + '_split_' + str(i) + '.h5ad'

        adata_split = adata_chunk[adata_chunk.obs['block']==i,:].copy()
        adata_split.write(split_path)
        split_paths.append(split_path)

# Persist the batch bookkeeping next to the data blocks.
with open(TEMP_DIR+'bdict.dict', "wb") as fh:
    pkl.dump([batch_dict, batch_keys], fh)

print('Consolidating split blocks...')

for i in range(block_n):
    # Match on the complete "_split_<i>.h5ad" suffix. Comparing the single character
    # split_path[-6] breaks once block ids reach two digits: id 12 would be merged
    # into block 2 and ids >= 10 would never match at all.
    block_suffix = '_split_' + str(i) + '.h5ad'
    adatas = []
    for split_path in split_paths:
        if split_path.endswith(block_suffix):
            adata_ = scanpy.read(split_path)
            adatas.append(adata_)
            # The split has been read into memory, so the temporary file can go.
            os.remove(split_path)

    if not adatas:
        # No split file exists for this block id; there is nothing to concatenate.
        continue

    adata_block = scanpy.concat(adatas, join="inner", index_unique=None)
    adata_block.write(TEMP_DIR+'adata_block_'+str(i)+'.h5ad')

In [None]:
DATA_DIR = './data/train'
TEMP_DIR = './reference_data/temp_store/'
# file_paths now lists the .h5ad files inside TEMP_DIR (not the raw inputs).
# glob order is arbitrary; sort so positional indexing into the list is stable.
file_paths = sorted(glob.glob(os.path.join(TEMP_DIR, "*.h5ad")))

In [None]:
# Display the current contents of file_paths as the cell output.
file_paths

In [None]:
import pandas as pd

# Stack the obs tables of (at most) the first five files in file_paths for inspection.
obs_frames = []

# Slicing instead of range(5) avoids an IndexError when fewer than five files exist.
for path in file_paths[:5]:
    adata_backed = scanpy.read(path, backed='r')
    try:
        obs_frames.append(adata_backed.obs.copy())
    finally:
        # Backed mode keeps the underlying file open until it is closed explicitly.
        adata_backed.file.close()

obs = pd.concat(obs_frames)
obs