# 🧠 PyTorch Connectomics Tutorial

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zudi-lin/pytorch_connectomics/blob/v2.0/notebooks/PyTorch_Connectomics_Tutorial.ipynb)

Welcome! This notebook will help you:
1. ✅ **Install** PyTorch Connectomics on Google Colab
2. 🎯 **Run a demo** with synthetic data
3. 🔬 **Train** a model on real mitochondria segmentation data
4. 📊 **Visualize** results

**Time:** 15-20 minutes

**GPU:** This notebook requires a GPU. Make sure to enable it:
- Runtime → Change runtime type → Hardware accelerator → GPU
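
If you'd like to confirm the runtime from Python rather than reading the `nvidia-smi` table, here is a minimal sketch using only the standard library. The helper name `describe_gpu` is made up for this example; it simply wraps the `nvidia-smi` query flags.

```python
import shutil
import subprocess


def describe_gpu() -> str:
    """Return a one-line GPU summary, or a hint if no GPU is visible."""
    # nvidia-smi is only on PATH when the runtime has an NVIDIA driver.
    if shutil.which("nvidia-smi") is None:
        return "No GPU detected: enable one via Runtime → Change runtime type."
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0 or not result.stdout.strip():
        return "nvidia-smi is installed but reported no usable GPU."
    # Report the first GPU only; Colab runtimes normally expose a single device.
    return "GPU detected: " + result.stdout.strip().splitlines()[0]


print(describe_gpu())
```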

---

## 📦 Step 1: Installation

Let's install PyTorch Connectomics and its dependencies.

**This takes ~2 minutes.**

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Install PyTorch Connectomics
print("📦 Installing PyTorch Connectomics...\n")

# Clone repository
!git clone https://github.com/zudi-lin/pytorch_connectomics.git
%cd pytorch_connectomics

# Install dependencies (pre-built packages to avoid compilation)
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install -q numpy h5py cython connected-components-3d
!pip install -q pytorch-lightning monai omegaconf

# Install PyTorch Connectomics
!pip install -q -e .

print("\n✅ Installation complete!")

# Verify installation
import torch
import connectomics
print(f"\n🔧 PyTorch: {torch.__version__}")
print(f"🔧 CUDA available: {torch.cuda.is_available()}")
print(f"🔧 PyTorch Connectomics: {connectomics.__version__}")

## 🎯 Step 2: Quick Demo

Let's verify the installation with a 30-second demo using synthetic data.

This will:
- Generate synthetic 3D volumes
- Train a small 3D U-Net for 5 epochs
- Validate the installation
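
To give a feel for what "synthetic 3D volumes" can look like, here is a small NumPy sketch that builds a noisy grayscale volume with a matching binary mask of random spheres. This is an illustration only: the function name, shapes, and intensity values are assumptions, and the generator behind `--demo` may work quite differently.

```python
import numpy as np


def make_synthetic_volume(shape=(32, 64, 64), num_blobs=5, radius=6, seed=0):
    """Return (image, label): a noisy grayscale volume and its binary mask."""
    rng = np.random.default_rng(seed)
    zz, yy, xx = np.indices(shape)
    label = np.zeros(shape, dtype=np.uint8)
    for _ in range(num_blobs):
        # Pick a random center inside the volume and mark voxels within `radius`.
        cz, cy, cx = (rng.integers(0, s) for s in shape)
        dist_sq = (zz - cz) ** 2 + (yy - cy) ** 2 + (xx - cx) ** 2
        label[dist_sq <= radius ** 2] = 1
    # Foreground is brighter than background; Gaussian noise reduces the contrast.
    image = 0.3 + 0.4 * label + rng.normal(0.0, 0.1, size=shape)
    image = np.clip(image, 0.0, 1.0).astype(np.float32)
    return image, label


image, label = make_synthetic_volume()
print(image.shape, label.shape, f"foreground: {label.mean() * 100:.1f}%")
```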

In [None]:
# Run demo
!python scripts/main.py --demo

## 📥 Step 3: Download Tutorial Data

Let's download the **Lucchi++ mitochondria segmentation dataset** (~100 MB).

This is real electron microscopy data with mitochondria annotations.

In [None]:
# Download Lucchi++ dataset
print("📥 Downloading Lucchi++ dataset...\n")

!mkdir -p datasets
!wget -q --show-progress -O Lucchi++.zip https://huggingface.co/datasets/pytc/tutorial/resolve/main/Lucchi%2B%2B.zip
!unzip -q Lucchi++.zip -d datasets/
!rm Lucchi++.zip

print("\n✅ Data downloaded!")
!ls -lh datasets/Lucchi++/

## 🔍 Step 4: Visualize Data

Let's look at the training data to understand what we're working with.

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt

# Load training data
with h5py.File('datasets/Lucchi++/train_image.h5', 'r') as f:
    train_image = f['main'][:]
with h5py.File('datasets/Lucchi++/train_label.h5', 'r') as f:
    train_label = f['main'][:]

print(f"Image shape: {train_image.shape}")
print(f"Label shape: {train_label.shape}")
print(f"Image range: [{train_image.min():.3f}, {train_image.max():.3f}]")
print(f"Mitochondria voxels: {(train_label > 0).sum() / train_label.size * 100:.1f}%")

# Visualize a slice
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

slice_idx = train_image.shape[0] // 2

axes[0].imshow(train_image[slice_idx], cmap='gray')
axes[0].set_title('EM Image (slice)')
axes[0].axis('off')

axes[1].imshow(train_label[slice_idx], cmap='gray')
axes[1].set_title('Mitochondria Labels')
axes[1].axis('off')

axes[2].imshow(train_image[slice_idx], cmap='gray')
axes[2].imshow(train_label[slice_idx], cmap='Reds', alpha=0.4)
axes[2].set_title('Overlay')
axes[2].axis('off')

plt.tight_layout()
plt.show()

## 🏃 Step 5: Train a Model

Now let's train a 3D U-Net on this data!

We'll use the pre-configured tutorial config and run a **fast-dev-run** first to make sure everything works.

In [None]:
# First, do a fast dev run (1 batch) to check everything works
print("🔧 Running fast-dev-run (1 batch)...\n")
!python scripts/main.py --config tutorials/monai_lucchi++.yaml --fast-dev-run

print("\n✅ Fast-dev-run completed!")

### Full Training (Optional)

For an actual training run, we'll reduce the number of epochs so that training fits within Colab's time limits.

**This takes ~10-15 minutes on a T4 GPU.**
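
As a rough sanity check on that estimate, the settings used in the training cell below (50 epochs, `data.iter_num_per_epoch=50`, batch size 2) imply the following budget. This assumes `iter_num_per_epoch` counts training iterations per epoch, which is a reading of the option name rather than something confirmed here, and it ignores validation and checkpointing overhead.

```python
max_epochs = 50
iters_per_epoch = 50  # assumed meaning of data.iter_num_per_epoch
batch_size = 2

total_iters = max_epochs * iters_per_epoch
total_samples = total_iters * batch_size
print(f"Total iterations: {total_iters}")      # 2500
print(f"Total samples seen: {total_samples}")  # 5000

# Time available per iteration if the whole run takes 10 or 15 minutes.
for minutes in (10, 15):
    seconds_per_iter = minutes * 60 / total_iters
    print(f"{minutes} min -> {seconds_per_iter:.2f} s per iteration")
```

If your iterations take noticeably longer than about a third of a second, expect the run to exceed the quoted time.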

In [None]:
# Train for 50 epochs (reduced from 1000 for demo)
print("🏃 Training model (50 epochs)...\n")

!python scripts/main.py \
    --config tutorials/monai_lucchi++.yaml \
    optimization.max_epochs=50 \
    system.training.batch_size=2 \
    data.iter_num_per_epoch=50

print("\n✅ Training complete!")
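
The `key.subkey=value` arguments in the command above are dotted overrides applied on top of the YAML config. The sketch below shows the general idea with plain nested dictionaries. It is not the project's parser: `apply_overrides` is a hypothetical helper, and apart from the 1000-epoch default mentioned above, the default values are placeholders.

```python
import ast
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with `a.b.c=value` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk down the nesting, creating intermediate levels if needed.
            node = node.setdefault(name, {})
        try:
            # Interpret numbers, booleans, lists, etc. as Python literals.
            node[leaf] = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything else stays a plain string.
            node[leaf] = raw_value
    return result


# Placeholder defaults standing in for the YAML file's contents.
defaults = {
    "optimization": {"max_epochs": 1000},
    "system": {"training": {"batch_size": 4}},   # hypothetical default
    "data": {"iter_num_per_epoch": 500},         # hypothetical default
}

config = apply_overrides(
    defaults,
    [
        "optimization.max_epochs=50",
        "system.training.batch_size=2",
        "data.iter_num_per_epoch=50",
    ],
)
print(config)
```

The notebook installs `omegaconf`, which offers this kind of merging; whether `scripts/main.py` uses it for these arguments is an assumption.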

## 📊 Step 6: View Training Progress

Let's load and visualize the training metrics.

In [None]:
# Install tensorboard
!pip install -q tensorboard

# Load TensorBoard
%load_ext tensorboard

# Find the latest run
import glob
import os

log_dirs = glob.glob('outputs/lucchi++_monai_unet/*/logs')
if log_dirs:
    latest_log = max(log_dirs, key=os.path.getmtime)
    print(f"📊 Loading TensorBoard from: {latest_log}")
    %tensorboard --logdir {latest_log}
else:
    print("⚠️  No logs found. Did training complete?")

## 🧪 Step 7: Test the Model

Let's test the trained model on the test set.

In [None]:
# Find the most recent checkpoint
import glob
import os
checkpoints = glob.glob('outputs/lucchi++_monai_unet/*/checkpoints/*.ckpt')

if checkpoints:
    # Get the most recent checkpoint
    best_ckpt = max(checkpoints, key=os.path.getmtime)
    print(f"🔍 Found checkpoint: {best_ckpt}\n")
    
    # Run test
    !python scripts/main.py \
        --config tutorials/monai_lucchi++.yaml \
        --mode test \
        --checkpoint {best_ckpt}
else:
    print("⚠️  No checkpoints found. Did training complete?")
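
Note that the cell above selects the most recently modified checkpoint, which is not necessarily the one with the best validation score. The same "newest match" pattern is used for the TensorBoard logs, so it can be factored into a small helper. Here is a sketch, demonstrated on throwaway files whose names are made up for the example:

```python
import glob
import os
import tempfile


def latest_match(pattern):
    """Return the most recently modified path matching `pattern`, or None."""
    matches = glob.glob(pattern)
    return max(matches, key=os.path.getmtime) if matches else None


with tempfile.TemporaryDirectory() as tmp:
    for i, name in enumerate(["epoch=01.ckpt", "epoch=02.ckpt", "last.ckpt"]):
        path = os.path.join(tmp, name)
        open(path, "w").close()
        # Set explicit, increasing modification times so the order is unambiguous.
        os.utime(path, (1_000_000 + i, 1_000_000 + i))

    newest = latest_match(os.path.join(tmp, "*.ckpt"))
    missing = latest_match(os.path.join(tmp, "*.h5"))

print(os.path.basename(newest))  # last.ckpt
print(missing)                   # None
```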

## 🎨 Step 8: Visualize Predictions

Let's visualize the model's predictions on test data.

In [None]:
# Load predictions and ground truth
import glob
import h5py
import matplotlib.pyplot as plt

# Find prediction file
pred_files = glob.glob('outputs/lucchi++_monai_unet/results/test_im_prediction.h5')

if pred_files:
    pred_file = pred_files[0]
    
    # Load predictions and test data
    with h5py.File(pred_file, 'r') as f:
        predictions = f['main'][:]
    
    with h5py.File('datasets/Lucchi++/test_image.h5', 'r') as f:
        test_image = f['main'][:]
    
    with h5py.File('datasets/Lucchi++/test_label.h5', 'r') as f:
        test_label = f['main'][:]
    
    # Visualize multiple slices
    num_slices = 3
    fig, axes = plt.subplots(num_slices, 4, figsize=(16, num_slices * 4))
    
    for i in range(num_slices):
        slice_idx = (i + 1) * test_image.shape[0] // (num_slices + 1)
        
        axes[i, 0].imshow(test_image[slice_idx], cmap='gray')
        axes[i, 0].set_title(f'Input (slice {slice_idx})')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(test_label[slice_idx], cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(predictions[slice_idx], cmap='gray')
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')
        
        # Overlay
        axes[i, 3].imshow(test_image[slice_idx], cmap='gray')
        axes[i, 3].imshow(predictions[slice_idx], cmap='Reds', alpha=0.4)
        axes[i, 3].set_title('Overlay')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n📊 Prediction Stats:")
    print(f"  Predicted mitochondria: {(predictions > 0.5).sum() / predictions.size * 100:.1f}%")
    print(f"  Ground truth: {(test_label > 0).sum() / test_label.size * 100:.1f}%")
else:
    print("⚠️  No predictions found. Did testing complete?")
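
Comparing foreground percentages only tells you whether the model predicts roughly the right amount of mitochondria, not whether it predicts them in the right place. Overlap metrics such as Dice and IoU capture that. Below is a minimal NumPy sketch, assuming the predictions are probabilities in [0, 1] with the same shape as the labels; if your prediction file has an extra channel axis or a different value range, adjust accordingly.

```python
import numpy as np


def dice_and_iou(pred, target, threshold=0.5, eps=1e-8):
    """Return (dice, iou) between a thresholded prediction and a label mask."""
    pred_mask = np.asarray(pred) > threshold
    target_mask = np.asarray(target) > 0
    if pred_mask.shape != target_mask.shape:
        raise ValueError(f"Shape mismatch: {pred_mask.shape} vs {target_mask.shape}")
    intersection = np.logical_and(pred_mask, target_mask).sum()
    union = np.logical_or(pred_mask, target_mask).sum()
    # eps keeps both scores defined (and equal to 1) when both masks are empty.
    dice = (2.0 * intersection + eps) / (pred_mask.sum() + target_mask.sum() + eps)
    iou = (intersection + eps) / (union + eps)
    return float(dice), float(iou)


# Tiny worked example: one true positive, one false positive.
toy_pred = np.array([[0.9, 0.2], [0.7, 0.1]])
toy_label = np.array([[1, 0], [0, 0]])
dice, iou = dice_and_iou(toy_pred, toy_label)
print(f"Dice: {dice:.3f}, IoU: {iou:.3f}")  # Dice: 0.667, IoU: 0.500
```

With the arrays loaded in the cell above, the call would be `dice_and_iou(predictions, test_label)`.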

## 🎓 Next Steps

Congratulations! You've successfully:
- ✅ Installed PyTorch Connectomics
- ✅ Trained a 3D U-Net
- ✅ Tested on real data
- ✅ Visualized results

### Where to go from here?

1. **Try different models:**
   - MedNeXt (state-of-the-art)
   - UNETR (transformer-based)
   - Swin UNETR

2. **Use your own data:**
   - Upload HDF5/TIFF files
   - Create custom config
   - Train on your dataset

3. **Optimize training:**
   - Mixed precision (faster)
   - Deep supervision (better)
   - Advanced augmentations

4. **Learn more:**
   - 📚 [Documentation](https://connectomics.readthedocs.io)
   - 🔧 [GitHub Repository](https://github.com/zudi-lin/pytorch_connectomics)
   - 💬 [Slack Community](https://join.slack.com/t/pytorchconnectomics/shared_invite/zt-obufj5d1-v5_NndNS5yog8vhxy4L12w)
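
For item 2, the tutorial's HDF5 files store each volume under the dataset key `main`. Here is a sketch of writing your own array in the same layout with `h5py`. The file name and the random stand-in volume are placeholders, and whether a custom config expects the `main` key is an assumption based on the tutorial data.

```python
import os
import tempfile

import h5py
import numpy as np


def save_volume(path, volume, key="main"):
    """Write a 3D array to an HDF5 file under the given dataset key."""
    with h5py.File(path, "w") as f:
        f.create_dataset(key, data=volume, compression="gzip")


# Stand-in data; replace with your own (z, y, x) image stack.
rng = np.random.default_rng(0)
volume = rng.integers(0, 256, size=(8, 32, 32), dtype=np.uint8)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_image.h5")  # hypothetical file name
    save_volume(path, volume)
    # Read it back the same way the notebook loads the tutorial data.
    with h5py.File(path, "r") as f:
        restored = f["main"][:]

print(restored.shape, restored.dtype)
```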

---

### 💡 Tips for Colab

- **Save checkpoints:** Download checkpoints before the session expires
- **Use Google Drive:** Mount Drive to save outputs persistently
- **GPU limits:** Colab has usage limits, so train in short sessions
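
Here is a sketch of the Drive tip. `drive.mount` comes from the `google.colab` package and only works inside a Colab runtime, so it is wrapped in its own function; the copy helper is plain `shutil` and is demonstrated on temporary folders. The helper names and the destination folder `pytc_outputs` are made up for this example.

```python
import os
import shutil
import tempfile


def backup_outputs(src, dst):
    """Copy the directory tree at `src` into `dst`, overwriting existing files."""
    return shutil.copytree(src, dst, dirs_exist_ok=True)


def mount_drive(mount_point="/content/drive"):
    """Mount Google Drive and return the user's Drive folder (Colab only)."""
    from google.colab import drive  # importable only inside Colab

    drive.mount(mount_point)
    return os.path.join(mount_point, "MyDrive")


# Local demonstration with a placeholder checkpoint file.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "outputs", "checkpoints")
    os.makedirs(src)
    with open(os.path.join(src, "last.ckpt"), "w") as f:
        f.write("placeholder")
    dst = backup_outputs(os.path.join(tmp, "outputs"), os.path.join(tmp, "backup"))
    copied = os.path.exists(os.path.join(dst, "checkpoints", "last.ckpt"))

print(copied)  # True
```

On Colab, the equivalent call would be `backup_outputs("outputs", os.path.join(mount_drive(), "pytc_outputs"))`, run after training finishes.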

### 📝 Feedback

Found this helpful? Have suggestions? Let us know:
- ⭐ Star us on [GitHub](https://github.com/zudi-lin/pytorch_connectomics)
- 🐛 Report issues
- 💬 Join our Slack community

Happy segmenting! 🔬🧠