In [1]:
import anndata as ad
import numpy as np
import scipy
import pandas as pd
import torch

### Larry Data Loading and Preprocessing

In [2]:
# original snakemnake file
'''
rule differentiation_make_adata:
    input:
        normed_counts="resources/differentiation/GSM4185642_stateFate_inVitro_normed_counts.mtx.gz",
        gene_names="resources/differentiation/GSM4185642_stateFate_inVitro_gene_names.txt.gz",
        clone_matrix="resources/differentiation/GSM4185642_stateFate_inVitro_clone_matrix.mtx.gz",
        metadata="resources/differentiation/GSM4185642_stateFate_inVitro_metadata.txt.gz",
    output:
        d4d6_h5ad="results/single_cell_differentiation/data/d4_d6_differentiation.h5ad",
        d2_h5ad="results/single_cell_differentiation/data/d2_differentiation.h5ad",
    conda:
        "../envs/TACCO_env.yml"
    resources:
        mem_mb=8000
    log:
        "logs/single_cell_differentiation/differentiation_make_adata.log"
    benchmark:
        "benchmarks/single_cell_differentiation/differentiation_make_adata.tsv"
    script:
        "differentiation_make_adata.py"
'''

'\nrule differentiation_make_adata:\n    input:\n        normed_counts="resources/differentiation/GSM4185642_stateFate_inVitro_normed_counts.mtx.gz",\n        gene_names="resources/differentiation/GSM4185642_stateFate_inVitro_gene_names.txt.gz",\n        clone_matrix="resources/differentiation/GSM4185642_stateFate_inVitro_clone_matrix.mtx.gz",\n        metadata="resources/differentiation/GSM4185642_stateFate_inVitro_metadata.txt.gz",\n    output:\n        d4d6_h5ad="results/single_cell_differentiation/data/d4_d6_differentiation.h5ad",\n        d2_h5ad="results/single_cell_differentiation/data/d2_differentiation.h5ad",\n    conda:\n        "../envs/TACCO_env.yml"\n    resources:\n        mem_mb=8000\n    log:\n        "logs/single_cell_differentiation/differentiation_make_adata.log"\n    benchmark:\n        "benchmarks/single_cell_differentiation/differentiation_make_adata.tsv"\n    script:\n        "differentiation_make_adata.py"\n'

This normalized count matrix reports the number of transcripts (UMIs) for each gene in each cell, after total-counts normalization (i.e. L1 normalization on cells). Rows represent cells and columns represent genes. There are no column or row labels. Gene names and cell metadata are provided in separate files.

In [3]:
normed_counts = "/Users/apple/Desktop/KB/Dataset1/stateFate_inVitro_normed_counts.mtx.gz"  #snakemake.input['normed_counts']
gene_names = "/Users/apple/Desktop/KB/Dataset1/stateFate_inVitro_gene_names.txt.gz" #snakemake.input['gene_names']
clone_matrix = "/Users/apple/Desktop/KB/Dataset1/stateFate_inVitro_clone_matrix.mtx.gz" #snakemake.input['clone_matrix']
metadata = "/Users/apple/Desktop/KB/Dataset1/stateFate_inVitro_metadata.txt.gz" #snakemake.input['metadata']

# output_d4d6_h5ad = "/Users/apple/Desktop/KB/SCSeq_LineageBarcoding2/SCSeq_LineageBarcoding/Larry_test/tacco_output/d4d6_h5ad" #snakemake.output['d4d6_h5ad']
# output_d2_h5ad = "/Users/apple/Desktop/KB/SCSeq_LineageBarcoding2/SCSeq_LineageBarcoding/Larry_test/tacco_output/d2_h5ad" #snakemake.output['d2_h5ad']

In [4]:
# load data
normed_counts_mat = scipy.io.mmread(normed_counts).tocsr()
genes = pd.read_csv(gene_names, sep='\t',header=None).to_numpy().flatten()
clone_mat = scipy.io.mmread(clone_matrix).tocsr()
meta_df = pd.read_csv(metadata, sep='\t')

In [5]:
meta_df.head()

Unnamed: 0,Library,Cell barcode,Time point,Starting population,Cell type annotation,Well,SPRING-x,SPRING-y
0,d6_2_2,GCGTGCAA-AGAAGTTA,6.0,Lin-Kit+Sca1-,Undifferentiated,2,411.496,-96.19
1,d6_2_2,AAGGGACC-CTCGATGC,6.0,Lin-Kit+Sca1-,Undifferentiated,2,-587.462,-306.925
2,d6_2_2,CGTACCGA-AGCGCCTT,6.0,Lin-Kit+Sca1-,Monocyte,2,1429.805,-429.3
3,d6_2_2,CTGAAGGG-AGGAGCTT,6.0,Lin-Kit+Sca1-,Neutrophil,2,1150.028,-2030.369
4,d6_2_2,CCGTAGCT-AGGCAGTT,6.0,Lin-Kit+Sca1-,Undifferentiated,2,-1169.594,362.01


In [6]:
# create full adata
adata = ad.AnnData(normed_counts_mat, obs=meta_df, var=pd.DataFrame(index=genes), dtype=np.float32)



In [7]:
# optimize dtypes
adata.obs['Library'] = adata.obs['Library'].astype('category')
adata.obs['Time point'] = adata.obs['Time point'].astype(int)
adata.obs['Starting population'] = adata.obs['Starting population'].astype('category')
adata.obs['Cell type annotation'] = adata.obs['Cell type annotation'].astype('category')
adata.obs['Well'] = adata.obs['Well'].astype(int)
# assign clone_id
adata.obs['clone_id'] = (clone_mat @ np.arange(1,1+clone_mat.shape[1])) - 1

In [8]:
adata.obs

Unnamed: 0,Library,Cell barcode,Time point,Starting population,Cell type annotation,Well,SPRING-x,SPRING-y,clone_id
0,d6_2_2,GCGTGCAA-AGAAGTTA,6,Lin-Kit+Sca1-,Undifferentiated,2,411.496,-96.190,573
1,d6_2_2,AAGGGACC-CTCGATGC,6,Lin-Kit+Sca1-,Undifferentiated,2,-587.462,-306.925,1440
2,d6_2_2,CGTACCGA-AGCGCCTT,6,Lin-Kit+Sca1-,Monocyte,2,1429.805,-429.300,394
3,d6_2_2,CTGAAGGG-AGGAGCTT,6,Lin-Kit+Sca1-,Neutrophil,2,1150.028,-2030.369,-1
4,d6_2_2,CCGTAGCT-AGGCAGTT,6,Lin-Kit+Sca1-,Undifferentiated,2,-1169.594,362.010,1972
...,...,...,...,...,...,...,...,...,...
130882,LSK_d6_1_3,TCTGATTT-CGGGCTTT,6,Lin-Kit+Sca1+,Undifferentiated,1,-308.468,-163.223,-1
130883,LSK_d6_1_3,AGTCACAA-TGTGTCCT,6,Lin-Kit+Sca1+,Undifferentiated,1,-3.435,575.133,1374
130884,LSK_d6_1_3,GGAGGTTT-AGGCAGTT,6,Lin-Kit+Sca1+,Monocyte,1,2548.309,24.683,-1
130885,LSK_d6_1_3,CCGGAAAT-GGGAAGGT,6,Lin-Kit+Sca1+,Monocyte,1,2658.601,131.098,-1


In [9]:
print("number of lineages: ", len(adata.obs['clone_id'].unique()))

number of lineages:  5865


### using the single-cell dataloader to generate batches

In [10]:
import DataLoader_tensor_sparse as dl
import SCDataset as ds

In [11]:
# input data
count_matrix = adata.X
cell_lineage = adata.obs['clone_id'].values.reshape(-1, 1)
count_matrix.shape, cell_lineage.shape

((130887, 25289), (130887, 1))

In [12]:
# step 1 generate designed batches
batchsize = 10
DLoader = dl.SClineage_DataLoader(count_matrix,cell_lineage,batch_size=batchsize, seed=7)
batch_all, num_batch = DLoader.batch_generator()
# step 2 generate real dataloader
sc_dataset = ds.SCDataset(batches=batch_all)

print("number of batches: ", num_batch)

number of batches:  43314


### using the real torch dataloader

In [13]:
data_loader = torch.utils.data.DataLoader(dataset=sc_dataset, batch_size=batchsize, shuffle=False)