# 🧬 AI4CellFate: Interpretable Cell Fate Prediction

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ComputationalMicroscopy4CellBio/AI4CellFate/blob/main/notebooks/Codeless_AI4CellFate_Google_Colab.ipynb)

**AI4CellFate** is a deep learning framework for predicting cell fate from single-frame microscopy images with full interpretability. This notebook provides an interface to train and apply AI4CellFate to your own data.

---

## 📋 **Quick Start Guide**

### **Step 1: Prepare Your Data**
- **Images**: Single-channel images (20×20 pixels recommended)
- **Labels**: Binary classification labels (0 and 1)
- **Format**: NumPy arrays (.npy files)
- **⚠️ IMPORTANT**: Normalise your images to the [0, 1] range

Then run the cells in order; the code is hidden for simplicity.
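A minimal sketch of the normalisation step, assuming your images are already loaded as a NumPy array of shape `[n_samples, height, width]`. It applies the global min-max formula listed in the data-upload cell; the random placeholder arrays and the `minmax_normalise` helper are illustrative only and are not part of AI4CellFate.

```python
import numpy as np


def minmax_normalise(images):
    """Rescale an image stack to [0, 1] using its global minimum and maximum."""
    images = images.astype(np.float32)
    lo, hi = images.min(), images.max()
    if hi == lo:  # constant stack: avoid dividing by zero
        return np.zeros_like(images)
    return (images - lo) / (hi - lo)


# Placeholder data: 100 single-channel 20x20 crops with 16-bit intensities
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 65536, size=(100, 20, 20)).astype(np.uint16)
raw_labels = rng.integers(0, 2, size=100)

x = minmax_normalise(raw_images)   # float32, shape (100, 20, 20), values in [0, 1]
y = raw_labels.astype(np.int64)    # binary labels 0 and 1

print(x.shape, x.dtype, float(x.min()), float(x.max()))
```

Normalising the whole stack once, before splitting, keeps intensities on the same scale across the training, validation and test sets.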

### **Step 2: Upload Data**
Upload your data files to this Colab session, or connect to Google Drive.

### **Step 3: Train the Model**
Run the cell that trains the AI4CellFate model.

### **Step 4: Interpret Results**
View the latent space visualisations and feature interpretations.

---
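The preprocessing steps listed in the data-upload cell (stratified 60/20/20 split, balancing, roughly 5x augmentation, saving as `.npy`) can be sketched with NumPy alone. This is an illustration under stated assumptions, not AI4CellFate's own pipeline. The helper names are hypothetical, balancing is done by randomly undersampling the majority class, and the five views (original, three 90° rotations, one horizontal flip) are only one possible choice. The rotations also assume square images. The repository ships its own `augment_dataset` and `augmentations` in `src/preprocessing/preprocessing_functions.py`, whose signatures are not shown here.

```python
import numpy as np


def stratified_split(x, y, fractions=(0.6, 0.2, 0.2), seed=42):
    """Split (x, y) into train/val/test while keeping class proportions in each split."""
    rng = np.random.default_rng(seed)
    parts = [[], [], []]
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        n_train = int(round(fractions[0] * len(idx)))
        n_val = int(round(fractions[1] * len(idx)))
        parts[0].append(idx[:n_train])
        parts[1].append(idx[n_train:n_train + n_val])
        parts[2].append(idx[n_train + n_val:])  # remainder goes to test
    return [(x[np.concatenate(p)], y[np.concatenate(p)]) for p in parts]


def balance_by_undersampling(x, y, seed=42):
    """Randomly drop samples so every class has as many samples as the smallest class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), counts.min(), replace=False) for c in classes
    ])
    return x[keep], y[keep]


def augment_5x(x, y):
    """Original + 90/180/270 degree rotations + horizontal flip (assumes square images)."""
    views = [x,
             np.rot90(x, 1, axes=(1, 2)),
             np.rot90(x, 2, axes=(1, 2)),
             np.rot90(x, 3, axes=(1, 2)),
             np.flip(x, axis=2)]
    return np.concatenate(views), np.tile(y, len(views))


# Placeholder data: 200 normalised 20x20 images, imbalanced 140 / 60
rng = np.random.default_rng(0)
x_all = rng.random((200, 20, 20), dtype=np.float32)
y_all = np.array([0] * 140 + [1] * 60)

(x_train, y_train), (x_val, y_val), (x_test, y_test) = stratified_split(x_all, y_all)
x_train_bal, y_train_bal = balance_by_undersampling(x_train, y_train)
x_train_aug, y_train_aug = augment_5x(x_train_bal, y_train_bal)

# Save under the file names the data-upload cell expects
for name, arr in [("x_train_aug", x_train_aug), ("y_train_aug", y_train_aug),
                  ("x_val", x_val), ("y_val", y_val), ("x_test", x_test), ("y_test", y_test)]:
    np.save(f"{name}.npy", arr)

print(x_train_aug.shape, np.bincount(y_train_aug), x_val.shape, x_test.shape)
```

Only the training set is balanced and augmented; the validation and test sets keep their original class proportions so that evaluation reflects the real data.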


In [None]:
#@title 🔧 **Setup and Installation** {display-mode: "form"}
#@markdown This cell installs AI4CellFate and required dependencies. Run this first!

import os
import sys
import subprocess

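# Clone the AI4CellFate repository (skipped if it is already present)
repo_dir = '/content/AI4CellFate'
if not os.path.exists(repo_dir):
    !git clone https://github.com/ComputationalMicroscopy4CellBio/AI4CellFate.git {repo_dir}

# Change to the repository directory (absolute path, so re-running this cell is safe)
os.chdir(repo_dir)
if repo_dir not in sys.path:
    sys.path.append(repo_dir)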

# Install package-specific requirements
%pip install -q tensorflow opencv-python scikit-learn matplotlib seaborn scipy

# Import necessary modules
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Import AI4CellFate modules
from src.training.train import train_autoencoder, train_cellfate
from src.models.classifier import mlp_classifier
from src.models import Encoder, Decoder, Discriminator
from src.preprocessing.preprocessing_functions import augment_dataset, augmentations

print("✅ AI4CellFate successfully installed and imported!")
print("📊 TensorFlow version:", tf.__version__)
print("🔥 GPU available:", "Yes" if tf.config.list_physical_devices('GPU') else "No")


In [None]:
#@title 📁 **Data Upload and Preparation** {display-mode: "form"}

#@markdown ## 📋 **Data Preparation Guidelines**
#@markdown Before uploading, ensure your data meets these requirements:
#@markdown
#@markdown ### **✅ Required Format:**
#@markdown - **Images**: NumPy arrays (.npy) with shape `[n_samples, height, width]`
#@markdown - **Labels**: NumPy arrays (.npy) with binary labels (0 and 1)
#@markdown - **Normalisation**: Images must be normalised to [0, 1] range
#@markdown - **Splits**: Provide separate train, validation, and test sets
#@markdown - **Balance**: Training set should be balanced (equal samples per class)
#@markdown - **Augmentation**: Training set should preferably be augmented (5x recommended)
#@markdown
#@markdown ### **📂 Expected Files:**
#@markdown - `x_train.npy` / `x_train_aug.npy` (training images)
#@markdown - `y_train.npy` / `y_train_aug.npy` (training labels)  
#@markdown - `x_val.npy` (validation images)
#@markdown - `y_val.npy` (validation labels)
#@markdown - `x_test.npy` (test images)
#@markdown - `y_test.npy` (test labels)
#@markdown
#@markdown ### **🔧 Preprocessing Steps (Do Before Upload):**
#@markdown 1. **Normalise** images: `images = images / images.max()` (non-negative intensities only) or, in general, `images = (images - images.min()) / (images.max() - images.min())`
#@markdown 2. **Split** data: 60% train, 20% validation, 20% test (stratified)
#@markdown 3. **Balance** training set: Equal number of samples per class
#@markdown 4. **Augment** training set: Apply rotations, flips (5x augmentation recommended)
#@markdown 5. **Select frame**: If time-lapse data, extract single frame per cell `[n_samples, height, width]`

from google.colab import drive
import numpy as np

#@markdown ---
#@markdown ## 📤 **Upload Method**
upload_method = "drive" #@param ["drive", "upload"]

if upload_method == "drive":
    # Mount Google Drive
    drive.mount('/content/drive')
    
    #@markdown ### **📁 Google Drive File Paths**
    #@markdown Specify the full paths to your prepared data files:
    
    x_train_path = "/content/drive/MyDrive/AI4CellFate_data/x_train_aug.npy" #@param {type:"string"}
    y_train_path = "/content/drive/MyDrive/AI4CellFate_data/y_train_aug.npy" #@param {type:"string"}
    x_val_path = "/content/drive/MyDrive/AI4CellFate_data/x_val.npy" #@param {type:"string"}
    y_val_path = "/content/drive/MyDrive/AI4CellFate_data/y_val.npy" #@param {type:"string"}
    x_test_path = "/content/drive/MyDrive/AI4CellFate_data/x_test.npy" #@param {type:"string"}
    y_test_path = "/content/drive/MyDrive/AI4CellFate_data/y_test.npy" #@param {type:"string"}
    
    # Load data from Drive
    print("📁 Loading data from Google Drive...")
    try:
        x_train_aug = np.load(x_train_path)
        y_train_aug = np.load(y_train_path)
        x_val = np.load(x_val_path)
        y_val = np.load(y_val_path)
        x_test = np.load(x_test_path)
        y_test = np.load(y_test_path)
        print("✅ All files loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading files: {e}")
        print("Please check your file paths and ensure files exist.")
        raise

else:
    # Upload files directly
    from google.colab import files
    
    print("📤 Upload your prepared data files (.npy format)")
    print("Expected files: x_train_aug.npy, y_train_aug.npy, x_val.npy, y_val.npy, x_test.npy, y_test.npy")
    
    uploaded = files.upload()
    
    # Load uploaded files
    try:
        # Find and load files based on naming patterns
        train_x_file = next((f for f in uploaded.keys() if 'train' in f.lower() and ('x_' in f.lower() or 'image' in f.lower())), None)
        train_y_file = next((f for f in uploaded.keys() if 'train' in f.lower() and ('y_' in f.lower() or 'label' in f.lower())), None)
        val_x_file = next((f for f in uploaded.keys() if 'val' in f.lower() and ('x_' in f.lower() or 'image' in f.lower())), None)
        val_y_file = next((f for f in uploaded.keys() if 'val' in f.lower() and ('y_' in f.lower() or 'label' in f.lower())), None)
        test_x_file = next((f for f in uploaded.keys() if 'test' in f.lower() and ('x_' in f.lower() or 'image' in f.lower())), None)
        test_y_file = next((f for f in uploaded.keys() if 'test' in f.lower() and ('y_' in f.lower() or 'label' in f.lower())), None)
        
        if not all([train_x_file, train_y_file, val_x_file, val_y_file, test_x_file, test_y_file]):
            missing = []
            if not train_x_file: missing.append("training images")
            if not train_y_file: missing.append("training labels")
            if not val_x_file: missing.append("validation images")
            if not val_y_file: missing.append("validation labels")
            if not test_x_file: missing.append("test images")
            if not test_y_file: missing.append("test labels")
            raise ValueError(f"Missing required files: {', '.join(missing)}")
        
        x_train_aug = np.load(train_x_file)
        y_train_aug = np.load(train_y_file)
        x_val = np.load(val_x_file)
        y_val = np.load(val_y_file)
        x_test = np.load(test_x_file)
        y_test = np.load(test_y_file)
        
        print("✅ All files loaded successfully!")
        
    except Exception as e:
        print(f"❌ Error loading files: {e}")
        raise

# Validate data format and requirements
print("\n🔍 **Data Validation:**")

# Check shapes
print(f"📊 **Data Shapes:**")
print(f"   • Training: {x_train_aug.shape} images, {y_train_aug.shape} labels")
print(f"   • Validation: {x_val.shape} images, {y_val.shape} labels")
print(f"   • Test: {x_test.shape} images, {y_test.shape} labels")

# Check if images are 3D (samples, height, width)
if len(x_train_aug.shape) != 3:
    print(f"⚠️  WARNING: Training images have shape {x_train_aug.shape}. Expected 3D: [samples, height, width]")
if len(x_val.shape) != 3:
    print(f"⚠️  WARNING: Validation images have shape {x_val.shape}. Expected 3D: [samples, height, width]")
if len(x_test.shape) != 3:
    print(f"⚠️  WARNING: Test images have shape {x_test.shape}. Expected 3D: [samples, height, width]")

# Check normalisation
train_min, train_max = x_train_aug.min(), x_train_aug.max()
val_min, val_max = x_val.min(), x_val.max()
test_min, test_max = x_test.min(), x_test.max()

print(f"\n📏 **Normalisation Check:**")
print(f"   • Training: [{train_min:.3f}, {train_max:.3f}]")
print(f"   • Validation: [{val_min:.3f}, {val_max:.3f}]")
print(f"   • Test: [{test_min:.3f}, {test_max:.3f}]")

if not (0 <= train_min and train_max <= 1 and 0 <= val_min and val_max <= 1 and 0 <= test_min and test_max <= 1):
    print("❌ WARNING: Images not properly normalised to [0,1] range!")
    print("Please normalise your images before proceeding.")
else:
    print("✅ Images properly normalised to [0,1] range")

# Check class balance in training set
unique_train, counts_train = np.unique(y_train_aug, return_counts=True)
print(f"\n⚖️  **Training Set Balance:**")
for class_label, count in zip(unique_train, counts_train):
    print(f"   • Class {class_label}: {count} samples")

if len(counts_train) == 2 and abs(counts_train[0] - counts_train[1]) / max(counts_train) < 0.1:
    print("✅ Training set is balanced")
else:
    print("⚠️  WARNING: Training set appears unbalanced. Consider balancing before training.")

# Check binary labels
all_labels = np.concatenate([y_train_aug, y_val, y_test])
unique_labels = np.unique(all_labels)
if not np.array_equal(unique_labels, [0, 1]):
    print(f"⚠️  WARNING: Labels contain values {unique_labels}. Expected binary labels [0, 1]")
else:
    print("✅ Binary labels [0, 1] detected")

print(f"\n🎯 **Data loaded and validated! Ready for training.**")
print(f"\n📋 **Variables available for training:**")
print(f"   • x_train_aug: {x_train_aug.shape}")
print(f"   • y_train_aug: {y_train_aug.shape}")
print(f"   • x_val: {x_val.shape}")
print(f"   • y_val: {y_val.shape}")
print(f"   • x_test: {x_test.shape}")
print(f"   • y_test: {y_test.shape}")


---

## 📐 **Image Size Requirements & Model Adaptation**

### **Default Configuration**
AI4CellFate is designed for **20x20 pixel grayscale images**. If your data matches this size, you can proceed directly to the next cell.

### **For Different Image Sizes**
If your images are larger (e.g., 64x64, 128x128), you'll need to modify the model architecture:

#### **Step-by-Step Guide:**

1. **Navigate to Model Files:**
   ```
   AI4CellFate/
   └── src/
       └── models/
           ├── encoder.py    ← Modify this
           └── decoder.py    ← Modify this
   ```

2. **Update Encoder Architecture (`src/models/encoder.py`):**
   - Locate the `Conv2D` layers in the `__init__` method
   - **Add more layers** for larger images:
     - For 64x64: Add 1-2 additional Conv2D layers
     - For 128x128: Add 2-3 additional Conv2D layers
   - **Pattern**: Each Conv2D layer with `strides=(2, 2)` halves the spatial dimensions
   - **Example**: 128x128 → 64x64 → 32x32 → 16x16 → 8x8 → 4x4 → flatten

3. **Update Decoder Architecture (`src/models/decoder.py`):**
   - Mirror the encoder changes in reverse
   - Adjust the initial `Dense` layer (and any `Reshape` that follows it) so that its output matches the encoder's final feature-map shape
   - Use `Conv2DTranspose` layers to upscale back to original size

4. **Key Parameters to Adjust:**
   - `filters`: Number of feature maps (typically 32, 64, 128, 256)
   - `kernel_size`: Usually (3,3) or (4,4)
   - `strides`: Usually (2,2) for downsampling/upsampling
   - `padding`: Usually 'same', so that the output size is set by the stride rather than by the kernel size
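The halving pattern above can be checked before editing any layers. A small sketch, assuming every added `Conv2D` uses `strides=(2, 2)` and `padding='same'` (Keras then outputs `ceil(n / 2)` pixels per side) and taking the 4x4 bottleneck from the example chain as the target; the helper names are hypothetical:

```python
import math


def encoder_sizes(input_size, n_stride2_layers):
    """Spatial size after each Conv2D with strides=(2, 2), padding='same' (Keras: ceil(n / 2))."""
    sizes = [input_size]
    for _ in range(n_stride2_layers):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


def layers_needed(input_size, target_size=4):
    """Smallest number of stride-2 layers that shrinks input_size to target_size or below."""
    n = 0
    while encoder_sizes(input_size, n)[-1] > target_size:
        n += 1
    return n


for size in (20, 64, 128):
    n = layers_needed(size)
    print(f"{size}x{size}: {n} stride-2 layers -> {encoder_sizes(size, n)}")
```

For 20x20 this gives 3 layers (20, 10, 5, 3), for 64x64 it gives 4 and for 128x128 it gives 5, consistent with the 128x128 example chain. Comparing the count with the stride-2 layers already in `encoder.py` tells you how many to add.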

#### **Quick Tips:**
- **Maintain symmetry**: Encoder downsampling should match decoder upsampling
- **Test incrementally**: Start with one additional layer pair
- **Monitor memory**: Larger images require more GPU memory
- **Adjust batch size**: You may need to reduce `batch_size` for larger images

#### **Example Modification:**
For 64x64 images, add one more Conv2D layer in encoder:
```python
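# In encoder.py, add after the existing Conv2D layers (functional-API style; if the
# layers are created in __init__, define the new ones there and call them in order instead):
from tensorflow.keras.layers import Conv2D, BatchNormalization, LeakyReLU
x = Conv2D(64, (3, 3), strides=(2, 2), padding='same')(x)  # halves height and width
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)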
```

And corresponding Conv2DTranspose in decoder:
```python
# In decoder.py, add before the existing Conv2DTranspose layers:
from tensorflow.keras.layers import Conv2DTranspose
x = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(x)  # doubles height and width
```
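Odd intermediate sizes are the usual way encoder/decoder symmetry breaks: `Conv2D` with `padding='same'` rounds up (`ceil(n / 2)`), while `Conv2DTranspose` with `strides=(2, 2)`, `padding='same'` and the default `output_padding=None` exactly doubles its input, so the decoder can overshoot the original size. A pure-Python sketch of that round trip, assuming every layer pair uses stride 2 with 'same' padding (`round_trip_size` is a hypothetical helper):

```python
import math


def round_trip_size(input_size, n_layers):
    """Size after n stride-2 'same' Conv2D layers followed by n stride-2 'same' Conv2DTranspose layers."""
    size = input_size
    for _ in range(n_layers):
        size = math.ceil(size / 2)  # Conv2D, strides=(2, 2), padding='same'
    for _ in range(n_layers):
        size = size * 2             # Conv2DTranspose, strides=(2, 2), padding='same'
    return size


for input_size, n_layers in [(64, 2), (64, 4), (20, 2), (20, 3)]:
    out = round_trip_size(input_size, n_layers)
    status = "OK" if out == input_size else "MISMATCH"
    print(f"{input_size}x{input_size} with {n_layers} layer pairs -> {out}x{out} ({status})")
```

A mismatch such as 20 to 24 means the reconstruction loss would compare arrays of different shapes. Either crop the decoder output (for example with Keras' `Cropping2D` layer) or choose a layer count that keeps every intermediate size even.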

---
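The λ weights set in the configuration cell that follows control how strongly each loss term counts during training. Assuming each stage minimises a weighted sum of its terms, and assuming a covariance loss of the common form "penalise the off-diagonal entries of the latent covariance matrix" (neither is confirmed here; the actual definitions live in `src/training/train.py`), the idea can be sketched in NumPy. The helper names and loss values are placeholders.

```python
import numpy as np


def covariance_penalty(z):
    """Sum of squared off-diagonal entries of the latent covariance matrix.

    Illustrative form only: AI4CellFate's own covariance loss may be defined differently.
    """
    z_centred = z - z.mean(axis=0, keepdims=True)
    cov = z_centred.T @ z_centred / (z.shape[0] - 1)
    off_diag = cov - np.diag(np.diag(cov))
    return float(np.sum(off_diag ** 2))


def weighted_total(losses, lambdas):
    """Weighted sum of named loss terms (assumed combination rule)."""
    return sum(lambdas[name] * value for name, value in losses.items())


rng = np.random.default_rng(0)
z_independent = rng.normal(size=(1000, 2))
# Second feature mostly copies the first: strongly correlated latent features
z_correlated = np.column_stack([z_independent[:, 0],
                                z_independent[:, 0] + 0.1 * z_independent[:, 1]])

print("independent:", covariance_penalty(z_independent))
print("correlated: ", covariance_penalty(z_correlated))

# Default Stage 2 weights from the configuration cell; the loss values are placeholders
lambdas = {"recon": 6, "adv": 4, "cov": 1, "contra": 8}
losses = {"recon": 0.5, "adv": 0.25, "cov": covariance_penalty(z_independent), "contra": 0.1}
print("weighted total:", weighted_total(losses, lambdas))
```

Under this reading, raising a λ makes the optimiser trade the other terms for a lower value of that term, which is why the Stage 2 defaults weight the contrastive term most heavily.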


In [None]:
#@title ⚙️ **Model Configuration** {display-mode: "form"}
#@markdown Configure AI4CellFate training parameters

# Training parameters
latent_dim = 2 #@param {type:"integer"}
#@markdown The latent space plot and the feature-interpretation cell assume `latent_dim = 2`.
batch_size = 30 #@param {type:"integer"}
stage1_epochs = 35 #@param {type:"integer"}
stage2_epochs = 100 #@param {type:"integer"}
learning_rate = 0.001 #@param {type:"number"}
random_seed = 42 #@param {type:"integer"}

# Advanced parameters
gaussian_noise_std = 0.003 #@param {type:"number"}
lambda_recon_stage1 = 5 #@param {type:"integer"}
lambda_adv_stage1 = 1 #@param {type:"integer"}
lambda_recon_stage2 = 6 #@param {type:"integer"}
lambda_adv_stage2 = 4 #@param {type:"integer"}
lambda_cov = 1 #@param {type:"number"}
lambda_contra = 8 #@param {type:"integer"}

# Configuration dictionaries
config_stage1 = {
    'batch_size': batch_size,
    'epochs': stage1_epochs,
    'learning_rate': learning_rate,
    'seed': random_seed,
    'latent_dim': latent_dim,
    'GaussianNoise_std': gaussian_noise_std,
    'lambda_recon': lambda_recon_stage1,
    'lambda_adv': lambda_adv_stage1,
}

config_stage2 = {
    'batch_size': batch_size,
    'epochs': stage2_epochs,
    'learning_rate': learning_rate,
    'seed': random_seed,
    'latent_dim': latent_dim,
    'GaussianNoise_std': gaussian_noise_std,
    'lambda_recon': lambda_recon_stage2,
    'lambda_adv': lambda_adv_stage2,
    'lambda_cov': lambda_cov,
    'lambda_contra': lambda_contra,
}

print(f"⚙️ **Model Configuration:**")
print(f"   • Latent dimensions: {latent_dim}")
print(f"   • Batch size: {batch_size}")
print(f"   • Stage 1 epochs: {stage1_epochs}")
print(f"   • Stage 2 epochs: {stage2_epochs}")
print(f"   • Learning rate: {learning_rate}")
print(f"   • Random seed: {random_seed}")
print(f"\n🔧 **Advanced Parameters:**")
print(f"   • Gaussian noise std: {gaussian_noise_std}")
print(f"   • Lambda reconstruction (S1/S2): {lambda_recon_stage1}/{lambda_recon_stage2}")
print(f"   • Lambda adversarial (S1/S2): {lambda_adv_stage1}/{lambda_adv_stage2}")
print(f"   • Lambda covariance: {lambda_cov}")
print(f"   • Lambda contrastive: {lambda_contra}")

print("\n✅ Model configuration completed!")


In [None]:
#@title 🚀 **AI4CellFate Model Training (Stage 1 + Stage 2)** {display-mode: "form"}
#@markdown Train the complete AI4CellFate model in two stages: (1) Adversarial Autoencoder, (2) Latent Space Engineering

print("🚀 **AI4CellFate Model Training**")
print("   Training will proceed in two stages:")
print("   • Stage 1: Adversarial Autoencoder (reconstruction + Gaussian latent space)")
print("   • Stage 2: Latent Space Engineering (+ covariance + contrastive losses)")
print()

# =============================================================================
# STAGE 1: Adversarial Autoencoder Training
# =============================================================================
print("🔄 **Stage 1: Training Adversarial Autoencoder**")
print("   Learning image reconstruction and a Gaussian-distributed latent space...")

stage1_results = train_autoencoder(
    config_stage1, 
    x_train_aug, 
    x_val
)

# Extract trained models
encoder = stage1_results['encoder']
decoder = stage1_results['decoder']
discriminator = stage1_results['discriminator']

print("\n✅ **Stage 1 Completed!**")
print(f"   • Final reconstruction loss: {stage1_results['recon_loss'][-1]:.4f}")
print(f"   • Final adversarial loss: {stage1_results['adv_loss'][-1]:.4f}")

# =============================================================================
# STAGE 2: AI4CellFate Training (Latent Space Engineering)
# =============================================================================
print("\n🔄 **Stage 2: AI4CellFate Training (Latent Space Engineering)**")
print("   Adding covariance and contrastive losses for feature disentanglement...")

stage2_results = train_cellfate(
    config_stage2,
    encoder,
    decoder, 
    discriminator,
    x_train_aug,
    y_train_aug,
    x_val,
    y_val,
    x_test,
    y_test
)

# Update models with final trained versions
encoder = stage2_results['encoder']
decoder = stage2_results['decoder']
discriminator = stage2_results['discriminator']
final_confusion_matrix = stage2_results['confusion_matrix']

print("\n🎉 **Stage 2 Completed!**")
print(f"   • Final reconstruction loss: {stage2_results['recon_loss'][-1]:.4f}")
print(f"   • Final adversarial loss: {stage2_results['adv_loss'][-1]:.4f}")
print(f"   • Final covariance loss: {stage2_results['cov_loss'][-1]:.4f}")
print(f"   • Final contrastive loss: {stage2_results['contra_loss'][-1]:.4f}")
print(f"   • Training stopped at epoch(s): {stage2_results['good_conditions_stop']}")

# =============================================================================
# FINAL RESULTS SUMMARY
# =============================================================================
print("\n🏆 **Final Training Results:**")
print("   • Both training stages completed")
print("   • Latent space refined with covariance and contrastive losses (Stage 2)")
print("   • Ready for latent space visualisation and interpretation")

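# Display final confusion matrix. Rows are normalised so each diagonal entry is the
# fraction of that class predicted correctly (assumes rows = true classes, as in
# sklearn.metrics.confusion_matrix; an already row-normalised matrix is unchanged)
if final_confusion_matrix is not None:
    cm = np.asarray(final_confusion_matrix, dtype=float)
    per_class_acc = np.diag(cm) / cm.sum(axis=1)
    print(f"\n📊 **Final Classification Performance:**")
    print(f"   • Class 0 accuracy: {per_class_acc[0]:.3f}")
    print(f"   • Class 1 accuracy: {per_class_acc[1]:.3f}")
    print(f"   • Balanced accuracy (mean of per-class accuracies): {per_class_acc.mean():.3f}")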

print("\n✅ **AI4CellFate training completed successfully!**")


In [None]:
#@title 📊 **Latent Space Visualisation** {display-mode: "form"}
#@markdown This cell visualises the 2D latent space learned by AI4CellFate on the training set, showing how the two cell fates are separated.

import matplotlib.pyplot as plt
import numpy as np

# Predict the latent representations
latent_2d = encoder.predict(x_train_aug) 

# Find extreme points for axis limits
x_min, x_max = latent_2d[:, 0].min() - 0.5, latent_2d[:, 0].max() + 0.5
y_min, y_max = latent_2d[:, 1].min() - 0.5, latent_2d[:, 1].max() + 0.5

# Prefer Arial, but fall back to DejaVu Sans (Arial may not be installed on Colab)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']

# Create the plot
plt.figure(figsize=(8, 6), dpi=300)

# Scatter plot for each class separately with thin black edges
plt.scatter(latent_2d[y_train_aug == 0][:, 0], latent_2d[y_train_aug == 0][:, 1], 
            color='#648fff', label="Fate 0", alpha=1, edgecolors='k', linewidth=0.5, rasterized=True)  
plt.scatter(latent_2d[y_train_aug == 1][:, 0], latent_2d[y_train_aug == 1][:, 1], 
            color='#dc267f', label="Fate 1", alpha=1, edgecolors='k', linewidth=0.5, rasterized=True)  

# Set axis limits from the data extremes computed above
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)

# Make tick labels much bigger
plt.tick_params(axis='both', which='major', labelsize=16)

# Increase font size of axis labels and title
plt.xlabel("Latent Feature 0 (z0)", fontsize=18)
plt.ylabel("Latent Feature 1 (z1)", fontsize=18)
plt.title("Latent Space", fontsize=20)

# Legend and grid
plt.legend(fontsize=14)
plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)

#plt.savefig("latent_space.eps", format="eps", dpi=600, bbox_inches="tight")
plt.show()

print("🎯 **Latent Space Analysis:**")
print(f"   • Total samples plotted: {len(latent_2d)}")
print(f"   • Fate 0 samples: {np.sum(y_train_aug == 0)}")
print(f"   • Fate 1 samples: {np.sum(y_train_aug == 1)}")
print(f"   • Latent Feature 0 range: [{latent_2d[:, 0].min():.3f}, {latent_2d[:, 0].max():.3f}]")
print(f"   • Latent Feature 1 range: [{latent_2d[:, 1].min():.3f}, {latent_2d[:, 1].max():.3f}]")


In [None]:
#@title 🔍 **Latent Feature Interpretation** {display-mode: "form"}
#@markdown This cell generates synthetic images by perturbing individual latent features to understand what each feature controls.

#@markdown ### **Interpretation Parameters**
feature_index = 0 #@param {type:"slider", min:0, max:1, step:1}
#@markdown Choose which latent feature to interpret (0 or 1; the slider assumes latent_dim = 2)

perturbation_min = -2.0 #@param {type:"slider", min:-3.0, max:0.0, step:0.1}
perturbation_max = 2.0 #@param {type:"slider", min:0.0, max:3.0, step:0.1}
#@markdown Set the range for perturbations

num_steps = 5 #@param {type:"slider", min:3, max:10, step:1}
#@markdown Number of perturbation steps to visualize

import matplotlib.pyplot as plt
import numpy as np

print(f"🎯 **Interpreting Latent Feature {feature_index}**")
print(f"   • Perturbation range: [{perturbation_min:.1f}, {perturbation_max:.1f}]")
print(f"   • Number of steps: {num_steps}")

# Create a baseline latent vector at the origin of the latent space, shape (1, latent_dim)
baseline_latent_vector = np.zeros((1, latent_dim), dtype=np.float32)

# Perturbation range
perturbations = np.linspace(perturbation_min, perturbation_max, num_steps)

# Store the perturbed reconstructions
perturbed_reconstructions = []

print(f"\n🔄 Generating synthetic images...")
for i, value in enumerate(perturbations):
    # Create a copy of the baseline latent vector
    perturbed_vector = baseline_latent_vector.copy()
    
    # Modify the selected feature
    perturbed_vector[0, feature_index] = value
    
    # Decode the perturbed vector with the trained decoder to generate a synthetic image
    synthetic_image = decoder.predict(perturbed_vector, verbose=0)
    
    # Store the result (decoder output assumed to be (batch_size, height, width, channels))
    perturbed_reconstructions.append(synthetic_image[0])
    print(f"   • Step {i+1}/{num_steps}: Perturbation value {value:.2f}")

# Convert list to numpy array for easier handling
perturbed_reconstructions = np.array(perturbed_reconstructions)

# Plot the results
fig, axs = plt.subplots(1, num_steps, figsize=(4*num_steps, 4))

# Handle single subplot case
if num_steps == 1:
    axs = [axs]

vmin = perturbed_reconstructions.min()
vmax = perturbed_reconstructions.max()

for i in range(num_steps):
    # Use one shared intensity scale so brightness differences between panels are real
    im = axs[i].imshow(perturbed_reconstructions[i, :, :, 0], cmap='gray', vmin=vmin, vmax=vmax)
    axs[i].set_title(f'z{feature_index} = {perturbations[i]:.2f}', fontsize=14)
    axs[i].axis('off')

plt.suptitle(f'Latent Feature {feature_index} Interpretation', fontsize=16, y=1.02)
plt.tight_layout()
#plt.savefig(f"perturbations_feat{feature_index}.eps", format="eps", dpi=300, bbox_inches="tight")
plt.show()

print(f"\n✅ **Feature Interpretation Complete!**")
print(f"   • Generated {num_steps} synthetic images")
print(f"   • To see what feature {feature_index} encodes, compare how the images change across the perturbation values")
print(f"   • Set feature_index to another value to interpret the other latent dimensions")
print(f"\n💡 **Interpretation Tips:**")
print(f"   • Smooth, gradual changes suggest a well-learned feature representation")
print(f"   • Abrupt changes may indicate feature entanglement")
print(f"   • Try different perturbation ranges to explore how sensitive the images are to this feature")
