# 🧬 AI4CellFate: Interpretable Cell Fate Prediction

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ComputationalMicroscopy4CellBio/AI4CellFate/blob/main/notebooks/Codeless_AI4CellFate.ipynb)

**AI4CellFate** is a deep learning framework for predicting cell fate from single-frame microscopy images with full interpretability. This notebook provides an interface to train and apply AI4CellFate to your own data.

---

## 📋 **Quick Start Guide**

### **Step 1: Prepare Your Data**
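- **Images**: Single-channel images of shape `(n_samples, H, W)` (20×20 pixels recommended); time-lapse stacks of shape `(n_samples, T, H, W)` are also accepted, and you choose the frame in the upload cell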
- **Labels**: Binary classification labels (0 and 1)
- **Format**: NumPy arrays (.npy files)
- **⚠️ IMPORTANT**: Normalise your images to the [0,1] range (the upload cell checks this and rescales all splits together if they fall outside it)
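A minimal sketch of preparing files in the expected layout, assuming your images are already in a NumPy array of shape `(n_samples, H, W)` with one 0/1 label per image. The random data below is a placeholder, and the file names are only an example that follows the `x_`/`y_` and `train`/`val`/`test` substrings the upload cell looks for. If you prepare several splits, compute `lo`/`hi` once and apply the same values to every split.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder data: replace with your own arrays.
images = rng.integers(0, 4096, size=(100, 20, 20)).astype(np.float32)  # (n_samples, H, W), e.g. 12-bit intensities
labels = rng.integers(0, 2, size=100)                                   # one 0/1 fate label per image

# Global min-max scaling to [0, 1]; reuse the same lo/hi for val and test arrays
lo, hi = images.min(), images.max()
images = (images - lo) / (hi - lo) if hi > lo else np.zeros_like(images)

# File names contain 'x_'/'y_' and the split name so the upload cell can match them
np.save("x_train.npy", images.astype(np.float32))
np.save("y_train.npy", labels.astype(np.int64))
```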

### **Step 2: Upload Data**
Upload your files directly to this Colab session (`upload_method = "upload"`) or load them from Google Drive (`upload_method = "drive"`) in the upload cell.
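In upload mode, the data cell finds each split's files by plain substring matching: a file belongs to a split if its lower-cased name contains `train`, `val` or `test`, and it is treated as images if it contains `image` or `x_` and as labels if it contains `label` or `y_`. The helper below is a hypothetical stand-alone copy of that rule (it is not part of AI4CellFate), so you can check your filenames before uploading.

```python
def match_split_files(filenames, split):
    """Return (images_file, labels_file) for one split using the upload cell's substring rules."""
    in_split = [f for f in filenames if split in f.lower()]
    x_file = next((f for f in in_split if "image" in f.lower() or "x_" in f.lower()), None)
    y_file = next((f for f in in_split if "label" in f.lower() or "y_" in f.lower()), None)
    return x_file, y_file

uploaded = ["x_train.npy", "y_train.npy", "x_val.npy", "y_val.npy", "test_images.npy", "test_labels.npy"]
for split in ("train", "val", "test"):
    print(split, match_split_files(uploaded, split))
```

Because the match is by substring and the first hit wins, avoid names where a pattern appears by accident: `y_max_train.npy` contains `x_` (inside `max_`), so if it is uploaded first it would also be picked as the training images file.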

### **Step 3: Run the Cells**
Run the cells in order; the code is hidden (form view) for simplicity.

### **Step 4: Interpret Results**
View latent space visualizations and feature interpretations

---


In [None]:
#@title 🔧 **Setup and Installation** {display-mode: "form"}
#@markdown This cell installs AI4CellFate and required dependencies. Run this first!

import os
import sys
import subprocess

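# Clone the AI4CellFate repository (absolute path, so re-running this cell does not nest clones)
REPO_DIR = '/content/AI4CellFate'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/ComputationalMicroscopy4CellBio/AI4CellFate.git {REPO_DIR}

# Change to the repository directory and make the `src` package importable
os.chdir(REPO_DIR)
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)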

# Install package-specific requirements
%pip install -q -r requirements.txt

# Import necessary modules
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Import AI4CellFate modules
from src.training.train import train_autoencoder, train_cellfate
from src.models.classifier import mlp_classifier
from src.models import Encoder, Decoder, Discriminator
from src.preprocessing.preprocessing_functions import augment_dataset, augmentations

print("✅ AI4CellFate successfully installed and imported!")
print("📊 TensorFlow version:", tf.__version__)
#print("🔥 GPU available:", "Yes" if tf.config.list_physical_devices('GPU') else "No")


In [None]:
#@title 📁 **Interactive Data Upload, Splitting, Augmentation and Normalisation** {display-mode: "form"}
#@markdown Use this cell to upload your data, split it if needed, pick the time frame to predict from (time-lapse data only), balance and augment the training set, and make sure all images are in [0,1]. Tick the boxes that describe your data: if it is not split yet, tick `no_splits_available`; if you already have an augmented training set, tick `has_augmented_train_set`.

from google.colab import files
from sklearn.model_selection import train_test_split
import numpy as np
import sys

# -----------------------------
# 1) Split availability options
# -----------------------------
print("🔧 Data availability options")
has_train_split = True  #@param {type:"boolean"}
has_val_split = True   #@param {type:"boolean"}
has_test_split = True   #@param {type:"boolean"}
no_splits_available = False  #@param {type:"boolean"}

if no_splits_available:
    has_train_split = False
    has_val_split = False
    has_test_split = False

# -----------------------------
# Augmentation settings (set before data upload)
# -----------------------------
has_augmented_train_set = False  #@param {type:"boolean"}
augment_times = 5  #@param {type:"integer"}
# -----------------------------
# Upload method
# -----------------------------
upload_method = "upload" #@param ["upload", "drive"]

#@markdown **Fill these paths in case you chose the "drive" option for data uploading:**

train_images_path = "/content/drive/MyDrive/final_split_ai4cellfate_data/augmented_x_train.npy" #@param {type:"string"}
train_labels_path = "/content/drive/MyDrive/final_split_ai4cellfate_data/augmented_y_train.npy" #@param {type:"string"}
val_images_path = "/content/drive/MyDrive/final_split_ai4cellfate_data/x_val.npy" #@param {type:"string"}
val_labels_path = "/content/drive/MyDrive/final_split_ai4cellfate_data/y_val.npy" #@param {type:"string"}
test_images_path = "/content/drive/MyDrive/final_split_ai4cellfate_data/x_test.npy" #@param {type:"string"}
test_labels_path = "/content/drive/MyDrive/final_split_ai4cellfate_data/y_test.npy" #@param {type:"string"}
combined_images_path = "" #@param {type:"string"}
combined_labels_path = "" #@param {type:"string"}

# -----------------------------
# 2) Load data depending on splits
# -----------------------------
X_train_raw = y_train_raw = X_val_raw = y_val_raw = X_test_raw = y_test_raw = None
all_uploaded_files = []

if not has_train_split and not has_val_split and not has_test_split:
    # Single combined dataset
    if upload_method == "upload":
        print("📤 Upload all your files (images and labels). You can click 'Choose Files' multiple times to upload from different locations.")
        print("Expected file naming: files should contain 'image'/'x_' for images and 'label'/'y_' for labels")
        
        # Allow multiple uploads
        upload_done = False
        while not upload_done:
            try:
                uploaded = files.upload()
                all_uploaded_files.extend(list(uploaded.keys()))
                print(f"📥 Files uploaded so far: {all_uploaded_files}")
                
                # Check if we have what we need
                images_file = next((k for k in all_uploaded_files if ('image' in k.lower() or 'x_' in k.lower())), None)
                labels_file = next((k for k in all_uploaded_files if ('label' in k.lower() or 'y_' in k.lower())), None)
                
                if images_file and labels_file:
                    user_input = input("✅ Found images and labels files. Type 'done' to proceed or press Enter to upload more files: ")
                    if user_input.lower() == 'done':
                        upload_done = True
                else:
                    user_input = input("⏳ Still need images and labels files. Press Enter to upload more files or type 'done' to proceed anyway: ")
                    if user_input.lower() == 'done':
                        upload_done = True
            except:
                upload_done = True
        
        images_file = next((k for k in all_uploaded_files if ('image' in k.lower() or 'x_' in k.lower())), None)
        labels_file = next((k for k in all_uploaded_files if ('label' in k.lower() or 'y_' in k.lower())), None)
        if images_file is None or labels_file is None:
            raise FileNotFoundError("❌ Please include both images and labels (files named x_[...] and y_[...]).")
        X_all = np.load(images_file)
        y_all = np.load(labels_file)
    else:
        from google.colab import drive
        drive.mount('/content/drive')
        X_all = np.load(combined_images_path)
        y_all = np.load(combined_labels_path)

    # Stratified 60/20/20 split: hold out 20% for test, then 25% of the remaining 80% (i.e. 20% overall) for validation
    print("\n🔀 Performing stratified 60/20/20 split (train/val/test)")
    X_temp, X_test_raw, y_temp, y_test_raw = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42, stratify=y_all
    )
    X_train_raw, X_val_raw, y_train_raw, y_val_raw = train_test_split(
        X_temp, y_temp, test_size=0.25, random_state=42, stratify=y_temp
    )
else:
    if upload_method == "upload":
        print("📤 Upload all your files (train/val/test images and labels). You can click 'Choose Files' multiple times to upload from different locations.")
        print("Expected file naming: files should contain 'train'/'val'/'test' and 'image'/'x_' for images, 'label'/'y_' for labels")
        
        # Allow multiple uploads
        upload_done = False
        while not upload_done:
            try:
                uploaded = files.upload()
                all_uploaded_files.extend(list(uploaded.keys()))
                print(f"📥 Files uploaded so far: {all_uploaded_files}")
                
                # Check what we have
                train_x = next((k for k in all_uploaded_files if ('train' in k.lower() and ('image' in k.lower() or 'x_' in k.lower()))), None)
                train_y = next((k for k in all_uploaded_files if ('train' in k.lower() and ('label' in k.lower() or 'y_' in k.lower()))), None)
                val_x = next((k for k in all_uploaded_files if ('val' in k.lower() and ('image' in k.lower() or 'x_' in k.lower()))), None)
                val_y = next((k for k in all_uploaded_files if ('val' in k.lower() and ('label' in k.lower() or 'y_' in k.lower()))), None)
                test_x = next((k for k in all_uploaded_files if ('test' in k.lower() and ('image' in k.lower() or 'x_' in k.lower()))), None)
                test_y = next((k for k in all_uploaded_files if ('test' in k.lower() and ('label' in k.lower() or 'y_' in k.lower()))), None)
                
                found_splits = []
                if has_train_split and train_x and train_y: found_splits.append("train")
                if has_val_split and val_x and val_y: found_splits.append("val")
                if has_test_split and test_x and test_y: found_splits.append("test")
                
                needed_splits = []
                if has_train_split: needed_splits.append("train")
                if has_val_split: needed_splits.append("val")
                if has_test_split: needed_splits.append("test")
                
                print(f"✅ Found splits: {found_splits}")
                print(f"📋 Need splits: {needed_splits}")
                
                if set(found_splits) >= set(needed_splits):
                    user_input = input("✅ All required files found. Type 'done' to proceed or press Enter to upload more files: ")
                    if user_input.lower() == 'done':
                        upload_done = True
                else:
                    missing = set(needed_splits) - set(found_splits)
                    user_input = input(f"⏳ Still missing: {list(missing)}. Press Enter to upload more files or type 'done' to proceed anyway: ")
                    if user_input.lower() == 'done':
                        upload_done = True
            except:
                upload_done = True
        
        # Load the files
        if has_train_split:
            xfile = next((k for k in all_uploaded_files if ('train' in k.lower() and ('image' in k.lower() or 'x_' in k.lower()))), None)
            yfile = next((k for k in all_uploaded_files if ('train' in k.lower() and ('label' in k.lower() or 'y_' in k.lower()))), None)
            if xfile and yfile:
                X_train_raw = np.load(xfile)
                y_train_raw = np.load(yfile)
        
        if has_val_split:
            xfile = next((k for k in all_uploaded_files if ('val' in k.lower() and ('image' in k.lower() or 'x_' in k.lower()))), None)
            yfile = next((k for k in all_uploaded_files if ('val' in k.lower() and ('label' in k.lower() or 'y_' in k.lower()))), None)
            if xfile and yfile:
                X_val_raw = np.load(xfile)
                y_val_raw = np.load(yfile)
        
        if has_test_split:
            xfile = next((k for k in all_uploaded_files if ('test' in k.lower() and ('image' in k.lower() or 'x_' in k.lower()))), None)
            yfile = next((k for k in all_uploaded_files if ('test' in k.lower() and ('label' in k.lower() or 'y_' in k.lower()))), None)
            if xfile and yfile:
                X_test_raw = np.load(xfile)
                y_test_raw = np.load(yfile)
    
    else:
        from google.colab import drive
        drive.mount('/content/drive')
        
        if has_train_split:
            X_train_raw = np.load(train_images_path)
            y_train_raw = np.load(train_labels_path)
        
        if has_val_split:
            X_val_raw = np.load(val_images_path)
            y_val_raw = np.load(val_labels_path)
        
        if has_test_split:
            X_test_raw = np.load(test_images_path)
            y_test_raw = np.load(test_labels_path)

    # If no validation split was provided, carve a stratified 20% validation set out of the training data
    if has_train_split and not has_val_split and X_train_raw is not None and y_train_raw is not None:
        print("\n🔀 Creating validation set (20%) from training data with stratification")
        X_train_raw, X_val_raw, y_train_raw, y_val_raw = train_test_split(
            X_train_raw, y_train_raw, test_size=0.2, random_state=42, stratify=y_train_raw
        )

# -----------------------------
# 3) Time-frame selection for time-lapse
# -----------------------------
frame_index = 0  #@param {type:"integer"}

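def select_frame(X, frame_idx):
    """Return one frame per sample from (n_samples, T, H, W) time-lapse data; pass other shapes through."""
    if X is None:
        return None
    if X.ndim == 4 and X.shape[-1] != 1:
        # (n_samples, time, H, W): clamp frame_idx to the valid range [0, T-1]
        t = X.shape[1]
        idx = max(0, min(frame_idx, t - 1))
        return X[:, idx, :, :]
    # (n_samples, H, W) or channel-last (n_samples, H, W, 1): no time axis to select from
    return X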

X_train = select_frame(X_train_raw, frame_index)
X_val = select_frame(X_val_raw, frame_index)
X_test = select_frame(X_test_raw, frame_index)

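# -----------------------------
# 4) Balance and augment the training set (skipped if it is already augmented)
# -----------------------------
if not has_augmented_train_set: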
    print("\n🔄 Balancing and augmenting TRAIN set...")
    X_train_aug, y_train_aug = augment_dataset(
        X_train, y_train_raw, augmentations, augment_times=augment_times, seed=42
    )
else:
    X_train_aug, y_train_aug = X_train, y_train_raw

# -----------------------------
# 5) Summary and checks
# -----------------------------
print("\n📊 **Dataset Summary (post-processing):**")
print(f"   • Train: {None if X_train_aug is None else X_train_aug.shape}")
print(f"   • Val:   {None if X_val is None else X_val.shape}")
print(f"   • Test:  {None if X_test is None else X_test.shape}")

# Balance check
if y_train_aug is not None:
    c0, c1 = np.sum(y_train_aug == 0), np.sum(y_train_aug == 1)
    print(f"\n🏷️  Train labels — Class 0: {c0}, Class 1: {c1}")
    print("   • Balanced:" , "Yes" if c0 == c1 else "No")

# Normalisation check across combined datasets (augmented train + val + test)
all_sets = []
if X_train_aug is not None: all_sets.append(X_train_aug)
if X_val is not None: all_sets.append(X_val)
if X_test is not None: all_sets.append(X_test)

if len(all_sets) > 0:
    combined_max = max([arr.max() for arr in all_sets])
    combined_min = min([arr.min() for arr in all_sets])
    print(f"\n🔎 Normalisation check — global range before: [{combined_min:.3f}, {combined_max:.3f}]")
    if combined_max > 1.0 or combined_min < 0.0:
        print("⚠️  Images not in [0,1]. Applying global normalisation across splits...")
        # Use one offset/scale for every split so train, val and test stay on the same scale.
        # Shift by the global minimum only if it is negative (otherwise just divide by the max),
        # so negative intensities are also mapped into [0,1].
        offset = min(combined_min, 0.0)
        scale = (combined_max - offset) or 1.0  # fall back to 1.0 for constant images
        if X_train_aug is not None: X_train_aug = (X_train_aug - offset) / scale
        if X_val is not None: X_val = (X_val - offset) / scale
        if X_test is not None: X_test = (X_test - offset) / scale
        remaining = [a for a in (X_train_aug, X_val, X_test) if a is not None]
        combined_max = max(arr.max() for arr in remaining)
        combined_min = min(arr.min() for arr in remaining)
        print(f"   • New global range: [{combined_min:.3f}, {combined_max:.3f}]")
    else:
        print("✅ Images already in [0,1].")

# -----------------------------
# 6) Save final variables with correct names for training
# -----------------------------
# Final names expected by the training cells (y_train_aug already has its final name)
x_train_aug = X_train_aug
x_val, y_val = X_val, y_val_raw
x_test, y_test = X_test, y_test_raw

print("\n✅ Data upload, splitting, augmentation and normalisation completed!")

# -----------------------------
# 7) Final variable summary
# -----------------------------
print("\n📋 **Final Variables Saved:**")
final_vars = {"x_train_aug": x_train_aug, "y_train_aug": y_train_aug,
              "x_val": x_val, "y_val": y_val,
              "x_test": x_test, "y_test": y_test}
for name, arr in final_vars.items():
    print(f"   • Variable saved as '{name}': {None if arr is None else arr.shape}")

print("\n🎯 **Ready for AI4CellFate training!**")


---

## 📐 **Image Size Requirements & Model Adaptation**

### **Default Configuration**
AI4CellFate is designed for **20x20 pixel grayscale images**. If your data matches this size, you can proceed directly to the next cell.

### **For Different Image Sizes**
If your images are larger (e.g., 64x64, 128x128), you'll need to modify the model architecture:

#### **Step-by-Step Guide:**

1. **Navigate to Model Files:**
   ```
   AI4CellFate/
   └── src/
       └── models/
           ├── encoder.py    ← Modify this
           └── decoder.py    ← Modify this
   ```

2. **Update Encoder Architecture (`src/models/encoder.py`):**
   - Locate the `Conv2D` layers in the `__init__` method
   - **Add more layers** for larger images:
     - For 64x64: Add 1-2 additional Conv2D layers
     - For 128x128: Add 2-3 additional Conv2D layers
   - **Pattern**: Each Conv2D layer with `strides=(2, 2)` halves the spatial dimensions (rounding up with `padding='same'`)
   - **Example**: 128x128 → 64x64 → 32x32 → 16x16 → 8x8 → 4x4 → flatten

3. **Update Decoder Architecture (`src/models/decoder.py`):**
   - Mirror the encoder changes in reverse
   - Adjust the `Dense` layer so its output matches the encoder's final feature-map size (height × width × filters) before upsampling
   - Use `Conv2DTranspose` layers to upscale back to original size

4. **Key Parameters to Adjust:**
   - `filters`: Number of feature maps (typically 32, 64, 128, 256)
   - `kernel_size`: Usually (3,3) or (4,4)
   - `strides`: Usually (2,2) for downsampling/upsampling
   - `padding`: Usually `'same'`, so the output size depends only on the stride (unchanged at stride 1, halved at stride 2)
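With `strides=(2, 2)` and `padding='same'`, a Keras `Conv2D` layer produces `ceil(size / 2)` pixels per side, and a `Conv2DTranspose` layer with the same settings produces `size * 2`. The helpers below are hypothetical illustrations (not part of AI4CellFate) that list the spatial sizes along the encoder and the mirrored decoder, so you can count the stride-2 layers you need and check that the decoder lands back on the input size.

```python
import math

def encoder_sizes(input_size, n_layers):
    """Spatial size after each stride-2, padding='same' Conv2D layer: ceil(size / 2)."""
    sizes = [input_size]
    for _ in range(n_layers):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes

def decoder_sizes(start_size, n_layers):
    """Spatial size after each stride-2, padding='same' Conv2DTranspose layer: size * 2."""
    return [start_size * 2 ** i for i in range(n_layers + 1)]

enc = encoder_sizes(128, 5)
print(enc)                        # [128, 64, 32, 16, 8, 4]
print(decoder_sizes(enc[-1], 5))  # [4, 8, 16, 32, 64, 128]
```

An odd intermediate size breaks the mirror: `encoder_sizes(20, 3)` gives `[20, 10, 5, 3]`, and three transposed layers starting from 3 return 24, not 20, so the decoder would need cropping or a different layout.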

#### **Quick Tips:**
- **Maintain symmetry**: Encoder downsampling should match decoder upsampling
- **Test incrementally**: Start with one additional layer pair
- **Monitor memory**: Larger images require more GPU memory
- **Adjust batch size**: You may need to reduce `batch_size` for larger images

#### **Example Modification:**
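For 64x64 images, add one more Conv2D block to the encoder:
```python
# Needed at the top of encoder.py if not already imported
from tensorflow.keras.layers import Conv2D, BatchNormalization, LeakyReLU
# Add after the existing Conv2D layers:
x = Conv2D(64, (3, 3), strides=(2, 2), padding='same')(x)  # halves height and width
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)
```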

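And the mirrored Conv2DTranspose layer in the decoder:
```python
# Needed at the top of decoder.py if not already imported
from tensorflow.keras.layers import Conv2DTranspose
# Add before the existing Conv2DTranspose layers:
x = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(x)  # doubles height and width
```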

---


In [None]:
#@title ⚙️ **Model Configuration** {display-mode: "form"}
#@markdown Configure AI4CellFate training parameters

# Training parameters
latent_dim = 2 #@param {type:"integer"}
batch_size = 30 #@param {type:"integer"}
stage1_epochs = 35 #@param {type:"integer"}
stage2_epochs = 100 #@param {type:"integer"}
learning_rate = 0.001 #@param {type:"number"}
random_seed = 42 #@param {type:"integer"}

# Advanced parameters
gaussian_noise_std = 0.003 #@param {type:"number"}
lambda_recon_stage1 = 5 #@param {type:"integer"}
lambda_adv_stage1 = 1 #@param {type:"integer"}
lambda_recon_stage2 = 6 #@param {type:"integer"}
lambda_adv_stage2 = 4 #@param {type:"integer"}
lambda_cov = 1 #@param {type:"number"}
lambda_contra = 8 #@param {type:"integer"}

# Configuration dictionaries
config_stage1 = {
    'batch_size': batch_size,
    'epochs': stage1_epochs,
    'learning_rate': learning_rate,
    'seed': random_seed,
    'latent_dim': latent_dim,
    'GaussianNoise_std': gaussian_noise_std,
    'lambda_recon': lambda_recon_stage1,
    'lambda_adv': lambda_adv_stage1,
}

config_stage2 = {
    'batch_size': batch_size,
    'epochs': stage2_epochs,
    'learning_rate': learning_rate,
    'seed': random_seed,
    'latent_dim': latent_dim,
    'GaussianNoise_std': gaussian_noise_std,
    'lambda_recon': lambda_recon_stage2,
    'lambda_adv': lambda_adv_stage2,
    'lambda_cov': lambda_cov,
    'lambda_contra': lambda_contra,
}

print(f"⚙️ **Model Configuration:**")
print(f"   • Latent dimensions: {latent_dim}")
print(f"   • Batch size: {batch_size}")
print(f"   • Stage 1 epochs: {stage1_epochs}")
print(f"   • Stage 2 epochs: {stage2_epochs}")
print(f"   • Learning rate: {learning_rate}")
print(f"   • Random seed: {random_seed}")
print(f"\n🔧 **Advanced Parameters:**")
print(f"   • Gaussian noise std: {gaussian_noise_std}")
print(f"   • Lambda reconstruction (S1/S2): {lambda_recon_stage1}/{lambda_recon_stage2}")
print(f"   • Lambda adversarial (S1/S2): {lambda_adv_stage1}/{lambda_adv_stage2}")
print(f"   • Lambda covariance: {lambda_cov}")
print(f"   • Lambda contrastive: {lambda_contra}")

print("\n✅ Model configuration completed!")


In [None]:
#@title 🚀 **AI4CellFate Model Training (Stage 1 + Stage 2)** {display-mode: "form"}
#@markdown Train the complete AI4CellFate model in two stages: (1) Adversarial Autoencoder, (2) Latent Space Engineering

print("🚀 **AI4CellFate Model Training**")
print("   Training will proceed in two stages:")
print("   • Stage 1: Adversarial Autoencoder (reconstruction + Gaussian latent space)")
print("   • Stage 2: Latent Space Engineering (+ covariance + contrastive losses)")
print()

# =============================================================================
# STAGE 1: Adversarial Autoencoder Training
# =============================================================================
print("🔄 **Stage 1: Training Adversarial Autoencoder**")
print("   Learning basic image reconstruction and Gaussian latent space...")

stage1_results = train_autoencoder(
    config_stage1, 
    x_train_aug, 
    x_val
)

# Extract trained models
encoder = stage1_results['encoder']
decoder = stage1_results['decoder']
discriminator = stage1_results['discriminator']

print("\n✅ **Stage 1 Completed!**")
print(f"   • Final reconstruction loss: {stage1_results['recon_loss'][-1]:.4f}")
print(f"   • Final adversarial loss: {stage1_results['adv_loss'][-1]:.4f}")

# =============================================================================
# STAGE 2: AI4CellFate Training (Latent Space Engineering)
# =============================================================================
print("\n🔄 **Stage 2: AI4CellFate Training (Latent Space Engineering)**")
print("   Adding covariance and contrastive losses for feature disentanglement...")

stage2_results = train_cellfate(
    config_stage2,
    encoder,
    decoder, 
    discriminator,
    x_train_aug,
    y_train_aug,
    x_val,
    y_val,
    x_test,
    y_test
)

# Update models with final trained versions
encoder = stage2_results['encoder']
decoder = stage2_results['decoder']
discriminator = stage2_results['discriminator']
final_confusion_matrix = stage2_results['confusion_matrix']

print("\n🎉 **Stage 2 Completed!**")
print(f"   • Final reconstruction loss: {stage2_results['recon_loss'][-1]:.4f}")
print(f"   • Final adversarial loss: {stage2_results['adv_loss'][-1]:.4f}")
print(f"   • Final covariance loss: {stage2_results['cov_loss'][-1]:.4f}")
print(f"   • Final contrastive loss: {stage2_results['contra_loss'][-1]:.4f}")
print(f"   • Training stopped at epochs: {stage2_results['good_conditions_stop']}")

# =============================================================================
# FINAL RESULTS SUMMARY
# =============================================================================
print("\n🏆 **Final Training Results:**")
print("   • Models trained successfully on both stages")
print("   • Latent space trained with covariance and contrastive losses for feature disentanglement")
print("   • Ready for latent space visualization and interpretation")

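# Display final confusion matrix, row-normalised so the diagonal holds per-class accuracy
# (assumes rows are true labels; works whether counts or already-normalised values are returned)
if final_confusion_matrix is not None:
    cm = np.asarray(final_confusion_matrix, dtype=float)
    cm = cm / cm.sum(axis=1, keepdims=True)
    print(f"\n📊 **Final Classification Performance:**")
    print(f"   • Class 0 accuracy: {cm[0,0]:.3f}")
    print(f"   • Class 1 accuracy: {cm[1,1]:.3f}")
    print(f"   • Balanced accuracy (mean of per-class accuracies): {np.mean(np.diag(cm)):.3f}")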

print("\n✅ **AI4CellFate training completed successfully!**")


---

## 🎯 **Understanding Your Results**

### **Latent Space Visualization**
- **Good separation**: Classes should form distinct clusters
- **Low correlation**: Latent features should be largely uncorrelated (absolute correlation < 0.3)
- **Centroid distance**: The distance between the class means in latent space; higher values indicate better class separation
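A minimal sketch of both checks, assuming you have already encoded a dataset into a 2-D latent array `z` of shape `(n_samples, 2)` (for example with the trained `encoder`; how it is called is not shown here) together with its 0/1 labels `y`. The helper name is hypothetical and the synthetic data stands in for real latent codes.

```python
import numpy as np

def latent_metrics(z, y):
    """Return (Pearson r between latent features 0 and 1, Euclidean distance between class centroids)."""
    z, y = np.asarray(z, dtype=float), np.asarray(y)
    r = np.corrcoef(z[:, 0], z[:, 1])[0, 1]
    centroid_0 = z[y == 0].mean(axis=0)
    centroid_1 = z[y == 1].mean(axis=0)
    return r, float(np.linalg.norm(centroid_1 - centroid_0))

# Synthetic stand-in for encoded cells: class 1 is shifted along feature 0 only
rng = np.random.default_rng(0)
y = np.repeat([0, 1], 200)
z = rng.normal(size=(400, 2))
z[y == 1, 0] += 3.0
r, d = latent_metrics(z, y)
print(f"|r| = {abs(r):.2f}, centroid distance = {d:.2f}")
```

The correlation here is computed over all cells pooled together, so a class shift along both features at once would itself show up as correlation.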

### **Feature Interpretation**
- **Feature 0 & 1**: Each typically controls different cellular properties
- **Common patterns**: Size vs. intensity, morphology vs. signal activity
- **Biological relevance**: Features should relate to known cell fate determinants

### **Performance Metrics**
- **Accuracy**: Overall classification performance
- **Precision**: Ability to avoid false positives
- **Recall**: Ability to identify all positive cases
- **F1-Score**: Balanced measure combining precision and recall
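These metrics can be computed from a 2×2 confusion matrix. A minimal sketch with class 1 as the positive class, assuming a count matrix laid out like scikit-learn's `confusion_matrix` (rows are true labels, columns are predictions); the helper name and example counts are made up.

```python
import numpy as np

def binary_metrics(cm):
    """Accuracy, precision, recall and F1 (class 1 = positive) from a 2x2 count matrix."""
    tn, fp, fn, tp = np.asarray(cm, dtype=float).ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

print(binary_metrics([[40, 10], [5, 45]]))
```

If the matrix is row-normalised (rows sum to 1), recall is still correct, but precision and accuracy only match their count-based values when both classes have the same number of samples.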

---

## 🔧 **Troubleshooting**

### **Common Issues**
1. **Poor separation**: Try adjusting the loss weights (e.g. `lambda_contra`) or increasing `stage2_epochs`
2. **High correlation**: Increase covariance loss weight (lambda_cov)
3. **Low accuracy**: Check data quality and normalisation, or try a different `latent_dim`
4. **Overfitting**: Reduce model complexity or increase regularisation (e.g. a larger `gaussian_noise_std` or more augmentation via `augment_times`)
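To see what "high correlation" means numerically, the sketch below computes a generic decorrelation penalty: the sum of squared off-diagonal entries of the latent covariance matrix. This is an assumed, illustrative form of a covariance loss, not necessarily the exact loss AI4CellFate optimises; it shows why a larger `lambda_cov` penalises correlated latent features more strongly. The function name is hypothetical.

```python
import numpy as np

def offdiag_cov_penalty(z):
    """Sum of squared off-diagonal entries of the sample covariance of z, shape (n_samples, latent_dim)."""
    zc = z - z.mean(axis=0)               # centre each latent feature
    cov = zc.T @ zc / (len(z) - 1)        # (latent_dim, latent_dim) sample covariance
    off_diag = cov - np.diag(np.diag(cov))
    return float(np.sum(off_diag ** 2))

rng = np.random.default_rng(0)
a = rng.normal(size=1000)
independent = np.stack([a, rng.normal(size=1000)], axis=1)          # features unrelated
correlated = np.stack([a, a + 0.1 * rng.normal(size=1000)], axis=1)  # feature 1 ≈ feature 0
p_ind = offdiag_cov_penalty(independent)
p_cor = offdiag_cov_penalty(correlated)
print(f"independent: {p_ind:.4f}, correlated: {p_cor:.4f}")
```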

### **Parameter Tuning**
- **Latent dimensions**: Start with 2-3, increase if performance plateaus
- **Lambda values**: Balance reconstruction, adversarial, covariance, and contrastive losses
- **Training epochs**: Monitor convergence, stop when losses stabilize
- **Batch size**: Adjust based on dataset size and available memory
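Assuming the Stage 2 objective is a weighted sum of its four terms (the configuration cell passes one weight per term, but the exact combination is defined inside `train_cellfate` and is not shown here), the lambdas enter as:

$$
\mathcal{L}_{\text{stage 2}} = \lambda_{\text{recon}}\,\mathcal{L}_{\text{recon}} + \lambda_{\text{adv}}\,\mathcal{L}_{\text{adv}} + \lambda_{\text{cov}}\,\mathcal{L}_{\text{cov}} + \lambda_{\text{contra}}\,\mathcal{L}_{\text{contra}}
$$

With the defaults $(\lambda_{\text{recon}}, \lambda_{\text{adv}}, \lambda_{\text{cov}}, \lambda_{\text{contra}}) = (6, 4, 1, 8)$, each weight sets the relative contribution of its term, so compare the weights alongside the final loss values printed after training rather than in isolation.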

---

## 📖 **Citation**

If you use AI4CellFate in your research, please cite:

```
[Your paper citation here]
```

---

**Developed by [Your Name/Lab]** | **📧 Contact: [your.email@institution.edu]** | **🔗 GitHub: [ComputationalMicroscopy4CellBio/AI4CellFate](https://github.com/ComputationalMicroscopy4CellBio/AI4CellFate)**
