<a href="https://colab.research.google.com/github/SaraAlinejad/vae_test_1/blob/main/untitled6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Task
Implement a variational autoencoder (VAE) with an orthogonality term in its loss function, using `mlcolvar`, to extract collective variables from the Aib9 dihedral data provided as a `.npy` file.

## Install mlcolvar

### Subtask:
Install the necessary library for implementing the VAE.


**Reasoning**:
Install `mlcolvar` with `pip install`. Its dependencies, including `lightning` and `KDEpy`, are pulled in automatically.



In [None]:
!pip install mlcolvar

Collecting mlcolvar
  Downloading mlcolvar-1.2.2-py3-none-any.whl.metadata (5.4 kB)
Collecting lightning (from mlcolvar)
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting KDEpy (from mlcolvar)
  Downloading kdepy-1.1.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning->mlcolvar)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting torchmetrics<3.0,>0.7.0 (from lightning->mlcolvar)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting pytorch-lightning (from lightning->mlcolvar)
  Downloading pytorch_lightning-2.5.5-py3-none-any.whl.metadata (20 kB)
Downloading mlcolvar-1.2.2-py3-none-any.whl (114 kB)
Downloading kdepy-1.1.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (686 kB)

## Load and prepare data

### Subtask:
Load the Aib9 dihedral data from the `.npy` file and convert it to a PyTorch tensor.


**Reasoning**:
Load the numpy data and convert it to a PyTorch tensor, ensuring the correct data type.



In [None]:
import numpy as np
import torch

# Load the data from the .npy file
data_np = np.load('/content/no_repe_sin_cos_run2_data_2us.npy')

# Convert the numpy array to a PyTorch tensor
# Ensure the data type is suitable for VAE (typically float32 or float64)
data_tensor = torch.from_numpy(data_np).float()

# Display the shape and data type of the tensor
print("Shape of the data tensor:", data_tensor.shape)
print("Data type of the data tensor:", data_tensor.dtype)

Shape of the data tensor: torch.Size([2000000, 36])
Data type of the data tensor: torch.float32
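The file name suggests that the 36 columns are sine and cosine values of dihedral angles. Under that assumption (18 dihedrals, sines first and then cosines, which is a guess about the layout and not something the notebook confirms), a quick sanity check could look like the sketch below. It uses synthetic angles in place of the real trajectory, and `check_sin_cos_features` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the real trajectory: 1000 frames of 18 dihedral angles (radians).
n_frames, n_dihedrals = 1000, 18
angles = rng.uniform(-np.pi, np.pi, size=(n_frames, n_dihedrals))

# Assumed layout: all sines first, then all cosines, giving 36 features per frame.
features = np.concatenate([np.sin(angles), np.cos(angles)], axis=1).astype(np.float32)


def check_sin_cos_features(x: np.ndarray) -> bool:
    """Return True if x is finite and every value lies in [-1, 1]."""
    return bool(np.isfinite(x).all() and (np.abs(x) <= 1.0).all())


print(features.shape, features.dtype, check_sin_cos_features(features))
```

Running the same check on the loaded array (for example `check_sin_cos_features(data_np)`) would flag NaNs or values outside the range expected for sine/cosine features before any training time is spent.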


## Define VAE model

### Subtask:
Define the variational autoencoder model architecture using MLCOLVAR.


**Reasoning**:
Define the encoder and decoder architectures and instantiate the VAE model.



In [None]:
from mlcolvar.cvs import VariationalAutoEncoderCV as VAE

# Define the latent dimension (number of collective variables)
latent_dim = 2

# Encoder layer sizes, from the input up to the last hidden layer.
# mlcolvar adds the mean and log-variance output layers of size latent_dim itself.
encoder_layers = [36, 64, 32]

# Decoder layer sizes after the latent input; mlcolvar prepends latent_dim itself.
decoder_layers = [32, 64, 36]

# Instantiate the VAE model
model = VAE(n_cvs=latent_dim, encoder_layers=encoder_layers, decoder_layers=decoder_layers)

print("VAE model architecture defined.")
print(model)
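To make explicit what a VAE of this shape computes, here is a minimal sketch in plain PyTorch. It is not the `mlcolvar` implementation: the class `TinyVAE`, its layer sizes, and the ReLU activations are assumptions chosen only to illustrate the mean and log-variance heads and the reparameterization step.

```python
import torch
from torch import nn


class TinyVAE(nn.Module):
    """Hypothetical plain-PyTorch VAE, written only to illustrate the structure."""

    def __init__(self, n_in: int = 36, n_latent: int = 2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_in, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
        )
        # Two separate heads map the last hidden layer to the latent Gaussian.
        self.mean_head = nn.Linear(32, n_latent)
        self.logvar_head = nn.Linear(32, n_latent)
        self.decoder = nn.Sequential(
            nn.Linear(n_latent, 32), nn.ReLU(),
            nn.Linear(32, 64), nn.ReLU(),
            nn.Linear(64, n_in),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.mean_head(h), self.logvar_head(h)
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.decoder(z), mu, logvar


torch.manual_seed(0)
x = torch.randn(8, 36)
recon_x, mu, logvar = TinyVAE()(x)
print(recon_x.shape, mu.shape, logvar.shape)
```

The latent mean `mu` plays the role of the collective variables, while `logvar` is only needed for sampling and for the KL term during training.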

## Train VAE

### Subtask:
Train the VAE model with the modified loss function.

**Reasoning**:
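Train the VAE on the prepared data with the custom loss, which adds an orthogonality term to the reconstruction and KL terms. This involves setting up an Adam optimizer and iterating over mini-batches for a fixed number of epochs. The loop calls `vae_ortho_loss`, which is defined in the cell that follows, so that cell must be run first.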

In [None]:
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Define training parameters
epochs = 100
batch_size = 256
learning_rate = 1e-3

# Create a DataLoader for the data
dataset = TensorDataset(data_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
print("Starting VAE training...")
for epoch in range(epochs):
    total_train_loss = 0
    total_recon_loss = 0
    total_kl_loss = 0
    total_ortho_loss = 0

    for batch in dataloader:
        x = batch[0]

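        # Forward pass: encode_decode returns the latent mean, log-variance, and reconstruction.
        # Calling model(x) directly would return only the CVs (latent means).
        mu, logvar, recon_x = model.encode_decode(x)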

        # Calculate loss
        loss, recon_loss, kl_loss, ortho_loss = vae_ortho_loss(recon_x, x, mu, logvar, model, ortho_coeff=0.1)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()
        total_ortho_loss += ortho_loss.item()

    # Print epoch loss
    print(f"Epoch [{epoch+1}/{epochs}], "
          f"Loss: {total_train_loss/len(dataloader):.4f}, "
          f"Recon Loss: {total_recon_loss/len(dataloader):.4f}, "
          f"KL Loss: {total_kl_loss/len(dataloader):.4f}, "
          f"Ortho Loss: {total_ortho_loss/len(dataloader):.4f}")

print("VAE training finished.")
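With 2,000,000 frames and a batch size of 256, each epoch runs through 7,813 batches, so it can be worth smoke-testing the loop on a random subset first. The sketch below shows one way to build such a subset loader. It uses a synthetic tensor as a stand-in for `data_tensor`, and the subset size of 1,000 is an arbitrary choice.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in with the same feature count as the real data (36 columns).
full_dataset = TensorDataset(torch.randn(10_000, 36))

# Draw a random subset of frames for a quick smoke test (size is an arbitrary choice).
n_subset = 1_000
indices = torch.randperm(len(full_dataset))[:n_subset].tolist()
subset_loader = DataLoader(Subset(full_dataset, indices), batch_size=256, shuffle=True)

# Collect the size of each batch to confirm how the subset is split.
batch_sizes = [batch[0].shape[0] for batch in subset_loader]
print(len(subset_loader), batch_sizes)  # 4 [256, 256, 256, 232]
```

Swapping `subset_loader` in for `dataloader` and lowering `epochs` gives a short run that exercises the forward pass, the loss, and the optimizer step before committing to the full data set.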

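## Define loss function

### Subtask:
Define the VAE loss function with an added orthogonality term.

In [None]: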
import torch.nn.functional as F

def vae_ortho_loss(recon_x, x, mu, logvar, model, ortho_coeff=0.1):
    """
    Calculates the VAE loss with an added orthogonality term.

    Args:
        recon_x: Reconstructed input.
        x: Original input.
        mu: Mean of the latent distribution.
        logvar: Log variance of the latent distribution.
        model: The VAE model (not used in the computation; kept so existing calls still work).
        ortho_coeff: Coefficient for the orthogonality term.

    Returns:
        A tuple (total_loss, recon_loss, kl_loss, ortho_loss), where
        total_loss = recon_loss + kl_loss + ortho_coeff * ortho_loss.
    """
    # Reconstruction loss (Mean Squared Error, summed over the batch)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # KL divergence between N(mu, exp(logvar)) and the standard normal prior
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Orthogonality term
    # Use the latent means as the encoded representation (the collective variables);
    # this avoids a second forward pass through the encoder.
    # Calculate the covariance matrix of the latent means across the batch
    cov_matrix = torch.cov(mu.T)
    # The orthogonality term is the sum of the absolute values of the off-diagonal elements
    ortho_loss = torch.sum(torch.triu(torch.abs(cov_matrix), diagonal=1))

    # Total loss
    total_loss = recon_loss + kl_loss + ortho_coeff * ortho_loss

    return total_loss, recon_loss, kl_loss, ortho_loss

print("Custom VAE loss function with orthogonality term defined.")

Custom VAE loss function with orthogonality term defined.
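The KL and orthogonality terms can be checked in isolation on inputs where the answer is known. The sketch below re-implements both in NumPy purely for illustration; the function names are hypothetical and the data are synthetic.

```python
import numpy as np


def kl_to_standard_normal(mu, logvar):
    """Summed KL divergence between N(mu, exp(logvar)) and N(0, 1)."""
    return -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))


def offdiag_cov_penalty(z):
    """Sum of |covariance| over the strict upper triangle of the latent covariance."""
    cov = np.atleast_2d(np.cov(z.T))
    return np.sum(np.triu(np.abs(cov), k=1))


rng = np.random.default_rng(0)

# KL is zero when the posterior equals the prior, and 0.5 per element for a unit mean shift.
kl_prior = kl_to_standard_normal(np.zeros((4, 2)), np.zeros((4, 2)))
kl_shift = kl_to_standard_normal(np.ones((4, 2)), np.zeros((4, 2)))

# Two latent columns: independent vs. strongly correlated.
a = rng.normal(size=5000)
independent = np.stack([a, rng.normal(size=5000)], axis=1)
correlated = np.stack([a, a + 0.1 * rng.normal(size=5000)], axis=1)

print(kl_prior, kl_shift)
print(offdiag_cov_penalty(independent), offdiag_cov_penalty(correlated))
```

The penalty is close to zero for independent columns and close to one for the correlated pair, which is the behavior the orthogonality term relies on. Note that it penalizes covariance, so it also shrinks when the latent variances themselves shrink.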


**Reasoning**:
An earlier attempt to import the VAE from `mlcolvar.vae` failed, so the correct import path for the VAE class in the installed `mlcolvar` package still needs to be found. The next attempt imports it directly from `mlcolvar`.

