In [1]:
import torch
import utils
import time
from torch import nn
from torch.utils import data
from concreteNet import ConcreteAutoencoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory handed to utils.anndata_load in the next cell.
file_path = "../../data/filtered_gene_bc_matrices/hg19"

# Set FORCE_CPU = True to debug on the CPU. Otherwise use whatever device
# utils.get_device() picks (presumably a GPU when one is available; confirm in utils).
FORCE_CPU = False
device = torch.device("cpu") if FORCE_CPU else utils.get_device()

In [3]:
# Read the matrix from `file_path` and immediately run the project's preprocessing;
# only the preprocessed object is kept under the name `adata`.
adata = utils.anndata_preprocess(utils.anndata_load(file_path))
adata.shape

(2700, 5000)

In [4]:
def train(net, loss, dataloader, num_epoch, learning_rate, weight_decay, device):
    """Train an autoencoder to reconstruct its own input using Adam.

    Each batch yielded by ``dataloader`` is used both as the network input and as
    the reconstruction target. ``net(x)`` must return a pair whose first element is
    the reconstruction; the second element is ignored here.

    Args:
        net: model to train. It is moved to ``device`` and put in training mode.
        loss: callable ``loss(reconstruction, target)`` returning a scalar tensor.
        dataloader: iterable of input tensors that also supports ``len()``.
        num_epoch: number of full passes over ``dataloader``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        device: device that the model and every batch are moved to.

    Prints one line per epoch with the mean per-batch loss and the wall-clock time
    elapsed since training started. Returns None.

    Raises:
        ValueError: if ``dataloader`` is empty (the mean loss would be undefined).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    start_time = time.time()
    net.train()
    for epoch in range(num_epoch):
        train_loss = 0.0
        # `batch` rather than `data`: avoids shadowing the `torch.utils.data` module.
        for batch in dataloader:
            inputs = batch.to(device, non_blocking=True)
            optimizer.zero_grad()
            reconstruction, _ = net(inputs)
            batch_loss = loss(reconstruction, inputs)
            batch_loss.backward()
            optimizer.step()
            # .item() yields a plain Python float. Adding the tensor itself would chain
            # every batch's loss into one growing autograd history.
            train_loss += batch_loss.item()
        process_time = time.time() - start_time
        train_loss /= num_batches
        print("Epoch: %d ; Loss %.5f; Time: %.2f s" % (epoch, train_loss, process_time))


In [5]:
def get_net(input_dim, k, hidden_dim, device, temperature):
    """Construct a ConcreteAutoencoder from the given hyperparameters.

    The arguments are forwarded positionally, in exactly this order, to
    ``ConcreteAutoencoder(input_dim, k, hidden_dim, device, temperature)``.
    ``k`` is presumably the number of input features the concrete layer selects
    (TODO: confirm against concreteNet).
    """
    ctor_args = (input_dim, k, hidden_dim, device, temperature)
    return ConcreteAutoencoder(*ctor_args)
def load_adata(adata, batch_size):
    """Build the batch iterable used for training from an already-loaded ``adata``.

    Despite its name, this function does not load ``adata`` itself. It delegates to
    ``utils.data_loader(adata, batch_size)`` and returns that result unchanged.
    """
    loader = utils.data_loader(adata, batch_size)
    return loader

In [6]:
# Seed for torch's global RNG. It covers model init and any torch-based sampling or
# shuffling from here on; numpy / `random` are not seeded by this cell.
SEED = 0

# Training hyperparameters. `weight_dacay` is misspelled but kept as-is because the
# training cell passes it by this name.
# NOTE(review): the logged loss plateaus around 0.256 from epoch 1 onward. An Adam
# learning rate of 0.1 is unusually high and may be the cause; try ~1e-3 and compare.
num_epoch, learning_rate, weight_dacay, batch_size, temperature = 64, 0.1, 0, 128, 0.1

# Model size: `k_select` is passed as ConcreteAutoencoder's `k` (presumably the number
# of genes it selects; confirm in concreteNet), `hidden_dim` as its hidden width.
k_select, hidden_dim = 1000, 128

input_dim = adata.n_vars
loss = nn.MSELoss()
torch.manual_seed(SEED)  # seed before building the model so init is repeatable
net = get_net(input_dim=input_dim, k=k_select, hidden_dim=hidden_dim, device=device, temperature=temperature)
dataloader = load_adata(adata=adata, batch_size=batch_size)

In [7]:
# Run the training loop; `train` prints one progress line per epoch and returns None.
train(
    device=device,
    net=net,
    dataloader=dataloader,
    loss=loss,
    num_epoch=num_epoch,
    learning_rate=learning_rate,
    weight_decay=weight_dacay,
)


Epoch: 0 ; Loss 0.26548; Time: 1.81 s
Epoch: 1 ; Loss 0.25605; Time: 2.57 s
Epoch: 2 ; Loss 0.25623; Time: 3.32 s
Epoch: 3 ; Loss 0.25581; Time: 4.10 s
Epoch: 4 ; Loss 0.25633; Time: 4.91 s
Epoch: 5 ; Loss 0.25601; Time: 5.74 s
Epoch: 6 ; Loss 0.25604; Time: 6.48 s
Epoch: 7 ; Loss 0.25628; Time: 7.23 s
Epoch: 8 ; Loss 0.25690; Time: 7.97 s
Epoch: 9 ; Loss 0.25637; Time: 8.77 s
Epoch: 10 ; Loss 0.25667; Time: 9.62 s
Epoch: 11 ; Loss 0.25675; Time: 10.43 s
Epoch: 12 ; Loss 0.25588; Time: 11.23 s
Epoch: 13 ; Loss 0.25640; Time: 12.02 s
Epoch: 14 ; Loss 0.25643; Time: 12.84 s
Epoch: 15 ; Loss 0.25680; Time: 13.69 s
Epoch: 16 ; Loss 0.25633; Time: 14.53 s
Epoch: 17 ; Loss 0.25619; Time: 15.38 s
Epoch: 18 ; Loss 0.25568; Time: 16.20 s
Epoch: 19 ; Loss 0.25616; Time: 16.95 s
Epoch: 20 ; Loss 0.25617; Time: 17.78 s
Epoch: 21 ; Loss 0.25645; Time: 18.61 s
Epoch: 22 ; Loss 0.25647; Time: 19.43 s
Epoch: 23 ; Loss 0.25658; Time: 20.21 s
Epoch: 24 ; Loss 0.25594; Time: 21.02 s
Epoch: 25 ; Loss 0.25