In [1]:
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import anndata

In [2]:
import scipy.sparse
sparse_X = scipy.sparse.load_npz('data/filtered_Counts.npz')
annoData = pd.read_table('data/annoData.txt')
y = annoData["cellIden"].to_numpy()
high_var_gene = 6000
# normlization and feature selection
adataSC = anndata.AnnData(X=sparse_X, obs=np.arange(sparse_X.shape[0]), var=np.arange(sparse_X.shape[1]))
sc.pp.filter_genes(adataSC, min_cells=10)
adataSC.raw = adataSC
sc.pp.highly_variable_genes(adataSC, n_top_genes=high_var_gene, flavor='seurat_v3')
sc.pp.normalize_total(adataSC, target_sum=1e4)
sc.pp.log1p(adataSC)

adataNorm = adataSC[:, adataSC.var.highly_variable]
dataframe = adataNorm.to_df()
x_ndarray = dataframe.values.squeeze()
y_ndarray = np.expand_dims(y, axis=1)
print(x_ndarray.shape,y_ndarray.shape)
dataframe.head()

  if index_name in anno:


(8569, 6000) (8569, 1)


Unnamed: 0,1,2,4,7,10,13,26,31,32,33,...,20104,20105,20108,20109,20115,20118,20121,20122,20123,20124
0,1.024218,0.0,0.0,0.0,1.302199,0.0,0.0,0.0,0.637877,0.0,...,0.0,0.0,0.36896,0.0,0.0,0.0,0.637877,0.0,0.0,0.36896
1,0.0,0.0,0.0,0.0,1.351171,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.888292,0.0,0.305824,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.4175,0.0,0.0,0.0,0.0,0.0,0.0,...,0.4175,0.0,0.0,0.0,0.4175,0.0,0.93785,0.4175,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.509045,0.0,0.0,0.509045


In [3]:
from torch.utils.data import DataLoader,random_split,TensorDataset
scDataset = TensorDataset(torch.tensor(x_ndarray, dtype=torch.float32),
                              torch.tensor(y_ndarray, dtype=torch.float32))

scTrainLength = int(len(scDataset) * 0.8)
scValidLength = len(scDataset) - scTrainLength
scTrain, scValid = random_split(scDataset, [scTrainLength, scValidLength])

scTrainDataLoader = DataLoader(scTrain, shuffle=True, batch_size=256,drop_last=True)
scValidDataLoader = DataLoader(scValid, shuffle=True, batch_size=256,drop_last=True)

for features, labels in scTrainDataLoader:
    print(len(features[-1]))
    print(len(features))
    print(len(labels))
    break

6000
256
256


In [4]:
def gaussian_noise(original):
    cellShape=original.shape
    # print(original.device)
    noise=0.1*torch.randn(size=cellShape)
    return original+noise

In [5]:
for features, labels in scTrainDataLoader:
    print(features.shape)
    testData=gaussian_noise(features).cuda()
    print(testData.device)
    print(testData[0])
    print(features[0])
    break

torch.Size([256, 6000])
cuda:0
tensor([-0.0070, -0.0354, -0.0331,  ..., -0.0785,  0.1530,  0.0013],
       device='cuda:0')
tensor([0., 0., 0.,  ..., 0., 0., 0.])


In [9]:
def chiasma(original,prob,percentage):
    geneCount=original.shape[1]
    s=np.random.uniform(0,1)
    if s<prob:
        chiasma_instance=int(geneCount*percentage/2)
        chiasma_pair=np.random.randint(geneCount,size=(chiasma_instance,2))
        # print(chiasma_pair)
        copy=original.clone()
        copy[:,chiasma_pair[:,0]],copy[:,chiasma_pair[:,1]]=copy[:,chiasma_pair[:,1]],copy[:,chiasma_pair[:,0]]
        return copy

In [23]:
for features, labels in scTrainDataLoader:
    # print(features.shape)
    testData=chiasma(features,0.5,0.1)
    # print(testData.device)
    if testData!=None:
        print(testData[0])
        print(features[0])
        print((testData[0]-features[0]).sum())
    break

tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor(-1.9294)


In [43]:
def random_mask(original,prob,percentage):
    cellCount,geneCount=original.shape
    s=np.random.uniform(0,1)
    print(s)
    if s<prob:
        mask=np.concatenate([np.ones(int(geneCount*percentage),dtype=bool),np.zeros(geneCount-int(geneCount*percentage),dtype=bool)])
        np.random.shuffle(mask)
        copy=original.clone()
        copy[:,mask]=0
        return copy


In [47]:
for features, labels in scTrainDataLoader:
    # print(features.shape)
    testData=random_mask(features,0.5,0.15)
    # print(testData.device)
    if testData!=None:
        print(testData[0].sum())
        print(features[0].sum())

        print((testData[0]-features[0]).sum())
    break

0.4474919526869785
tensor(773.4749)
tensor(915.6711)
tensor(-142.1962)
