In [None]:
# pip install plotly

In [None]:
# pip install umap-learn   # UMAP's PyPI name is "umap-learn" (imported as `umap`); the "umap" package is unrelated

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pathlib import Path

from learn.dataset import TabularDataset
from learn.VAE import Autoencoder,VariationalAutoencoder
from learn.train import train_model

import umap
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Input directory, resolved relative to the notebook's working directory.
data_path = Path("data")

In [None]:
# import os
# os.getcwd()

In [None]:
# Read the CSV (first column becomes the index), transpose it, and replace the
# resulting index with a plain 0..n-1 RangeIndex.
rna = (
    pd.read_csv(data_path / "rna_scale.csv", index_col=0)
    .T
    .reset_index(drop=True)
)

In [None]:
# (rows, columns) of the transposed table: rows are samples (presumably cells),
# columns are the features (genes) fed to the models below.
rna.shape

In [None]:
# Model input width = number of columns in the expression table.
nfeatures = len(rna.columns)

# Hold out 10% of the rows for validation; random_state=0 keeps the split reproducible.
train, valid = train_test_split(
    rna.to_numpy(dtype=np.float32),
    random_state=0,
    test_size=0.1,
)

In [None]:
# Wrap both float32 arrays in the project's TabularDataset (train first, then valid).
train_ds, valid_ds = map(TabularDataset, (train, valid))

In [None]:
# Mini-batches of 64; only the training loader reshuffles each epoch.
train_dl, valid_dl = (
    DataLoader(ds, batch_size=64, shuffle=is_train)
    for ds, is_train in ((train_ds, True), (valid_ds, False))
)

In [None]:
# Pull a single batch from a fresh iterator to sanity-check tensor shapes.
x, y = next(iter(train_dl))
tuple(t.shape for t in (x, y))

In [None]:
# Seed torch's global RNG so the weight initialisation below is reproducible. The same
# seed also makes later shuffling by DataLoaders built without their own `generator`
# (e.g. train_dl) deterministic. Seed value matches random_state=0 used for the split.
torch.manual_seed(0)
modelAE = Autoencoder(in_dims=nfeatures, latent_dims=20)

In [None]:
# Display the autoencoder's module structure (the repr of the cell's last expression).
modelAE

In [None]:
from collections import defaultdict

In [None]:
# Hyper-parameters for the plain autoencoder run (lr and epochs are reused further down).
lr, epochs = 1e-2, 50
model, losses = train_model(
    modelAE, train_dl, valid_dl,
    lr=lr,
    epochs=epochs,
)

In [None]:
# Learning curves returned by train_model above: one line per data split.
fig = go.Figure()
for split in ('train', 'valid'):
    split_losses = losses[split]
    # Derive the x-axis from the recorded values rather than `epochs`, so the plot
    # stays correct if train_model ever records a different number of points.
    fig.add_trace(go.Scatter(x=np.arange(1, len(split_losses) + 1),
                             y=split_losses,
                             mode='lines',
                             name=split))
fig.update_layout(title='Autoencoder training curves',
                  xaxis_title='epoch',
                  yaxis_title='loss')
fig.show()

In [None]:
# Sketch: one small encoder/decoder pair per GO term, plus a shared decoder that
# reconstructs every gene from the merged per-term embedding.
#   N = the number of cells (rows of `rna`)
#   G = the number of genes (columns of `rna`, i.e. nfeatures)
#   M = the number of GO terms
#   gene_indices[i] = the list of column indices of genes that are in the i-th GO term
# NOTE(review): `Encoder` and `Decoder` are not imported anywhere in this notebook;
# presumably they live next to `Autoencoder` in learn.VAE — TODO import them.

# TODO: placeholder example mapping — replace with the real gene -> GO-term index lists.
gene_indices = [
    [0, 1],
    [2, 4, 6],
]
G = nfeatures
M = len(gene_indices)

# Fall back to CPU so the cell also runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
latent_dim = 20
# NOTE(review): MSE reconstruction loss is an assumption (inputs are a scaled
# expression matrix) — confirm it matches what train_model uses.
criterion = nn.MSELoss()

# create encoders/decoders; Encoder(in, out) takes a term's gene subset to the latent
# space, Decoder(latent, out) maps back (presumably to len(gene_idx) genes).
encoders = []
decoders = []
decoder_shared = Decoder(latent_dim, G).to(device)
for gene_idx in gene_indices:
    encoders.append(Encoder(len(gene_idx), latent_dim).to(device))
    decoders.append(Decoder(latent_dim, len(gene_idx)).to(device))

def step(x, include_shared=False):
    """Forward one batch through the per-GO-term autoencoders and return their losses.

    x: a tensor of shape (batch size, G).
    include_shared: if True, also decode the mean of the per-term embeddings with
        `decoder_shared`, score it against the full `x`, and append that loss last.

    Returns a list of scalar losses: one per GO term (in `gene_indices` order),
    followed by the shared-decoder loss when `include_shared` is True.
    """
    # The encoders/decoders were moved to `device`; batches arrive from the DataLoader.
    x = x.to(device)

    # Encode each term's gene subset and stack along a new leading axis.
    # Assumes each encoder returns (batch size, latent_dim) — TODO confirm; result is
    # then (M, batch size, latent_dim) and stays on `device` and in the autograd graph.
    embeddings = torch.stack(
        [encoder(x[:, gene_idx]) for gene_idx, encoder in zip(gene_indices, encoders)]
    )

    loss_list = []
    for gene_idx, decoder, embedding in zip(gene_indices, decoders, embeddings):
        xhat = decoder(embedding)  # or use the merged embedding
        # Each term's decoder only reconstructs that term's own genes, so compare
        # against the matching column subset rather than the full x.
        loss_list.append(criterion(xhat, x[:, gene_idx]))

    if include_shared:
        # Average over GO terms (axis 0); may try self-attention instead.
        embedding_merged = embeddings.mean(0)
        xhat = decoder_shared(embedding_merged)
        loss_list.append(criterion(xhat, x))

    return loss_list

# Jointly train all per-GO-term encoders/decoders. Per-epoch mean losses go into
# `go_losses` (kept separate from `losses`, which holds the autoencoder curves).
go_modules = [*encoders, *decoders, decoder_shared]
# NOTE(review): Adam at the notebook's `lr` is an assumption — the original cell used
# an `optimizer` that was never defined.
optimizer = torch.optim.Adam(
    [p for module in go_modules for p in module.parameters()], lr=lr
)
go_losses = {'train': [], 'valid': []}

for epoch in range(epochs):
    # `module`, not `model`, so the trained autoencoder above is not shadowed.
    for module in go_modules:
        module.train()
    train_total = 0.0
    for x, _ in train_dl:
        optimizer.zero_grad()
        # Sum the per-term losses and backpropagate once. The losses share part of the
        # graph, so separate backward() calls would re-traverse freed buffers.
        loss = torch.stack(step(x)).sum()
        loss.backward()
        optimizer.step()
        train_total += loss.item()

    for module in go_modules:
        module.eval()
    valid_total = 0.0
    with torch.no_grad():
        for x, _ in valid_dl:
            valid_total += torch.stack(step(x)).sum().item()

    # Mean (over batches) of the summed per-term loss.
    go_losses['train'].append(train_total / len(train_dl))
    go_losses['valid'].append(valid_total / len(valid_dl))