In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
# Load the expression matrix (tab-separated, first column used as the row index).
gene_expression_data = pd.read_csv(
    '../data/scRNA_seq_for_hw3_hw5.tsv',
    sep='\t',
    index_col=0,
)

# Load the label table (tab-separated, first column used as the row index).
labels = pd.read_csv(
    '../data/label_for_hw3_hw5.tsv',
    sep='\t',
    index_col=0,
)

In [4]:
# The raw file has features as rows and samples as columns; transpose so each row is one
# sample, matching the one-row-per-sample `labels` table. The guard makes the cell
# idempotent: a bare `.T` would flip the matrix back every time the cell is re-run.
if gene_expression_data.shape[0] != len(labels):
    gene_expression_data = gene_expression_data.T
assert gene_expression_data.shape[0] == len(labels), (
    "expression matrix rows do not match the number of labels in either orientation"
)

In [3]:
# (rows, columns) of the label table, displayed as the cell output.
np.shape(labels)

(10412, 1)

In [6]:
# Dense float32 copy of the (samples x features) expression matrix fed to the model.
gene_data = gene_expression_data.to_numpy(dtype=np.float32)
print(np.isnan(gene_data).any())
# A single NaN/inf would propagate through the network and poison every gradient step,
# so fail loudly here instead of only printing a flag that is easy to overlook.
assert np.isfinite(gene_data).all(), "gene expression matrix contains NaN or inf values"

False


In [7]:
class MyDataset(Dataset):
    """Map-style dataset over the rows of an array-like expression matrix.

    The data is copied once into a float32 tensor at construction time. Indexing
    returns the selected row(s) of that tensor only; no labels are attached.
    """

    def __init__(self, gene_expression_data):
        super().__init__()
        # torch.tensor always copies, so later edits to the source array are not seen here.
        self.genes = torch.tensor(gene_expression_data, dtype=torch.float32)

    def __len__(self):
        """Number of rows (first-dimension length) of the stored tensor."""
        return len(self.genes)

    def __getitem__(self, idx):
        """Return row(s) `idx` of the stored tensor."""
        selected = self.genes[idx]
        return selected

In [26]:
class MyVAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an ``input_dim`` vector to ``2 * latent_dim`` values, which are split
    into the posterior mean ``mu`` and log-variance ``logvar``. A latent ``z`` is drawn with
    the reparameterisation trick and decoded back to ``input_dim`` values.

    Args:
        input_dim: number of input features per sample.
        latent_dim: size of the latent space.
        hidden_dims: widths of the encoder's hidden layers, in order; the decoder uses the
            same widths reversed. The default ``(512, 256)`` reproduces the original fixed
            architecture, including its layer order and ``state_dict`` keys, so previously
            saved checkpoints still load.
    """

    def __init__(self, input_dim, latent_dim, hidden_dims=(512, 256)):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dims = tuple(hidden_dims)

        # input -> hidden widths -> 2 * latent_dim (mu and logvar side by side; split in forward()).
        self.encoder = self._mlp([input_dim, *self.hidden_dims, latent_dim * 2])
        # latent_dim -> reversed hidden widths -> input_dim. The output layer has no
        # activation, so reconstructions are unbounded real values.
        self.decoder = self._mlp([latent_dim, *reversed(self.hidden_dims), input_dim])

    @staticmethod
    def _mlp(sizes):
        """Linear layers between consecutive ``sizes``, with ReLU between them but not after the last."""
        layers = []
        for i, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if i > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        return nn.Sequential(*layers)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling this way keeps the result differentiable w.r.t. ``mu`` and ``logvar``.
        Note: it always samples, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(x_hat, mu, logvar, h, z)``: the reconstruction, posterior mean, posterior
            log-variance, raw encoder output ``h`` (``mu`` and ``logvar`` concatenated on
            the last dim), and the sampled latent ``z``.
        """
        h = self.encoder(x)
        mu, logvar = torch.chunk(h, 2, dim=-1)
        z = self.reparameterize(mu, logvar)

        # Decode the sampled latent back to input space.
        x_hat = self.decoder(z)

        return x_hat, mu, logvar, h, z

In [18]:
# Show the concrete class of the expression table.
gene_expression_data.__class__

pandas.core.frame.DataFrame

In [24]:
# Seed torch's global RNG. Weight initialisation, DataLoader shuffling (no explicit
# generator is passed) and the reparameterisation noise all draw from it. Seeding makes
# training reproducible under Restart & Run All.
torch.manual_seed(42)

# Take the feature count from the array actually fed to the model. Later cells append
# columns (UMAP coordinates, labels) to the DataFrame, so its width changes if this cell
# is re-run afterwards.
input_dim = gene_data.shape[1]
print(input_dim)
latent_dim = 256

model = MyVAE(input_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
dataset = MyDataset(gene_data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3000


In [38]:
# Number of full passes over `dataloader` in the training loop.
num_epochs: int = 150

In [39]:
from torch.utils.tensorboard import SummaryWriter
# Train by minimising the negative ELBO:
#   loss = reconstruction error + kl_weight * KL(q(z|x) || N(0, I))
# The previous loop optimised only the MSE reconstruction term, so nothing pulled the
# posterior towards the prior: the "VAE" was a plain (noisy) autoencoder and the latent
# space was not regularised.
# `criterion` (mean-reduced MSE) is not used here. It averages over features, which would
# put the reconstruction term on a different scale from the per-sample KL sum.
kl_weight = 1.0  # beta: < 1 favours reconstruction, > 1 a smoother latent space

writer = SummaryWriter()
avg_losses = []
model.train()
try:
    for epoch in range(num_epochs):
        total_loss = 0.0
        total_recon = 0.0
        total_kl = 0.0
        for x in dataloader:
            optimizer.zero_grad()
            x_hat, mu, logvar, h, z = model(x)
            batch_size = x.size(0)

            # Squared error summed over features, averaged over the batch.
            recon_loss = nn.functional.mse_loss(x_hat, x, reduction='sum') / batch_size
            # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and
            # averaged over the batch, i.e. on the same per-sample scale as recon_loss.
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
            loss = recon_loss + kl_weight * kl_loss

            loss.backward()
            optimizer.step()

            # Weight by batch size so the smaller final batch does not skew the epoch mean.
            total_loss += loss.item() * batch_size
            total_recon += recon_loss.item() * batch_size
            total_kl += kl_loss.item() * batch_size

        n_samples = len(dataloader.dataset)
        avg_loss = total_loss / n_samples
        avg_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss}')

        # Log the total and its two components to TensorBoard.
        writer.add_scalar('Loss/Train', avg_loss, epoch)
        writer.add_scalar('Loss/Reconstruction', total_recon / n_samples, epoch)
        writer.add_scalar('Loss/KL', total_kl / n_samples, epoch)

        # Checkpoint every 50 epochs.
        if (epoch + 1) % 50 == 0:
            model_save_path = f'model_epoch_{epoch+1}.pt'
            torch.save(model.state_dict(), model_save_path)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

plt.figure(figsize=(10, 5))
plt.plot(avg_losses)
plt.title('Loss over time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.savefig('loss.png')
plt.show()

# UMAP可视化 (UMAP visualization of the VAE latent space)

In [None]:
import umap
import numpy as np
import scanpy as sc
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from MulticoreTSNE import MulticoreTSNE as mTSNE
import umap
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Embed every sample with the trained encoder.
# - Use the posterior mean `mu`, not the reparameterised sample `z`. The sample adds random
#   noise, so the embedding (and the UMAP below) would change on every run.
# - no_grad() avoids building an autograd graph over the whole dataset.
# - eval() is hygiene only; this model has no dropout or batch-norm layers.
model.eval()
with torch.no_grad():
    x_hat, mu, logvar, h, z = model(torch.tensor(gene_data, dtype=torch.float32))
z = mu.numpy()
print(z.shape)

(10412, 256)


In [None]:
# Project the latent embedding to 2-D. UMAP is stochastic; a fixed random_state makes the
# figure reproducible across runs (umap-learn then runs without parallelism, so it is slower).
umap_result = umap.UMAP(random_state=42).fit_transform(z)
# NOTE(review): these columns are appended to the expression DataFrame itself. Re-running
# earlier cells that read its values or shape after this point will pick them up.
gene_expression_data['UMAP1'] = umap_result[:, 0]
gene_expression_data['UMAP2'] = umap_result[:, 1]

In [None]:
# Labels are matched to rows by index label, not by position. Rows whose index is missing
# from `labels` get NaN. NOTE(review): confirm both tables share the same sample IDs.
gene_expression_data['label'] = labels['label'].reindex(gene_expression_data.index)

In [None]:
# Scatter the 2-D UMAP embedding coloured by label and write it to disk.
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=gene_expression_data,
    x='UMAP1',
    y='UMAP2',
    hue='label',
    palette='viridis',
    ax=ax,
)
ax.set_title('vis for VAE')
# Anchor the legend's lower-left corner at the axes' bottom-right, i.e. outside the plot area.
ax.legend(bbox_to_anchor=(1, 0), loc='lower left', borderaxespad=0., ncol=1)
# No extension in the name, so matplotlib uses rcParams['savefig.format'] (png by default).
fig.savefig('hw5_vis', bbox_inches='tight')
# Closed without plt.show(): the figure is written to file only, not rendered inline.
plt.close(fig)