In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with one hidden layer per side.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent
    sample back to input space through a sigmoid, so reconstructions lie
    in [0, 1].

    Args:
        input_dim: Size of the last dimension of the input (and output).
        latent_dim: Size of the latent vector.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
    """

    def __init__(self, input_dim=270, latent_dim=54, hidden_dim=128):
        super().__init__()

        # Encoder: shared hidden layer followed by two parallel heads.
        self.encoder_fc = nn.Linear(input_dim, hidden_dim)
        self.encoder_mu = nn.Linear(hidden_dim, latent_dim)
        self.encoder_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: mirror of the encoder.
        self.decoder_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.decoder_fc2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h = F.relu(self.encoder_fc(x))
        mu = self.encoder_mu(h)
        logvar = self.encoder_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2) -> sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent vector ``z`` to a reconstruction in [0, 1]."""
        h = F.relu(self.decoder_fc1(z))
        # NOTE(review): the sigmoid bounds outputs to [0, 1]; this presumably
        # expects inputs scaled to the same range -- confirm against the data.
        return torch.sigmoid(self.decoder_fc2(h))

    def forward(self, x):
        """Encode, sample, and decode ``x``.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

In [None]:
def loss_fn(recon_x, x, mu, logvar):
    """VAE objective: summed squared reconstruction error plus a KL term.

    Both terms are summed (not averaged) over every element passed in, so
    the returned scalar grows with the number of samples in the batch.

    Args:
        recon_x: Reconstruction produced by the decoder.
        x: Target tensor compared against ``recon_x``.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor ``reconstruction + KL``.
    """
    reconstruction_term = F.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), element by element.
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence_term = -0.5 * torch.sum(kl_elementwise)
    return reconstruction_term + divergence_term

In [None]:
def train_vae(vae, dataloader, optimizer, epochs=50, device='cpu'):
    """Train ``vae`` in place and print the average per-sample loss each epoch.

    Args:
        vae: Model whose forward pass returns ``(recon, mu, logvar)``.
        dataloader: Iterable of batches. Each batch is either a tensor or a
            list/tuple whose first element is the input tensor.
        optimizer: Optimizer already bound to ``vae``'s parameters.
        epochs: Number of passes over ``dataloader``.
        device: Device the model and every batch are moved to.
    """
    vae.to(device)
    vae.train()

    for epoch in range(epochs):
        total_loss = 0.0
        n_samples = 0
        for batch in dataloader:
            # Loaders built on e.g. TensorDataset yield a list/tuple per
            # batch; the first entry is taken as the input tensor.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            optimizer.zero_grad()

            recon, mu, logvar = vae(x)
            loss = loss_fn(recon, x, mu, logvar)

            loss.backward()
            optimizer.step()

            # loss_fn sums over the batch, so the value is already a batch
            # total; scaling it by the batch size again would inflate the
            # reported average.
            total_loss += loss.item()
            n_samples += x.size(0)

        # Divide by the samples actually seen so the average stays correct
        # when the loader skips samples (drop_last, samplers); guard against
        # an empty loader.
        avg_loss = total_loss / max(n_samples, 1)
        print(f"Epoch {epoch+1} | Loss: {avg_loss}")