In [1]:
import scanpy as sc
import pandas as pd
import lightning

In [2]:
import torch

In [3]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning as L
import sklearn.model_selection

In [4]:
def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` module with Kaiming-normal values.

    Modules of any other type are ignored, and the bias of a linear layer is
    left as it is. The function always returns ``None``.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.kaiming_normal_(m.weight)

class Cell_Encoder(nn.Module):
    """Three-layer MLP producing a ``hidden_dim``-sized embedding per input row.

    Layout: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, hidden_dim)
    -> ReLU -> Linear(hidden_dim, hidden_dim)``; the last layer has no activation.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        # Stored as ``l1`` so parameter names stay ``l1.0.weight``, ``l1.2.weight``, ...
        self.l1 = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the MLP and return the resulting embedding."""
        embedding = self.l1(x)
        return embedding
    
class Gene_Encoder(nn.Module):
    """MLP that projects each input row to a ``hidden_dim``-dimensional vector.

    Three linear layers with ReLU between them and no activation on the
    output: ``input_dim -> hidden_dim -> hidden_dim -> hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        first = nn.Linear(input_dim, hidden_dim)
        middle = nn.Linear(hidden_dim, hidden_dim)
        last = nn.Linear(hidden_dim, hidden_dim)
        # Positions 0/2/4 hold the linear layers, so parameters are named
        # ``l1.0.*``, ``l1.2.*`` and ``l1.4.*``.
        self.l1 = nn.Sequential(first, nn.ReLU(), middle, nn.ReLU(), last)

    def forward(self, x):
        """Apply the MLP to ``x`` and return its output."""
        projected = self.l1(x)
        return projected

In [5]:
def filter_nonzero(x):
    """Return the element-wise softplus of ``x``.

    Despite the name, nothing is filtered: every element is passed through
    ``torch.nn.functional.softplus`` and the result has the shape of ``x``.
    """
    smoothed = F.softplus(x)
    return smoothed

In [6]:
class LitAutoEncoder(L.LightningModule):
    """Lightning module that predicts a matrix as ``cell_embedding @ gene_embedding.T``.

    ``forward`` encodes the input batch with ``encoder1`` and the fixed
    ``gene_emb`` matrix with ``encoder2``, multiplies the two embeddings and
    passes the product through Softplus (``nonneg=True``) or Identity.
    The train/validation/test losses are the MSE between the columns of that
    output selected by ``train_index`` and the batch target.

    Args:
        encoder1: module applied to the batch input ``x``.
        encoder2: module applied to ``gene_emb``.
        gene_emb: fixed input of ``encoder2``. A tensor is registered as a
            non-persistent buffer; any other object is stored as a plain attribute.
        train_index: column indices of the output that are compared with the target.
        eta: learning rate of the Adam optimizer.
        nonneg: if true, apply Softplus to the output; otherwise leave it unchanged.
    """

    def __init__(self, encoder1, encoder2, gene_emb, train_index, eta=1e-4, nonneg=False):
        super().__init__()
        self.cellencoder = encoder1
        self.geneencoder = encoder2
        if isinstance(gene_emb, torch.Tensor):
            # A buffer follows the module across devices (.to()/.cuda()/Trainer
            # placement). persistent=False keeps it out of the state_dict, so
            # checkpoint keys are the same as when it was a plain attribute.
            self.register_buffer("gene_emb", gene_emb, persistent=False)
        else:
            self.gene_emb = gene_emb
        self.train_index = train_index
        self.eta = eta

        # Output activation; the attribute name ``id`` is kept for compatibility.
        self.id = nn.Softplus() if nonneg else nn.Identity()

    def forward(self, x):
        """Return the activated product of the cell and gene embeddings for ``x``."""
        z_cell = self.cellencoder(x)
        z_gene = self.geneencoder(self.gene_emb)

        out_final = torch.matmul(z_cell, z_gene.T)
        return self.id(out_final)

    def _masked_mse(self, batch):
        """MSE between the ``train_index`` columns of the prediction and the target."""
        x, y = batch
        prediction = self(x)
        return F.mse_loss(prediction[:, self.train_index], y)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self._masked_mse(batch)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it per epoch as ``val_loss``."""
        val_loss = self._masked_mse(batch)
        self.log('val_loss', val_loss, on_epoch=True, on_step=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it as ``test_loss``."""
        test_loss = self._masked_mse(batch)
        self.log('test_loss', test_loss)
        return test_loss

    def configure_optimizers(self):
        """Adam over all module parameters with learning rate ``eta``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.eta)
        return optimizer

In [7]:
# Read the AnnData object stored in seqfish_data.h5ad.
# NOTE(review): absolute, user-specific path — the cell will not run elsewhere as is.
seqfish_path = "/home/tl688/pitl688/seqfish_data.h5ad"
adata = sc.read_h5ad(seqfish_path)

In [9]:
# Replace `adata` with a transposed copy of itself (obs and var axes swapped).
# NOTE(review): the cell rebinds its own input, so running it a second time
# transposes the object back; later cells rely on the once-transposed form.
adata = adata.T.copy()

In [None]:
# Read the embedding AnnData; its `.X` is later passed to the model as `gene_emb`.
embedding_path = "./adata_mouse_embedding_enformer.h5ad"
adata_emb = sc.read_h5ad(embedding_path)

In [None]:
# Relabel the rows of the embedding object with their position 0..n-1, so the
# labels can later be turned back into integer indices.
# NOTE(review): AnnData expects string index labels — confirm integer labels are handled as intended.
adata_emb.obs_names = list(range(len(adata_emb.obs_names)))

In [13]:
# Per-row annotation table of the embedding object; it is filtered on its
# `add_id` column in the next cell.
df_id = adata_emb.obs 

In [14]:
# Keep only the embedding rows whose `add_id` also occurs in `adata.obs_names`.
id_in_adata = df_id['add_id'].isin(adata.obs_names)
df_id = df_id.loc[id_in_adata]

In [16]:
# Transposed copy of `adata`, used from here on as the training matrix.
# `adata` itself was transposed earlier, so this restores the orientation of the loaded file.
adata_train =  adata.T.copy()

In [17]:
# Integer form of the retained embedding row labels; the model uses these to
# pick the output columns that are compared with the target.
train_index = list(map(int, df_id.index))

In [19]:
# Restrict the columns to the retained `add_id` values, in the row order of
# `df_id`, so column k matches `train_index[k]`.
# NOTE(review): presumably this yields an AnnData view rather than a copy — confirm before adding `.copy()`.
matched_ids = df_id['add_id'].values
adata_train = adata_train[:, matched_ids]

In [20]:
# Use the values stored in the 'normalized' layer as the main data matrix.
adata_train.X = adata_train.layers['normalized']

In [21]:
# Densify the data matrix. The guard makes the cell safe to re-run and to use
# with an already dense layer: a NumPy array has no `.toarray()`, so the
# unconditional call would raise AttributeError in those cases.
if hasattr(adata_train.X, "toarray"):
    adata_train.X = adata_train.X.toarray()

In [99]:
import sklearn.model_selection
import numpy as np
import random
# Seed every RNG the notebook draws from. The encoder weights are initialised
# by torch, so seeding only NumPy and `random` leaves the model initialisation
# different on every run.
np.random.seed(2024)
random.seed(2024)
torch.manual_seed(2024)

In [100]:
# Hold out a test set of row names (default split fraction). An explicit
# `random_state` keeps the split identical however often, or in what order,
# the cells are run, instead of depending on the global NumPy RNG state.
train_name, test_name = sklearn.model_selection.train_test_split(
    adata_train.obs_names, random_state=2024
)

In [101]:
# Split the remaining names into training and validation sets, with a fixed
# `random_state` so the partition does not depend on the global RNG state.
# NOTE(review): this rebinds `train_name`, so re-running the cell shrinks the training set again.
train_name, valid_name = sklearn.model_selection.train_test_split(
    train_name, random_state=2024
)

In [102]:
# The cell encoder takes one full row of the training matrix as input.
n_input_features = adata_train.X.shape[1]
cell_enc = Cell_Encoder(input_dim=n_input_features, hidden_dim=64)

In [103]:
# The gene encoder consumes rows of `adata_emb.X`, so its input width is read
# from that matrix rather than hard-coded (the previous literal was 3072).
# hidden_dim must equal the cell encoder's, because the model multiplies the two embeddings.
gene_enc = Gene_Encoder(input_dim=adata_emb.X.shape[1], hidden_dim=64)

In [105]:
# Embedding matrix as a float32 tensor on the GPU; it is the fixed input of the gene encoder.
gene_embedding = torch.FloatTensor(adata_emb.X).cuda()
model = LitAutoEncoder(
    encoder1=cell_enc,
    encoder2=gene_enc,
    gene_emb=gene_embedding,
    train_index=train_index,
    eta=1e-4,
)

In [107]:
def _split_to_tensor(names):
    """Return the rows of `adata_train` named in `names` as a dense float32 tensor.

    `.X` may be sparse or already dense (an earlier cell densifies it), so
    `.toarray()` is only called when the matrix provides it; calling it on a
    NumPy array would raise AttributeError.
    """
    matrix = adata_train[names].X
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    return torch.FloatTensor(np.asarray(matrix))


X_tr = _split_to_tensor(train_name)
X_val = _split_to_tensor(valid_name)
X_test = _split_to_tensor(test_name)
# The targets are the inputs themselves, so the same tensors are reused
# instead of converting every split a second time.
y_tr, y_val, y_test = X_tr, X_val, X_test

In [108]:
# Pair each split's inputs with its targets as a TensorDataset (train, validation, test).
train_dataset, valid_dataset, test_dataset = (
    torch.utils.data.TensorDataset(inputs, targets)
    for inputs, targets in ((X_tr, y_tr), (X_val, y_val), (X_test, y_test))
)

In [109]:
from torch.utils.data import DataLoader

# Reshuffle the training samples every epoch so mini-batches are not always
# composed of the same rows; the validation loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=2048, num_workers=1, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=2048, num_workers=1)


In [None]:
# train with both splits
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
trainer = L.Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=100)], max_epochs=1000)
trainer.fit(model, train_loader, valid_loader, )

In [None]:
# Checkpoint path reported by the trainer's checkpoint callback.
best_checkpoint = trainer.checkpoint_callback.best_model_path
print(best_checkpoint)

# The constructor arguments are supplied again next to the checkpoint path when the model is rebuilt.
restore_kwargs = dict(
    encoder1=cell_enc,
    encoder2=gene_enc,
    eta=1e-4,
    gene_emb=torch.FloatTensor(adata_emb.X).cuda(),
    train_index=train_index,
)
model = LitAutoEncoder.load_from_checkpoint(best_checkpoint, **restore_kwargs)

In [37]:
# Move the restored model's parameters to the GPU for inference.
model = model.cuda()

In [38]:
# Inference only: put the module in eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    # Call the module itself rather than `.forward`, so registered hooks run.
    # Inputs go to whichever device the model is on instead of a hard-coded `.cuda()`.
    imputed = model(X_test.to(model.device)).cpu()

In [42]:
# Wrap the model output for the test split in an AnnData object.
# No row or column names are attached here.
imputed_matrix = imputed.numpy()
adata_imp = sc.AnnData(imputed_matrix)

In [43]:
# Independent copy of the held-out rows. The following cells run PCA,
# neighbors and UMAP on this object; an explicit copy avoids relying on an
# implicit view-to-copy conversion and guarantees `adata_train` is left untouched.
adata_test = adata_train[test_name].copy()

In [None]:
# Embedding pipeline on the held-out rows, in order: PCA, neighbour graph, UMAP.
# Each step is called with its default settings.
for embed_step in (sc.tl.pca, sc.pp.neighbors, sc.tl.umap):
    embed_step(adata_test)

In [None]:
# Plot the UMAP embedding of the held-out rows, coloured by the 'scClassify' annotation.
sc.pl.umap(adata_test, color='scClassify')