In [133]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Reproducibility: seed both RNGs the notebook draws from (numpy and torch).
torch.manual_seed(42)
np.random.seed(42)

# Configuration kept in one place instead of buried in the calls below.
DATA_PATH = "/kaggle/input/disentangled-digits-a-beta-vae-challenge-on-mnist/mnist_images.csv"
IMG_SHAPE = (1, 28, 28)   # (channels, height, width) expected by BetaVAE
BATCH_SIZE = 128

df = pd.read_csv(DATA_PATH)
ids = df["ID"].values
X = df.drop(columns=["ID"]).values.astype(np.float32)

# Fail loudly on a schema/scale mismatch instead of training on garbage:
# every row must hold exactly 28*28 pixels, and the decoder ends in a Sigmoid,
# so the MSE reconstruction target must lie in [0, 1].
n_pixels = int(np.prod(IMG_SHAPE))
assert X.shape[1] == n_pixels, f"expected {n_pixels} pixel columns, got {X.shape[1]}"
assert X.min() >= 0.0 and X.max() <= 1.0, (
    f"pixel values must be in [0, 1] for the sigmoid decoder, got [{X.min()}, {X.max()}]; "
    "rescale (e.g. X / 255) before training"
)

X = X.reshape(-1, *IMG_SHAPE)
tensor_x = torch.from_numpy(X)   # shares memory with X (float32), no extra copy
dataset = TensorDataset(tensor_x)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [134]:
class BetaVAE(nn.Module):
    """Convolutional beta-VAE for 1x28x28 images.

    Encoder: three stride-2 convolutions shrink 28 -> 14 -> 7 -> 3, and the
    flattened 128x3x3 feature map is projected to the mean and log-variance of
    a diagonal Gaussian posterior q(z|x).
    Decoder: a linear layer maps z back to a 128x3x3 map, and three transposed
    convolutions grow it 3 -> 7 -> 14 -> 28, with a final Sigmoid so every
    reconstructed pixel lies in [0, 1].

    Args:
        latent_dim: size of the latent vector z.
        beta: KL weight. It is only stored as ``self.beta`` for the caller to
            pass to the loss; ``forward`` never reads it.

    Note: ``forward`` always samples z through ``reparameterize``, in train and
    eval mode alike.
    """

    # Deepest feature map as (channels, height, width), shared by the encoder's
    # flatten step and the decoder's un-flatten step.
    _FEATURE_SHAPE = (128, 3, 3)
    _FEATURE_SIZE = 128 * 3 * 3

    def __init__(self, latent_dim=16, beta=0.1):
        super().__init__()
        self.beta = beta
        flat_size = self._FEATURE_SIZE

        # Encoder (layer order fixes the state_dict keys enc.0 / enc.2 / enc.4).
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 14 -> 7
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 7 -> 3
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(flat_size, latent_dim)
        self.fc_logvar = nn.Linear(flat_size, latent_dim)

        # Decoder (layer order fixes the state_dict keys dec.0 / dec.2 / dec.4).
        self.fc_decode = nn.Linear(latent_dim, flat_size)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2),                              # 3 -> 7
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # 14 -> 28
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of q(z|x), one row per input image."""
        features = self.enc(x)
        features = features.reshape(len(features), -1)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors back to 1x28x28 images with values in [0, 1]."""
        grid = self.fc_decode(z).view(-1, *self._FEATURE_SHAPE)
        return self.dec(grid)

    def forward(self, x):
        """Encode, sample a latent, and decode; returns ``(x_recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

In [135]:
def loss_function(recon_x, x, mu, logvar, beta):
    """beta-VAE objective for one batch.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target batch.
        mu, logvar: parameters of the diagonal Gaussian posterior q(z|x).
        beta: weight applied to the KL term.

    Returns:
        ``(total, recon_loss, kl)`` with ``total = recon_loss + beta * kl``.
        All three are summed (not averaged) over the batch, so they grow with
        batch size.
    """
    # Squared error summed over every element of the batch.
    recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed over
    # batch and latent dimensions.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.sum()
    total = recon_loss + beta * kl
    return total, recon_loss, kl

In [136]:
# Train on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters in one place.
LATENT_DIM = 20
BETA = 0.1            # KL weight: loss = recon + BETA * KL; < 1 down-weights the KL penalty
LEARNING_RATE = 1e-4
epochs = 25

model = BetaVAE(latent_dim=LATENT_DIM, beta=BETA).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

model.train()
for epoch in range(epochs):
    # Track the reconstruction and KL parts separately: their balance is the
    # main diagnostic for a beta-VAE, and loss_function already returns both.
    total_loss = total_rec = total_kl = 0.0
    for (batch,) in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        loss, rec, kl = loss_function(recon, batch, mu, logvar, model.beta)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_rec += rec.item()
        total_kl += kl.item()
    # loss_function sums over the batch, so divide by the dataset size for per-image values.
    n_samples = len(dataset)
    print(
        f"Epoch {epoch+1}, Loss: {total_loss/n_samples:.4f}, "
        f"Recon: {total_rec/n_samples:.4f}, KL: {total_kl/n_samples:.4f}"
    )

Epoch 1, Loss: 63.7815
Epoch 2, Loss: 26.0620
Epoch 3, Loss: 18.6079
Epoch 4, Loss: 16.1327
Epoch 5, Loss: 14.9340
Epoch 6, Loss: 14.1363
Epoch 7, Loss: 13.5454
Epoch 8, Loss: 13.0857
Epoch 9, Loss: 12.7320
Epoch 10, Loss: 12.4393
Epoch 11, Loss: 12.2030
Epoch 12, Loss: 11.9920
Epoch 13, Loss: 11.8160
Epoch 14, Loss: 11.6662
Epoch 15, Loss: 11.5281
Epoch 16, Loss: 11.4163
Epoch 17, Loss: 11.3128
Epoch 18, Loss: 11.2103
Epoch 19, Loss: 11.1341
Epoch 20, Loss: 11.0538
Epoch 21, Loss: 10.9709
Epoch 22, Loss: 10.9108
Epoch 23, Loss: 10.8432
Epoch 24, Loss: 10.7824
Epoch 25, Loss: 10.7294


In [137]:
# Deterministic reconstructions: decode the posterior mean rather than a sampled z.
# Calling model(batch) goes through reparameterize(), which adds Gaussian noise even
# in eval mode, so each submitted image would be a random draw instead of the
# model's best reconstruction.
model.eval()
recons = []
with torch.no_grad():
    # shuffle defaults to False, so output rows stay aligned with `ids`.
    for (batch,) in DataLoader(dataset, batch_size=256):
        batch = batch.to(device)
        mu, _ = model.encode(batch)
        recon = model.decode(mu)
        recons.append(recon.cpu().numpy())
recons = np.concatenate(recons, axis=0)
recons_flat = recons.reshape(-1, 784)
assert len(recons_flat) == len(ids), "reconstruction count must match the number of IDs"

In [138]:
# Each image becomes one space-separated string of its 784 pixel values,
# formatted to 6 decimals, in the same row order as `ids`.
preds = [" ".join(format(px, ".6f") for px in pixels) for pixels in recons_flat]

columns = {"ID": ids, "Predicted_reconstruction": preds}
submission = pd.DataFrame(columns)
submission.to_csv("submission_beta_vae.csv", index=False)