# Deep learning-based algorithm for methylation profile reconstruction

* Motivation
  * DNA methylation data includes high-dimensional biological signals which is difficult to be used for downstream analysis. Popular feature selection strategy like choosing most variable or differential signals introduce the bias. We introduce the a deep learning model in autoencoder fashion to address this problem. The high-dimensional methylation data will be reconstructed by a low-dimensional latent space. 
* Model selection
  * Autoencoder
  * Variational autoencoder


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Beta
from torch.utils.data import DataLoader, TensorDataset


In [7]:
device = torch.device('cpu' if not torch.cpu.is_available() else 'mps')

In [8]:
# Define a simple VAE with Beta distribution as the latent variable

class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        
        # Encoder
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc_alpha = nn.Linear(128, latent_dim)  # for alpha parameter
        self.fc_beta = nn.Linear(128, latent_dim)   # for beta parameter
        
        # Decoder
        self.fc2 = nn.Linear(latent_dim, 128)
        self.fc3 = nn.Linear(128, input_dim)
        
    def encode(self, x):
        x = F.relu(self.fc1(x))
        alpha = torch.exp(self.fc_alpha(x)) + 1e-6  # adding small epsilon to avoid zero
        beta = torch.exp(self.fc_beta(x)) + 1e-6    # adding small epsilon to avoid zero
        return alpha, beta
    
    def reparameterize(self, alpha, beta):
        # Sample from Beta distribution
        batch_size = alpha.size(0)
        epsilon = torch.rand_like(alpha)
        z = alpha * epsilon + beta * (1.0 - epsilon)
        return z
    
    def decode(self, z):
        z = F.relu(self.fc2(z))
        x_recon = torch.sigmoid(self.fc3(z))  # assuming input data is normalized to [0, 1]
        return x_recon
    
    def forward(self, x):
        # Encode
        alpha, beta = self.encode(x)
        
        # Reparameterize and sample z
        z = self.reparameterize(alpha, beta)
        
        # Decode
        x_recon = self.decode(z)
        
        return x_recon, alpha, beta
    
    def compute_loss(self, x, x_recon, alpha, beta, beta_factor=1.0):
        # Reconstruction loss (MSE)
        reconstruction_loss = F.mse_loss(x_recon, x, reduction='mean')
        
        # KL divergence between Beta(alpha, beta) and Beta(1, 1)
        prior_alpha = torch.ones_like(alpha)
        prior_beta = torch.ones_like(beta)
        
        kl_divergence = self.compute_beta_kl(alpha, beta, prior_alpha, prior_beta)
        
        # Total loss
        total_loss = reconstruction_loss + beta_factor * kl_divergence
        
        return total_loss, reconstruction_loss, kl_divergence
    
    def compute_beta_kl(self, alpha_q, beta_q, alpha_p, beta_p):
        # Compute KL divergence between Beta distributions
        kl_divergence = (torch.lgamma(alpha_p) + torch.lgamma(beta_p) - torch.lgamma(alpha_p + beta_p) -
                         torch.lgamma(alpha_q) - torch.lgamma(beta_q) + torch.lgamma(alpha_q + beta_q) +
                         (alpha_q - alpha_p) * (torch.digamma(alpha_q) - torch.digamma(alpha_q + beta_q)) +
                         (beta_q - beta_p) * (torch.digamma(beta_q) - torch.digamma(alpha_q + beta_q)))
        
        return kl_divergence.mean()


In [18]:
# Example usage

# Initialize the VAE
input_dim = 500
latent_dim = 128
vae = VAE(input_dim, latent_dim)

# Move the VAE to the device
vae.to(device)

# Define optimizer
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# Training loop
num_epochs = 20
beta_factor = 1.0  # weight of KL divergence term

# Initialize data loader
train_data = torch.randn(1000, input_dim)
test_data = torch.randn(100, input_dim)
train_dataset = TensorDataset(train_data)
test_dataset = TensorDataset(test_data)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True)

In [19]:
model = VAE(input_dim, latent_dim) 
alpha, beta = model.encode(test_data)
z = model.reparameterize(alpha, beta)
print(z.shape)

torch.Size([100, 128])


In [12]:
def train_loop(model, train_loader, optimizer):
    for epoch in range(num_epochs):
        vae.train()
        total_loss = 0.0
        total_rec_loss = 0.0
        total_kl_loss = 0.0
        
        for batch_idx, (data, ) in enumerate(train_loader):
            data = data.view(-1, input_dim)
            data = data.to(device)  # Move data to the device
            optimizer.zero_grad()
            
            # Forward pass
            recon_batch, alpha, beta = vae(data)
            
            # Compute loss
            loss, rec_loss, kl_loss = vae.compute_loss(data, recon_batch, alpha, beta, beta_factor)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Accumulate losses
            total_loss += loss.item()
            total_rec_loss += rec_loss.item()
            total_kl_loss += kl_loss.item()
            
        # Print average losses
        avg_loss = total_loss / len(train_loader)
        avg_rec_loss = total_rec_loss / len(train_loader)
        avg_kl_loss = total_kl_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Rec Loss: {avg_rec_loss:.4f}, KL Loss: {avg_kl_loss:.4f}")


Epoch [1/20], Loss: 1.0709, Rec Loss: 1.0465, KL Loss: 0.0244
Epoch [2/20], Loss: 1.0025, Rec Loss: 0.9981, KL Loss: 0.0044
Epoch [3/20], Loss: 0.9988, Rec Loss: 0.9975, KL Loss: 0.0013
Epoch [4/20], Loss: 0.9980, Rec Loss: 0.9974, KL Loss: 0.0006
Epoch [5/20], Loss: 0.9977, Rec Loss: 0.9973, KL Loss: 0.0004
Epoch [6/20], Loss: 0.9975, Rec Loss: 0.9973, KL Loss: 0.0002
Epoch [7/20], Loss: 0.9974, Rec Loss: 0.9973, KL Loss: 0.0002
Epoch [8/20], Loss: 0.9974, Rec Loss: 0.9973, KL Loss: 0.0001
Epoch [9/20], Loss: 0.9974, Rec Loss: 0.9973, KL Loss: 0.0001
Epoch [10/20], Loss: 0.9973, Rec Loss: 0.9973, KL Loss: 0.0001
Epoch [11/20], Loss: 0.9973, Rec Loss: 0.9973, KL Loss: 0.0001
Epoch [12/20], Loss: 0.9973, Rec Loss: 0.9973, KL Loss: 0.0001
Epoch [13/20], Loss: 0.9973, Rec Loss: 0.9973, KL Loss: 0.0001
Epoch [14/20], Loss: 0.9974, Rec Loss: 0.9972, KL Loss: 0.0001
Epoch [15/20], Loss: 0.9974, Rec Loss: 0.9972, KL Loss: 0.0002
Epoch [16/20], Loss: 0.9974, Rec Loss: 0.9973, KL Loss: 0.0002
E

In [None]:
def test_loop(dataloader, model):
    for i, data in enumerate(dataloader):
        x = data[0].to(device)
        x_recon, alpha, beta = model(x)
        z = model.reparameterize(alpha, beta)
    return x_recon, z