In [1]:
import torch
import utils
import time
import numpy as np
from torch import nn
from torch.utils import data
from concreteNet import ConcreteAutoencoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Compute device selected by the project helper; use the commented line to force CPU.
device = utils.get_device()
#device = torch.device("cpu")

# Input directory with the hg19 filtered gene/barcode matrices.
file_path = "../../data/filtered_gene_bc_matrices/hg19"

In [3]:
# Load the raw matrices and apply the project's preprocessing in a single step.
adata = utils.anndata_preprocess(utils.anndata_load(file_path))
# Last expression: display the resulting matrix shape.
adata.shape

(2700, 4758)

In [4]:
def train(net, loss, dataloader, num_epoch, learning_rate, weight_decay, device):
    """Train ``net`` as an autoencoder, minimising ``loss(reconstruction, input)``.

    Every batch yielded by ``dataloader`` is used both as the network input and
    as the reconstruction target. After each epoch the mean batch loss and the
    cumulative wall-clock time since training started are printed.

    Args:
        net: ``nn.Module`` whose forward returns a 3-tuple; only the first
            element (the reconstruction) is used here.
        loss: callable ``loss(prediction, target)`` returning a scalar tensor.
        dataloader: iterable of input tensors (re-iterated once per epoch).
        num_epoch: number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        device: device the model and every batch are moved to.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    start_time = time.time()
    net.train()
    for epoch in range(num_epoch):
        epoch_loss = 0.0
        num_batches = 0
        # `batch` / `inputs` avoid shadowing the `data` module alias and the `input` builtin.
        for batch in dataloader:
            inputs = batch.to(device, non_blocking=True)
            optimizer.zero_grad()
            reconstruction, _, _ = net(inputs)
            batch_loss = loss(reconstruction, inputs)
            batch_loss.backward()
            optimizer.step()
            # Accumulate a plain float: summing the loss tensor itself would chain
            # an autograd graph across every batch of the epoch.
            epoch_loss += batch_loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")
        epoch_loss /= num_batches
        elapsed = time.time() - start_time
        print("Epoch: %d ; Loss %.5f; Time: %.2f s" % (epoch, epoch_loss, elapsed))

#def inference(net, dataloader):
    

In [5]:
def get_net(input_dim, k, hidden_dim, device, temperature):
    """Construct a ConcreteAutoencoder.

    All arguments are forwarded positionally, in the same order, to the
    ``ConcreteAutoencoder`` constructor.
    """
    model = ConcreteAutoencoder(input_dim, k, hidden_dim, device, temperature)
    return model
def load_adata(adata, batch_size):
    """Return the batched loader produced by ``utils.data_loader`` for ``adata``."""
    loader = utils.data_loader(adata, batch_size)
    return loader

In [14]:
# Reproducibility: seed the global RNGs before the model is initialised and the
# loader is built (and before the training cell iterates it), so a fresh
# Restart & Run All repeats the same weight init and torch-driven randomness.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Training hyper-parameters.
num_epoch = 128
learning_rate = 0.05
weight_dacay = 0  # (sic) name kept as-is: the training cell passes `weight_dacay`
batch_size = 256
temperature = 0.1

# Architecture sizes passed to ConcreteAutoencoder as `k` and `hidden_dim`
# (presumably number of selected genes and hidden width — confirm in concreteNet).
NUM_SELECTED = 1000
HIDDEN_DIM = 128

input_dim = adata.n_vars
loss = nn.MSELoss()
net = get_net(input_dim=input_dim, k=NUM_SELECTED, hidden_dim=HIDDEN_DIM,
              device=device, temperature=temperature)
dataloader = load_adata(adata=adata, batch_size=batch_size)

In [15]:
# Fit the autoencoder with the hyper-parameters defined in the setup cell.
train(
    net=net,
    loss=loss,
    dataloader=dataloader,
    num_epoch=num_epoch,
    learning_rate=learning_rate,
    weight_decay=weight_dacay,
    device=device,
)

Epoch: 0 ; Loss 4.04366; Time: 0.65 s
Epoch: 1 ; Loss 0.72354; Time: 1.30 s
Epoch: 2 ; Loss 0.59253; Time: 1.88 s
Epoch: 3 ; Loss 0.49475; Time: 2.48 s
Epoch: 4 ; Loss 0.43796; Time: 3.21 s
Epoch: 5 ; Loss 0.40687; Time: 4.19 s
Epoch: 6 ; Loss 0.38934; Time: 4.85 s
Epoch: 7 ; Loss 0.37917; Time: 5.50 s
Epoch: 8 ; Loss 0.37315; Time: 6.09 s
Epoch: 9 ; Loss 0.36934; Time: 6.62 s
Epoch: 10 ; Loss 0.36718; Time: 7.13 s
Epoch: 11 ; Loss 0.36606; Time: 7.70 s
Epoch: 12 ; Loss 0.36552; Time: 8.21 s
Epoch: 13 ; Loss 0.36512; Time: 8.81 s
Epoch: 14 ; Loss 0.36519; Time: 9.29 s
Epoch: 15 ; Loss 0.36473; Time: 9.87 s
Epoch: 16 ; Loss 0.36477; Time: 10.36 s
Epoch: 17 ; Loss 0.36461; Time: 10.85 s
Epoch: 18 ; Loss 0.36446; Time: 11.32 s
Epoch: 19 ; Loss 0.36456; Time: 11.97 s
Epoch: 20 ; Loss 0.36443; Time: 12.53 s
Epoch: 21 ; Loss 0.36475; Time: 13.20 s
Epoch: 22 ; Loss 0.36464; Time: 13.78 s
Epoch: 23 ; Loss 0.36498; Time: 14.39 s
Epoch: 24 ; Loss 0.36455; Time: 15.02 s
Epoch: 25 ; Loss 0.36503; 

In [8]:
# Run the trained network once over the full expression matrix.
# `adata.X` may be a SciPy sparse matrix or a dense ndarray: `.A` does not exist
# on ndarrays and is deprecated/removed for sparse matrices in recent SciPy, so
# densify via `toarray()` when available.
x_raw = adata.X
x_dense = x_raw.toarray() if hasattr(x_raw, "toarray") else np.asarray(x_raw)
x_val = torch.as_tensor(x_dense, dtype=torch.float32, device=device)

# Inference: leave training mode (train() set net.train()) and skip autograd
# bookkeeping, which is pure overhead for a forward-only pass over every cell.
net.eval()
with torch.no_grad():
    _, z, m = net(x_val)

In [None]:
import scanpy
