# 🌟 **MemBrain-pick Tutorial**
Welcome to the **MemBrain-pick Tutorial**!  
This guide will walk you through the main steps of using MemBrain-pick to localize membrane particles.  
The tutorial uses example data from the paper:  
["Molecular architecture of thylakoid membranes within intact spinach chloroplasts"](https://www.biorxiv.org/content/10.1101/2024.11.24.625035v1.abstract).

---

### 🗂 **Tutorial Outline**
1. 📥 **Install MemBrain-pick**  
   Set up the tools you'll need for this tutorial.
2. 📊 **Load Data from Zenodo**  
   Use our example data to get started.
3. 🛠️ **Preprocess Data**  
   Prepare your data for MemBrain-pick processing.
4. 🖼️ **Visualize Data**  
   Inspect tomographic densities and membrane meshes.
5. 🧠 **Train a MemBrain-pick Model** *(optional)*  
   Learn how to train your own particle-detection model.
6. 🔍 **Predict Particle Positions**  
   Use a trained model to find particle positions.
7. 🎨 **Visualize Results**  
   Explore the predicted positions and distance maps.

---

### 🚀 **Let’s Get Started!**  
Follow along step by step to learn how to use MemBrain-pick!

### ⚡ Enable GPU for Better Performance

To take full advantage of this tutorial, please enable the GPU runtime in Google Colab:

1. Click on **Runtime** in the top menu.
2. Select **Change runtime type**.
3. In the popup:
   - Set **Hardware accelerator** to **GPU**.
   - Click **Save**.
4. Verify that the runtime is connected by checking the top-right corner.

With GPU enabled, the computations will run a bit faster! 🚀
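
If you would like to confirm that the GPU is actually visible before continuing, a small check along these lines may help. This is only a sketch: the helper name `pick_device` is ours (not part of MemBrain-pick), and it assumes PyTorch is importable, falling back to `"cpu"` if it is not.

```python
def pick_device():
    """Return "cuda" if PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch  # assumed to be present in the Colab runtime
    except ImportError:
        # Without PyTorch we cannot query CUDA, so report the CPU
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(f"Running on: {pick_device()}")
```

If this prints `cpu` after you switched the runtime type, reconnecting the runtime is usually the first thing to try.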

## 🛠️ Step 1: Install MemBrain-pick

To get started, you’ll need to install MemBrain-pick and its dependencies.

### 📖 **Reference**  
For detailed installation instructions, visit the [official GitHub repository](https://github.com/CellArchLab/membrain-pick/blob/master/docs/Installation.md).

---

### ⚙️ **Simplified Installation**  
The basic setup is straightforward! You can install the core packages using the following command:

```
pip install membrain-pick surforama napari
```

### 💡 **Additional Dependencies for Colab**  
Since this tutorial is running in **Google Colab**, we’ll need a few extra packages:
- **`open3d`** and **`plotly`**: To visualize outputs (since Napari doesn’t run on Colab).
- **`pymeshlab`**: For faster mesh generation.

Install everything with the following command:

```
!pip install membrain-pick surforama napari open3d plotly pymeshlab
```

In [2]:
%%capture
# Clone the tutorial scripts repository from GitHub
!git clone https://github.com/CellArchLab/membrain_tutorial_scripts.git

# Install MemBrain-pick and its dependencies
!pip install membrain-pick==0.0.5  # Specific version of MemBrain-pick
!pip install surforama
!pip install napari

# Install additional libraries for Colab-specific visualization and processing
!pip install open3d  # For 3D data visualization
!pip install plotly  # For interactive plots in Colab
!pip install pymeshlab  # Faster mesh generation on Colab

## 📥 Step 2: Load Data from Zenodo

We’ve provided some example data for **MemBrain-seg** and **MemBrain-pick** on **Zenodo**.  
You can find it at the link below:

🔗 [Zenodo Dataset](https://zenodo.org/records/14610597)

---

### 📂 **Data Contents**
The dataset includes:
- 🧪 **Tomogram**: `Tomo0001.mrc`
- 🧩 **Membrane Segmentations**: A folder containing 5 segmentations (`./membranes`)
- 🗂️ **Manual Annotations**: Positions for Photosystem II (PSII) stored in `./positions`

---

### 🚀 **Let’s Load the Data**
We’ll load the dataset in the next cell so it’s ready for preprocessing.


In [3]:
# Import the function to load tutorial data
from membrain_tutorial_scripts.membrain_tutorial_scripts import load_tutorial_data

# Download and extract the tutorial dataset from Zenodo
load_tutorial_data()

# Change the working directory to the extracted data folder
%cd data5mbs

Downloading data from Zenodo. This can take few minutes.
Unzipping downloaded data.

Done. Files in the tutorial folder:
membranes
positions
Tomo0001.mrc
/content/data5mbs
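
To double-check that the download worked, you could compare the folder contents against the three entries listed above. The helper below is a hypothetical convenience function written for this tutorial text, not part of the tutorial scripts.

```python
from pathlib import Path

# Entries the tutorial dataset is expected to contain (see "Data Contents")
EXPECTED = ("membranes", "positions", "Tomo0001.mrc")


def missing_entries(root=".", expected=EXPECTED):
    """Return the expected files/folders that are absent from `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).exists()]


missing = missing_entries(".")
print("All tutorial files found." if not missing else f"Missing: {missing}")
```

Run it from inside the extracted data folder (`data5mbs` in this tutorial).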


## 🛠️ Step 3: Preprocess Data

In this step, we’ll [prepare the data](https://github.com/CellArchLab/membrain-pick/blob/master/docs/Data_Preparation.md) for training and prediction using MemBrain-pick.

---

### ⚙️ **What Happens During Preprocessing?**

1. 🧩 **Membrane Segmentations** are converted into triangular meshes.
2. 🎥 **Tomographic Densities** are projected onto the vertices of the meshes along their normal vectors.
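
To make the second step more concrete, here is a minimal NumPy sketch of sampling a volume along vertex normals. It is an illustration of the idea only, not MemBrain-pick's implementation: the function name, the nearest-neighbour sampling, and the assumption that vertices are given in voxel index order are all ours.

```python
import numpy as np


def sample_along_normals(volume, vertices, normals, offsets):
    """Nearest-neighbour sample `volume` at vertices + offset * normal.

    Returns an array of shape (n_vertices, n_offsets).
    """
    # Normalise the normals so that offsets are measured in voxels
    normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
    # Sampling coordinates, shape (n_vertices, n_offsets, 3)
    coords = vertices[:, None, :] + offsets[None, :, None] * normals[:, None, :]
    # Round to the nearest voxel and clamp to the volume bounds
    idx = np.rint(coords).astype(int)
    idx = np.clip(idx, 0, np.array(volume.shape) - 1)
    return volume[idx[..., 0], idx[..., 1], idx[..., 2]]


# Toy example: a 4x4x4 volume and a single vertex with a normal along the last axis
volume = np.arange(64).reshape(4, 4, 4)
vertices = np.array([[1.0, 1.0, 1.0]])
normals = np.array([[0.0, 0.0, 2.0]])
offsets = np.array([-1.0, 0.0, 1.0])
profile = sample_along_normals(volume, vertices, normals, offsets)
print(profile.shape)
```

Each row of `profile` holds the density values sampled below, at, and above one vertex, which is the kind of per-vertex feature the network can learn from.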

---

### 📦 **Result**  
After preprocessing, you’ll get one `.h5` container per membrane. Each container stores all the information the neural network needs in the next steps, and Step 6 later picks these containers up from `./mesh_data`.

---

### 🚀 **Run the Preprocessing**  
We’ll preprocess the data in the next cell. Simply run the provided command to prepare your dataset!


In [4]:
!membrain_pick convert_mb_folder --mb-folder ./membranes --tomo-path ./Tomo0001.mrc --input-pixel-size 14.08 --pymeshlab-meshing

Processing ./membranes/T1S1M12.mrc
Processing ./membranes/T1S1M19.mrc
Processing ./membranes/T1S1M14.mrc
Processing ./membranes/T1S1M17.mrc
Processing ./membranes/T1S1M16.mrc


## 🖼️ Step 4: Visualize Data

### 📖 **Standard Visualization**  
Typically, you would visualize your `.h5` containers in **Surforama**, a specialized tool for inspecting membrane meshes and tomographic densities in 3D. To start Surforama, you can use the following command locally (not in Colab):

```
membrain_pick surforama
```

#### 🌟 Surforama Advantages:
- 🖌️ **Detailed Visualizations**: Surforama provides a polished and highly accurate view of the membrane and tomographic densities.
- 🎯 **Reduced Distortions**: Visualizations are free from the distortions present in the simplified approach shown here.
- 🧩 **Interactive Exploration**: Allows smooth navigation and inspection of 3D structures for better analysis.

---

### 🔍 **Colab Visualization**
Since **Surforama** and **Napari** cannot run in Colab, we will use a simpler approach to visualize the data here. Specifically:
- A **3D point cloud** of the membrane surface with tomographic densities projected onto its vertices.
- **Ground truth particle positions** will be overlaid on the same surface for comparison.

---

### ⚠️ **Important Note**
- This visualization is for **concept demonstration only** and may contain **distortions**.
- For precise and polished visualizations, **always use Surforama locally** with the command:

```
membrain_pick surforama
```

In [5]:
# Import shared utilities used throughout the tutorial
from membrain_tutorial_scripts.membrain_tutorial_scripts import (
    visualize_membranes,
    get_checkpoint_file,
    load_membrane_data_raw
)

# Load raw membrane data for visualization
out_dict = load_membrane_data_raw("T1S1M12")

# Extract points, tomographic values, and ground truth positions
points = out_dict["points"]
tomo_values = out_dict["tomo_values"]
positions = out_dict["positions"]

# Visualize the membrane with tomographic densities and ground truth positions
visualize_membranes(
    points=[points, points],
    positions=positions,
    colors=[tomo_values * -1, tomo_values * -1],
    color_scales=['Greys', 'Greys'],
    z_shifts=[0, 100]
)


## 🧠 Step 5: Train a MemBrain-pick Model (Optional)

Model training can take time, so this step is optional. You can skip to Step 6 and use the pretrained model if preferred.  
For a detailed explanation, refer to the [training documentation](https://github.com/CellArchLab/membrain-pick/blob/master/docs/Training.md).

---

### ⚙️ **What Happens During Training?**
1. 🧩 **Initialization**:
   - Ground truth positions are projected onto the membrane surface.
   - The surface is partitioned into smaller chunks for efficient processing.
2. 🖥️ **Device Check**: GPU availability is displayed.
3. 🧠 **Model Summary**: Network parameters are printed.
4. 🔄 **Epochs**: Training and validation losses are printed for each epoch.
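
To picture the first initialization step, the sketch below snaps each annotated position onto the tangent plane of its nearest mesh vertex. This is a simplified stand-in that we wrote for illustration; the function name and the nearest-vertex strategy are assumptions, and MemBrain-pick's own projection may differ.

```python
import numpy as np


def project_to_surface(positions, vertices, normals):
    """Project each position onto the tangent plane of its nearest vertex."""
    normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
    # Distance from every position to every vertex, shape (n_pos, n_vert)
    dists = np.linalg.norm(positions[:, None, :] - vertices[None, :, :], axis=2)
    nearest = dists.argmin(axis=1)
    v, n = vertices[nearest], normals[nearest]
    # Signed height of each position above the plane through v with normal n
    height = np.einsum("ij,ij->i", positions - v, n)
    # Subtract the out-of-plane component
    return positions - height[:, None] * n


# Toy example: a flat "membrane" in the xy-plane and one position hovering above it
vertices = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
normals = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
positions = np.array([[1.0, 2.0, 5.0]])
print(project_to_surface(positions, vertices, normals))
```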

---

### ⏳ **Time Considerations**
- Training starts slowly but speeds up after the first epoch.
- The default **10 epochs** will run quickly but may not yield good results.
- For better performance, consider training for **200 epochs**. The cell below uses only 3 epochs to keep this demo short.

---

### 🚀 **Ready to Train?**
First, we’ll prepare the training data folder, then start the training process.


In [6]:
# Import necessary utilities for this step
from membrain_tutorial_scripts.membrain_tutorial_scripts import (
    create_membrain_pick_training_data
)
 
# Prepare training data
create_membrain_pick_training_data()

# Train the model (example with 3 epochs for a quick run)
!membrain_pick train --data-dir ./training_data --input-pixel-size 14.08 --max-epochs 3 --no-verbose

Training data created.
Loading  membranes into dataset.
Projecting points to nearest hyperplane.
Projecting points to nearest hyperplane.
Projecting points to nearest hyperplane.
Precomputing partitioning of the mesh.
Computing partitioning for membrane 0.
Precomputing partitioning for membrane 0 with 33488 faces.
100% 33488/33488 [00:01<00:00, 21051.56it/s]
Computing partitioning for membrane 1.
Precomputing partitioning for membrane 1 with 33484 faces.
100% 33484/33484 [00:02<00:00, 13426.80it/s]
Computing partitioning for membrane 2.
Precomputing partitioning for membrane 2 with 33466 faces.
100% 33466/33466 [00:02<00:00, 13256.37it/s]
Loading  membranes into dataset.
Projecting points to nearest hyperplane.
Projecting points to nearest hyperplane.
Precomputing partitioning of the mesh.
Computing partitioning for membrane 0.
Precomputing partitioning for membrane 0 with 33472 faces.
100% 33472/33472 [00:01<00:00, 20047.75it/s]
Computing partitioning for membrane 1.
Precomputing part

## 🔍 Step 6: Predict Using the Trained Model

In this step, we’ll [perform predictions](https://github.com/CellArchLab/membrain-pick/blob/master/docs/Prediction.md) to identify particle positions on the membranes.

You can use:
- The **provided pretrained model** for faster results.
- The **latest checkpoint** from your training (if completed). Note: Results may vary depending on training duration.

---

### 📜 **Printed Outputs During Prediction**
1. **Mesh Partitioning**: The script partitions membranes into smaller chunks for processing.
2. **Heatmap Generation**: Predicts a heatmap for each chunk (typically 40–50 chunks per membrane) and combines them into a single heatmap.
3. **Clustering**: Performs clustering on heatmaps to identify particle positions.
4. **Completion Messages**: Confirms saved outputs for each membrane.
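
To give a feeling for the clustering step, here is a tiny flat-kernel mean shift written in plain NumPy. It is a teaching sketch rather than MemBrain-pick's implementation: the function name and defaults are ours, and the score thresholding applied by the prediction command is left out. The `bandwidth` argument plays the same conceptual role as `--mean-shift-bandwidth` in this step.

```python
import numpy as np


def mean_shift(points, bandwidth, n_iter=50, tol=1e-3):
    """Flat-kernel mean shift; returns one centre per detected cluster."""
    points = np.asarray(points, dtype=float)
    modes = points.copy()
    for _ in range(n_iter):
        # Distances between current modes and the original points
        dists = np.linalg.norm(modes[:, None, :] - points[None, :, :], axis=2)
        within = dists <= bandwidth
        # Move each mode to the mean of the points inside its window
        sums = (within[:, :, None] * points[None, :, :]).sum(axis=1)
        new_modes = sums / within.sum(axis=1, keepdims=True)
        shift = np.linalg.norm(new_modes - modes, axis=1).max()
        modes = new_modes
        if shift < tol:
            break
    # Merge modes that converged to (almost) the same location
    centres = []
    for mode in modes:
        if not any(np.linalg.norm(mode - c) < bandwidth / 2 for c in centres):
            centres.append(mode)
    return np.array(centres)


# Toy example: two well-separated groups of three points each
points = np.array([
    [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
    [100.0, 0.0, 0.0], [101.0, 0.0, 0.0], [100.0, 1.0, 0.0],
])
centres = mean_shift(points, bandwidth=10.0)
print(f"Found {len(centres)} clusters")
```

A larger bandwidth merges nearby candidates into fewer particles, while a smaller one may split a single particle into several picks.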

---

### 📦 **File Outputs**
- `.h5` containers: Include the mesh, predicted distance maps, and particle positions.
- `.star` files: Store the predicted particle positions.
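
If you want to peek into a `.star` file without extra dependencies, a minimal reader for a single `loop_` table could look like the sketch below. The function is hypothetical, and the RELION-style column labels in the example text are an assumption for illustration; check the header of your own file for the labels that are actually written.

```python
def read_star_loop(text):
    """Parse the first `loop_` table of a STAR file into a list of dicts."""
    labels, rows, in_loop = [], [], False
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        if line == "loop_":
            if rows:
                break  # a second table starts: stop after the first one
            in_loop, labels = True, []
        elif in_loop and line.startswith("data_"):
            break  # a new data block starts
        elif in_loop and line.startswith("_"):
            labels.append(line.split()[0])  # drop the trailing "#1", "#2", ...
        elif in_loop and labels:
            rows.append(dict(zip(labels, line.split())))
    return rows


# Example text with assumed (RELION-style) column labels
example = """
data_

loop_
_rlnCoordinateX #1
_rlnCoordinateY #2
_rlnCoordinateZ #3
10.0 20.0 30.0
11.5 21.5 31.5
"""
rows = read_star_loop(example)
print(len(rows), "positions read")
```

Values are returned as strings, so convert them with `float(...)` before doing any arithmetic.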

---

### ⚙️ **Run the Prediction**
Use the following code to perform predictions:

In [18]:
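# Create an input folder for the prediction (-p avoids an error if it already exists)
!mkdir -p pred_inputs
# Copy one preprocessed membrane into it (a plain local copy, so cp is enough)
!cp ./mesh_data/Tomo0001_T1S1M19.h5 pred_inputs/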

In [19]:
# Import necessary utilities for this step
from membrain_tutorial_scripts.membrain_tutorial_scripts import (
    get_checkpoint_file
)

# Set whether to use the newly trained model or the provided pretrained model
use_newly_trained_checkpoint = False  # Set to True to use the latest checkpoint from your training

# Define the directory containing preprocessed data
data_dir = "./pred_inputs"

# Get the appropriate checkpoint path (latest or pretrained)
ckpt_path = get_checkpoint_file(latest=use_newly_trained_checkpoint)

# Perform the prediction using MemBrain-pick
!membrain_pick predict --data-dir {data_dir} --ckpt-path {ckpt_path} --input-pixel-size 14.08 --mean-shift-bandwidth 75 --mean-shift-score-threshold 7.0

Loading  membranes into dataset.
Precomputing partitioning of the mesh.
Loaded partitioning for membrane 0.
  return torch.sparse.FloatTensor(
46it [00:02, 20.99it/s]
Performing mean shift...
Clustering found 76 clusters.
Saving to ./predict_output/Tomo0001_T1S1M19.h5


## 🎨 Step 7: Visualize Prediction Results

In this step, we’ll load the data stored in the prediction `.h5` containers. These files contain:
- 🧩 **Mesh**: The processed membrane surface.
- 🔥 **Predicted Network Scores**: Heatmaps generated by the network.
- 🎯 **Particle Positions**: Obtained through Mean Shift clustering.

---

### 📊 **Visualization Overview**
- **Upper Panel**: Displays the **predicted heatmap**.
- **Lower Panel**: Shows the **original tomographic densities** for comparison.

This helps you understand how well the model predicted particle positions relative to the input data.

---

### 🚀 **Run the Visualization**
The following cells will load the `.h5` data and generate the visualization.


In [20]:
# Import necessary utilities for this step
from membrain_tutorial_scripts.membrain_tutorial_scripts import (
    load_membrane_data_pred,
    visualize_membranes  # re-imported so this cell also works if Step 4 was skipped
)

# Load prediction data for visualization
# "T1S1M19" is an example membrane from the prediction dataset (not used for training)
membrane_dict = load_membrane_data_pred("T1S1M19")

# Extract key components from the prediction data
points = membrane_dict["points"]          # 3D coordinates of the membrane mesh
tomo_values = membrane_dict["tomo_values"]  # Original tomographic density values
positions = membrane_dict["positions"]    # Predicted particle positions (Mean Shift clustering)
scores = membrane_dict["scores"]          # Predicted network scores (heatmap values)

# Lower panel (z-shift 0): tomographic densities; upper panel (z-shift 100): predicted scores
visualize_membranes(
    points=[points, points],
    positions=positions,
    colors=[tomo_values * -1, scores * -1],
    color_scales=['Greys', 'RdBu'],
    z_shifts=[0, 100]
)

## 🎉 **Congrats!**

You’ve successfully completed the MemBrain-pick tutorial! 🎊  
We hope this guide helped you get started with analyzing membrane particle data.

---

### 💬 **Feedback and Support**
If you encounter any issues or have suggestions, please let us know via the **GitHub Issues** page:  
[📌 Report Issues or Provide Feedback](https://github.com/CellArchLab/membrain-pick/issues)

---

### 🔗 **Additional Resources**
For more details, examples, and documentation, visit the project’s GitHub repository:  
[📘 MemBrain-pick GitHub Repository](https://github.com/CellArchLab/membrain-pick/)

---

### 🚀 **Next Steps**
- Try using your own data with MemBrain-pick.
- Experiment with training the model for more epochs.

Thank you for following along! 😊