# 🧬 AI4CellFate: Interpretable Cell Fate Prediction

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ComputationalMicroscopy4CellBio/AI4CellFate/blob/main/notebooks/Codeless_AI4CellFate.ipynb)

**AI4CellFate** is a deep learning framework for predicting cell fate from single-frame microscopy images with full interpretability. This notebook provides an interface to train and apply AI4CellFate to your own data.

---

## 📋 **Quick Start Guide**

### **Step 1: Prepare Your Data**
- **Images**: Single-channel images (20×20 pixels recommended)
- **Labels**: Binary classification labels (0 and 1)
- **Format**: NumPy arrays (.npy files)
- **⚠️ IMPORTANT**: Normalise your images to the [0, 1] range
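
As a minimal sketch of Step 1, suppose your raw data is a NumPy stack of shape `[n_cells, height, width]`, or `[n_cells, n_frames, height, width]` for time-lapse data. That 4D layout, the `frame_index` choice and the helper name `prepare_images` are assumptions for illustration. The sketch reduces the stack to one frame per cell and min-max normalises it to [0, 1]:

```python
import numpy as np

def prepare_images(stack, frame_index=0):
    """Hypothetical helper: select one frame per cell and min-max normalise to [0, 1].

    stack: [n_cells, height, width], or [n_cells, n_frames, height, width]
           for time-lapse data (layout assumed, adapt to your own arrays).
    """
    images = np.asarray(stack, dtype=np.float32)
    if images.ndim == 4:
        # Time-lapse: keep a single frame per cell -> [n_cells, height, width]
        images = images[:, frame_index]
    if images.ndim != 3:
        raise ValueError(f"Expected [n_samples, height, width], got {images.shape}")
    lo, hi = images.min(), images.max()
    if hi == lo:
        raise ValueError("Images are constant; cannot min-max normalise.")
    # One global min/max keeps relative intensities comparable across cells
    return (images - lo) / (hi - lo)

# Synthetic example: 10 cells, 3 frames each, 20x20 pixels, 12-bit intensities
raw = np.random.default_rng(0).integers(0, 4096, size=(10, 3, 20, 20))
images = prepare_images(raw, frame_index=0)
print(images.shape, images.min(), images.max())
```

The result can then be written with `np.save` to produce the `.npy` files the notebook expects.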

### **Step 2: Upload Data**
Upload your data files to this Colab session or connect to Google Drive
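
The six files expected by the upload cell can be produced from a prepared `images`/`labels` pair along the lines below. This is a sketch, not part of the AI4CellFate package, and the helper name is hypothetical. It follows the notebook's guidelines (stratified 60/20/20 split, balanced training set, 5x augmentation), but several details are assumptions:

- balancing is done by random undersampling;
- the 5x augmentation uses the original image plus three 90° rotations and one horizontal flip;
- images are square.

The repository's own `augment_dataset` helper may augment differently.

```python
import os
import numpy as np

def split_balance_augment(images, labels, out_dir=".", seed=42):
    """Hypothetical helper (not part of AI4CellFate): stratified 60/20/20 split,
    balanced and 5x-augmented training set, saved under the expected file names.
    Assumes square images of shape [n_samples, height, width] and integer labels."""
    rng = np.random.default_rng(seed)
    images, labels = np.asarray(images), np.asarray(labels).ravel()
    train_idx, val_idx, test_idx = [], [], []
    # Stratified split: shuffle each class separately, then cut it 60/20/20
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train, n_val = int(0.6 * len(idx)), int(0.2 * len(idx))
        train_idx.append(idx[:n_train])
        val_idx.append(idx[n_train:n_train + n_val])
        test_idx.append(idx[n_train + n_val:])
    # Balance the training set by undersampling each class to the smallest class size
    n_min = min(len(i) for i in train_idx)
    train_idx = np.concatenate([rng.choice(i, n_min, replace=False) for i in train_idx])
    val_idx, test_idx = np.concatenate(val_idx), np.concatenate(test_idx)

    x_train, y_train = images[train_idx], labels[train_idx]
    # 5x augmentation: original + 90/180/270 degree rotations + horizontal flip
    rotations = [np.rot90(x_train, k, axes=(1, 2)) for k in (1, 2, 3)]
    x_train_aug = np.concatenate([x_train, *rotations, np.flip(x_train, axis=2)])
    y_train_aug = np.tile(y_train, 5)  # labels are unchanged by rotation/flip
    order = rng.permutation(len(y_train_aug))  # mix classes and augmentations

    arrays = {
        "x_train_aug": x_train_aug[order], "y_train_aug": y_train_aug[order],
        "x_val": images[val_idx], "y_val": labels[val_idx],
        "x_test": images[test_idx], "y_test": labels[test_idx],
    }
    os.makedirs(out_dir, exist_ok=True)
    for name, array in arrays.items():
        np.save(os.path.join(out_dir, f"{name}.npy"), array)
    return arrays
```

For example, `split_balance_augment(images, labels, out_dir="/content/drive/MyDrive/AI4CellFate_data")` would write the files to the default Google Drive paths used in the upload cell, provided Drive is mounted.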

### **Step 3: Run the Cells**
Execute the cells in order; the code is hidden for simplicity

### **Step 4: Interpret Results**
View latent space visualizations and feature interpretations
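
With the default `latent_dim = 2`, the latent space can be plotted directly. The sketch below assumes you already have latent codes `z` of shape `[n_samples, latent_dim]`. If the trained `encoder` is a Keras model, `z = encoder.predict(x_test)` is one plausible way to obtain them, although the input may need a channel axis, so check the repository for the encoder's actual call signature. The helper name `plot_latent_space` is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_latent_space(z, labels, dims=(0, 1)):
    """Scatter two latent dimensions, coloured by binary cell-fate label."""
    z = np.asarray(z)
    labels = np.asarray(labels).ravel()
    fig, ax = plt.subplots(figsize=(5, 5))
    for cls, colour in zip((0, 1), ("tab:blue", "tab:orange")):
        mask = labels == cls
        ax.scatter(z[mask, dims[0]], z[mask, dims[1]], s=8, alpha=0.6,
                   color=colour, label=f"Class {cls}")
    ax.set_xlabel(f"Latent feature {dims[0] + 1}")
    ax.set_ylabel(f"Latent feature {dims[1] + 1}")
    ax.legend()
    return fig, ax

# Synthetic stand-in for encoded test cells: two Gaussian blobs
rng = np.random.default_rng(0)
z_demo = np.vstack([rng.normal(-1, 0.5, (50, 2)), rng.normal(1, 0.5, (50, 2))])
y_demo = np.array([0] * 50 + [1] * 50)
fig, ax = plot_latent_space(z_demo, y_demo)
```

Call `plt.show()` afterwards to display the figure in Colab.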

---


In [None]:
#@title 🔧 **Setup and Installation** {display-mode: "form"}
#@markdown This cell installs AI4CellFate and required dependencies. Run this first!

import os
import sys
import subprocess

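# Clone the AI4CellFate repository (absolute path, so re-running this cell is safe)
REPO_DIR = '/content/AI4CellFate'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/ComputationalMicroscopy4CellBio/AI4CellFate.git {REPO_DIR}

# Change to the repository directory and make its modules importable
os.chdir(REPO_DIR)
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)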

# Install package-specific requirements
%pip install -q tensorflow opencv-python scikit-learn matplotlib seaborn scipy

# Import necessary modules
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Import AI4CellFate modules
from src.training.train import train_autoencoder, train_cellfate
from src.models.classifier import mlp_classifier
from src.models import Encoder, Decoder, Discriminator
from src.preprocessing.preprocessing_functions import augment_dataset, augmentations

print("✅ AI4CellFate successfully installed and imported!")
print("📊 TensorFlow version:", tf.__version__)
#print("🔥 GPU available:", "Yes" if tf.config.list_physical_devices('GPU') else "No")


In [None]:
#@title 📁 **Data Upload and Preparation** {display-mode: "form"}

#@markdown ## 📋 **Data Preparation Guidelines**
#@markdown Before uploading, ensure your data meets these requirements:
#@markdown
#@markdown ### **✅ Required Format:**
#@markdown - **Images**: NumPy arrays (.npy) with shape `[n_samples, height, width]`
#@markdown - **Labels**: NumPy arrays (.npy) with binary labels (0 and 1)
#@markdown - **Normalisation**: Images must be normalised to [0, 1] range
#@markdown - **Splits**: Provide separate train, validation, and test sets
#@markdown - **Balance**: Training set should be balanced (equal samples per class)
#@markdown - **Augmentation**: Training set should preferably be augmented (5x recommended)
#@markdown
#@markdown ### **📂 Expected Files:**
#@markdown - `x_train.npy` / `x_train_aug.npy` (training images)
#@markdown - `y_train.npy` / `y_train_aug.npy` (training labels)  
#@markdown - `x_val.npy` (validation images)
#@markdown - `y_val.npy` (validation labels)
#@markdown - `x_test.npy` (test images)
#@markdown - `y_test.npy` (test labels)
#@markdown
#@markdown ### **🔧 Preprocessing Steps (Do Before Upload):**
#@markdown 1. **Select frame**: If using time-lapse data, extract a single frame per cell so arrays have shape `[n_samples, height, width]`
#@markdown 2. **Normalise** images to [0, 1]: `images = (images - images.min()) / (images.max() - images.min())` (or `images = images / images.max()` if intensities are already non-negative)
#@markdown 3. **Split** data: 60% train, 20% validation, 20% test (stratified by label)
#@markdown 4. **Balance** training set: Equal number of samples per class
#@markdown 5. **Augment** training set: Apply rotations and flips (5x augmentation recommended)

from google.colab import drive
import numpy as np

#@markdown ---
#@markdown ## 📤 **Upload Method**
upload_method = "drive" #@param ["drive", "upload"]

if upload_method == "drive":
    # Mount Google Drive
    drive.mount('/content/drive')
    
    #@markdown ### **📁 Google Drive File Paths**
    #@markdown Specify the full paths to your prepared data files:
    
    x_train_path = "/content/drive/MyDrive/AI4CellFate_data/x_train_aug.npy" #@param {type:"string"}
    y_train_path = "/content/drive/MyDrive/AI4CellFate_data/y_train_aug.npy" #@param {type:"string"}
    x_val_path = "/content/drive/MyDrive/AI4CellFate_data/x_val.npy" #@param {type:"string"}
    y_val_path = "/content/drive/MyDrive/AI4CellFate_data/y_val.npy" #@param {type:"string"}
    x_test_path = "/content/drive/MyDrive/AI4CellFate_data/x_test.npy" #@param {type:"string"}
    y_test_path = "/content/drive/MyDrive/AI4CellFate_data/y_test.npy" #@param {type:"string"}
    
    # Load data from Drive
    print("📁 Loading data from Google Drive...")
    try:
        x_train_aug = np.load(x_train_path)
        y_train_aug = np.load(y_train_path)
        x_val = np.load(x_val_path)
        y_val = np.load(y_val_path)
        x_test = np.load(x_test_path)
        y_test = np.load(y_test_path)
        print("✅ All files loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading files: {e}")
        print("Please check your file paths and ensure files exist.")
        raise

else:
    # Upload files directly
    from google.colab import files
    
    print("📤 Upload your prepared data files (.npy format)")
    print("Expected files: x_train_aug.npy, y_train_aug.npy, x_val.npy, y_val.npy, x_test.npy, y_test.npy")
    
    uploaded = files.upload()
    
    # Load uploaded files
    try:
        # Find and load files based on naming patterns
        train_x_file = next((f for f in uploaded.keys() if 'train' in f.lower() and ('x_' in f.lower() or 'image' in f.lower())), None)
        train_y_file = next((f for f in uploaded.keys() if 'train' in f.lower() and ('y_' in f.lower() or 'label' in f.lower())), None)
        val_x_file = next((f for f in uploaded.keys() if 'val' in f.lower() and ('x_' in f.lower() or 'image' in f.lower())), None)
        val_y_file = next((f for f in uploaded.keys() if 'val' in f.lower() and ('y_' in f.lower() or 'label' in f.lower())), None)
        test_x_file = next((f for f in uploaded.keys() if 'test' in f.lower() and ('x_' in f.lower() or 'image' in f.lower())), None)
        test_y_file = next((f for f in uploaded.keys() if 'test' in f.lower() and ('y_' in f.lower() or 'label' in f.lower())), None)
        
        if not all([train_x_file, train_y_file, val_x_file, val_y_file, test_x_file, test_y_file]):
            missing = []
            if not train_x_file: missing.append("training images")
            if not train_y_file: missing.append("training labels")
            if not val_x_file: missing.append("validation images")
            if not val_y_file: missing.append("validation labels")
            if not test_x_file: missing.append("test images")
            if not test_y_file: missing.append("test labels")
            raise ValueError(f"Missing required files: {', '.join(missing)}")
        
        x_train_aug = np.load(train_x_file)
        y_train_aug = np.load(train_y_file)
        x_val = np.load(val_x_file)
        y_val = np.load(val_y_file)
        x_test = np.load(test_x_file)
        y_test = np.load(test_y_file)
        
        print("✅ All files loaded successfully!")
        
    except Exception as e:
        print(f"❌ Error loading files: {e}")
        raise

# Validate data format and requirements
print("\n🔍 **Data Validation:**")

# Check shapes
print(f"📊 **Data Shapes:**")
print(f"   • Training: {x_train_aug.shape} images, {y_train_aug.shape} labels")
print(f"   • Validation: {x_val.shape} images, {y_val.shape} labels")
print(f"   • Test: {x_test.shape} images, {y_test.shape} labels")

# Check if images are 3D (samples, height, width)
if len(x_train_aug.shape) != 3:
    print(f"⚠️  WARNING: Training images have shape {x_train_aug.shape}. Expected 3D: [samples, height, width]")
if len(x_val.shape) != 3:
    print(f"⚠️  WARNING: Validation images have shape {x_val.shape}. Expected 3D: [samples, height, width]")
if len(x_test.shape) != 3:
    print(f"⚠️  WARNING: Test images have shape {x_test.shape}. Expected 3D: [samples, height, width]")

# Check normalisation
train_min, train_max = x_train_aug.min(), x_train_aug.max()
val_min, val_max = x_val.min(), x_val.max()
test_min, test_max = x_test.min(), x_test.max()

print(f"\n📏 **Normalisation Check:**")
print(f"   • Training: [{train_min:.3f}, {train_max:.3f}]")
print(f"   • Validation: [{val_min:.3f}, {val_max:.3f}]")
print(f"   • Test: [{test_min:.3f}, {test_max:.3f}]")

if not (0 <= train_min and train_max <= 1 and 0 <= val_min and val_max <= 1 and 0 <= test_min and test_max <= 1):
    print("❌ WARNING: Images not properly normalised to [0,1] range!")
    print("Please normalise your images before proceeding.")
else:
    print("✅ Images properly normalised to [0,1] range")

# Check class balance in training set
unique_train, counts_train = np.unique(y_train_aug, return_counts=True)
print(f"\n⚖️  **Training Set Balance:**")
for class_label, count in zip(unique_train, counts_train):
    print(f"   • Class {class_label}: {count} samples")

if len(counts_train) == 2 and abs(counts_train[0] - counts_train[1]) / max(counts_train) < 0.1:
    print("✅ Training set is balanced")
else:
    print("⚠️  WARNING: Training set appears unbalanced. Consider balancing before training.")

# Check binary labels
all_labels = np.concatenate([y_train_aug, y_val, y_test])
unique_labels = np.unique(all_labels)
if not np.array_equal(unique_labels, [0, 1]):
    print(f"⚠️  WARNING: Labels contain values {unique_labels}. Expected binary labels [0, 1]")
else:
    print("✅ Binary labels [0, 1] detected")

print(f"\n🎯 **Data loaded and validated! Ready for training.**")
print(f"\n📋 **Variables available for training:**")
print(f"   • x_train_aug: {x_train_aug.shape}")
print(f"   • y_train_aug: {y_train_aug.shape}")
print(f"   • x_val: {x_val.shape}")
print(f"   • y_val: {y_val.shape}")
print(f"   • x_test: {x_test.shape}")
print(f"   • y_test: {y_test.shape}")


---

## 📐 **Image Size Requirements & Model Adaptation**

### **Default Configuration**
AI4CellFate is designed for **20x20 pixel grayscale images**. If your data matches this size, you can proceed directly to the next cell.

### **For Different Image Sizes**
If your images are larger (e.g., 64x64, 128x128), you'll need to modify the model architecture:

#### **Step-by-Step Guide:**

1. **Navigate to Model Files:**
   ```
   AI4CellFate/
   └── src/
       └── models/
           ├── encoder.py    ← Modify this
           └── decoder.py    ← Modify this
   ```

2. **Update Encoder Architecture (`src/models/encoder.py`):**
   - Locate the `Conv2D` layers in the `__init__` method
   - **Add more layers** for larger images:
     - For 64x64: Add 1-2 additional Conv2D layers
     - For 128x128: Add 2-3 additional Conv2D layers
   - **Pattern**: Each Conv2D layer with `strides=(2, 2)` and `padding='same'` halves the spatial dimensions (rounding up for odd sizes)
   - **Example**: 128x128 → 64x64 → 32x32 → 16x16 → 8x8 → 4x4 → flatten

3. **Update Decoder Architecture (`src/models/decoder.py`):**
   - Mirror the encoder changes in reverse
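   - Adjust the number of units in the `Dense` layer that maps the latent vector to a feature map (and any `Reshape` that follows it) so they match the encoder's final feature-map shape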
   - Use `Conv2DTranspose` layers to upscale back to original size

4. **Key Parameters to Adjust:**
   - `filters`: Number of feature maps (typically 32, 64, 128, 256)
   - `kernel_size`: Usually (3,3) or (4,4)
   - `strides`: Usually (2,2) for downsampling/upsampling
   - `padding`: Usually 'same', so that each stride-(2,2) layer halves (encoder, rounding up) or doubles (decoder) the spatial size
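
Before editing any code, it helps to work out how many stride-2 layers a given image size needs and whether the decoder will reproduce the input size. The plain-Python sketch below does this using Keras' shape rules for `padding='same'`:

- a `Conv2D` with stride 2 outputs `ceil(n / 2)`;
- a `Conv2DTranspose` with stride 2 outputs `2 * n`.

The function name, the `filters_last` value and the layer counts in the loop are illustrative assumptions, not the repository's defaults.

```python
import math

def spatial_plan(size, n_stride2_layers, filters_last=64):
    """Spatial sizes through n stride-2 'same' Conv2D layers and back through the
    mirrored Conv2DTranspose layers (square images assumed)."""
    encoder = [size]
    for _ in range(n_stride2_layers):
        encoder.append(math.ceil(encoder[-1] / 2))  # Conv2D, strides=2, padding='same'
    decoder = [encoder[-1]]
    for _ in range(n_stride2_layers):
        decoder.append(decoder[-1] * 2)  # Conv2DTranspose, strides=2, padding='same'
    # Units the decoder's Dense layer needs before reshaping (filters_last is hypothetical)
    dense_units = encoder[-1] * encoder[-1] * filters_last
    return encoder, decoder, dense_units

for size, n in [(20, 2), (64, 4), (128, 5), (20, 3)]:
    enc, dec, units = spatial_plan(size, n)
    status = "OK" if dec[-1] == size else "MISMATCH: crop or pad the decoder output"
    print(f"{size}px, {n} layers: encoder {enc}, decoder {dec}, Dense units {units} -> {status}")
```

The last case shows why symmetry alone is not always enough. With 20x20 input and three stride-2 layers, the encoder rounds 5 up to 3, so the mirrored decoder returns 24x24 instead of 20x20.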

#### **Quick Tips:**
- **Maintain symmetry**: Encoder downsampling should match decoder upsampling
- **Test incrementally**: Start with one additional layer pair
- **Monitor memory**: Larger images require more GPU memory
- **Adjust batch size**: You may need to reduce `batch_size` for larger images

#### **Example Modification:**
For 64x64 images, add one more Conv2D block to the encoder:
```python
from tensorflow.keras.layers import BatchNormalization, Conv2D, LeakyReLU

# In encoder.py, add after the existing Conv2D layers:
x = Conv2D(64, (3, 3), strides=(2, 2), padding='same')(x)  # halves height and width
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)
```

And add the matching Conv2DTranspose layer to the decoder:
```python
from tensorflow.keras.layers import Conv2DTranspose

# In decoder.py, add before the existing Conv2DTranspose layers:
x = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(x)  # doubles height and width
```

---


In [None]:
#@title ⚙️ **Model Configuration** {display-mode: "form"}
#@markdown Configure AI4CellFate training parameters

# Training parameters
latent_dim = 2 #@param {type:"integer"}
batch_size = 30 #@param {type:"integer"}
stage1_epochs = 35 #@param {type:"integer"}
stage2_epochs = 100 #@param {type:"integer"}
learning_rate = 0.001 #@param {type:"number"}
random_seed = 42 #@param {type:"integer"}

# Advanced parameters
gaussian_noise_std = 0.003 #@param {type:"number"}
lambda_recon_stage1 = 5 #@param {type:"integer"}
lambda_adv_stage1 = 1 #@param {type:"integer"}
lambda_recon_stage2 = 6 #@param {type:"integer"}
lambda_adv_stage2 = 4 #@param {type:"integer"}
lambda_cov = 1 #@param {type:"number"}
lambda_contra = 8 #@param {type:"integer"}

# Configuration dictionaries
config_stage1 = {
    'batch_size': batch_size,
    'epochs': stage1_epochs,
    'learning_rate': learning_rate,
    'seed': random_seed,
    'latent_dim': latent_dim,
    'GaussianNoise_std': gaussian_noise_std,
    'lambda_recon': lambda_recon_stage1,
    'lambda_adv': lambda_adv_stage1,
}

config_stage2 = {
    'batch_size': batch_size,
    'epochs': stage2_epochs,
    'learning_rate': learning_rate,
    'seed': random_seed,
    'latent_dim': latent_dim,
    'GaussianNoise_std': gaussian_noise_std,
    'lambda_recon': lambda_recon_stage2,
    'lambda_adv': lambda_adv_stage2,
    'lambda_cov': lambda_cov,
    'lambda_contra': lambda_contra,
}

print(f"⚙️ **Model Configuration:**")
print(f"   • Latent dimensions: {latent_dim}")
print(f"   • Batch size: {batch_size}")
print(f"   • Stage 1 epochs: {stage1_epochs}")
print(f"   • Stage 2 epochs: {stage2_epochs}")
print(f"   • Learning rate: {learning_rate}")
print(f"   • Random seed: {random_seed}")
print(f"\n🔧 **Advanced Parameters:**")
print(f"   • Gaussian noise std: {gaussian_noise_std}")
print(f"   • Lambda reconstruction (S1/S2): {lambda_recon_stage1}/{lambda_recon_stage2}")
print(f"   • Lambda adversarial (S1/S2): {lambda_adv_stage1}/{lambda_adv_stage2}")
print(f"   • Lambda covariance: {lambda_cov}")
print(f"   • Lambda contrastive: {lambda_contra}")

print("\n✅ Model configuration completed!")

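The `lambda_*` parameters above weight the individual loss terms. The sketch below is hedged: the exact formulation lives in `src/training/train.py` and may differ. It reads the Stage 2 objective as a weighted sum of the reconstruction, adversarial, covariance and contrastive losses. It also illustrates one common form of covariance penalty for disentanglement, the sum of squared off-diagonal entries of the latent covariance matrix. The function names and the loss values in the example are hypothetical.

```python
import numpy as np

def covariance_penalty(z):
    """Sum of squared off-diagonal covariances of latent codes z [n_samples, latent_dim].
    Zero when latent features are uncorrelated (one common choice; an assumption here)."""
    z = np.asarray(z, dtype=float)
    cov = np.cov(z, rowvar=False)           # [latent_dim, latent_dim]
    off_diag = cov - np.diag(np.diag(cov))  # keep only cross-feature terms
    return float(np.sum(off_diag ** 2))

def weighted_stage2_loss(losses, config):
    """Combine per-term losses with the lambda weights (weighted sum assumed)."""
    return (config["lambda_recon"] * losses["recon"]
            + config["lambda_adv"] * losses["adv"]
            + config["lambda_cov"] * losses["cov"]
            + config["lambda_contra"] * losses["contra"])

rng = np.random.default_rng(0)
independent = rng.normal(size=(1000, 2))
correlated = np.column_stack([independent[:, 0], independent[:, 0] + 0.1 * independent[:, 1]])
print(covariance_penalty(independent), covariance_penalty(correlated))

# Notebook default weights; the loss values are made up for illustration
weights = {"lambda_recon": 6, "lambda_adv": 4, "lambda_cov": 1, "lambda_contra": 8}
total = weighted_stage2_loss({"recon": 0.02, "adv": 0.7, "cov": 0.01, "contra": 0.3}, weights)
print(total)
```

Inside the notebook, `config_stage2` already holds these keys and could be passed in place of `weights`. Raising `lambda_cov` pushes the latent features towards decorrelation, possibly at some cost in reconstruction quality.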

In [None]:
#@title 🚀 **AI4CellFate Model Training (Stage 1 + Stage 2)** {display-mode: "form"}
#@markdown Train the complete AI4CellFate model in two stages: (1) Adversarial Autoencoder, (2) Latent Space Engineering

print("🚀 **AI4CellFate Model Training**")
print("   Training will proceed in two stages:")
print("   • Stage 1: Adversarial Autoencoder (reconstruction + Gaussian latent space)")
print("   • Stage 2: Latent Space Engineering (+ covariance + contrastive losses)")
print()

# =============================================================================
# STAGE 1: Adversarial Autoencoder Training
# =============================================================================
print("🔄 **Stage 1: Training Adversarial Autoencoder**")
print("   Learning basic image reconstruction and Gaussian latent space...")

stage1_results = train_autoencoder(
    config_stage1, 
    x_train_aug, 
    x_val
)

# Extract trained models
encoder = stage1_results['encoder']
decoder = stage1_results['decoder']
discriminator = stage1_results['discriminator']

print("\n✅ **Stage 1 Completed!**")
print(f"   • Final reconstruction loss: {stage1_results['recon_loss'][-1]:.4f}")
print(f"   • Final adversarial loss: {stage1_results['adv_loss'][-1]:.4f}")

# =============================================================================
# STAGE 2: AI4CellFate Training (Latent Space Engineering)
# =============================================================================
print("\n🔄 **Stage 2: AI4CellFate Training (Latent Space Engineering)**")
print("   Adding covariance and contrastive losses for feature disentanglement...")

stage2_results = train_cellfate(
    config_stage2,
    encoder,
    decoder, 
    discriminator,
    x_train_aug,
    y_train_aug,
    x_val,
    y_val,
    x_test,
    y_test
)

# Update models with final trained versions
encoder = stage2_results['encoder']
decoder = stage2_results['decoder']
discriminator = stage2_results['discriminator']
final_confusion_matrix = stage2_results['confusion_matrix']

print("\n🎉 **Stage 2 Completed!**")
print(f"   • Final reconstruction loss: {stage2_results['recon_loss'][-1]:.4f}")
print(f"   • Final adversarial loss: {stage2_results['adv_loss'][-1]:.4f}")
print(f"   • Final covariance loss: {stage2_results['cov_loss'][-1]:.4f}")
print(f"   • Final contrastive loss: {stage2_results['contra_loss'][-1]:.4f}")
print(f"   • Training stopped at epochs: {stage2_results['good_conditions_stop']}")

# =============================================================================
# FINAL RESULTS SUMMARY
# =============================================================================
print("\n🏆 **Final Training Results:**")
print("   • Models trained successfully in both stages")
print("   • Latent space shaped by the covariance and contrastive losses for feature disentanglement")
print("   • Ready for latent space visualization and interpretation")

# Display final confusion matrix, row-normalised so each diagonal entry is that class's accuracy
if final_confusion_matrix is not None:
    cm = np.asarray(final_confusion_matrix, dtype=float)
    cm = cm / cm.sum(axis=1, keepdims=True)  # no-op if the matrix is already row-normalised
    print(f"\n📊 **Final Classification Performance:**")
    print(f"   • Class 0 accuracy: {cm[0,0]:.3f}")
    print(f"   • Class 1 accuracy: {cm[1,1]:.3f}")
    print(f"   • Balanced accuracy (mean of diagonal): {np.mean(np.diag(cm)):.3f}")

print("\n✅ **AI4CellFate training completed successfully!**")

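You may want to recompute the classification metrics yourself from test-set predictions. The variable `y_pred` below is an assumption: `train_cellfate` does not necessarily return predictions. scikit-learn's `confusion_matrix` with `normalize='true'` gives the row-normalised matrix (rows = true class), whose diagonal holds the per-class accuracies. `balanced_accuracy_score` equals their mean. The helper name is hypothetical.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, confusion_matrix

def per_class_report(y_true, y_pred):
    """Row-normalised confusion matrix, per-class accuracies and balanced accuracy."""
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1], normalize="true")
    per_class = np.diag(cm)  # fraction of each true class predicted correctly
    return cm, per_class, balanced_accuracy_score(y_true, y_pred)

# Toy labels and predictions for illustration only
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 1, 0, 0])
cm, per_class, bal_acc = per_class_report(y_true, y_pred)
print(cm)
print(per_class, bal_acc)
```

If the classifier outputs probabilities rather than labels, threshold them first (for example at 0.5) to obtain `y_pred`.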