# STED-FM for Super-Resolution Microscopy

In this notebook we go through the steps of installing and using the pretrained STED-FM model, and demonstrate how to apply it to super-resolution microscopy tasks on sample images.

We will use the `STED-FM` repository, which provides a PyTorch implementation of the STED-FM model.

The notebook will cover the following steps:
1. **Installation**: Clone the STED-FM repository and install the required dependencies.
2. **Loading the Model**: Load the pretrained STED-FM model.
3. **Image Retrieval**: Use the model embeddings to retrieve similar images in the SO dataset.
4. **Segmentation**: Train a segmentation model on top of the pretrained STED-FM backbone.
5. **Image Generation**: Use a pretrained diffusion model, conditioned on STED-FM embeddings, to generate an image from a sample image.

## Installation

This notebook requires the `STED-FM` repository to be cloned and installed. The notebook will check if the repository is already cloned, and if not, it will clone it. After cloning, it will install the required packages.

The last line of the cell (`exit()`) stops the kernel, which is then relaunched automatically so that the newly installed packages are picked up. You can ignore the warning about the session crashing; the restart is necessary for the installation to take effect. The installation may take a few minutes.

In [None]:
import os
if os.path.isdir("./STED-FM"):
    !git -C ./STED-FM/ pull
else:
    !git clone https://github.com/FLClab/STED-FM.git
%pip install -e ./STED-FM
exit()
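Once the kernel has restarted, a quick way to confirm that the editable install is visible is to look the package up with `importlib`. This small helper (`is_installed` is our own name, not part of STED-FM) only checks that `stedfm` can be found by the import system:

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if `package_name` can be found by the import system."""
    return importlib.util.find_spec(package_name) is not None


# After the restart, this should print True if the installation succeeded.
print("stedfm installed:", is_installed("stedfm"))
```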

## Loading the Model

In this section, we load the pretrained STED-FM model. The `MAE_SMALL_STED` weights correspond to a small masked autoencoder (MAE) pretrained on a dataset of STED super-resolution microscopy images. The weights are downloaded automatically if they are not already present.

Setting the `global_pool` argument to `"patch"` gives access to the embeddings predicted for each image patch:
```python
from stedfm import get_pretrained_model_v2

model, cfg = get_pretrained_model_v2(
    name="mae-lightning-small",
    weights="MAE_SMALL_STED",
    as_classifier=True,
    global_pool="patch",  # Set to "patch" to access the predicted embeddings for each image patch
)
```
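With `global_pool="patch"`, it is often convenient to arrange the per-patch embeddings on the spatial grid of the image. The sketch below assumes that the patch embeddings come as a `(batch, num_patches, dim)` tensor without a class token, and that a 224 x 224 input is split into 16 x 16 patches (196 patches); both are assumptions about the backbone that this notebook does not state. A random tensor stands in for the model output, and `patch_tokens_to_grid` is a hypothetical helper:

```python
import math

import torch


def patch_tokens_to_grid(tokens: torch.Tensor) -> torch.Tensor:
    """Reshape patch embeddings (B, N, D) into a feature map (B, D, H, W).

    Assumes N is a perfect square (a square image split into square patches)
    and that `tokens` contains no class token.
    """
    batch, num_patches, dim = tokens.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f"{num_patches} patches do not form a square grid")
    # (B, N, D) -> (B, H, W, D) -> (B, D, H, W), patches ordered row by row
    return tokens.reshape(batch, side, side, dim).permute(0, 3, 1, 2)


# Stand-in for the patch embeddings of one image (assumed 196 patches, 384 dims).
tokens = torch.randn(1, 196, 384)
grid = patch_tokens_to_grid(tokens)
print(grid.shape)  # torch.Size([1, 384, 14, 14])
```

The resulting `(B, D, H, W)` layout is the usual input format for convolutional heads and can be upsampled with `torch.nn.functional.interpolate` to overlay embeddings on the image.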

In [None]:
import torch

from stedfm import get_pretrained_model_v2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, cfg = get_pretrained_model_v2(
    name = "mae-lightning-small",
    weights = "MAE_SMALL_STED",
    as_classifier=True
)
model.to(device)
model.eval()

with torch.no_grad():
    img = torch.randn(1, 1, 224, 224).to(device)
    out = model.forward_features(img) # (1, 384)
    print(out.shape)

## Image Retrieval

In the STED-FM paper, we performed an image retrieval task on the SO dataset: given a query image, the model embeddings are used to find the most similar images in the dataset. The following section demonstrates this task with the pretrained model.

We first download the SO dataset and embed every image with the model. For each query image, the cosine similarity between its embedding and the embeddings of all images is used to rank the dataset; a retrieved image is considered relevant when it has the same label as the query.
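The evaluation cell below calls `F.cosine_similarity` once per query. The same similarities can be obtained for all pairs at once by L2-normalizing the embeddings and taking a matrix product. A minimal sketch on random stand-in embeddings (the helper names are ours); unlike the notebook's loop, `rank_neighbours` drops the query from its own ranking:

```python
import torch
import torch.nn.functional as F


def cosine_similarity_matrix(embeddings: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarities: (N, D) embeddings -> (N, N) matrix."""
    normalized = F.normalize(embeddings, p=2, dim=1)  # unit-length rows
    return normalized @ normalized.T


def rank_neighbours(similarities: torch.Tensor, query: int) -> torch.Tensor:
    """Indices of all other images, from most to least similar to `query`."""
    order = torch.argsort(similarities[query], descending=True)
    return order[order != query]  # drop the query itself


toy_embeddings = torch.randn(5, 384)
sims = cosine_similarity_matrix(toy_embeddings)
print(rank_neighbours(sims, query=0))
```

For a few thousand images the `(N, N)` matrix fits comfortably in memory; for much larger datasets it can be computed in blocks of queries.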

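Retrieval performance is measured with the area under the receiver operating characteristic curve (AUROC), computed for each query and averaged over all queries. Setting `metric = "aupr"` in the code uses the average precision (`average_precision_score`) instead.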
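A compact, self-contained version of the per-query AUROC on toy 2-D embeddings, written for illustration (the name `mean_retrieval_auroc` is ours). It follows the same convention as the notebook when a query has only relevant or only irrelevant candidates (score 1.0 or 0.0). One deliberate difference: the query is excluded from its own ranking, whereas the notebook's loop keeps it, where it always ranks first with a matching label and slightly favours a high score:

```python
import numpy
from sklearn.metrics import roc_auc_score


def mean_retrieval_auroc(similarities: numpy.ndarray, labels: numpy.ndarray) -> float:
    """Average over all queries of the AUROC of ranking the other images.

    An image is relevant for a query when it has the same label.
    The query itself is excluded from its own ranking.
    """
    scores = []
    for q in range(len(labels)):
        others = numpy.arange(len(labels)) != q
        relevant = (labels[others] == labels[q]).astype(int)
        if relevant.min() == relevant.max():
            # AUROC is undefined with a single class: 1.0 if all relevant, else 0.0
            scores.append(float(relevant[0]))
            continue
        scores.append(roc_auc_score(relevant, similarities[q, others]))
    return float(numpy.mean(scores))


# Toy example: two well-separated classes of 2-D embeddings.
emb = numpy.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
emb /= numpy.linalg.norm(emb, axis=1, keepdims=True)
toy_labels = numpy.array([0, 0, 1, 1])
print(mean_retrieval_auroc(emb @ emb.T, toy_labels))  # 1.0
```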


In [None]:
import os
import sys

# Everything is relative to this BASE_PATH in the code
from stedfm.DEFAULTS import BASE_PATH
home = BASE_PATH

IN_COLAB = 'google.colab' in sys.modules

!mkdir -p {home}/evaluation-data
if not os.path.isfile(os.path.join(home, "evaluation-data", "optim-data.zip")):
    !wget -O {home}/evaluation-data/optim-data.zip https://s3.valeria.science/flclab-foundation-models/evaluation-data/optim-data.zip

if not os.path.isdir(os.path.join(home, "evaluation-data", "optim-data")):
    !unzip {home}/evaluation-data/optim-data.zip -d {home}/evaluation-data

In [None]:
import numpy
import torch.nn.functional as F
import torch
from matplotlib import pyplot
from sklearn.metrics import roc_auc_score, average_precision_score
from tqdm.auto import trange
from stedfm.loaders import get_dataset

metric = "auc"

_, _, test_loader = get_dataset(name="optim")

# Embed dataset
embeddings, labels, dataset_indices = [], [], []
N = len(test_loader.dataset)

with torch.no_grad():
    for n in trange(N):
        img = test_loader.dataset[n][0].unsqueeze(0).to(device)
        metadata = test_loader.dataset[n][1]
        label = metadata["label"]
        dataset_idx = metadata["dataset-idx"]
        output = model.forward_features(img)
        embeddings.append(output)
        labels.append(label)
        dataset_indices.append(dataset_idx)
embeddings = torch.cat(embeddings, dim=0)
labels = numpy.array(labels)
dataset_indices = numpy.array(dataset_indices)
assert embeddings.shape[0] == labels.shape[0] == dataset_indices.shape[0]

scores = []
for e in trange(embeddings.shape[0]):
    curr_embedding = embeddings[e]
    curr_label = labels[e]
    # Cosine similarity between the query and every embedding (including itself)
    similarities = F.cosine_similarity(embeddings, curr_embedding.unsqueeze(0), dim=1).cpu().numpy()
    sorted_indices = numpy.argsort(similarities)[::-1]

    # 1 if the retrieved image has the same label as the query, else 0
    query_labels = (labels[sorted_indices] == curr_label).astype(int)
    if numpy.unique(query_labels).shape[0] == 1:
        # The score is undefined with a single class: 1.0 if all relevant, 0.0 otherwise
        score = float(query_labels[0])
    elif metric == "auc":
        score = roc_auc_score(query_labels, similarities[sorted_indices])
    elif metric == "aupr":
        score = average_precision_score(query_labels, similarities[sorted_indices])
    else:
        raise ValueError(f"Unknown metric: {metric}")
    scores.append(score)

print("AUROC:" if metric == "auc" else "AUPR:", numpy.mean(scores))


## Segmentation Experiment

In this section, we train a segmentation model on top of the pretrained STED-FM backbone, using the F-actin dataset of super-resolution microscopy images. The model learns to segment each image into F-actin rings and F-actin fibers.

We will first download the F-actin dataset and then use the STED-FM model to train a segmentation model.

In [None]:
import os
import sys
import numpy
import torch

from matplotlib import pyplot

# Everything is relative to this BASE_PATH in the code
from stedfm.DEFAULTS import BASE_PATH
home = BASE_PATH

!mkdir -p {home}/segmentation-data
if not os.path.isfile(os.path.join(home, "segmentation-data", "factin-segmentation-data.zip")):
    !wget -O {home}/segmentation-data/factin-segmentation-data.zip https://s3.valeria.science/flclab-foundation-models/segmentation-data/factin-segmentation-data.zip

if not os.path.isdir(os.path.join(home, "segmentation-data", "factin")):
    !unzip {home}/segmentation-data/factin-segmentation-data.zip -d {home}/segmentation-data

In the next cell, we import the necessary libraries and set up the configuration of the segmentation experiment, such as whether the backbone is frozen, the number of epochs, the learning rate, and the batch size.

The parameters that can be modified are contained in the following section:
```python
####################################################
##################### <UPDATE> #####################
####################################################

RANDOM_SEED = 42
...

####################################################
##################### </UPDATE> #####################
####################################################
```
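When `LABEL_PERCENTAGE` is below 1.0, the training setup further down scales up the number of epochs so that the total number of training images seen stays the same, and the number of epochs is always capped at 1000. That arithmetic, extracted into a standalone helper for illustration (`effective_num_epochs` is a hypothetical name, not part of STED-FM):

```python
def effective_num_epochs(num_epochs: int, label_percentage: float,
                         dataset_size: int, max_epochs: int = 1000) -> int:
    """Number of epochs after the label-budget adjustment used in this notebook.

    With a fraction of the labels, each epoch sees fewer images, so the number
    of epochs is scaled up to keep the training budget constant, then capped.
    """
    if label_percentage >= 1.0:
        return min(num_epochs, max_epochs)
    budget = dataset_size * num_epochs  # total images seen with all labels
    epochs = int(budget / (label_percentage * dataset_size))
    return min(epochs, max_epochs)


print(effective_num_epochs(10, 0.25, dataset_size=1000))  # 40
```

With the default `num_epochs = 10`, using 25% of the labels therefore trains for 40 epochs, and very small label fractions hit the 1000-epoch cap.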


In [None]:
import sys
insert_to_path = "./STED-FM/experiments/segmentation-experiments"
while insert_to_path not in sys.path:
    sys.path.insert(0, insert_to_path)

import os
import random
import numpy
import torch
import pickle
import shutil
import time
import json

from collections import defaultdict
from torch.utils.data import SubsetRandomSampler
from typing import Any
from tqdm.auto import tqdm
from lightly.utils.scheduler import CosineWarmupScheduler

from main import intensity_scale_
from eval import evaluate_segmentation
from datasets import get_dataset

from stedfm import get_decoder
from stedfm import get_pretrained_model_v2
from stedfm.utils import update_cfg, save_cfg, track_loss
from stedfm.configuration import Configuration
from stedfm.DEFAULTS import BASE_PATH

####################################################
##################### <UPDATE> #####################
####################################################

SAVE_FOLDER = "/content/segmentation-baselines"
RANDOM_SEED = 42
USE_TENSORBOARD = False
LABEL_PERCENTAGE = 1.0

class SegmentationConfiguration(Configuration):

    freeze_backbone: bool = True
    num_epochs: int = 10
    learning_rate: float = 1e-4
    batch_size: int = 32

####################################################
##################### </UPDATE> #####################
####################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def validation_step(model: torch.nn.Module, valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, epoch: int, device: torch.device, writer: "torch.utils.tensorboard.SummaryWriter" = None):
    """Computes the loss on each validation batch and returns the list of losses."""
    is_training = model.training

    model.eval()

    statLossTest = []
    # No gradients are needed for validation, which saves memory
    with torch.no_grad():
        for i, (X, y) in enumerate(tqdm(valid_loader, desc="[----] ")):

            # Reshape
            if isinstance(X, (list, tuple)):
                X = [_X.unsqueeze(0) if _X.dim() == 2 else _X for _X in X]
            else:
                if X.dim() == 3:
                    X = X.unsqueeze(1)

            # Send to device
            X = X.to(torch.float32)
            X = X.to(device)
            y = y.to(device)

            # Prediction and loss computation
            pred = model.forward(X)
            loss = criterion(pred, y)

            # Keeping track of statistics
            statLossTest.append(loss.item())

            if (i == 0) and USE_TENSORBOARD:
                writer.add_images("Images-test/image", intensity_scale_(X[:16]), epoch, dataformats="NCHW")
                for c in range(cfg.dataset_cfg.num_classes):
                    writer.add_images(f"Images-test/label-{c}", y[:16, c:c+1], epoch, dataformats="NCHW")
                    writer.add_images(f"Images-test/pred-{c}", pred[:16, c:c+1], epoch, dataformats="NCHW")

            # To avoid memory leak
            torch.cuda.empty_cache()
            del X, y, pred, loss

    if is_training:
        model.train()
    return statLossTest


In [None]:
numpy.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

backbone, cfg = get_pretrained_model_v2(
    name = "mae-lightning-small",
    weights = "MAE_SMALL_STED",
)

training_dataset, validation_dataset, testing_dataset = get_dataset(
    name="factin",
    cfg=cfg,
    use_cache=False
)

segmentation_cfg = SegmentationConfiguration()
for key, values in segmentation_cfg.__dict__.items():
    setattr(cfg, key, values)
print(f"Config: {cfg.__dict__}")

probe = "pretrained"
model_name = "pretrained-"
if cfg.freeze_backbone:
    probe = "pretrained-frozen"
    model_name += "frozen-"
model_name += "MAE_SMALL_STED"
if LABEL_PERCENTAGE < 1.0:
    model_name += f"-{int(LABEL_PERCENTAGE * 100)}%-labels"
model_name += f"-{RANDOM_SEED}"

OUTPUT_FOLDER = os.path.join(SAVE_FOLDER, "mae-lightning-small", "factin", model_name)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Saves and prints the configuration
cfg.save(os.path.join(OUTPUT_FOLDER, "config.json"))
print(cfg)

# Build the UNet model
model = get_decoder(backbone, cfg)
model = model.to(device)

stats = defaultdict(list)
min_valid_loss = numpy.inf
start_epoch = 0

sampler = None
if LABEL_PERCENTAGE < 1.0:
    rng = numpy.random.default_rng(RANDOM_SEED)
    indices = list(range(len(training_dataset)))
    rng.shuffle(indices)
    split = int(numpy.floor(LABEL_PERCENTAGE * len(training_dataset)))
    train_indices = indices[:split]
    sampler = SubsetRandomSampler(train_indices)

print("----------------------------------------")
print("Training Dataset")
print("Dataset size: ", len(training_dataset))
print("Dataset size (with sampler): ", len(sampler) if sampler else len(training_dataset))
print("----------------------------------------")
print("Validation Dataset")
print("Dataset size: ", len(validation_dataset))
print(f"Batch size: {cfg.batch_size}")
print("----------------------------------------")

# Build a PyTorch dataloader.
train_loader = torch.utils.data.DataLoader(
    training_dataset,  # Pass the dataset to the dataloader.
    batch_size=cfg.batch_size,  # A large batch size helps with the learning.
    shuffle=sampler is None,  # Shuffling is important!
    num_workers=int(os.cpu_count()),
    sampler=sampler, drop_last=False
)
valid_loader = torch.utils.data.DataLoader(
    validation_dataset,  # Pass the dataset to the dataloader.
    batch_size=cfg.batch_size,  # Same batch size as for training.
    shuffle=True,
    num_workers=int(os.cpu_count())
)

# Defines the training budget
num_epochs = cfg.num_epochs
if LABEL_PERCENTAGE < 1.0:
    budget = len(training_dataset) * num_epochs
    num_epochs = int(budget / (LABEL_PERCENTAGE * len(training_dataset)))
    cfg.num_epochs = num_epochs
    print(f"Training budget is updated: {cfg.num_epochs} epochs")

if cfg.num_epochs > 1000:
    cfg.num_epochs = 1000

# Defines the optimizer (note: the learning rate is fixed to 1e-3 below; `cfg.learning_rate` is not used here)
if probe == "pretrained-frozen":
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = CosineWarmupScheduler(
        optimizer=optimizer, warmup_epochs=0.1*cfg.num_epochs, max_epochs=cfg.num_epochs,
        start_value=1.0, end_value=0.01,
        period=cfg.num_epochs//10
    )
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05, betas=(0.9, 0.999))
    scheduler = CosineWarmupScheduler(
        optimizer=optimizer, warmup_epochs=0.1*cfg.num_epochs, max_epochs=cfg.num_epochs,
        start_value=1.0, end_value=0.01
    )

criterion = getattr(torch.nn, cfg.dataset_cfg.criterion)()
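When `freeze_backbone` is `True`, only the decoder should be updated, even though the optimizer above receives all of `model.parameters()`. We assume here that `get_decoder` freezes the backbone by setting `requires_grad=False` on its parameters; this notebook does not show it. Under that assumption, parameters without gradients are skipped by the optimizer, and counting the parameters that require gradients is a quick check. A minimal sketch with a hypothetical helper, demonstrated on a toy two-layer model:

```python
import torch


def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts of `model`."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# Toy stand-in: a frozen "backbone" layer followed by a trainable "decoder" layer.
toy = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 2))
toy[0].requires_grad_(False)
print(count_parameters(toy))  # (10, 46)
```

Calling `count_parameters(model)` after the cell above should report far fewer trainable than total parameters if the backbone is indeed frozen.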


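Once the model and the training are configured, we can train the model on the F-actin dataset. The model is trained for `cfg.num_epochs` epochs. The validation loss is computed periodically during training, the loss curves are saved to `training-curves.png`, and at the end of each epoch the model is saved to `result.pt` if its most recent validation loss is the lowest so far.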
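The structure of the training loop (validation every few steps, checkpoint kept at the end of an epoch only when the latest validation loss improves) can be reproduced on a toy problem. This is a stripped-down sketch for illustration only: it replaces the F-actin data and the UNet with random tensors and a linear layer, and omits the scheduler and TensorBoard logging. The `toy_` names avoid overwriting the notebook's variables:

```python
import math
import os
import tempfile

import torch

torch.manual_seed(0)

# Toy regression data standing in for the F-actin images and masks.
w_true = torch.randn(4, 1)
x_train, x_valid = torch.randn(64, 4), torch.randn(16, 4)
y_train, y_valid = x_train @ w_true, x_valid @ w_true

toy_model = torch.nn.Linear(4, 1)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_criterion = torch.nn.MSELoss()
batch_size = 8
validate_every = 5  # the notebook uses int(25 * 32 / cfg.batch_size) steps
toy_checkpoint = os.path.join(tempfile.mkdtemp(), "result.pt")

step, best_valid_loss, valid_losses = 0, math.inf, []
for epoch in range(3):
    toy_model.train()
    for i in range(0, len(x_train), batch_size):
        loss = toy_criterion(toy_model(x_train[i:i + batch_size]), y_train[i:i + batch_size])
        toy_optimizer.zero_grad()
        loss.backward()
        toy_optimizer.step()

        # Periodic validation, without gradient tracking
        if step % validate_every == 0:
            toy_model.eval()
            with torch.no_grad():
                valid_losses.append(toy_criterion(toy_model(x_valid), y_valid).item())
            toy_model.train()
        step += 1

    # End of epoch: keep the weights if the latest validation loss is the best so far
    if valid_losses[-1] < best_valid_loss:
        best_valid_loss = valid_losses[-1]
        torch.save({"model": toy_model.state_dict()}, toy_checkpoint)

print(f"best validation loss: {best_valid_loss:.4f}")
```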

In [None]:
step = start_epoch * len(train_loader)

writer = None
if USE_TENSORBOARD:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=os.path.join(OUTPUT_FOLDER, "logs"))

for epoch in range(start_epoch, cfg.num_epochs):

    start = time.time()
    print("[----] Starting epoch {}/{}".format(epoch + 1, cfg.num_epochs))

    # Keep track of the loss of train and test
    statLossTrain = []

    # Puts the model in training mode
    model.train()
    for i, (X, y) in enumerate(tqdm(train_loader, desc="[----] ")):

        # Reshape
        if isinstance(X, (list, tuple)):
            X = [_X.unsqueeze(0) if _X.dim() == 2 else _X for _X in X]
        else:
            if X.dim() == 3:
                X = X.unsqueeze(1)

        # Send to gpu
        X = X.to(torch.float32)
        X = X.to(device)
        y = y.to(device)

        # Prediction and loss computation
        pred = model.forward(X)
        loss = criterion(pred, y)

        # Keeping track of statistics
        statLossTrain.append(loss.item())

        # Back-propagation and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i == 0) and USE_TENSORBOARD:
            writer.add_images("Images-train/image", intensity_scale_(X[:16]), epoch, dataformats="NCHW")
            for i in range(cfg.dataset_cfg.num_classes):
                writer.add_images(f"Images-train/label-{i}", y[:16, i:i+1], epoch, dataformats="NCHW")
                writer.add_images(f"Images-train/pred-{i}", pred[:16, i:i+1], epoch, dataformats="NCHW")

        # To avoid memory leak
        torch.cuda.empty_cache()
        del X, y, pred, loss

        # Periodically evaluate the model on the validation set
        if step % int(25 * 32 / cfg.batch_size) == 0:
            # Validation step
            statLossTest = validation_step(model, valid_loader, criterion, epoch, device, writer)
            for key, func in zip(("testMean", "testMed", "testMin", "testStd"),
                                    (numpy.mean, numpy.median, numpy.min, numpy.std)):
                stats[key].append(func(statLossTest))
                if USE_TENSORBOARD:
                    writer.add_scalar(f"Loss/{key}", stats[key][-1], step)
            stats["testStep"].append(step)
        step += 1

    # Aggregate stats
    for key, func in zip(("trainMean", "trainMed", "trainMin", "trainStd"),
            (numpy.mean, numpy.median, numpy.min, numpy.std)):
        stats[key].append(func(statLossTrain))
        if USE_TENSORBOARD:
            writer.add_scalar(f"Loss/{key}", stats[key][-1], step)

    # scheduler.step(numpy.min(stats["testMean"]))
    scheduler.step()
    stats["lr"].append(scheduler.get_last_lr())
    if USE_TENSORBOARD:
        _lr = stats["lr"][-1]
        if isinstance(_lr, list):
            for i in range(len(_lr)):
                writer.add_scalar(f"Learning-rate/lr-{i}", _lr[i], step)
        else:
            writer.add_scalar(f"Learning-rate/lr", _lr, step)
        writer.add_scalar(f"Epochs/epoch", epoch, step)
    stats["trainStep"].append(step)

    track_loss(
        train_loss=stats["trainMean"],
        val_loss=stats["testMean"],
        val_acc=stats["testAcc"],
        lrates=stats["lr"],
        save_dir=os.path.join(OUTPUT_FOLDER, "training-curves.png")
    )
    # Save if best model so far
    if min_valid_loss > stats["testMean"][-1]:
        min_valid_loss = stats["testMean"][-1]
        savedata = {
            "model" : model.state_dict(),
            "optimizer" : optimizer.state_dict(),
            "stats" : stats,
        }
        torch.save(
            savedata,
            os.path.join(OUTPUT_FOLDER, "result.pt"))

        del savedata


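Once the model is trained, we restore the checkpoint with the lowest validation loss and evaluate it on the F-actin test set. The per-class scores returned by `evaluate_segmentation` are saved to `segmentation-scores.json`, summarized (mean, standard deviation, median), and shown as box plots.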
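The summary step in the evaluation cell treats each column of a score array as one class and ignores entries equal to -1, which mark invalid scores. The same logic as a small standalone function on made-up numbers (`summarize_scores` is a hypothetical helper):

```python
import numpy


def summarize_scores(values: numpy.ndarray) -> list[dict]:
    """Mean, std and median per class (column), ignoring -1 entries.

    `values` has shape (num_images, num_classes); -1 marks an invalid score.
    """
    summary = []
    for column in values.T:
        valid = column[column != -1]
        summary.append({
            "avg": float(numpy.mean(valid)),
            "std": float(numpy.std(valid)),
            "med": float(numpy.median(valid)),
        })
    return summary


toy_scores = numpy.array([[0.8, -1.0], [0.6, 0.5], [0.7, 0.3]])
print(summarize_scores(toy_scores))
```

A class whose scores are all -1 would leave an empty array, for which NumPy returns `nan` with a warning.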

In [None]:
# Build the UNet model.
model = get_decoder(backbone, cfg)
ckpt = torch.load(os.path.join(OUTPUT_FOLDER, "result.pt"), weights_only=False)["model"]
print("Restoring model...")
model.load_state_dict(ckpt)
model = model.to(device)
model.eval()

# Build a PyTorch dataloader for the test set (no shuffling is needed for evaluation).
test_loader = torch.utils.data.DataLoader(
    testing_dataset,  # Pass the dataset to the dataloader.
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=0
)

scores = evaluate_segmentation(model, test_loader, savefolder=None, device=device, dataset_name="factin")
with open(os.path.join(OUTPUT_FOLDER, "segmentation-scores.json"), "w") as file:
    json.dump(scores, file, indent=4)

for key, values in scores.items():
    print("Results for", key)
    values = numpy.array(values)

    fig, ax = pyplot.subplots(figsize=(3, 3))
    for i in range(values.shape[1]):
        data = values[:, i]

        # Remove -1 values as they are not valid
        data = data[data != -1]

        print(
                "avg : {:0.4f}".format(numpy.mean(data, axis=0)),
                "std : {:0.4f}".format(numpy.std(data, axis=0)),
                "med : {:0.4f}".format(numpy.median(data, axis=0)),)

        bplot = ax.boxplot(data, positions=[i], widths=0.8)
        for element in ['boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps']:
            pyplot.setp(bplot[element], color='black')

    ax.set(
        xticks = numpy.arange(values.shape[1]), xticklabels=testing_dataset.classes,
        ylim = (0, 1)
    )
    pyplot.show()



## Diffusion Experiment

In the STED-FM paper, we trained a diffusion model conditioned on the STED-FM embeddings. In this section, we use this pretrained diffusion model to generate images from the embeddings of real images, taken from the following dataset:
> Deschênes, A., Santiague, J.-G. S., & Lavoie-Cardinal, F. (2025). Confocal- and STED-FLIM images of neuronal proteins [Data set]. Zenodo. https://doi.org/10.5281/zenodo.15438495
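
For context, the `DDPM` object built below uses `timesteps=1000` and `beta_schedule="linear"`. In a standard denoising diffusion probabilistic model (DDPM), the forward process adds Gaussian noise to an image $x_0$ so that $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon$, with $\bar\alpha_t = \prod_{s \le t}(1-\beta_s)$. The sketch below uses the linear $\beta$ range from the original DDPM paper (1e-4 to 0.02); whether `ddpm_lightning.DDPM` uses the same endpoints is an assumption, and the function names are ours:

```python
import torch


def linear_beta_schedule(timesteps: int, beta_start: float = 1e-4,
                         beta_end: float = 0.02) -> torch.Tensor:
    """Linearly spaced noise variances beta_1..beta_T."""
    return torch.linspace(beta_start, beta_end, timesteps)


def q_sample(x0: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor,
             noise: torch.Tensor) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I)."""
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)  # one value per image in the batch
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise


timesteps = 1000
betas = linear_beta_schedule(timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(2, 1, 64, 64)        # stand-in for two normalized images
t = torch.tensor([0, timesteps - 1])  # almost clean vs. almost pure noise
x_t = q_sample(x0, t, alphas_cumprod, torch.randn_like(x0))
print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())
```

After 1000 steps $\bar\alpha_T$ is close to zero, so $x_T$ is essentially pure noise; generation runs this process in reverse, with the denoiser conditioned on the STED-FM embedding.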

In [None]:
# Download the model
import os
import sys
import numpy
import torch

from matplotlib import pyplot

# Everything is relative to this BASE_PATH in the code
from stedfm.DEFAULTS import BASE_PATH
home = BASE_PATH

!mkdir -p {home}/baselines
if not os.path.isfile(os.path.join(home, "baselines", "diffusion-model.zip")):
    !wget -O {home}/baselines/diffusion-model.zip https://s3.valeria.science/flclab-foundation-models/models/diffusion-model.zip

model_path = os.path.join(home, "baselines", "DiffusionModels", "latent-guidance")
if not os.path.isdir(model_path):
    !unzip {home}/baselines/diffusion-model.zip -d {home}/baselines

# Download the data
!mkdir -p {home}/evaluation-data
if not os.path.isfile(os.path.join(home, "evaluation-data", "low-high-quality.zip")):
    !wget -O {home}/evaluation-data/low-high-quality.zip https://s3.valeria.science/flclab-foundation-models/evaluation-data/low-high-quality.zip

if not os.path.isdir(os.path.join(home, "evaluation-data", "low-high-quality")):
    !unzip {home}/evaluation-data/low-high-quality.zip -d {home}/evaluation-data

In [None]:
import sys
insert_to_path = "./STED-FM/experiments/diffusion-experiments"
while insert_to_path not in sys.path:
    sys.path.insert(0, insert_to_path)

from typing import Union
from diffusion_models.diffusion.ddpm_lightning import DDPM
from diffusion_models.diffusion.denoising.unet import UNet
from attribute_datasets import LowHighResolutionDataset

def denormalize(img: Union[numpy.ndarray, torch.Tensor], mu: float = 0.010903545655310154, std: float = 0.03640301525592804) -> Union[numpy.ndarray, torch.Tensor]:
    """
    Denormalizes an image. Note that the parameters mu and std seem hard-coded, but they were computed on the training sets and can be found
    in the attribute_datasets.py file.
    """
    return img * std + mu

latent_encoder, model_config = get_pretrained_model_v2(
    name="mae-lightning-small",
    weights="MAE_SMALL_STED",
    as_classifier=True,
)
denoising_model = UNet(
    dim=64,
    channels=1,
    dim_mults=(1,2,4),
    cond_dim=model_config.dim,
    condition_type="latent",
    num_classes=4
)
diffusion_model = DDPM(
    denoising_model=denoising_model,
    timesteps=1000,
    beta_schedule="linear",
    condition_type="latent",
    latent_encoder=latent_encoder,
)

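# map_location allows loading the checkpoint on a CPU-only machine; the checkpoint comes from a trusted source
ckpt = torch.load(os.path.join(model_path, "MAE_SMALL_STED", "checkpoint-69.pth"), map_location="cpu", weights_only=False)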
diffusion_model.load_state_dict(ckpt["state_dict"])
diffusion_model.to(device)
diffusion_model.eval()

dataset = LowHighResolutionDataset(
    h5path=os.path.join(home, "evaluation-data", "low-high-quality", "testing.hdf5"),
    num_samples=None,
    transform=None,
    n_channels=1,
    num_classes=2,
    classes=["low", "high"]
)
print(len(dataset))

In the next cell, we select an image from the dataset and use the diffusion model to generate a new image conditioned on the selected image's embedding. The original and generated images are then displayed side by side.
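`p_sample_loop` produces the sample by running the reverse diffusion process over all timesteps, conditioned on `latent_code`. The textbook DDPM ancestral sampling loop is sketched below with a dummy noise predictor and a random condition vector; it is not the code of `ddpm_lightning.DDPM`, whose details (variance choice, clipping, how the condition is injected) may differ. It uses $\sigma_t^2 = \beta_t$, one of the two variance choices from the DDPM paper:

```python
import torch


@torch.no_grad()
def p_sample_loop_sketch(eps_model, shape, cond, betas):
    """Ancestral DDPM sampling: start from noise and denoise step by step.

    `eps_model(x_t, t, cond)` must predict the noise contained in x_t.
    """
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)
    for t in reversed(range(len(betas))):
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        eps = eps_model(x, t_batch, cond)
        # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean  # no noise is added at the last step
    return x


def dummy_eps_model(x, t, cond):
    """Stand-in denoiser that ignores its inputs and predicts zero noise."""
    return torch.zeros_like(x)


toy_betas = torch.linspace(1e-4, 0.02, 50)  # short schedule to keep the example fast
toy_sample = p_sample_loop_sketch(dummy_eps_model, (1, 1, 32, 32), torch.randn(1, 384), toy_betas)
print(toy_sample.shape)  # torch.Size([1, 1, 32, 32])
```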

In [None]:
index = 0 # Change this index to select a different image
img, metadata = dataset[index]
label = metadata["label"]

img = img.clone().unsqueeze(0).to(device)

# No gradients are needed for inference
with torch.no_grad():
    latent_code = diffusion_model.latent_encoder.forward_features(img)

    sample = diffusion_model.p_sample_loop(
        shape=(img.shape[0], 1, img.shape[2], img.shape[3]),
        cond=latent_code,
        progress=True
    )
sample = denormalize(sample)

In [None]:
fig, axes = pyplot.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img.squeeze().cpu().numpy(), cmap="hot")
axes[1].imshow(sample.squeeze().cpu().numpy(), cmap="hot")
axes[0].set_title(f"Original Image (Label: {label})")
axes[1].set_title("Generated Sample")
for ax in axes:
    ax.axis("off")
pyplot.show()