In [None]:
import torch
import utils
import time
import numpy as np
from torch import nn
from torch.utils import data
from concreteNet import ConcreteAutoencoder

In [None]:
file_path = "../../data/filtered_gene_bc_matrices/hg19"
device = utils.get_device()  # replace with torch.device("cpu") to force CPU execution

# Seed the RNGs used below so torch-side weight initialisation, any shuffling done by
# the data loader, and everything downstream of training are reproducible across
# Restart & Run All. (Some GPU kernels may still be non-deterministic.)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the dataset at file_path and apply utils.anndata_preprocess with its default settings.
adata = utils.anndata_preprocess(utils.anndata_load(file_path))
adata.shape

In [None]:
def train(net, loss, dataloader, num_epoch, learning_rate, weight_decay, device):
    """Fit ``net`` as an autoencoder with Adam, printing the mean loss of every epoch.

    Each batch is fed through ``net``. The first of the three returned values is
    compared against the batch itself with ``loss``, giving a reconstruction objective.

    Args:
        net: ``torch.nn.Module`` whose forward returns a 3-tuple. Only the first
            element (the reconstruction) is used here.
        loss: callable ``loss(reconstruction, target)`` returning a scalar tensor,
            e.g. ``nn.MSELoss()``.
        dataloader: re-iterable source of input tensors that supports ``len()``
            (the number of batches per epoch).
        num_epoch: number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        device: device the model and every batch are moved to.

    Returns:
        list[float]: mean batch loss of each epoch. Callers that ignore the return
        value are unaffected.

    Raises:
        ValueError: if ``dataloader`` has no batches. The per-epoch mean would
            otherwise divide by zero.
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("dataloader has no batches; nothing to train on")
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    start_time = time.time()
    net.train()
    history = []
    for epoch in range(num_epoch):
        running_loss = 0.0
        # `inputs` rather than `data`/`input`: avoids shadowing torch.utils.data and the builtin.
        for inputs in dataloader:
            inputs = inputs.to(device, non_blocking=True)
            optimizer.zero_grad()
            reconstruction, _, _ = net(inputs)
            batch_loss = loss(reconstruction, inputs)
            batch_loss.backward()
            optimizer.step()
            # detach(): summing the raw loss would keep autograd history attached to the
            # running total. Staying a tensor avoids a host sync on every batch.
            running_loss = running_loss + batch_loss.detach()
        process_time = time.time() - start_time  # cumulative since training started
        epoch_loss = float(running_loss) / n_batches
        history.append(epoch_loss)
        print("Epoch: %d ; Loss %.5f; Time: %.2f s" % (epoch, epoch_loss, process_time))
    return history

# TODO(review): an inference(net, dataloader) helper was stubbed here but never written.

In [None]:
def get_net(input_dim, k, hidden_dim,device, temperature):
    """Build the ConcreteAutoencoder used by this notebook.

    Thin factory: forwards its arguments, positionally and in the same order,
    to ``ConcreteAutoencoder``.
    """
    model = ConcreteAutoencoder(input_dim, k, hidden_dim, device, temperature)
    return model
def load_adata(adata, batch_size):
    """Return the loader produced by ``utils.data_loader`` for ``adata`` at ``batch_size``."""
    loader = utils.data_loader(adata, batch_size)
    return loader

In [None]:
# Training hyper-parameters.
num_epoch = 128
learning_rate = 0.05
weight_decay = 0
batch_size = 256
temperature = 0.1
# The training cell below still reads the original misspelled name; keep it as an alias.
weight_dacay = weight_decay

# Model size. n_selected_features is passed as ConcreteAutoencoder's `k` argument
# (presumably the number of genes the concrete layer selects -- confirm in concreteNet).
n_selected_features = 1000
hidden_dim = 128

input_dim = adata.n_vars
loss = nn.MSELoss()
net = get_net(input_dim=input_dim, k=n_selected_features, hidden_dim=hidden_dim,
              device=device, temperature=temperature)
dataloader = load_adata(adata=adata, batch_size=batch_size)

In [None]:
# Fit the autoencoder. Arguments are in train()'s signature order; the trailing ";"
# keeps the cell output to the printed epoch log.
train(net, loss, dataloader, num_epoch, learning_rate, weight_dacay, device);

In [None]:
# Embed the full dataset in a single forward pass (no mini-batching).
# eval() + no_grad(): this is inference, so no autograd graph is built over every cell
# and any training-only behaviour of the network is switched off.
# NOTE(review): confirm ConcreteAutoencoder's eval-mode selection is what the gene
# analysis below expects (train() switches back to training mode on its own).
net.eval()
# `.A` is deprecated (and removed in newer SciPy) on sparse matrices and does not
# exist on dense arrays; toarray() handles sparse input, dense input is used as-is.
X_dense = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
x_val = torch.tensor(X_dense, dtype=torch.float32, device=device)
with torch.no_grad():
    # Named outputs instead of `_`, which would clobber IPython's last-output cache.
    reconstruction, z, m = net(x_val)
z.shape

In [None]:
# Move the latent matrix z and the matrix m to the CPU as NumPy arrays.
z, m = (t.detach().cpu().numpy() for t in (z, m))

In [None]:
# Row index of the largest entry in each column of m. These positions index
# adata.var_names further down, i.e. the gene each column picks.
max_idx = [np.argmax(m[:, col]) for col in range(m.shape[1])]
max_idx

In [None]:
# Largest entry of each column of m, computed in one vectorised call and shown as a
# single array instead of one printed line per column.
m.max(axis=0)

In [None]:
def find_same(index):
    """Print and return the number of position pairs in ``index`` holding equal values.

    A value occurring ``c`` times contributes ``c * (c - 1) // 2`` pairs, which equals
    the count from comparing every (earlier, later) pair of positions. Applied to
    ``max_idx`` it is 0 exactly when every column picked a different gene.

    Args:
        index: iterable of hashable values, e.g. the list of argmax positions.

    Returns:
        int: the number of equal-valued pairs (also printed, as before).
    """
    from collections import Counter

    # O(n) via value counts instead of the quadratic pairwise comparison.
    count = sum(c * (c - 1) // 2 for c in Counter(index).values())
    print(count)
    return count

In [None]:
# Count pairs of columns that picked the same gene; ";" keeps the output to the printed count.
find_same(index=max_idx);

In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

In [None]:
# Wrap the latent matrix z (rows presumably = cells, in adata order) for scanpy analysis.
res_adat = ad.AnnData(X=z)

In [None]:
# Show the AnnData summary of the wrapped latent matrix.
res_adat

In [None]:
# PCA of the latent features; results are stored on res_adat for the steps below.
sc.tl.pca(res_adat, svd_solver='arpack');

In [None]:
# Scatter plot of the cells on the leading principal components.
sc.pl.pca(res_adat);

In [None]:
# kNN graph built on the first 40 PCs; it feeds both the UMAP and the Leiden clustering below.
sc.pp.neighbors(adata=res_adat, n_pcs=40, n_neighbors=10)

In [None]:
# 2-D UMAP embedding of the neighbour graph, then plot it.
sc.tl.umap(res_adat)
sc.pl.umap(res_adat);

In [None]:
# Leiden clusters go into obs['leiden'] (the default key) and colour the UMAP.
sc.tl.leiden(res_adat, key_added='leiden')
sc.pl.umap(res_adat, color='leiden');

In [None]:
# Gene name picked by each column of m (one entry per column; duplicates possible).
res_genes = adata.var_names.take(max_idx)

In [None]:
res_adat.var_names = res_genes # type: ignore
# Several columns of m can pick the same gene (see the find_same count above). Make the
# names unique right away so rank_genes_groups below does not report ambiguous duplicates.
res_adat.var_names_make_unique()

In [None]:
# Per-cluster marker ranking (t-test against the rest), then plot the top 25 of each.
sc.tl.rank_genes_groups(res_adat, groupby='leiden', method='t-test')
sc.pl.rank_genes_groups(res_adat, n_genes=25, sharey=False);

In [None]:
# NOTE(review): this rebinds the notebook-wide `adata` to a second preprocessing run
# (n_top_genes=1000). Cells above that read `adata` (e.g. the res_genes lookup) give
# different results if re-run after this point; consider a separate name.
adata = utils.anndata_preprocess(utils.anndata_load(file_path), n_top_genes=1000)

In [None]:
# Gene names kept by the n_top_genes=1000 preprocessing.
var_genes = adata.var.index

In [None]:
# Suffix repeated gene names ("-1", "-2", ...) so every variable name is distinct.
res_adat.var_names_make_unique(join='-')

In [None]:
# Set of the (now unique) selected gene names, for fast membership tests.
s = {name for name in res_adat.var_names}

In [None]:
# Genes in var_genes (in their original order) that are not among the selected names.
temp1 = list(filter(lambda gene: gene not in s, var_genes))

In [None]:
# How many genes from var_genes are absent from the selected gene names.
len(temp1)

In [None]:
# Selected gene names (in s) that are missing from var_genes. Sorted, because iterating
# a set of strings yields a hash-randomised order that changes between kernel sessions.
temp2 = sorted(s.difference(var_genes))

In [None]:
# How many selected gene names are absent from var_genes.
len(temp2)

In [None]:
# Genes in var_genes that were not selected.
temp1

In [None]:
# Selected gene names that are not in var_genes.
temp2