In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Normal
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

In [2]:
class Encoder(nn.Module):
    def __init__(
        self,
        input_dim=2000,
        layers=[512,256],
        latent_dim=50,
    ):
        super().__init__()

        hidden_dim = layers[-1]

        hidden_layers = []
        for i in range(len(layers) - 1):
            hidden_layers.append(nn.Linear(layers[i], layers[i + 1]))
            hidden_layers.append(nn.Tanh())
            # hidden_layers.append(nn.Sigmoid())
        
        self.encoder_layers = nn.Sequential(
            nn.Linear(input_dim, layers[0]),  # input_layer
            *hidden_layers  # Unpack the list of layers
        )

        self.var_enc = nn.Sequential(nn.Linear(hidden_dim, latent_dim))
        self.mu_enc = nn.Sequential(nn.Linear(hidden_dim, latent_dim))

    def reparameterize(self, mu, var):
        return Normal(mu, var.sqrt()).rsample()


    def forward(self, x):
        # Pass input through encoder layers
        x = self.encoder_layers(x)

        # Compute mean and variance
        mu = self.mu_enc(x)
        # make sure var>0
        var = torch.clamp(torch.exp(self.var_enc(x)), min=1e-20)
        z = self.reparameterize(mu, var)

        return z, mu, var


In [3]:
class Decoder(nn.Module):
    def __init__(
        self,
        output_dim=2000,
        latent_dim=50,
        layers=[256,512],
        is_norm_init=True,
    ):
        super().__init__()

        hidden_layers = []
        for i in range(len(layers) - 1):
            hidden_layers.append(nn.Linear(layers[i], layers[i + 1]))
            hidden_layers.append(nn.Tanh())
            # hidden_layers.append(nn.Sigmoid())
        
        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, layers[0]),  # input_layer
            *hidden_layers  # Unpack the list of layers
        )

        self.out_layer = nn.Sequential(
            nn.Linear(layers[-1], output_dim), nn.Sigmoid()
        )

        if is_norm_init:
            self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)



    def forward(self, x):
        x = self.decoder_layers(x)
        x = self.out_layer(x)
        
        return x

In [4]:
class VAE(nn.Module):
    def __init__(
        self,
        input_dim=2000,
        latent_dim=50,
        enc_layers=[512,256],
        dec_layers=[256,512],
        is_initialize=True,
        dec_norm_init=True,
    ):
        super().__init__()
        # for parameter record
        self.latent_dim = latent_dim

        # z, mu, var
        self.encoder = Encoder(
            input_dim=input_dim,
            layers=enc_layers,
            latent_dim=latent_dim,
        )

        self.decoder = Decoder(
            output_dim=2000,
            latent_dim=50,
            layers=[256,512],
            is_norm_init=dec_norm_init,
        )
        
        
        if is_initialize:
            self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)


### 模型定义检查

In [5]:
device = 'cuda:1'

net = VAE().to(device)

arr = np.random.random((32,2000))
data = torch.Tensor(arr).to(device)

In [6]:
z, mu, var = net.encoder(data)
result = net.decoder(z)

print(f'z:{z.shape}, result:{result.shape}')

z:torch.Size([32, 50]), result:torch.Size([32, 2000])


# 模型训练

In [7]:
import os
import torch
import torch.nn as nn
import numpy as np
import random
from tqdm import tqdm
import scanpy as sc
import torch.nn.functional as F
from collections import defaultdict
from torch.distributions import Normal, kl_divergence
from torch.optim import lr_scheduler

In [8]:
def kl_div(mu, var):
    return (
        kl_divergence(
            Normal(mu, var.sqrt()), Normal(torch.zeros_like(mu), torch.ones_like(var))
        )
        .sum(dim=1)
        .mean()
    )


In [9]:
def VAE_train(
    model,
    dataloader,
    num_epoch,
    kl_scale=0.5,
    device=torch.device("cuda:0"),
    lr=2e-4,
    seed=1234,
    is_lr_scheduler=True,
    weight_decay=5e-4,
):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] =str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoch, last_epoch=-1)

    tq = tqdm(range(num_epoch), ncols=80)
    for epoch in tq:
        model.train()
        epoch_loss = defaultdict(float)
        for i, (x,y) in enumerate(dataloader):
            x = x.float().to(device)

            z, mu, var = model.encoder(x)
            recon_x = model.decoder(z)

            # using bce loss estimating the error
            recon_loss = F.binary_cross_entropy(recon_x, x) * x.size(-1)
            kl_loss = kl_div(mu, var)
            loss = {"recon_loss": recon_loss, "kl_loss": kl_scale * kl_loss}

            optimizer.zero_grad()
            sum(loss.values()).backward()
            optimizer.step()

            for k, v in loss.items():
                epoch_loss[k] += loss[k].item()

        if is_lr_scheduler:
            scheduler.step()

        epoch_loss = {k: v / (i + 1) for k, v in epoch_loss.items()}
        epoch_info = ",".join(["{}={:.3f}".format(k, v) for k, v in epoch_loss.items()])
        tq.set_postfix_str(epoch_info)

       
    # for some config record
    return epoch_loss


In [10]:
from dataloader_VAE import get_h5ad_data,get_dataloader,normalize,inverse_normalize

In [11]:
num_epoch=20
batch_size=64

data_list = get_h5ad_data()
norm_data_list, scalers = normalize(data_list)
dataloader = get_dataloader(norm_data_list,batch_size=batch_size)

In [13]:
# def train_script(num_epoch,model_path=None, batch_size=64):



model = VAE()

device = torch.device("cuda:1")
net = VAE().to(device)

VAE_train(net,dataloader,num_epoch,device=device)


# model_save_path = + f"model/ckpt/" + model_path
# torch.save(model.state_dict(), model_save_path)


100%|█████████| 20/20 [00:39<00:00,  1.97s/it, recon_loss=249.323,kl_loss=5.581]


{'recon_loss': 249.32267700799622, 'kl_loss': 5.580858720017739}