In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class BasicVAE(nn.Module):
    def __init__(self, input_dim, hidden_dims, z_dim):
        super(BasicVAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], z_dim * 2)  # for mean and log-variance
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.encoder(x)
        mu, logvar = x.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
class EncDecVAE_Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dims, z_dim):
        super(EncDecVAE_Encoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[2]),
            nn.ReLU()
        )

        # Decoder within Encoder
        self.encoder_decoder = nn.Sequential(
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], z_dim * 2)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = self.encoder_decoder(x)
        mu, logvar = x.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
class EncDecVAE_Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dims, z_dim):
        super(EncDecVAE_Decoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], z_dim * 2)  # for mean and log-variance
        )

        # Encoder within Decoder
        self.decoder_encoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU()
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dims[1], hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.encoder(x)
        mu, logvar = x.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = self.decoder_encoder(z)
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
class EncDecVAE_Both(nn.Module):
    def __init__(self, input_dim, hidden_dims, z_dim):
        super(EncDecVAE_Both, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[2]),
            nn.ReLU()
        )

        # Decoder within Encoder
        self.encoder_decoder = nn.Sequential(
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], z_dim * 2)
        )

        # Encoder within Decoder
        self.decoder_encoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU()
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dims[1], hidden_dims[2]),
            nn.ReLU(),
            nn.Linear(hidden_dims[2], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = self.encoder_decoder(x)
        mu, logvar = x.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = self.decoder_encoder(z)
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [None]:
# Define a transform to normalize the data to [0, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


In [None]:
# Adjust the normalization to [0, 1]
class NormalizeTo01:
    def __call__(self, tensor):
        return (tensor + 1) / 2

transform = transforms.Compose([
    transforms.ToTensor(),
    NormalizeTo01()
])


In [None]:
# Download and load the training data
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Download and load the test data
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)


In [None]:
import torch.optim as optim

# Define the loss function
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Define the training loop
def train_vae(model, dataloader, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.view(data.size(0), -1)  # Flatten the images
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        print(f'Epoch {epoch + 1}, Loss: {train_loss / len(dataloader.dataset)}')

# Evaluate the model
def evaluate_vae(model, dataloader):
    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.view(data.size(0), -1)
            recon_batch, mu, logvar = model(data)
            eval_loss += loss_function(recon_batch, data, mu, logvar).item()
    avg_loss = eval_loss / len(dataloader.dataset)
    print(f'Average Loss: {avg_loss}')
    return avg_loss

# Train and evaluate the basic VAE
input_dim = 784  # 28x28 images
hidden_dims = [400, 200, 100]
z_dim = 20
epochs = 10

# Initialize the model, optimizer
basic_vae = BasicVAE(input_dim=input_dim, hidden_dims=hidden_dims, z_dim=z_dim)
optimizer = optim.Adam(basic_vae.parameters(), lr=1e-3)

# Train the model
train_vae(basic_vae, train_loader, optimizer, epochs)

# Evaluate the model
print("Evaluating Basic VAE")
evaluate_vae(basic_vae, test_loader)


Epoch 1, Loss: 525.2857182942709
Epoch 2, Loss: 523.2812800130208
Epoch 3, Loss: 522.7997369140625
Epoch 4, Loss: 522.1077425130209
Epoch 5, Loss: 521.1180700520833
Epoch 6, Loss: 520.39628125
Epoch 7, Loss: 519.6862736979167
Epoch 8, Loss: 518.85702421875
Epoch 9, Loss: 518.2068593098959
Epoch 10, Loss: 517.7478166666667
Evaluating Basic VAE
Average Loss: 516.5412401367188


516.5412401367188

In [None]:
# Train and evaluate EncDecVAE_Encoder
encdec_vae_encoder = EncDecVAE_Encoder(input_dim=input_dim, hidden_dims=hidden_dims, z_dim=z_dim)
optimizer = optim.Adam(encdec_vae_encoder.parameters(), lr=1e-3)

train_vae(encdec_vae_encoder, train_loader, optimizer, epochs)
print("Evaluating EncDecVAE_Encoder")
evaluate_vae(encdec_vae_encoder, test_loader)

# Train and evaluate EncDecVAE_Decoder
encdec_vae_decoder = EncDecVAE_Decoder(input_dim=input_dim, hidden_dims=hidden_dims, z_dim=z_dim)
optimizer = optim.Adam(encdec_vae_decoder.parameters(), lr=1e-3)

train_vae(encdec_vae_decoder, train_loader, optimizer, epochs)
print("Evaluating EncDecVAE_Decoder")
evaluate_vae(encdec_vae_decoder, test_loader)

# Train and evaluate EncDecVAE_Both
encdec_vae_both = EncDecVAE_Both(input_dim=input_dim, hidden_dims=hidden_dims, z_dim=z_dim)
optimizer = optim.Adam(encdec_vae_both.parameters(), lr=1e-3)

train_vae(encdec_vae_both, train_loader, optimizer, epochs)
print("Evaluating EncDecVAE_Both")
evaluate_vae(encdec_vae_both, test_loader)


Epoch 1, Loss: 525.0125225260417
Epoch 2, Loss: 523.1357188151042
Epoch 3, Loss: 522.668647265625
Epoch 4, Loss: 522.3084804036458
Epoch 5, Loss: 522.2341761067709
Epoch 6, Loss: 522.5656405598959
Epoch 7, Loss: 521.963316796875
Epoch 8, Loss: 521.52902265625
Epoch 9, Loss: 521.05743046875
Epoch 10, Loss: 520.337030859375
Evaluating EncDecVAE_Encoder
Average Loss: 519.0538999511718
Epoch 1, Loss: 525.4752927083333
Epoch 2, Loss: 523.578536328125
Epoch 3, Loss: 522.5085984375
Epoch 4, Loss: 521.57775546875
Epoch 5, Loss: 520.7263766927083
Epoch 6, Loss: 520.3041289713542
Epoch 7, Loss: 519.9044121744791
Epoch 8, Loss: 519.4789244791667
Epoch 9, Loss: 519.1372115885416
Epoch 10, Loss: 518.9431379557292
Evaluating EncDecVAE_Decoder
Average Loss: 517.7775861328125
Epoch 1, Loss: 525.3061408203125
Epoch 2, Loss: 523.1367444010417
Epoch 3, Loss: 522.6147083333333
Epoch 4, Loss: 522.1848690104167
Epoch 5, Loss: 521.8696764973959
Epoch 6, Loss: 521.7331504557292
Epoch 7, Loss: 521.766788736979

523.8562215820313