In [201]:
import torch.nn.functional as F
import torch

class VAE_Loss:
    """Callable VAE objective mixing a reconstruction term and a KL term.

    The reconstruction term is the summed squared error between the decoder
    output and the input. The KL term is the closed-form divergence between the
    diagonal Gaussian described by (z_mean, z_log_var) and a standard normal,
    summed over every element.

    Value returned per mode:
        'standard'               -> reconstruction + kl
        'negative_kl'            -> reconstruction - gamma * kl
        'reduced_reconstruction' -> alpha * reconstruction + kl
        'combined'               -> alpha * reconstruction - gamma * kl
    """

    def __init__(self, mode='standard', alpha=1.0, gamma=1.0):
        """
        Store the loss configuration
        Args:
        - mode (str): One of 'standard', 'negative_kl', 'reduced_reconstruction', 'combined'
        - alpha (float): Multiplier on the reconstruction term; only read by the
          'reduced_reconstruction' and 'combined' modes (default=1.0)
        - gamma (float): Multiplier on the KL term, which is *subtracted*; only read
          by the 'negative_kl' and 'combined' modes (default=1.0)

        The mode is not validated here; an unknown mode raises when the loss is called.
        """
        self.mode, self.alpha, self.gamma = mode, alpha, gamma

    def __call__(self, recon_x, x, z_mean, z_log_var):
        """
        Evaluate the configured objective for one batch
        Args:
        - recon_x: Decoder output
        - x: Original input the decoder output is compared against
        - z_mean: Mean of the approximate posterior
        - z_log_var: Log-variance of the approximate posterior
        Returns:
        - Scalar tensor holding the loss (sum-reduced, not averaged over the batch)
        Raises:
        - ValueError: if self.mode is not one of the four supported modes
        """
        # Both terms are computed up front, before the mode is inspected.
        recon_term = F.mse_loss(recon_x, x, reduction='sum')
        kl_term = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

        selected = self.mode

        # Plain sum of both terms.
        if selected == 'standard':
            return recon_term + kl_term

        # KL enters with a minus sign, so a positive gamma rewards larger KL.
        if selected == 'negative_kl':
            return recon_term - self.gamma * kl_term

        # Only the reconstruction term is re-weighted.
        if selected == 'reduced_reconstruction':
            return self.alpha * recon_term + kl_term

        # Re-weighted reconstruction together with the subtracted KL term.
        if selected == 'combined':
            return self.alpha * recon_term - self.gamma * kl_term

        raise ValueError(f"Invalid mode selected: {self.mode}. Choose from ['standard', 'negative_kl', 'reduced_reconstruction', 'combined']")

# Example usage of VAE_Loss class

# Assume recon_x, x, z_mean, z_log_var are the outputs from the model

# Initialize the loss class
# vae_loss_standard = VAE_Loss(mode='standard')
# vae_loss_negative_kl = VAE_Loss(mode='negative_kl', gamma=1.5)
# vae_loss_reduced_reconstruction = VAE_Loss(mode='reduced_reconstruction', alpha=0.5)
# vae_loss_combined = VAE_Loss(mode='combined', alpha=0.5, gamma=1.5)

# Example: Using the loss function for each mode
# loss_standard = vae_loss_standard(recon_x, x, z_mean, z_log_var)
# loss_negative_kl = vae_loss_negative_kl(recon_x, x, z_mean, z_log_var)
# loss_reduced_reconstruction = vae_loss_reduced_reconstruction(recon_x, x, z_mean, z_log_var)
# loss_combined = vae_loss_combined(recon_x, x, z_mean, z_log_var)
# 
# print(f"Standard Loss: {loss_standard.item()}")
# print(f"Negative KL Loss: {loss_negative_kl.item()}")
# print(f"Reduced Reconstruction Loss: {loss_reduced_reconstruction.item()}")
# print(f"Combined Loss: {loss_combined.item()}")

In [202]:
from kl.utils import load_fx
import numpy as np
# Load the dataset through the project helper. It returns a pair (X, y): X is what
# gets fed to the VAE below, while y is only written to disk at the end of the notebook.
# NOTE(review): load_fx is opaque from here — the meaning of data_len / shift and the
# layout of X are assumptions to confirm (the VAE below is built with input_dim=8,
# so X presumably has 8 feature columns).
X, y = load_fx(data_len=5000, shift=2)
print(X.shape)  # sanity-check the array dimensions before building the model

In [203]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Convert the loaded data to a float32 PyTorch tensor. The whole dataset becomes a
# single tensor: it is pushed through the VAE in one forward pass below and wrapped
# in a TensorDataset for mini-batch training later on.
X_tensor = torch.tensor(X, dtype=torch.float32)


# VAE definition
class VAE(nn.Module):
    """Small fully connected variational autoencoder.

    Encoder: input_dim -> 128 (ReLU) -> two linear heads of size latent_dim
    (posterior mean and log-variance).
    Decoder: latent_dim -> 128 (ReLU) -> input_dim, followed by a sigmoid, so
    reconstructions always lie in (0, 1).

    Args:
        input_dim (int): Size of the last dimension of the input (default=8)
        latent_dim (int): Size of the latent code (default=2)
    """

    # Bounds applied to the predicted log-variance.
    # The upper bound has to keep exp(z_log_var) representable: in float32, exp()
    # overflows to inf for arguments above roughly 88.7, so the former limit of 100
    # still allowed the KL term to become inf. exp(10) ~= 2.2e4 is already a very
    # wide posterior and stays comfortably finite.
    LOG_VAR_MIN = -10.0
    LOG_VAR_MAX = 10.0

    def __init__(self, input_dim=8, latent_dim=2):
        super().__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2_mean = nn.Linear(128, latent_dim)
        self.fc2_logvar = nn.Linear(128, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, 128)
        self.fc4 = nn.Linear(128, input_dim)

    def encode(self, x):
        """Map inputs to the posterior parameters.

        Returns:
            (z_mean, z_log_var): two tensors whose last dimension is latent_dim.
            z_log_var is clamped to [LOG_VAR_MIN, LOG_VAR_MAX].
        """
        h1 = F.relu(self.fc1(x))
        z_mean = self.fc2_mean(h1)
        z_log_var = self.fc2_logvar(h1)

        # Clamp z_log_var so that exp(z_log_var) and exp(0.5 * z_log_var) stay finite.
        # Values outside the range receive zero gradient through the clamp.
        z_log_var = torch.clamp(z_log_var, min=self.LOG_VAR_MIN, max=self.LOG_VAR_MAX)

        return z_mean, z_log_var

    def reparameterize(self, z_mean, z_log_var):
        """Draw z = mean + eps * std with eps ~ N(0, I) (reparameterization trick).

        The sample is random on every call; seed torch for reproducible results.
        """
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def decode(self, z):
        """Map latent codes back to input space; the sigmoid bounds outputs to (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Returns:
            (reconstruction, z_mean, z_log_var)
        """
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        return self.decode(z), z_mean, z_log_var

# Loss function: earlier inline drafts of loss_function (one using binary
# cross-entropy, one using MSE for the reconstruction term, both adding the same
# KL term) were superseded by the VAE_Loss class defined at the top of the
# notebook, which uses the MSE variant.

# Initialize the VAE
input_dim = 8  # number of features the model expects per sample
latent_dim = 2  # size of the latent code
vae = VAE(input_dim, latent_dim)

# Forward pass of the freshly initialised (untrained) model over the full dataset.
# These outputs serve as the "initial" baseline for the KL and histogram cells
# below. Gradients are tracked here, which is why later cells call .detach()
# before converting to NumPy.
recon_x, z_mean, z_log_var = vae(X_tensor)

In [204]:
# Calculate the initial (pre-training) KL divergence from the model outputs above.
# Note: this averages over every element (batch and latent dimensions), whereas
# VAE_Loss sums the same expression, so this number is a per-element average and
# is not on the same scale as the KL term inside the training loss.
# .item() is applied to the mean before the -0.5 factor, so the result is a plain float.
initial_kl_divergence = -0.5 * torch.mean(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()).item()
print(f"Initial KL Divergence: {initial_kl_divergence}")

In [205]:
# Evaluate the standard objective (reconstruction + KL) once on the full-dataset
# outputs computed above. The training cell further down rebinds `loss` on every
# batch, so the value computed here is only a pre-training reference.
vae_loss = VAE_Loss(mode='standard')
loss = vae_loss(recon_x, X_tensor, z_mean, z_log_var)

In [206]:
# Summary statistics of the inputs. Worth checking because the decoder ends in a
# sigmoid (outputs in (0, 1)) while the reconstruction loss is MSE against these
# raw values — inputs outside that range cannot be reconstructed exactly.
print(X_tensor.mean(), X_tensor.std())

In [207]:
# Pull the encoder outputs out of the autograd graph as NumPy arrays
z_mean_values = z_mean.detach().numpy()
z_log_var_values = z_log_var.detach().numpy()

# Draw both distributions on one axes, semi-transparent so the overlap stays visible
fig, ax = plt.subplots(figsize=(8, 6))

for series, colour, series_label in (
    (z_mean_values, 'blue', 'z_mean'),
    (z_log_var_values, 'green', 'z_log_var'),
):
    ax.hist(series.flatten(), bins=30, color=colour, alpha=0.5, label=series_label)

ax.set_title('Overlapping Histogram of z_mean and z_log_var')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.legend(loc='upper right')

plt.show()

In [218]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Train a fresh VAE on mini-batches of X_tensor (the float32 tensor built in an
# earlier cell from the loaded data).

# Create a DataLoader for batching
batch_size = 64
dataset = TensorDataset(X_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Re-instantiate the model so training starts from new random weights rather than
# continuing from the instance created earlier.
input_dim = 8
latent_dim = 2
vae = VAE(input_dim, latent_dim)

# Choose the loss mode (alternatives kept for experimentation)
# loss_function = VAE_Loss(mode='combined', alpha=0.5, gamma=0.1)  # Lower gamma to stabilize
# loss_function = VAE_Loss(mode='negative_kl', alpha=0.5, gamma=0.1)
loss_function = VAE_Loss(mode='standard')

# Define optimizer
# NOTE(review): lr=1e-6 is a very small step size for Adam — confirm it is intentional.
optimizer = optim.Adam(vae.parameters(), lr=1e-6)

# Training loop parameters
epochs = 1500  # Number of training epochs

# Training loop
for epoch in range(epochs):
    epoch_loss = 0.0
    # TensorDataset yields 1-tuples, hence the (x_batch,) unpacking.
    for (x_batch,) in dataloader:
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass through the VAE.
        # These names are notebook globals and are rebound on every batch.
        recon_x, z_mean, z_log_var = vae(x_batch)

        # Minimise the objective itself. Optimising its reciprocal (1 / loss), as
        # this cell did before, makes gradient descent push the VAE loss *up*
        # and turns the logged value into 1/loss instead of the loss.
        loss = loss_function(recon_x, x_batch, z_mean, z_log_var)

        # Backpropagation
        loss.backward()

        # Clip gradients to prevent them from exploding
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=5.0)

        # Optimizer step
        optimizer.step()

        # Accumulate the (sum-reduced) batch loss for the epoch
        epoch_loss += loss.item()

    # Report the mean batch loss for this epoch
    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_epoch_loss:.4f}")

In [219]:
# Run the trained model over the full dataset. No gradients are needed for
# evaluation, so skip building the autograd graph.
with torch.no_grad():
    recon_x_trained, z_mean_trained, z_log_var_trained = vae(X_tensor)

# Calculate the final (post-training) KL divergence. This must use the *_trained
# tensors from the pass above: the bare z_mean / z_log_var names hold outputs from
# an earlier forward pass (the last training mini-batch), not from this evaluation.
# The mean over all elements mirrors how the initial KL value is computed, so the
# two numbers are directly comparable.
final_kl_divergence = -0.5 * torch.mean(
    1 + z_log_var_trained - z_mean_trained.pow(2) - z_log_var_trained.exp()
).item()
print(f"Final KL Divergence: {final_kl_divergence}")

In [220]:
# Pull the trained model's encoder outputs out of the autograd graph as NumPy arrays
z_mean_values_trained = z_mean_trained.detach().numpy()
z_log_var_values_trained = z_log_var_trained.detach().numpy()

# Draw both distributions on one axes, semi-transparent so the overlap stays visible
fig, ax = plt.subplots(figsize=(8, 6))

for series, colour, series_label in (
    (z_mean_values_trained, 'blue', 'z_mean'),
    (z_log_var_values_trained, 'green', 'z_log_var'),
):
    ax.hist(series.flatten(), bins=30, color=colour, alpha=0.5, label=series_label)

ax.set_title('Overlapping Histogram of z_mean and z_log_var after training')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.legend(loc='upper right')

plt.show()

In [221]:
# Persist the results as plain-text files in the current working directory:
# the trained latent means and log-variances, the y array returned by load_fx,
# and the trained model's reconstructions. Existing files with these names are
# overwritten.
np.savetxt('z_mean_trained.txt', z_mean_values_trained)
np.savetxt('z_log_var_trained.txt', z_log_var_values_trained)
np.savetxt('y.txt', y)
# .detach() takes the tensor out of the autograd graph before the NumPy conversion.
np.savetxt('recon_x_trained.txt', recon_x_trained.detach().numpy())