**Starch Area Quantification by X-net**

This Colab notebook provides a pipeline for quantifying starch granule areas in stomatal guard cells visualized by PS-PI staining. We used X-net for automated starch segmentation; details of the model architecture and training are provided in the Methods section.

**Note on Model Generalizability:**
This model was optimized for images acquired with a Nikon C1 confocal microscope (100x objective, $50.54 \times 50.54$ $\mu\text{m}$ field of view). When applying this pipeline to images from different optical setups, we strongly recommend validating the automated results against manual analysis in ImageJ, following the protocol established by Flütsch et al. (2018), to ensure quantitative accuracy.

**Reference**
Flütsch S., Distefano L., and Santelia D. Quantification of starch in guard cells of Arabidopsis thaliana. Bio-protocol 8, e2920 (2018).


**Preparation for Analysis**

1. Environment Setup:
Ensure that your trained model file (model.pth) and the folder containing your confocal images are uploaded to your Google Drive.

2. Drive Mounting and Authentication:
Execute the first two code cells. You will be prompted to authenticate your Google account and grant the necessary permissions.

3. Accessing Data:
Once authenticated, the Google Drive directory (/content/drive/) will appear in the "Files" tab on the left sidebar. This allows the notebook to directly read your input data and save the analysis results back to your Drive.
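Optionally, after mounting you can confirm that the notebook can see your files before starting the analysis. This is a minimal sketch: MOUNT_POINT and MODEL_PATH are placeholders to adapt, and check_inputs is a hypothetical helper, not part of the pipeline.

```python
import os

# Placeholders: adjust to where Drive is mounted and where you uploaded model.pth
MOUNT_POINT = "/content/drive"
MODEL_PATH = "/content/drive/MyDrive/model.pth"


def check_inputs(mount_point, model_path):
    """Return a list of problems found; an empty list means every check passed."""
    problems = []
    if not os.path.isdir(mount_point):
        problems.append("Drive is not mounted at {}".format(mount_point))
    if not os.path.isfile(model_path):
        problems.append("Model file not found: {}".format(model_path))
    return problems


if __name__ == "__main__":
    for msg in check_inputs(MOUNT_POINT, MODEL_PATH) or ["All inputs found."]:
        print(msg)
```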

In [None]:
%pip install -U --no-cache-dir imagecodecs tifffile

Collecting imagecodecs
  Downloading imagecodecs-2026.1.14-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (20 kB)
Collecting tifffile
  Downloading tifffile-2026.2.16-py3-none-any.whl.metadata (30 kB)
Downloading imagecodecs-2026.1.14-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (24.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.7/24.7 MB 135.5 MB/s eta 0:00:00
Downloading tifffile-2026.2.16-py3-none-any.whl (233 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 233.6/233.6 kB 257.4 MB/s eta 0:00:00
Installing collected packages: tifffile, imagecodecs
  Attempting uninstall: tifffile
    Found existing installation: tifffile 2026.1.28
    Uninstalling tifffile-2026.1.28:
      Successfully uninstalled tifffile-2026.1.28
Successfully installed imagecodecs-2026.1.14 tifffile-2026.2.16


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Import libraries

Please run the code below to import the necessary libraries.

In [None]:
import glob
import cv2
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.functional as TF
import torch.utils.data as data
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import numbers
import numpy as np
import argparse
import matplotlib.pyplot as plt
import sys
import seaborn as sns
from torch.autograd import Variable
from tqdm import tqdm
from natsort import natsorted
from PIL import Image
from enum import Enum
from torch import nn

Definition of DataLoader

In [None]:
# compose
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, lbl):
        for t in self.transforms:
            img, lbl = t(img, lbl)
        return img, lbl

# to Tensor
class ToTensor(object):
    def __init__(self, normalize=True, target_type='uint8'):
        self.normalize = normalize
        self.target_type = target_type
    def __call__(self, pic, lbl):
        if self.normalize:
            return TF.to_tensor(pic), torch.from_numpy( np.array( lbl, dtype=self.target_type) )
        else:
            return torch.from_numpy( np.array( pic, dtype=np.float32).transpose(2, 0, 1) ), torch.from_numpy( np.array( lbl, dtype=self.target_type) )

class DataLoader(data.Dataset):

    def __init__(self, img_path, transform=None):

        self.image_path = img_path

        self.image_list = sorted(os.listdir(self.image_path))
        self.transform = transform


    def __getitem__(self, index):

        image_name = self.image_list[index]


        image = Image.open(self.image_path + "/{}".format(image_name)).convert("RGB")


        if self.transform:

            image, label = self.transform(image,image)

        return image, image_name



    def __len__(self):

        return len(self.image_list)
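With normalize=True (the default used when the segmentation step builds its loader), ToTensor delegates to torchvision's TF.to_tensor, which for an 8-bit RGB PIL image moves the channels first and scales values from [0, 255] to [0, 1]; torch's DataLoader with batch_size=1 then adds a leading batch dimension, giving (1, 3, 512, 512). A NumPy sketch of that conversion on a dummy image (to_chw_float is a hypothetical name, not the torchvision implementation):

```python
import numpy as np


def to_chw_float(rgb_uint8):
    """HWC uint8 in [0, 255] -> CHW float32 in [0, 1], mirroring TF.to_tensor on an RGB image."""
    arr = np.asarray(rgb_uint8, dtype=np.float32) / 255.0
    return arr.transpose(2, 0, 1)


# A dummy 512 x 512 RGB crop (every pixel 255) standing in for one file in crop_img
dummy = np.full((512, 512, 3), 255, dtype=np.uint8)
x = to_chw_float(dummy)
print(x.shape, x.dtype, x.max())  # (3, 512, 512) float32 1.0
```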

Definition of Neural Network model
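The hard-coded widths in XNet (Decoder(1536), Decoder(640), Decoder(288) and the 72- and 144-channel output convolutions) follow from each Encoder doubling the channel count and halving the resolution, and each Decoder concatenating two inputs and reducing the result to a quarter of its channels. The arithmetic sketch below (hypothetical helper names) traces one branch for ch=64 and a 512 x 512 input. The upsampling size uses PyTorch's documented nn.ConvTranspose2d output formula with output_padding=1: the value Decoder passes as `dilation` is the sixth positional argument of nn.ConvTranspose2d, which is output_padding. Three 2x poolings also mean the input side must be a multiple of 8, which the 512 x 512 crop satisfies.

```python
def deconv_out(size, kernel=3, stride=2, padding=1, output_padding=1, dilation=1):
    """Output length of nn.ConvTranspose2d along one spatial axis (PyTorch's formula)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def trace_xnet(ch=64, size=512):
    """Trace (channels, side length) through one XNet branch, mirroring the layer widths."""
    # Encoders: channels double, MaxPool2d(2, 2) halves the side length
    e1, e2, e3 = (ch * 2, size // 2), (ch * 4, size // 4), (ch * 8, size // 8)
    center = (2 * e3[0], e3[1])              # both branches' e3 concatenated: 1024 channels
    d1_in = center[0] + e3[0]                # cat(center, e3): 1024 + 512 = 1536
    d1 = (d1_in // 4, deconv_out(center[1]))
    d2_in = d1[0] + e2[0]                    # cat(d1, e2): 384 + 256 = 640
    d2 = (d2_in // 4, deconv_out(d1[1]))
    d3_in = e1[0] + d2[0]                    # cat(e1, d2): 128 + 160 = 288
    d3 = (d3_in // 4, deconv_out(d2[1]))
    # out1/out2 see one branch (72 ch); out3 sees both branches concatenated (144 ch)
    return {"dec_in": (d1_in, d2_in, d3_in), "d3": d3, "fused": 2 * d3[0]}


print(trace_xnet())  # {'dec_in': (1536, 640, 288), 'd3': (72, 512), 'fused': 144}
```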

In [None]:
class ChannelSELayer(nn.Module):
    def __init__(self, num_channels, reduction_ratio=4):
        super(ChannelSELayer, self).__init__()
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        batch_size, num_channels, H, W = input_tensor.size()
        # Average along each channel
        squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        a, b = squeeze_tensor.size()
        output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1))
        return output_tensor


class SpatialSELayer(nn.Module):
    def __init__(self, num_channels):
        super(SpatialSELayer, self).__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor, weights=None):
        # spatial squeeze
        batch_size, channel, a, b = input_tensor.size()

        if weights is not None:
            weights = weights.view(1, channel, 1, 1)
            out = F.conv2d(input_tensor, weights)
        else:
            out = self.conv(input_tensor)
        squeeze_tensor = self.sigmoid(out)

        # spatial excitation
        output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, a, b))

        return output_tensor


class ChannelSpatialSELayer(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ChannelSpatialSELayer, self).__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, input_tensor):
        output_tensor = self.cSE(input_tensor) + self.sSE(input_tensor)
        return output_tensor


class BaseBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(BaseBlock, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, in_ch, 3, 1, 1), nn.BatchNorm2d(in_ch), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch, in_ch, 3, 1, 1), nn.BatchNorm2d(in_ch), nn.ReLU())

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)

        return x


class Encoder(nn.Module):
    def __init__(self, ch=64):
        super(Encoder, self).__init__()
        self.block = BaseBlock(in_ch=ch, out_ch=ch)
        self.conv = nn.Sequential(nn.Conv2d(ch, ch*2, 3, 1, 1), nn.BatchNorm2d(ch*2), nn.ReLU())
        self.se = ChannelSpatialSELayer(num_channels=ch*2)
        self.pool = nn.MaxPool2d(2,2)

    def forward(self, x):
        x = self.block(x)
        x = self.conv(x)
        x = self.se(x)
        x = self.pool(x)

        return x


class Decoder(nn.Module):
    def __init__(self, ch=1024, dilation=1):
        super(Decoder, self).__init__()
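        # The 6th positional argument of nn.ConvTranspose2d is output_padding (not dilation);
        # with dilation=1 this gives output_padding=1, so the layer exactly doubles H and W.
        self.deconv1 = nn.ConvTranspose2d(ch, ch//4, 3, 2, 1, dilation)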
        self.block = BaseBlock(in_ch=ch//4, out_ch=ch//4)
        self.se = ChannelSpatialSELayer(num_channels=ch//4)

    def forward(self, x1, x2):
        x = torch.cat((x1,x2),1)
        x = self.deconv1(x)
        x = self.block(x)
        x = self.se(x)

        return x


class XNet(nn.Module):
    def __init__(self, in_ch, out_ch, ch=64, dil=1):
        super(XNet, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, ch, 3, 1, 1, dilation=dil)
        self.bn1 = nn.BatchNorm2d(ch)

        self.enc1_1 = Encoder(ch)
        self.enc1_2 = Encoder(ch*2)
        self.enc1_3 = Encoder(ch*4)
        self.dec1_1 = Decoder(1536, dilation=dil)
        self.dec1_2 = Decoder(640, dilation=dil)
        self.dec1_3 = Decoder(288, dilation=dil)

        self.center_conv = BaseBlock(in_ch=ch*8*2, out_ch=ch*8*2)

        self.enc2_1 = Encoder(ch)
        self.enc2_2 = Encoder(ch*2)
        self.enc2_3 = Encoder(ch*4)
        self.dec2_1 = Decoder(1536, dilation=dil)
        self.dec2_2 = Decoder(640, dilation=dil)
        self.dec2_3 = Decoder(288, dilation=dil)

        self.out1_conv = nn.Conv2d(72, out_ch, 1, 1)
        self.out2_conv = nn.Conv2d(72, out_ch, 1, 1)
        self.out3_conv = nn.Conv2d(144, out_ch, 1, 1)

    def forward(self,x):
        h = F.relu(self.bn1(self.conv1(x)))
        h1 = h
        h2 = h

        h1_e1 = self.enc1_1(h1)
        h2_e1 = self.enc2_1(h2)

        h1_e2 = self.enc1_2(h1_e1)
        h2_e2 = self.enc2_2(h2_e1)

        h1_e3 = self.enc1_3(h1_e2)
        h2_e3 = self.enc2_3(h2_e2)

        cat = torch.cat((h1_e3,h2_e3),1)
        hc = self.center_conv(cat)

        h1_d1 = self.dec1_1(hc,h1_e3)
        h2_d1 = self.dec2_1(hc, h2_e3)

        h1_d2 = self.dec1_2(h1_d1,h1_e2)
        h2_d2 = self.dec2_2(h2_d1, h2_e2)

        h1_d3 = self.dec1_3(h1_e1,h1_d2)
        h2_d3 = self.dec2_3(h2_e1,h2_d2)
        cat = torch.cat((h1_d3, h2_d3),1)

        out1 = self.out1_conv(h1_d3)
        out2 = self.out2_conv(h2_d3)
        out3 = self.out3_conv(cat)

        return out1, out2, out3, cat
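ChannelSpatialSELayer adds two attention paths: ChannelSELayer averages each channel over H and W, passes the result through FC-ReLU-FC-sigmoid and rescales each channel, while SpatialSELayer applies a 1 x 1 convolution and a sigmoid to rescale each pixel. The NumPy sketch below reproduces that arithmetic on a toy tensor with random stand-in weights (not the trained ones); weight shapes follow nn.Linear's (out_features, in_features) layout, and the function names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cse(x, w1, b1, w2, b2):
    """Channel SE: squeeze H and W by averaging, excite with FC-ReLU-FC-sigmoid, rescale channels."""
    s = x.mean(axis=(2, 3))                                   # (N, C)
    z = sigmoid(np.maximum(s @ w1.T + b1, 0.0) @ w2.T + b2)   # (N, C): one gate per channel
    return x * z[:, :, None, None]


def sse(x, w, b):
    """Spatial SE: 1x1 convolution to a single channel, sigmoid, rescale every pixel."""
    q = sigmoid(np.einsum("nchw,c->nhw", x, w) + b)           # (N, H, W): one gate per pixel
    return x * q[:, None, :, :]


rng = np.random.default_rng(0)
N, C, H, W, r = 1, 8, 4, 4, 4                                 # toy sizes; r = reduction ratio
x = rng.standard_normal((N, C, H, W))
w1, b1 = rng.standard_normal((C // r, C)), np.zeros(C // r)   # like nn.Linear(C, C // r)
w2, b2 = rng.standard_normal((C, C // r)), np.zeros(C)        # like nn.Linear(C // r, C)
w, b = rng.standard_normal(C), 0.0                            # like nn.Conv2d(C, 1, 1), flattened
y = cse(x, w1, b1, w2, b2) + sse(x, w, b)                     # ChannelSpatialSELayer sums both paths
print(y.shape)  # (1, 8, 4, 4)
```

Because both gates lie between 0 and 1, each path can only attenuate features; summing the two paths lets a feature pass strongly if either the channel gate or the spatial gate favors it.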

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


# ↓ ① Specify the image folder path (Update this for each analysis)

In the code cell below, replace the placeholder path "/content/drive/MyDrive/XXX" inside list = [ ] with the path to the folder containing the images you wish to analyze.

Follow these steps to copy the folder path:

1. Click the Files icon in the left sidebar.

2. Navigate to drive > MyDrive.

3. Right-click the folder containing your images.

4. Select "Copy path".

5. Paste the path inside the brackets of list = [ ], replacing the placeholder "/content/drive/MyDrive/XXX".

Example:
If your image file is located at:
/content/drive/MyDrive/Analysis Images/Data A/img_a1.tif
You must specify the path up to the folder level ("Data A"):
list = ["/content/drive/MyDrive/Analysis Images/Data A"]

Note:

Enclose each path in quotation marks ("").

If you enter multiple paths, separate them with commas.

Example: list = ["path_A", "path_B", "path_C"]

(You may add line breaks after each comma for better readability.)

Please ensure you re-run this cell whenever you change the path.
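To catch typos before running step ②, you could check each entry of the list. This is a minimal sketch, and check_folders is a hypothetical helper. It counts files with the same "*.tif" pattern that the cropping cell uses; on Colab's Linux file system this pattern is case-sensitive, so files ending in ".TIF" or ".tiff" would not be picked up.

```python
import glob
import os


def check_folders(folders):
    """Map each folder to the number of .tif files directly inside it (-1 if the folder is missing)."""
    counts = {}
    for folder in folders:
        if not os.path.isdir(folder):
            counts[folder] = -1
        else:
            counts[folder] = len(glob.glob(os.path.join(folder, "*.tif")))
    return counts


# Example path taken from the instructions above; use the same list you put in path_list()
for folder, n in check_folders(["/content/drive/MyDrive/Analysis Images/Data A"]).items():
    print(folder, "-> folder not found" if n < 0 else "-> {} .tif file(s)".format(n))
```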

In [None]:
# ① Specify the folder(s) containing the images
def path_list():
    # Replace with the path(s) to the folder(s) containing the images to be analyzed
    list = [
        "/content/drive/MyDrive/XXX",


    ]

    return list

# ↓ ② Preprocessing: Crop images to 512 x 512

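This step standardizes every image to 512 x 512 pixels: an image smaller than 512 pixels in height or width is first padded with black pixels along its bottom or right edge, and the top-left 512 x 512 region is then kept (any part of a larger image outside this region is discarded). The cropped images are saved in a new folder named crop_img created within each source directory.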

During execution, the path of the image currently being processed and the progress status will be displayed in the output area below the cell.
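The pad-then-crop rule can be sketched in NumPy as follows; np.pad with a constant zero border stands in for cv2.copyMakeBorder with BORDER_CONSTANT, and the function name and test arrays are hypothetical.

```python
import numpy as np


def pad_and_crop(img, size=512):
    """Zero-pad the bottom/right of an (H, W, C) image up to `size`, then keep the top-left size x size."""
    h, w = img.shape[:2]
    pad = ((0, max(size - h, 0)), (0, max(size - w, 0)), (0, 0))
    return np.pad(img, pad, mode="constant", constant_values=0)[:size, :size]


small = np.ones((400, 300, 3), dtype=np.uint8)    # smaller than 512 x 512: padded
large = np.ones((1024, 1024, 3), dtype=np.uint8)  # larger than 512 x 512: cropped
print(pad_and_crop(small).shape, pad_and_crop(large).shape)  # (512, 512, 3) (512, 512, 3)
```

The padded border is black, so it is segmented as background and adds no starch area; conversely, granules lying outside the top-left 512 x 512 region of a larger image are not measured.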

In [None]:
# ② Crop images to 512 x 512 as a preprocessing step
folder_paths = path_list()

for i in folder_paths:
    print(i)
    if not os.path.exists(i):
        # Report a wrong path and skip it (otherwise makedirs below would silently create it)
        print("######################################################################")
        print("No file", i)
        print("######################################################################")
        continue

    os.makedirs(os.path.join(i, "crop_img"), exist_ok=True)

    imgs = glob.glob(i + "/*.tif")

    for j in tqdm(imgs):
        img = cv2.imread(j)  # loaded as an 8-bit, 3-channel BGR image
        if img is None:
            print("Could not read", j)
            continue

        file_name = os.path.basename(j)
        h, w, c = img.shape

        # Pad with black pixels on the bottom/right if the image is smaller than 512 px
        if h < 512:
            img = cv2.copyMakeBorder(img, 0, 512 - h, 0, 0, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        if w < 512:
            img = cv2.copyMakeBorder(img, 0, 0, 0, 512 - w, cv2.BORDER_CONSTANT, value=(0, 0, 0))

        # Keep the top-left 512 x 512 region
        crop = img[0:512, 0:512]
        cv2.imwrite(os.path.join(i, "crop_img", file_name), crop)



# ↓ ③ Segmentation: Identify starch granules (Update model path for each analysis)

In the code cell below, replace the default path in model_path = "/content/drive/MyDrive/model.pth" with the actual file path where your model is stored.

Example:
model_path = "/content/drive/MyDrive/model.pth"

How to set the model path:

1. Click the Files icon in the left sidebar.

2. Navigate to drive > MyDrive.

3. Locate the model file (model.pth) and right-click it.

4. Select "Copy path".

5. Paste the path between the quotation marks of model_path = "...", replacing the existing path.
(Note: The path must be enclosed in quotation marks.)

Output:
Upon execution, two new folders named segmentation and feature_map will be created within each source directory. The segmentation masks and the feature-map visualizations will be saved in these folders, respectively.

Progress:
During execution, the path of the image currently being processed and the progress status will be displayed in the output area below the cell.
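Save_image in the segmentation cell turns the network output into the saved mask: it takes the per-pixel argmax over the two classes of the softmax output, drawing class 0 ("Starch" in label_list) red and class 1 black. Because softmax preserves the ordering of the scores, this argmax equals the argmax of the raw out3 logits. A NumPy sketch of that step on a toy 2 x 2 probability map (probs_to_mask is a hypothetical name):

```python
import numpy as np


def probs_to_mask(probs):
    """(1, 2, H, W) class probabilities -> (H, W, 3) BGR mask: class 0 (starch) red, class 1 black."""
    pred = np.argmax(probs, axis=1)[0]                  # (H, W) winning class per pixel
    mask = np.zeros(pred.shape + (3,), dtype=np.uint8)
    mask[pred == 0] = (0, 0, 255)                       # red in OpenCV's BGR channel order
    return mask


probs = np.zeros((1, 2, 2, 2))
probs[0, 0] = [[0.9, 0.2], [0.6, 0.1]]   # toy starch probabilities
probs[0, 1] = 1.0 - probs[0, 0]          # background probabilities (the two classes sum to 1)
print(probs_to_mask(probs)[..., 2])      # red channel: 255 where starch has the higher probability
```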

In [None]:
# ③ Segmentation of starch granules
from tifffile import imread
import imagecodecs  # TIFF compression codecs used by tifffile (installed in the first code cell)

def dataload(path):
    """Return a DataLoader over the cropped images in `path` (batch size 1, fixed file order)."""
    test_transform = Compose([ToTensor()])
    test_dataset = DataLoader(img_path=path, transform=test_transform)
    return torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True)


def min_max(x, axis=None):
    """Rescale x to [0, 1]; a constant array maps to zeros instead of dividing by zero."""
    x_min = x.min(axis=axis, keepdims=True)
    x_max = x.max(axis=axis, keepdims=True)
    return (x - x_min) / np.maximum(x_max - x_min, 1e-12)


def Save_fmap(x, img_name, out_dir):
    """Save the channel-averaged feature map as a JET-colored PNG in out_dir/feature_map."""
    x = x.mean(dim=1)                   # (1, C, H, W) -> (1, H, W)
    x = np.squeeze(x.cpu().numpy())     # -> (H, W)
    x = min_max(x)
    x = cv2.applyColorMap(np.uint8(x * 255.0), cv2.COLORMAP_JET)
    cv2.imwrite("{}/{}/map_{}.png".format(out_dir, "feature_map", img_name), x)


def Save_image(output, img_name, out_dir):
    """Save the predicted mask in out_dir/segmentation: starch (class 0) red, background black."""
    pred = np.argmax(output, axis=1)[0]                 # (1, 2, H, W) -> (H, W) class indices
    out_array = np.zeros(pred.shape + (3,), dtype=np.uint8)
    out_array[pred == 0] = [0, 0, 255]                  # red in OpenCV's BGR order
    cv2.imwrite("{}/{}/seg_{}.png".format(out_dir, "segmentation", img_name), out_array)


def test(model, test_loader, device, out_dir):
    model.eval()
    with torch.no_grad():
        for inputs, img_name in tqdm(test_loader):
            inputs = inputs.to(device, non_blocking=True)
            out1, out2, out3, y = model(inputs)

            # out3 is predicted from the concatenated features of both branches
            output = F.softmax(out3, dim=1).cpu().numpy()

            img_name = img_name[0].replace(".tif", "")  # batch of one file name
            Save_image(output, img_name, out_dir)
            Save_fmap(y, img_name, out_dir)


if __name__ == "__main__":
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Replace with the path to your model.pth
    model_path = "/content/drive/MyDrive/model.pth"

    image_size = [3, 512, 512]
    label_list = ["Starch", "Background"]   # class 0 = starch, class 1 = background

    # Build and load the network once; map_location lets GPU-trained weights load on a CPU runtime
    model = XNet(in_ch=image_size[0], out_ch=len(label_list)).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))

    for i in path_list():
        print(i)
        os.makedirs("{}/{}".format(i, "segmentation"), exist_ok=True)
        os.makedirs("{}/{}".format(i, "feature_map"), exist_ok=True)

        test_loader = dataload("{}/{}".format(i, "crop_img"))
        test(model, test_loader, device, i)

# ↓ ④ Quantification of starch granule area

Upon execution, the program creates a new folder named labeling within the source directory. The analysis results, including images and text files, will be stored in this folder.

During execution, the path of the image currently being processed and the progress status will be displayed in the output area below the cell.
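The Labeling function in the cell below measures granules with cv2.connectedComponentsWithStats, which by default groups foreground pixels using 8-connectivity, so granules that touch even at a single corner are counted as one region. The sketch below is a hypothetical pure-Python flood fill (a different, much slower algorithm than OpenCV's, meant only for toy arrays) that computes the same per-region pixel counts; the ID order it produces may differ from OpenCV's.

```python
from collections import deque

import numpy as np


def component_areas(binary):
    """Pixel counts of 8-connected foreground regions, in raster-scan order of each region's first pixel."""
    h, w = binary.shape
    seen = np.zeros((h, w), dtype=bool)
    areas = []
    for y in range(h):
        for x in range(w):
            if binary[y, x] and not seen[y, x]:
                # Breadth-first flood fill from an unvisited foreground pixel
                seen[y, x] = True
                queue, area = deque([(y, x)]), 0
                while queue:
                    cy, cx = queue.popleft()
                    area += 1
                    for dy in (-1, 0, 1):
                        for dx in (-1, 0, 1):
                            ny, nx = cy + dy, cx + dx
                            if 0 <= ny < h and 0 <= nx < w and binary[ny, nx] and not seen[ny, nx]:
                                seen[ny, nx] = True
                                queue.append((ny, nx))
                areas.append(area)
    return areas


mask = np.array([[1, 1, 0, 0],
                 [0, 1, 0, 1],
                 [0, 0, 0, 1],
                 [1, 0, 0, 0]], dtype=np.uint8)
areas = component_areas(mask)
print(areas, "Total:", sum(areas))  # [3, 2, 1] Total: 6
```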

In [None]:
# ④ Quantify starch granule area from the segmentation images

def Labeling(img_path, img_list):
    """Measure every connected starch region in each mask; log areas (pixels) and save annotated images."""
    labeling_result = "{}/{}/{}.txt".format(img_path, "labeling", "labeling_result")
    with open(labeling_result, mode='w') as f:
        for path in tqdm(img_list):
            img_name = os.path.basename(path).replace("seg_", "").replace(".png", "")
            f.write("\n%s\n" % (img_name))

            # Binarize the mask: red starch pixels -> 255, black background -> 0
            img = cv2.imread(path)
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

            # Label connected regions (8-connectivity by default); row 0 of stats is the background
            _, _, stats, _ = cv2.connectedComponentsWithStats(gray)
            stats = stats[1:]

            color_src = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
            sum_s = 0

            for j, s in enumerate(stats):
                x0 = int(s[cv2.CC_STAT_LEFT])
                y0 = int(s[cv2.CC_STAT_TOP])
                x1 = x0 + int(s[cv2.CC_STAT_WIDTH])
                y1 = y0 + int(s[cv2.CC_STAT_HEIGHT])
                area = int(s[cv2.CC_STAT_AREA])

                # Bounding box and ID beside each granule; "ID: S:" list in the top-left corner
                cv2.rectangle(color_src, (x0, y0), (x1, y1), (0, 255, 255))
                cv2.putText(color_src, str(j + 1), (x0, y1 + 8),
                            cv2.FONT_HERSHEY_PLAIN, 0.5, (0, 0, 255))
                cv2.putText(color_src, "ID:{} S: {}".format(j + 1, area), (0, 10 * (j + 1)),
                            cv2.FONT_HERSHEY_PLAIN, 0.7, (0, 0, 255))

                sum_s += area
                f.write("ID:%d\t%d\n" % (j + 1, area))

            f.write("Total:\t%d\n" % (sum_s))
            cv2.imwrite("{}/{}/labeling_{}.png".format(img_path, "labeling", img_name), color_src)


for i in path_list():
    print(i)
    os.makedirs("{}/{}".format(i, "labeling"), exist_ok=True)

    seg_img_path = "{}/{}".format(i, "segmentation")
    seg_img_list = natsorted(glob.glob(seg_img_path + "/*.png"))

    Labeling(i, seg_img_list)

**⑤ Calculation of starch granule area**

The file labeling_result.txt is located in the labeling folder. When you open this file in Microsoft Excel:

Column A displays the ID numbers of the starch granules identified by segmentation.

Column B shows the calculated area (in pixels) of each granule.

The value in Column B next to the "Total:" label in Column A is the sum of the starch granule areas (in pixels) detected in that image; the results for each image are preceded by a line containing the image name.
To determine the actual area in square micrometers ($\mu m^2$), divide the pixel values obtained from the labeling results by the pixel density (pixels/$\mu m^2$) of the analyzed image.
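The conversion can also be scripted. This is a minimal sketch, assuming the tab-separated layout that the Labeling cell writes (an image-name line, then ID:n<TAB>area lines, then a Total:<TAB>sum line). The helper names are hypothetical, the sample text is made up, and the 512-pixel image width is only an illustrative assumption: use the field of view and pixel dimensions of your own acquisitions. Dividing by a pixel density in pixels/$\mu m^2$ is equivalent to multiplying by the area of one pixel, (field of view / width in pixels)², and the padding or cropping in step ② does not change the pixel size.

```python
def parse_labeling_result(text):
    """Parse labeling_result.txt into {image_name: {"areas": [...], "total": int}} (pixel units)."""
    results, current = {}, None
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith("ID:"):
            results[current]["areas"].append(int(line.split("\t")[1]))
        elif line.startswith("Total:"):
            results[current]["total"] = int(line.split("\t")[1])
        else:
            current = line  # any other non-empty line is an image name
            results[current] = {"areas": [], "total": 0}
    return results


def px_to_um2(area_px, fov_um, width_px):
    """Convert an area in pixels to um^2 for a square field fov_um wide imaged at width_px pixels."""
    um_per_px = fov_um / width_px
    return area_px * um_per_px ** 2


sample = "\nimg_a1\nID:1\t120\nID:2\t80\nTotal:\t200\n"  # made-up example, not real data
res = parse_labeling_result(sample)
# Hypothetical: a 50.54 um field of view recorded at 512 px per side
print(res["img_a1"]["total"], round(px_to_um2(res["img_a1"]["total"], 50.54, 512), 3))
```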

**Note:**

Multiple Stomata: If an image contains multiple stomata, starch granules from all stomata will be detected and included in the total.

False Positives: Background fluorescence or starch granules in mesophyll cells may be inadvertently detected. Please review the segmentation and feature_map images to verify the results and avoid erroneous measurements.


If you have any questions, please contact Dr. Atsushi Takemiya (take.pcs at yamaguchi-u.ac.jp), Dr. Kazuhiro Hotta (kazuhotta at meijo-u.ac.jp), or Dr. Shota Yamauchi (shyamauchi at rs.tus.ac.jp).