In [None]:
# Variational autoencoder with 1 hidden layer

class VAE(nn.Module):
    """Variational autoencoder with one linear encoder layer and one linear decoder layer.

    The encoder maps logit-transformed inputs to the mean and log-variance of a
    diagonal Gaussian over the latent code. The decoder maps a latent sample back
    to per-feature values in (0, 1) through a sigmoid.

    Args:
        obs_dim: number of observed features per sample (last input dimension).
        z_dim: dimensionality of the latent space.
        logit_eps: clamp passed to ``torch.logit`` as ``eps``. Inputs are clamped
            to ``[logit_eps, 1 - logit_eps]`` before the logit, so exact 0s and 1s
            map to finite values. Defaults to 0.01.
    """

    def __init__(self, obs_dim, z_dim, logit_eps=0.01):
        # Zero-arg super() always binds to this class. Naming another class
        # here (e.g. AE) fails because a VAE instance is not an instance of it.
        super().__init__()
        self.logit_eps = logit_eps
        self.fc1 = nn.Linear(obs_dim, z_dim)  # encoder head: latent mean
        self.fc2 = nn.Linear(obs_dim, z_dim)  # encoder head: latent log-variance
        self.fc3 = nn.Linear(z_dim, obs_dim)  # decoder

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior q(z | x).

        ``x`` is clamped to ``[logit_eps, 1 - logit_eps]`` and passed through a
        logit before the two linear heads.
        """
        # Compute the transformed input once and share it between both heads.
        logits = torch.logit(x, eps = self.logit_eps)
        return self.fc1(logits), self.fc2(logits)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)

        return mu + eps * std

    def decode(self, z):
        """Map latent codes to per-feature values in (0, 1)."""
        return torch.sigmoid(self.fc3(z))

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Returns:
            ``(x_reconst_mu, mu, log_var)``: the decoded output and the
            posterior parameters used to draw the latent sample.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst_mu = self.decode(z)

        return x_reconst_mu, mu, log_var

In [None]:
# Hyperparameters

from torch.utils.data import DataLoader

# Sizes taken from the data: rows of SNP become the batch size and columns
# become the model's input/output dimension.
batch_size, obs_dim = SNP.shape[0], SNP.shape[1]

# Model and optimisation settings.
z_dim = 200            # latent dimensionality
beta = 0.5             # weight on the KL term in the training loss
learning_rate = 5e-4   # Adam step size
num_epochs = 1000

# Element-wise BCE (no reduction) so each entry can be weighted before averaging.
loss_fn = nn.BCELoss(reduction = 'none')

# With batch_size equal to SNP's row count, each epoch is one full batch
# (assuming mega has the same number of rows as SNP — TODO confirm).
data_loader = DataLoader(mega, batch_size = batch_size)

In [None]:
# Training

import time

def _vae_loss_terms(model, x_full, n_obs):
    """Run ``model`` on the first ``n_obs`` columns of ``x_full`` and compute its loss terms.

    The remaining columns of ``x_full`` multiply the element-wise
    reconstruction loss. Presumably they are a per-entry weight/mask aligned
    with the first ``n_obs`` columns — TODO confirm against how ``mega`` is built.

    Returns:
        ``(x_reconst_mu, mu, log_var, recon_loss, kl_div)``. The last two are
        0-d tensors: the weighted mean BCE and the mean KL divergence of the
        posterior from a standard normal.
    """
    x_obs = x_full[:, :n_obs]
    weights = x_full[:, n_obs:]
    x_reconst_mu, mu, log_var = model(x_obs)
    kl_div = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
    recon_loss = torch.mean(loss_fn(x_reconst_mu, x_obs) * weights)
    return x_reconst_mu, mu, log_var, recon_loss, kl_div


torch.manual_seed(0)  # reproducible weight init and reparameterization noise

start_time = time.time()
model = VAE(obs_dim, z_dim)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
cost_recon = np.empty(num_epochs)  # per-epoch full-data reconstruction loss
cost_div = np.empty(num_epochs)    # per-epoch full-data KL divergence

for epoch in range(num_epochs):

    for batch, x in enumerate(data_loader):

        x_reconst_mu, mu, log_var, recon_loss, kl_div = _vae_loss_terms(model, x, obs_dim)
        loss = recon_loss + beta * kl_div

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Track the losses over the whole dataset. No autograd graph is needed here,
    # and .item() turns the 0-d tensors into Python floats that NumPy can store.
    # Assigning grad-requiring tensors directly can raise on newer NumPy.
    with torch.no_grad():
        x_reconst_mu, mu, log_var, recon_all, kl_all = _vae_loss_terms(model, mega, obs_dim)
    cost_recon[epoch] = recon_all.item()
    cost_div[epoch] = kl_all.item()

end_time = time.time()

# Posterior means of every SNP row, computed without tracking gradients.
with torch.no_grad():
    latent = model.encode(SNP)[0]