In [1]:
from torch.nn import functional as f
from torch import nn, optim
import torch as th
import numpy as np
from torch.utils.data import DataLoader,Dataset
from  preprocessing import load_data
import argparse
from tqdm import tqdm
import scanpy as sc
from torch.distributions import Normal
device = th.device("cuda" if th.cuda.is_available() else "cpu")



class VAE(nn.Module):
    def __init__(self,input_dim,hidden_layers,latent_dim):
        layers = [input_dim] + hidden_layers
        layers.append(latent_dim)
        self.encoder_block = nn.Sequential([nn.Sequential
                                            (nn.Linear(layers[i],layers[i+1]),nn.ReLU()) 
                                            for i in range(len(layers)-1)])
        self.decoder_block = nn.Sequential([nn.Sequential
                                            (nn.Linear(layers[i],layers[i-1]),nn.ReLU()) 
                                            for i in range(len(layers)-1,-1,-1)])
        self.mu = nn.Sequential(nn.Linear(latent_dim,latent_dim),nn.ReLU())
        self.logvar = nn.Sequential(nn.Linear(latent_dim,latent_dim),nn.ReLU())
        self.normal = Normal(0,1)
        self.latent_dim = latent_dim
    def reparametrize(self,mu:th.Tensor,logvar:th.Tensor):
        return(mu + (th.exp(0.5*logvar)*self.normal.sample(sample_shape=self.latent_dim)))
    def forward(self,x):
        encoded = self.encoder_block(x)
        mu,logvar = self.mu(encoded),self.logvar(encoded)
        decoded = self.decoder_block(self.reparametrize(mu,logvar))
        return decoded,mu,logvar
class TrainingCellData(Dataset):
    def __init__(self,adata,is_perturbed=0):
        self.adata = adata
        self.is_perturbed = is_perturbed
    def __getitem__(self, index):
        gene_vector = th.zeros(self.adata.shape[1])
        gene_vector = self.adata.X[index]
        return gene_vector
    def __len__(self):
        return(self.adata.shape[0])
    


In [8]:
train_adata = sc.read_h5ad(r"C:\Users\saira\OneDrive\Desktop\scgen\train_pbmc.h5ad")
val_adata= sc.read_h5ad(r"C:\Users\saira\OneDrive\Desktop\scgen\train_pbmc.h5ad")
cell_types = list(set(train_adata.obs["cell_type"]))
print(cell_types)
holdout_cell = np.random.choice(np.array(cell_types))
print(holdout_cell)
cell_mask = (train_adata.obs["cell_type"]==holdout_cell)
cell_mask = cell_mask.values
holdoutcelltype_adata = train_adata[cell_mask]
holdout_control = holdoutcelltype_adata[holdoutcelltype_adata.obs["condition"] =="control"]
holdout_stim = holdoutcelltype_adata[holdoutcelltype_adata.obs["condition"] =="stimulated"]
train_adata = train_adata[0==cell_mask]


This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(


['NK', 'FCGR3A+Mono', 'B', 'Dendritic', 'CD14+Mono', 'CD4T', 'CD8T']
CD4T



This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(


In [16]:
def find_common_names(stimulated_names, control_names):
    processed_stimulated = [name.replace('-stimulated', '') for name in stimulated_names]
    processed_control = [name.replace('-control', '') for name in control_names]
    common_names = set(processed_stimulated).intersection(processed_control)
    print("Common Names:", common_names)
find_common_names(holdout_stim.obs_names, holdout_control.obs_names.values)

Common Names: {'GTAGCCCTAGACTC-1', 'CTCCACGAACGGGA-1', 'GGAACACTTTCGGA-1', 'TAGGTTCTTCTACT-1', 'ACAAATTGTAGAAG-1', 'TAAGGCTGTCAAGC-1', 'GAATGGCTCTCAAG-1', 'GGGATGGATGCCCT-1', 'CACCTGACTGACTG-1'}


In [None]:
class HoldoutCellData(Dataset):
    def __init__(self,adata_stim,adata_ctrl):
        self.adata_ctrl = adata_ctrl
        self.adata_stim = adata_stim
    def __getitem__(self,idx):
        ctrl_vector = self.adata_ctrl.X[idx]
        cell_name = self.adata_ctrl.index[idx]
        try:
            stim_vector = self.adata_stim[f"{cell_name}-stimulated"]
        except:
            stim_vector = None
        return(ctrl_vector,stim_vector)
    def __len__(self):
        return(self.adata_ctrl.shape[0])
    
holdout_data = HoldoutCellData(holdout_stim,holdout_control)
print(len(holdout_data))
holdoutloader = DataLoader(holdout_data)
for ctrl,stim in holdoutloader:
    print(ctrl,stim)
    
    


2437


AttributeError: 'AnnData' object has no attribute 'iloc'

In [7]:
print(holdout_stim)

View of AnnData object with n_obs × n_vars = 0 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'


In [None]:

hvg_size = 2000
train_set = TrainingCellData(train_adata)
val_set = TrainingCellData(val_adata)
trainloader = DataLoader(train_set,batch_size=4)
valloader =DataLoader(val_set,batch_size=4)

hidden_layers = [512,256,128]
latent_dim=64
num_epochs = 100
lr = 3e3
model = VAE(hvg_size,hidden_layers,latent_dim).to(device)
optimizer = optim.adagrad(VAE.parameters(),lr)

with tqdm(total=len(trainloader)*num_epochs) as pbar:
    ### training phase
    for epoch in range(1,num_epochs+1):
        epoch_wise_recon_loss =0
        epoch_wise_kl_loss = 0
        for batch in trainloader:
            optimizer.zero_grad()
            reconstruction,mu,logvar = model(batch.to(device))
            recon_loss = f.mse_loss(reconstruction,batch)
            kl_loss = 0.5*(-logvar -1 +th.exp(logvar) + mu@mu.T).sum()
            loss = recon_loss + kl_loss
            loss.backward()
            optimizer.step()
            pbar.write(f'Loss: {loss.item()}')
            pbar.update()
            epoch_wise_recon_loss += recon_loss
            epoch_wise_kl_loss += kl_loss
        epoch_wise_kl_loss = epoch_wise_kl_loss/len(trainloader)
        epoch_wise_recon_loss = epoch_wise_recon_loss/len(trainloader)
        print(f"Recon Loss for epoch {epoch} is {epoch_wise_recon_loss}")
        print(f"KL Loss for epoch {epoch} is {epoch_wise_kl_loss}")

        ## validation phase.
        with tqdm(total=len(valloader),leave=False) as pbar:
            epoch_wise_recon_loss =0
            epoch_wise_kl_loss = 0
            model.eval()
            with th.no_grad():
                for batch in valloader:
                    reconstruction,mu,logvar = model(batch.to(device))
                    recon_loss = f.mse_loss(reconstruction,batch)
                    kl_loss = 0.5*(-logvar -1 +th.exp(logvar) + mu@mu.T).sum()
                    epoch_wise_recon_loss += recon_loss.item()
                    epoch_wise_kl_loss += kl_loss
                epoch_wise_kl_loss = epoch_wise_kl_loss/len(valloader)
                epoch_wise_recon_loss = epoch_wise_recon_loss/len(valloader)
                print(f"Val Recon Loss for epoch {epoch} is {epoch_wise_recon_loss}")
                print(f"Val KL Loss for epoch {epoch} is {epoch_wise_kl_loss}")
            ## now let's calculate  how well it predicts the 100 degs
            ## and how well it predicts the actual 
            with th.no_grad():
                for cells in holdout_set:
                    unperturbed_latent,mu,logvar = model(cells)
                    




                

            

            


        
    





        