### Design of Biosynthetic Gene Clusters (BGCs) Using an Attention-Based Conditional Variational Autoencoder (cVAE)

The model is conditioned on the class of the BGC to which the protein or domain belongs (e.g., Class 1, Class 2).

In [1]:
# Import packages

import Bio
print(Bio.__version__)
from Bio import SeqIO  # Import SeqIO from the Bio package
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import OneHotEncoder
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.linalg import sqrtm
import time
print(torch.__version__)


1.83
2.5.1


In [2]:
# Define the path to the embeddings file
file_path = "/Users/josephtsenum/Documents/PHA6935_AI_for_Drug_Discovery/Project/hybrid_pfam_esm_embeddings.npy"

# Load the embeddings
integrated_embeddings = np.load(file_path, allow_pickle=True)

# Inspect the loaded data
print("Data Type:", type(integrated_embeddings))
print("Data Shape:", integrated_embeddings.shape if hasattr(integrated_embeddings, 'shape') else "No shape attribute")
print("Sample Data:", integrated_embeddings[0] if isinstance(integrated_embeddings, (list, np.ndarray)) else integrated_embeddings)


Data Type: <class 'numpy.ndarray'>
Data Shape: (19450, 20372)
Sample Data: [ 0.          0.          0.         ... -0.08841293 -0.02668052
  0.07184622]


In [3]:
### Data preparation

import torch

# Convert NumPy array to PyTorch tensor
embedding_data = torch.tensor(integrated_embeddings, dtype=torch.float32)

# Inspect the tensor
print("Tensor Shape:", embedding_data.shape)


Tensor Shape: torch.Size([19450, 20372])


### Encoder

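Concatenate the condition (the one-hot BGC class vector) with the features that the 1D CNN extracts from the input embedding, then map the joint vector to the latent mean and log-variance. This ensures the latent space is informed by the condition.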

In [4]:
class Encoder(nn.Module):
    def __init__(self, input_dim, condition_dim, latent_dim):
        super(Encoder, self).__init__()
        
        # 1D CNN layers for feature extraction
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        
        # Fully connected layers for latent space
        self.fc1 = nn.Linear(32 * input_dim + condition_dim, 1024)
        self.fc2_mean = nn.Linear(1024, latent_dim)
        self.fc2_logvar = nn.Linear(1024, latent_dim)

    def forward(self, x, condition):
        # x shape: (batch_size, input_dim)
        
        # Expand dimensions for 1D CNN: (batch_size, 1, input_dim)
        x = x.unsqueeze(1)
        
        # Pass through 1D CNN layers
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        
        # Flatten the output for fully connected layers
        x = self.flatten(x)  # Shape: (batch_size, 32 * input_dim)
        
        # Concatenate input embeddings with condition vector
        x = torch.cat((x, condition), dim=1)  # Shape: (batch_size, 32 * input_dim + condition_dim)
        
        # Pass through fully connected layers to get mean and logvar
        x = self.relu(self.fc1(x))
        mean = self.fc2_mean(x)
        logvar = self.fc2_logvar(x)
        return mean, logvar
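A quick way to check the encoder's shapes is to run the convolutional front end on toy sizes. The sketch below uses hypothetical dimensions (an input of 10 features, a batch of 4) rather than the notebook's data, and rebuilds the two convolutions with `nn.Sequential` for brevity. It also computes the size of `fc1` at the notebook's dimensions: because the flattened CNN output has `32 * input_dim` features, that layer alone holds roughly 668 million parameters.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration); the notebook uses input_dim=20372, condition_dim=3
toy_batch, toy_input_dim, toy_condition_dim = 4, 10, 3

x = torch.randn(toy_batch, toy_input_dim)
condition = torch.eye(toy_condition_dim)[torch.tensor([0, 1, 2, 0])]  # one-hot BGC classes

# Same two convolutions as Encoder.conv1 / Encoder.conv2, followed by flattening
cnn = nn.Sequential(
    nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.Flatten(),
)
features = cnn(x.unsqueeze(1))                   # (batch, 32 * input_dim)
joint = torch.cat((features, condition), dim=1)  # (batch, 32 * input_dim + condition_dim)
print(features.shape, joint.shape)

# Weights + bias of fc1 at the notebook's sizes: Linear(32 * 20372 + 3, 1024)
fc1_params = (32 * 20372 + 3) * 1024 + 1024
print(f"fc1 parameters: {fc1_params:,} (~{fc1_params * 4 / 1e9:.2f} GB in float32)")
```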


### Decoder
Concatenate the condition with the latent vector before decoding.

In [5]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, condition_dim, output_dim, sequence_length):
        super(Decoder, self).__init__()
        
        # The output embedding is emitted as `sequence_length` equal-sized chunks, one per time step
        assert output_dim % sequence_length == 0, "output_dim must be divisible by sequence_length"
        self.sequence_length = sequence_length
        self.chunk_size = output_dim // sequence_length
        
        # RNN Layer (LSTM)
        self.rnn = nn.LSTM(
            input_size=latent_dim + condition_dim, 
            hidden_size=latent_dim, 
            num_layers=1, 
            batch_first=True
        )
        
        # Self-attention mechanism (latent_dim must be divisible by num_heads)
        self.attention = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=4, batch_first=True)
        
        # Fully connected layer mapping each time step to one chunk of the output.
        # No sigmoid: the target embeddings contain negative values, so the output stays linear.
        self.fc_output = nn.Linear(latent_dim, self.chunk_size)

    def forward(self, z, condition):
        # Concatenate latent vector and condition
        z_c = torch.cat((z, condition), dim=1)  # Shape: (batch_size, latent_dim + condition_dim)
        
        # Expand to match sequence length: (batch_size, sequence_length, latent_dim + condition_dim)
        z_c = z_c.unsqueeze(1).repeat(1, self.sequence_length, 1)
        
        # Pass through RNN
        rnn_output, _ = self.rnn(z_c)  # Shape: (batch_size, sequence_length, latent_dim)
        
        # Apply self-attention
        attn_output, _ = self.attention(rnn_output, rnn_output, rnn_output)  # Shape: (batch_size, sequence_length, latent_dim)
        
        # Map each time step to a chunk, then flatten back to the embedding shape
        chunks = self.fc_output(attn_output)  # Shape: (batch_size, sequence_length, chunk_size)
        return chunks.reshape(z.size(0), -1)  # Shape: (batch_size, output_dim)
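The conditioning and attention steps of the decoder can be traced on toy sizes. This sketch (hypothetical dimensions: batch 2, latent size 8, sequence length 5) builds a standalone LSTM and `nn.MultiheadAttention` rather than the notebook's `Decoder`, and prints the head-averaged attention weights, which form a `(batch, L, L)` tensor. That quadratic term is why the number of decoder time steps should stay small.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions); the latent size must be divisible by num_heads
batch, latent, cond_dim, seq_len = 2, 8, 3, 5

z = torch.randn(batch, latent)
cond = torch.eye(cond_dim)[torch.tensor([0, 2])]                      # one-hot conditions
z_c = torch.cat((z, cond), dim=1).unsqueeze(1).repeat(1, seq_len, 1)  # same input at every step

lstm = nn.LSTM(latent + cond_dim, latent, batch_first=True)
attn = nn.MultiheadAttention(latent, num_heads=4, batch_first=True)
h, _ = lstm(z_c)
out, weights = attn(h, h, h)  # weights are averaged over heads: (batch, seq_len, seq_len)
print(z_c.shape, out.shape, weights.shape)

# Attention weights for batch 16 and a sequence as long as the 20372-feature embedding
full_size = 16 * 20372 ** 2
print(f"{full_size:,} attention weights (~{full_size * 4 / 1e9:.1f} GB in float32)")
```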


### Updating the cVAE Class

Pass the condition into both the encoder and the decoder during the forward pass. The updated cVAE class combines the 1D CNN encoder with the self-attention decoder.

In [6]:
class cVAE(nn.Module):
    def __init__(self, input_dim, condition_dim, latent_dim, output_dim, sequence_length):
        super(cVAE, self).__init__()
        # Encoder: 1D CNN-based Encoder
        self.encoder = Encoder(input_dim, condition_dim, latent_dim)
        
        # Decoder: RNN with self-attention
        self.decoder = Decoder(latent_dim, condition_dim, output_dim, sequence_length)

    def forward(self, x, condition):
        # Encode the input to obtain mean and log variance
        mean, logvar = self.encoder(x, condition)
        
        # Reparameterization trick: sample latent vector z
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
        
        # Decode the latent vector
        reconstructed = self.decoder(z, condition)
        
        return reconstructed, mean, logvar
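The reparameterization trick in `forward` moves the randomness into `eps`, so gradients can flow back to `mean` and `logvar`. The following minimal check uses toy tensors (an assumption; it does not call the notebook's model) with `logvar = 0`, so `std = 1`.

```python
import torch

torch.manual_seed(0)
mean = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)  # log(sigma^2) = 0, so sigma = 1

std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # the noise itself carries no gradient
z = mean + std * eps         # differentiable with respect to mean and logvar

z.sum().backward()
print(mean.grad)    # d(sum z)/d(mean) = 1 everywhere
print(logvar.grad)  # d(sum z)/d(logvar) = 0.5 * std * eps
```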


### Preparing the Conditioning Variable

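Since our condition is categorical (BGC class), we one-hot encode it. Note that the labels below are placeholders assigned cyclically (Class1, Class2, Class3, Class1, ...); for the conditioning to be meaningful they should be replaced with the actual BGC class of each protein or domain.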

In [7]:
# Ensure the number of BGC classes matches the number of rows in embedding_data
num_samples = embedding_data.shape[0]  # Number of rows in embedding_data

# Placeholder BGC class labels assigned cyclically; replace with the real class of each embedding
bgc_classes = ['Class1', 'Class2', 'Class3'] * (num_samples // 3) + ['Class1'] * (num_samples % 3)

# Verify the length of bgc_classes matches num_samples
assert len(bgc_classes) == num_samples, "bgc_classes length does not match embedding_data"

print("Updated BGC Classes Length:", len(bgc_classes))
print("Sample BGC Classes:", bgc_classes[:10])  # Display a sample of the classes


Updated BGC Classes Length: 19450
Sample BGC Classes: ['Class1', 'Class2', 'Class3', 'Class1', 'Class2', 'Class3', 'Class1', 'Class2', 'Class3', 'Class1']


In [8]:
# OneHotEncode the corrected bgc_classes
encoder = OneHotEncoder(sparse_output=False)  # Create OneHotEncoder instance
bgc_classes_array = np.array(bgc_classes).reshape(-1, 1)  # Reshape for encoding
condition_data = torch.tensor(encoder.fit_transform(bgc_classes_array), dtype=torch.float32)  # Convert to tensor

# Verify the shape of condition_data
assert condition_data.shape[0] == embedding_data.shape[0], "Condition data shape mismatch"
print("Condition Data Shape:", condition_data.shape)


Condition Data Shape: torch.Size([19450, 3])


In [9]:
print("Embedding Data Shape:", embedding_data.shape)
print("Number of BGC Classes:", len(bgc_classes))


Embedding Data Shape: torch.Size([19450, 20372])
Number of BGC Classes: 19450


In [10]:
# Number of embeddings
num_embeddings = embedding_data.shape[0]

# Define the unique BGC classes
classes = ['Class1', 'Class2', 'Class3']
num_classes = len(classes)

# Create the list of BGC class labels by cycling through the classes
bgc_classes = [classes[i % num_classes] for i in range(num_embeddings)]

# One-hot encode the BGC classes
encoder = OneHotEncoder(sparse_output=False)  # Create the encoder
condition_data = torch.tensor(
    encoder.fit_transform(np.array(bgc_classes).reshape(-1, 1)),  # One-hot encoding
    dtype=torch.float32  # Convert to PyTorch tensor
)

# Verify that the condition data matches the embedding data
assert condition_data.shape[0] == embedding_data.shape[0], "Mismatch between embeddings and conditions"
print("Condition Data Shape:", condition_data.shape)


Condition Data Shape: torch.Size([19450, 3])


### Updating Training Loop

Pass the condition into the forward pass of the cVAE.

In [11]:
# Reload embedding_data and condition_data if necessary
if isinstance(embedding_data, np.ndarray):
    embedding_data = torch.tensor(embedding_data, dtype=torch.float32)
if isinstance(condition_data, np.ndarray):
    condition_data = torch.tensor(condition_data, dtype=torch.float32)

# Ensure shapes are correct
print(f"Embedding Data Shape: {embedding_data.shape}")
print(f"Condition Data Shape: {condition_data.shape}")


Embedding Data Shape: torch.Size([19450, 20372])
Condition Data Shape: torch.Size([19450, 3])


In [12]:
# Define input, latent, and output dimensions
input_dim = embedding_data.shape[1]  # Number of features in embedding_data
latent_dim = 256  # Dimension of latent space
output_dim = input_dim  # Same as input_dim for reconstruction task
condition_dim = condition_data.shape[1]  # Dimension of the condition vector

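# Instantiate the cVAE
# sequence_length is the number of decoder time steps; it must divide output_dim
# (20372 = 44 x 463) and stay small, since self-attention memory grows with its square
vae = cVAE(input_dim, condition_dim, latent_dim, output_dim, sequence_length=44)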
print(f"Model instantiated with input_dim={input_dim}, condition_dim={condition_dim}, latent_dim={latent_dim}, output_dim={output_dim}")


Model instantiated with input_dim=20372, condition_dim=3, latent_dim=256, output_dim=20372


### Defining loss_function and optimizer

In [13]:
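# Define the loss function (negative ELBO, averaged over the batch)
def loss_function(reconstructed, original, mean, logvar):
    batch_size = original.size(0)

    # Reconstruction loss: squared error summed over features, averaged over the batch
    reconstruction_loss = nn.functional.mse_loss(reconstructed, original, reduction="sum") / batch_size
    
    # KL divergence between N(mean, exp(logvar)) and the standard normal prior,
    # summed over latent dimensions and averaged over the batch (same scale as the term above)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()) / batch_size
    
    return reconstruction_loss + kl_divergence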

# Define the optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

print("Loss function and optimizer defined successfully.")


Loss function and optimizer defined successfully.
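The closed-form KL term in `loss_function` can be cross-checked against `torch.distributions`. This sketch uses random toy values for `mean` and `logvar` (assumptions, not model outputs) and compares the formula with `kl_divergence` between $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(4, 3)    # toy latent means
logvar = torch.randn(4, 3)  # toy log-variances

# Closed form used in loss_function (before dividing by the batch size)
closed_form = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

# Reference: KL(q || p) with q = N(mean, std^2) and p = N(0, 1), summed over all entries
q = Normal(mean, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
reference = kl_divergence(q, p).sum()
print(closed_form.item(), reference.item())
```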


In [14]:
batch_size = 16  # Try a smaller batch size


In [15]:
print(f"Current Batch Size: {batch_size}")


Current Batch Size: 16


### Training

In [1]:
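# Number of training epochs (the checkpoint saved below is named for 50 epochs)
epochs = 50

# Train on a GPU when available; mixed precision is enabled only on CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"
vae.to(device)

# torch.amp replaces the deprecated torch.cuda.amp API
scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

epoch_losses = []  # mean loss per epoch, used by the plot below
num_batches = (embedding_data.shape[0] + batch_size - 1) // batch_size

for epoch in range(epochs):
    vae.train()
    epoch_loss = 0.0

    for i in range(0, embedding_data.shape[0], batch_size):
        batch = embedding_data[i:i + batch_size].to(device)
        condition_batch = condition_data[i:i + batch_size].to(device)

        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, enabled=use_amp):  # mixed precision on CUDA
            reconstructed, mean, logvar = vae(batch, condition_batch)
            loss = loss_function(reconstructed, batch, mean, logvar)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    epoch_losses.append(epoch_loss / num_batches)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_losses[-1]:.4f}")

# Move the model back to the CPU so the evaluation cells below can use CPU tensors
vae.to("cpu")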


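The training loop slices `embedding_data` in a fixed order. A common alternative, sketched here on toy stand-in tensors (assumptions), is `TensorDataset` plus `DataLoader(shuffle=True)`, which reshuffles every epoch while keeping each condition row paired with its embedding row.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
# Toy stand-ins (assumptions): row i of `embeddings` stores i in every column,
# and its condition is the one-hot vector for class i % 3
embeddings = torch.arange(10.0).unsqueeze(1).repeat(1, 6)
conditions = torch.eye(3)[torch.arange(10) % 3]

loader = DataLoader(TensorDataset(embeddings, conditions), batch_size=4, shuffle=True)

sizes, aligned = [], True
for batch, condition_batch in loader:
    sizes.append(batch.shape[0])
    # After shuffling, each condition row still belongs to the embedding row beside it
    aligned = aligned and torch.equal(condition_batch.argmax(dim=1), batch[:, 0].long() % 3)
print(sizes, aligned)
```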

### Evaluating the Training Loss

Plotting the training loss over epochs to understand how well the model converged.

In [None]:
# Plot training loss
plt.figure()
plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, label="Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss over Epochs")
plt.legend()
plt.show()


### Visualizing the Latent Space

To evaluate the latent representations, we will:

1. Encode the input embeddings using the encoder.
2. Reduce the latent space to 2D with t-SNE.
3. Visualize the points, optionally color-coded by their conditions (BGC class).

In [None]:
# Pass embeddings through the encoder
vae.eval()
with torch.no_grad():
    latent_vectors = []
    for i in range(embedding_data.shape[0]):
        embedding = embedding_data[i].unsqueeze(0)  # Add batch dimension
        condition = condition_data[i].unsqueeze(0)  # Add batch dimension
        mean, _ = vae.encoder(embedding, condition)
        latent_vectors.append(mean.numpy().flatten())

latent_vectors = np.array(latent_vectors)

# Use t-SNE for dimensionality reduction to 2D
tsne = TSNE(n_components=2, random_state=42)
latent_2d = tsne.fit_transform(latent_vectors)

# Plot the reduced latent space
plt.figure(figsize=(8, 6))
plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c=np.argmax(condition_data.numpy(), axis=1), cmap='viridis', alpha=0.7)
plt.colorbar(label="Condition (BGC Class)")
plt.title("Latent Space Visualization")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.show()
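t-SNE requires the perplexity to be smaller than the number of samples. The sketch below runs the same `TSNE` call on three synthetic clusters (hypothetical stand-ins for latent mean vectors); with the notebook's 19,450 points the default perplexity of 30 is fine, but small subsets need a lower value.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(42)
# Three synthetic clusters of 20 points each in 16-D (stand-ins for latent mean vectors)
centers = rng.normal(scale=5.0, size=(3, 16))
toy_labels = np.repeat(np.arange(3), 20)
toy_latent = centers[toy_labels] + rng.normal(size=(60, 16))

# perplexity must be < n_samples (60 here)
toy_tsne = TSNE(n_components=2, perplexity=10, random_state=42)
toy_2d = toy_tsne.fit_transform(toy_latent)
print(toy_2d.shape)
```

Coloring the scatter plot by `toy_labels` plays the same role as coloring by the argmax of `condition_data` above.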


### Reconstructing Inputs

Evaluate how well the decoder can reconstruct the input embeddings given the latent representations and conditions. The cells below compare original and reconstructed embeddings for a few samples, first by printing their leading features and then by plotting them.

In [None]:
# Select a few random samples to reconstruct
vae.eval()
samples = embedding_data[:5]  # Replace with random indices if needed
conditions = condition_data[:5]

# Perform reconstruction
with torch.no_grad():
    reconstructed, _, _ = vae(samples, conditions)

# Compare original and reconstructed embeddings
for i in range(samples.shape[0]):
    print(f"Sample {i + 1}")
    print("Original:", samples[i].numpy()[:10])  # Show first 10 features
    print("Reconstructed:", reconstructed[i].numpy()[:10])  # Show first 10 features
    print()


In [None]:
# Ensure the model is in evaluation mode
vae.eval()

# Select a few random samples to reconstruct
num_samples = 5  # Number of samples to visualize
sample_indices = torch.randint(0, embedding_data.shape[0], (num_samples,))  # Random indices
samples = embedding_data[sample_indices]
conditions = condition_data[sample_indices]

# Perform reconstruction
with torch.no_grad():
    reconstructed, _, _ = vae(samples, conditions)

# Plot original and reconstructed embeddings for each sample
for i in range(num_samples):
    plt.figure(figsize=(10, 6))
    
    # Plot original embedding
    plt.plot(samples[i].numpy(), label="Original", alpha=0.7, linewidth=2)
    
    # Plot reconstructed embedding
    plt.plot(reconstructed[i].numpy(), label="Reconstructed", alpha=0.7, linewidth=2, linestyle="--")
    
    # Add labels and legend
    plt.title(f"Sample {i + 1}: Original vs Reconstructed Embedding")
    plt.xlabel("Feature Index")
    plt.ylabel("Value")
    plt.legend()
    plt.tight_layout()
    plt.show()
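To put numbers on reconstruction quality, per-sample MSE and cosine similarity can be computed alongside the plots. In this sketch the tensors are synthetic stand-ins (an assumption); in the notebook, `samples` and `reconstructed` from the cell above would be used directly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
toy_samples = torch.randn(5, 100)                                # stand-in for original embeddings
toy_reconstructed = toy_samples + 0.1 * torch.randn(5, 100)      # stand-in for vae(...) output

mse_per_sample = F.mse_loss(toy_reconstructed, toy_samples, reduction="none").mean(dim=1)
cosine_per_sample = F.cosine_similarity(toy_reconstructed, toy_samples, dim=1)
for i, (m, c) in enumerate(zip(mse_per_sample, cosine_per_sample)):
    print(f"Sample {i + 1}: MSE={m:.4f}, cosine={c:.4f}")
```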


### Predict New Domain Configurations

To generate new domain configurations or embeddings, we will sample from the latent space and use the decoder.

Generate New Configurations:

1. Samples latent vectors from a Gaussian distribution (latent_samples).
2. Uses the decoder to generate new domain embeddings conditioned on the first few condition vectors in `condition_data` (replace them with a specific BGC class as needed).

In [None]:
# Sample latent vectors from a Gaussian distribution
vae.eval()
num_samples = 5
latent_samples = torch.randn((num_samples, latent_dim))

# Condition on the first rows of condition_data (replace with specific one-hot BGC classes if needed)
random_conditions = condition_data[:num_samples]
with torch.no_grad():
    generated = vae.decoder(latent_samples, random_conditions)

# Print or analyze generated embeddings
print("Generated Embeddings:")
print(generated)


In [None]:
# Generate new domain configurations by sampling from the latent space
vae.eval()  # Ensure the model is in evaluation mode
num_samples = 5  # Number of new samples to generate

# Sample latent vectors from a Gaussian distribution
latent_samples = torch.randn((num_samples, latent_dim))

# Condition on the first rows of condition_data
random_conditions = condition_data[:num_samples]  # Select the first few conditions

# Generate new embeddings using the decoder
with torch.no_grad():
    generated_embeddings = vae.decoder(latent_samples, random_conditions)

# Plot the generated embeddings
for i in range(num_samples):
    plt.figure(figsize=(8, 4))
    plt.plot(generated_embeddings[i].numpy(), label=f"Generated Sample {i + 1}", alpha=0.7, linewidth=2)
    plt.title(f"Generated Embedding {i + 1}")
    plt.xlabel("Feature Index")
    plt.ylabel("Value")
    plt.legend()
    plt.tight_layout()
    plt.show()


### Save Results

Save the trained model, the latent representations, and their conditions for further use or analysis.



In [None]:
# Save the trained model
torch.save(vae.state_dict(), "trained_cvae_model_50_epochs.pth")


In [None]:
# Save latent vectors and conditions for further analysis
np.save("latent_vectors.npy", latent_vectors)
np.save("latent_conditions.npy", condition_data.numpy())
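Loading a saved `state_dict` only works if the model is rebuilt with the same constructor arguments. The round trip can be checked on a toy model (`nn.Linear` as a stand-in for the cVAE, an assumption); `weights_only=True` restricts `torch.load` to tensors and plain containers.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the cVAE
path = os.path.join(tempfile.mkdtemp(), "toy_model.pth")
torch.save(model.state_dict(), path)

restored = nn.Linear(4, 2)  # must be built with the same constructor arguments
restored.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
restored.eval()

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```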


### Evaluation with Fréchet Inception Distance (FID) Score

Fréchet Inception Distance (FID) is a metric for evaluating generative models such as a cVAE, in particular the quality of generated data (e.g., embeddings or domain configurations). It fits a Gaussian (mean and covariance) to the real data and to the generated data in a feature space and measures the distance between the two; lower FID scores indicate that the generated data is closer to the real data. Standard FID uses features from a pretrained Inception network; in this notebook the embeddings themselves (after PCA) serve as the feature space.

Why FID Score?

1. Reconstruction-Based Tasks: FID measures how similar the generated embeddings (or reconstructions) are to the real ones in terms of distribution.
2. Robustness: Unlike MSE, which measures pointwise error, FID compares the first two moments of the distributions, so it reflects both the quality and the diversity of generated samples.
3. Widely Used: FID is commonly used in generative modeling tasks (e.g., GANs, VAEs) to evaluate the fidelity and diversity of generated data.


In [None]:
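# Rebuild the model with the same arguments used for training (including sequence_length)
vae = cVAE(input_dim, condition_dim, latent_dim, output_dim, sequence_length=44)
vae.load_state_dict(torch.load("trained_cvae_model_50_epochs.pth", map_location="cpu", weights_only=True))
vae.eval()  # Set to evaluation mode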


In [None]:
print("Embedding Data Shape:", embedding_data.shape)
print("Condition Data Shape:", condition_data.shape)


#### Compute FID Score

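Extracting Real and Generated Embeddings:
1. Standard FID passes real and generated samples through a pretrained feature extractor (an Inception network for images). Here the protein embeddings already serve as the feature space, so the original embeddings (`embedding_data`) act as the real features and the cVAE reconstructions act as the generated features.
2. Collect both sets as NumPy arrays for the statistics computed below.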

In [None]:

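# Generate reconstructions in mini-batches to keep memory bounded
vae.eval()
reconstructed_batches = []
with torch.no_grad():
    for i in range(0, embedding_data.shape[0], batch_size):
        recon_batch, _, _ = vae(embedding_data[i:i + batch_size], condition_data[i:i + batch_size])
        reconstructed_batches.append(recon_batch)
reconstructed = torch.cat(reconstructed_batches, dim=0)

# Use the original embeddings (real) and reconstructed embeddings (generated)
real_features = embedding_data.numpy()  # Real embeddings
generated_features = reconstructed.numpy()  # Generated embeddings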


### Calculating the Mean and Covariance

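Compute the mean and covariance for both the real and generated embeddings. At the full 20,372 dimensions each covariance matrix has about 415 million entries (roughly 3.3 GB in float64), which is why the FID computation further below first reduces the dimensionality with PCA.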


In [None]:
def calculate_statistics(features):
    """Calculate mean and covariance of the features."""
    mean = np.mean(features, axis=0)
    cov = np.cov(features, rowvar=False)
    return mean, cov

# Compute statistics for real and generated features
real_mean, real_cov = calculate_statistics(real_features)
gen_mean, gen_cov = calculate_statistics(generated_features)


In [None]:
# Plot the means for comparison
plt.figure(figsize=(10, 6))
plt.plot(real_mean, label="Real Embeddings Mean", alpha=0.7, linewidth=2)
plt.plot(gen_mean, label="Generated Embeddings Mean", alpha=0.7, linewidth=2, linestyle="--")
plt.title("Mean of Real vs Generated Embeddings")
plt.xlabel("Feature Index")
plt.ylabel("Mean Value")
plt.legend()
plt.tight_layout()
plt.show()

### Computing the FID Score

The FID score is calculated as:

$$
FID = ||\mu_1 - \mu_2||^2 + \text{Tr}(C_1 + C_2 - 2 \sqrt{C_1 C_2})
$$

Where:

- **$\mu_1$**: Mean of the real embeddings (feature distribution of real data).
- **$\mu_2$**: Mean of the generated embeddings (feature distribution of generated data).
- **$C_1$**: Covariance matrix of the real embeddings.
- **$C_2$**: Covariance matrix of the generated embeddings.
- **$\text{Tr}$**: Trace of a matrix (sum of its diagonal elements).
- **$||\mu_1 - \mu_2||^2$**: Squared Euclidean distance between the means of the real and generated embeddings.
- **$\sqrt{C_1 C_2}$**: Matrix square root of the product of the covariance matrices.
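As a worked example of the formula, take two one-dimensional Gaussians (assumed for illustration), $\mathcal{N}(0, 1)$ and $\mathcal{N}(1, 4)$: the mean term is $(0-1)^2 = 1$ and the trace term is $1 + 4 - 2\sqrt{1 \cdot 4} = 1$, so FID $= 2$. The same computation with NumPy and `scipy.linalg.sqrtm`:

```python
import numpy as np
from scipy.linalg import sqrtm

# Two 1-D Gaussians (assumed for illustration): N(0, 1) and N(1, 4), as 1x1 covariance matrices
mu1, mu2 = np.array([0.0]), np.array([1.0])
C1, C2 = np.array([[1.0]]), np.array([[4.0]])

mean_term = np.sum((mu1 - mu2) ** 2)                      # ||mu1 - mu2||^2 = 1
trace_term = np.trace(C1 + C2 - 2 * sqrtm(C1 @ C2).real)  # 1 + 4 - 2 * 2 = 1
fid = mean_term + trace_term
print(fid)
```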


### Generate Reconstructions
Use your trained cVAE model to reconstruct the embeddings based on the input embedding_data and condition_data.

In [None]:
vae.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation; reconstruct in mini-batches to limit memory
    reconstructed = torch.cat([
        vae(embedding_data[i:i + batch_size], condition_data[i:i + batch_size])[0]
        for i in range(0, embedding_data.shape[0], batch_size)
    ])


### Define Real and Generated Features
Convert both the original embeddings (embedding_data) and the reconstructed embeddings (reconstructed) to NumPy for further analysis.

In [None]:
# Convert embeddings to NumPy
real_features = embedding_data.numpy()  # Real embeddings
generated_features = reconstructed.numpy()  # Generated embeddings

# Verify shapes
print("Real Features Shape:", real_features.shape)
print("Generated Features Shape:", generated_features.shape)


### Reduce Dimensionality with PCA
Perform PCA to reduce the dimensionality of the real and generated embeddings. This helps make covariance computations faster and more stable.

In [None]:
from sklearn.decomposition import PCA

# Apply PCA to reduce dimensionality
pca = PCA(n_components=512)  # Adjust the number of components if necessary
real_features_reduced = pca.fit_transform(real_features)
gen_features_reduced = pca.transform(generated_features)

# Verify shapes after PCA
print("Reduced Real Features Shape:", real_features_reduced.shape)
print("Reduced Generated Features Shape:", gen_features_reduced.shape)
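Before trusting the FID on 512 components, it helps to check how much of the real data's variance they retain via `explained_variance_ratio_`. The sketch below uses synthetic low-rank data (an assumption) and follows the same pattern as the cell above: fit PCA on the real features only, then project both sets.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
# Synthetic stand-in: 500 samples in 200-D that mostly live in a 10-D subspace
real = rng.normal(size=(500, 10)) @ rng.normal(size=(10, 200)) + 0.01 * rng.normal(size=(500, 200))
generated = real + 0.1 * rng.normal(size=real.shape)

toy_pca = PCA(n_components=20).fit(real)  # fit on real data only
retained = toy_pca.explained_variance_ratio_.sum()
print(f"Variance retained by 20 components: {retained:.3f}")

real_reduced, gen_reduced = toy_pca.transform(real), toy_pca.transform(generated)
```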


### Compute Mean and Covariance
Calculate the mean and covariance of the reduced embeddings using the previously defined calculate_statistics function.

In [None]:
def calculate_statistics(features):
    """Calculate mean and covariance of the features."""
    mean = np.mean(features, axis=0)
    cov = np.cov(features, rowvar=False)
    return mean, cov

# Compute statistics for real and generated features
real_mean, real_cov = calculate_statistics(real_features_reduced)
gen_mean, gen_cov = calculate_statistics(gen_features_reduced)

# Print the computed statistics
print("Real Mean Shape:", real_mean.shape)
print("Real Covariance Shape:", real_cov.shape)
print("Generated Mean Shape:", gen_mean.shape)
print("Generated Covariance Shape:", gen_cov.shape)


### Compute FID Score
Using the computed mean and covariance, calculate the FID score:

In [None]:
from scipy.linalg import sqrtm

def calculate_fid(real_mean, real_cov, gen_mean, gen_cov):
    """Calculate the Fréchet Inception Distance (FID)."""
    mean_diff = np.sum((real_mean - gen_mean) ** 2)  # Squared Euclidean distance between means
    cov_sqrt = sqrtm(real_cov @ gen_cov)  # Matrix square root of product of covariances
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real  # Handle numerical instability
    fid = mean_diff + np.trace(real_cov + gen_cov - 2 * cov_sqrt)  # FID formula
    return fid

# Compute FID score
fid_score = calculate_fid(real_mean, real_cov, gen_mean, gen_cov)
print(f"FID Score: {fid_score:.4f}")


### Visualize Results
Plot the means and covariance diagonal to compare the distributions of real and generated embeddings.

In [None]:
import matplotlib.pyplot as plt

# Plot the means for comparison
plt.figure(figsize=(10, 6))
plt.plot(real_mean, label="Real Embeddings Mean", alpha=0.7, linewidth=2)
plt.plot(gen_mean, label="Generated Embeddings Mean", alpha=0.7, linewidth=2, linestyle="--")
plt.title("Mean of Real vs Generated Embeddings")
plt.xlabel("Feature Index")
plt.ylabel("Mean Value")
plt.legend()
plt.tight_layout()
plt.show()


### Plot Covariance Diagonal Comparison

In [None]:
# Plot the diagonal of the covariance matrices for comparison
plt.figure(figsize=(10, 6))
plt.plot(np.diag(real_cov), label="Real Embeddings Covariance (Diagonal)", alpha=0.7, linewidth=2)
plt.plot(np.diag(gen_cov), label="Generated Embeddings Covariance (Diagonal)", alpha=0.7, linewidth=2, linestyle="--")
plt.title("Covariance (Diagonal) of Real vs Generated Embeddings")
plt.xlabel("Feature Index")
plt.ylabel("Variance Value")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# Inspect the first embedding
print("Sample Embedding:", embedding_data[0].numpy())
print("Number of Zeros in Sample Embedding:", (embedding_data[0] == 0).sum().item())



start_time = time.time()
fid_score = calculate_fid(real_mean, real_cov, gen_mean, gen_cov)
end_time = time.time()

print(f"FID Score: {fid_score:.4f}")
print(f"Time Taken: {end_time - start_time:.4f} seconds")


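### Interpret the FID Score
Lower is better: a lower FID score means the generated embeddings are closer to the real ones.

Thresholds:

Commonly quoted rules of thumb (FID around 10 for high-quality generations, FID above 50 for a clear distribution mismatch) come from image generation with 2048-dimensional Inception features.

The absolute value depends on the feature space and its dimensionality (here, 512 PCA components of the embeddings), so these thresholds do not transfer directly. A more informative reference is the FID between two random halves of the real data, which shows how large the score is when both sets come from the same distribution.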
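The real-versus-real baseline can be sketched as follows. `fid_from_features` is a hypothetical helper that combines the statistics and FID steps above, and the data are synthetic stand-ins (assumptions); in the notebook, `real_features_reduced` would take the place of `real`.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid_from_features(a, b):
    """Hypothetical helper: FID between two feature matrices (rows = samples)."""
    mu_a, mu_b = a.mean(axis=0), b.mean(axis=0)
    cov_a, cov_b = np.cov(a, rowvar=False), np.cov(b, rowvar=False)
    cov_sqrt = sqrtm(cov_a @ cov_b).real  # drop tiny imaginary parts from round-off
    return float(np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a + cov_b - 2 * cov_sqrt))

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))  # stand-in for real_features_reduced
perm = rng.permutation(len(real))

baseline = fid_from_features(real[perm[:1000]], real[perm[1000:]])  # same distribution
shifted = fid_from_features(real, real + 1.0)                       # every feature shifted by 1
print(f"real-vs-real baseline: {baseline:.3f}, shifted: {shifted:.3f}")
```

A generated set whose FID is close to the baseline is hard to distinguish from real data at the level of means and covariances.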

Advantages of FID over MSE

1. Captures Distributional Similarity: Unlike MSE, which evaluates individual points, FID compares entire distributions.
2. Handles Diversity: FID penalizes lack of diversity in generated data.
3. Task-Agnostic: FID works well for tasks where direct reconstruction is not the only goal (e.g., generating new domain configurations or embeddings).

### Time Complexity

#### How Much Time Will FID Calculation Take?

The time required to calculate the **Fréchet Inception Distance (FID)** using the provided function depends on several factors:

---

#### **Factors Influencing Time**
1. **Dimensionality of Features (Number of Features)**:
   - Higher-dimensional feature spaces (e.g., thousands of features) will take more time because matrix operations (like covariance computation and square root calculation) become more computationally expensive.

2. **Matrix Square Root (`sqrtm`)**:
   - This is the most computationally expensive operation in the function. It has a complexity of $O(d^3)$, where $d$ is the dimensionality of the covariance matrices.

3. **Hardware**:
   - The CPU or GPU you are using significantly impacts the computation time.

---

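#### **Time Complexity**
Let $d$ be the dimensionality of the features (512 after the PCA step above; 2048 for standard Inception-based FID) and $n$ the number of samples:
- **Covariance Computation (`np.cov`, done before `calculate_fid`)**: $O(n d^2)$
- **Matrix Multiplication (`real_cov @ gen_cov`)**: $O(d^3)$
- **Matrix Square Root (`sqrtm`)**: $O(d^3)$
- **Squared Distance Between Means and Trace**: $O(d)$ for the means and $O(d^2)$ to form $C_1 + C_2 - 2\sqrt{C_1 C_2}$

Thus, `calculate_fid` itself runs in approximately **$O(d^3)$** time.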

---

#### **Empirical Estimation**
- For features of size $d = 2048$, typical in Inception-based FID:
  - **Small Datasets (e.g., 10,000 samples)**: typically seconds to about a minute on a modern CPU, depending on the hardware and the linear algebra backend.
  - **Larger Datasets**: the covariance step grows linearly with the number of samples, while `sqrtm` depends only on $d$.

---

#### **Optimizing the Calculation**
1. **Use GPU Acceleration**:
   - Libraries like PyTorch or TensorFlow can handle matrix operations on GPUs more efficiently.

2. **Precompute Covariance**:
   - If you're comparing multiple generated sets to the same real data, precomputing the real data statistics (mean and covariance) can save time.
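One further option, sketched here as an alternative to `scipy.linalg.sqrtm` rather than the method used above, relies on the identity $\text{Tr}\big((C_1 C_2)^{1/2}\big) = \sum_i \sqrt{\lambda_i}$, where $\lambda_i$ are the eigenvalues of the symmetric matrix $C_1^{1/2} C_2 C_1^{1/2}$. Symmetric eigendecompositions (`numpy.linalg.eigh`) are typically faster and more stable than a general matrix square root; the function name `fid_eigh` is hypothetical.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid_eigh(mu1, C1, mu2, C2):
    """Hypothetical helper: FID via symmetric eigendecompositions instead of sqrtm."""
    w, V = np.linalg.eigh(C1)
    sqrt_C1 = (V * np.sqrt(np.clip(w, 0, None))) @ V.T  # symmetric square root of C1
    M = sqrt_C1 @ C2 @ sqrt_C1
    M = (M + M.T) / 2                                   # remove round-off asymmetry
    eig = np.clip(np.linalg.eigvalsh(M), 0, None)       # clip tiny negative eigenvalues
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(C1) + np.trace(C2) - 2 * np.sum(np.sqrt(eig)))

# Compare with the sqrtm-based formula on toy statistics (assumed data)
rng = np.random.default_rng(0)
A, B = rng.normal(size=(50, 6)), rng.normal(size=(50, 6)) + 0.5
mu1, mu2 = A.mean(axis=0), B.mean(axis=0)
C1, C2 = np.cov(A, rowvar=False), np.cov(B, rowvar=False)
reference = np.sum((mu1 - mu2) ** 2) + np.trace(C1 + C2 - 2 * sqrtm(C1 @ C2).real)
print(fid_eigh(mu1, C1, mu2, C2), reference)
```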

---

#### **Benchmarking FID Calculation**
You can measure the time taken to compute FID in your environment using Python’s `time` module:

```python
import time

# time.perf_counter() is a high-resolution clock suited to timing short operations
start_time = time.perf_counter()
fid_score = calculate_fid(real_mean, real_cov, gen_mean, gen_cov)
end_time = time.perf_counter()

print(f"FID Score: {fid_score:.4f}")
print(f"Time Taken: {end_time - start_time:.4f} seconds")
```
