## EMATM0047: Data Science Project
---
### Code Section S2.2: Image segmentation - TransUnet model training section
#### Author: Alan Liu
#### Faculty of Engineering
#### University of Bristol

Input:
1. The cryo-EM images: 4 stages with 100 .mrc files each, 400 files in total.
2. The corresponding masks: 4 stages with 100 .npy files each, 400 files in total.

Operation:
1. Construct the DataLoader
2. Construct the TransUnet model for 7 classes
3. Conduct the model training, followed by a particle picking section

Connect to Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Check whether mrcfile and ml-collections are installed

In [None]:
# check if mrcfile is installed
# check if ml-collections is installed, needed to read the ConfigDict
try:
  import mrcfile
  import ml_collections
  print("Yes")
except ImportError:
  print("No")
  !pip install mrcfile
  !pip install ml-collections
  print()
  print("mrcfile and ml-collections installed.")

Here we follow the README of the TransUnet GitHub repository to complete all preparation steps.

[The TransUnet repo](https://github.com/Beckschen/TransUNet)

### Step 1: clone the repo

In [None]:
# make a new folder
!mkdir -p "/content/drive/MyDrive/Final Project/TransUnet"
# switch to this path
%cd "/content/drive/MyDrive/Final Project/TransUnet"
# clone the repo
!git clone https://github.com/Beckschen/TransUNet.git

# This code only needs to be run once

### Step 2: install the dependencies

In [None]:
%cd "/content/drive/MyDrive/Final Project/TransUnet"
%cd TransUNet
# install the dependencies
# note: requirements.txt pins torch==1.4.0, which is not available here
!pip install -r requirements.txt
# so import torch and check which version is actually present
try:
  import torch
  print("Torch installed with version:", torch.__version__)
except ImportError:
  raise Exception("Torch not installed")

# This cell must be run every time

### Step 3: Download the pre-trained weights

In [None]:
# here we use the R50+ViT-B_16
# download
!wget https://storage.googleapis.com/vit_models/imagenet21k/R50%2BViT-B_16.npz
# create the target directory
try:
    !mkdir -p "/content/drive/MyDrive/Final Project/TransUnet/model/vit_checkpoint/imagenet21k"
    print("Folder created")
    # move the weight
    !mv R50+ViT-B_16.npz "/content/drive/MyDrive/Final Project/TransUnet/model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz"
    print("Weight moved")
except:
    raise Exception("Encounter error")

# This code only needs to be run once

### Step 4: Construct the dataloader
The dataloader can be copied from the Unet part, with a few modifications:

1. set the patch_size to 224, the default in TransUnet
2. stack the mrc up to 3 channels, since the pre-trained TransUnet backbone expects RGB input
3. for the normalization method, use the ImageNet normalization with:

 mean: [0.485, 0.456, 0.406]

 std: [0.229, 0.224, 0.225]
4. change the shape of the mrc from (H, W, C) to (C, H, W)
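
As a rough illustration of steps 2 to 4, the sketch below converts one single-channel patch into the (3, 224, 224) layout. It assumes the patch has already been scaled to [0, 1]; the helper name `to_transunet_input` is hypothetical and only mirrors what the dataset class does in `__getitem__`.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_transunet_input(patch: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (H, W) patch in [0, 1] -> (3, H, W) float32 array."""
    patch = patch.astype(np.float32)
    # step 2: repeat the grey-scale patch along a new last axis -> (H, W, 3)
    rgb = np.stack([patch] * 3, axis=2)
    # step 3: ImageNet normalization, broadcast over the channel axis
    rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
    # step 4: move the channel axis to the front -> (3, H, W)
    return rgb.transpose(2, 0, 1)

demo = to_transunet_input(np.full((224, 224), 0.5))
print(demo.shape)
```

Because the three channels are copies of the same image, they differ only through the per-channel mean and std.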

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import mrcfile
from pathlib import Path
import random
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
# tqdm is used by PatchBasedData for the progress bars
from tqdm import tqdm

# set the random seed to make sure the result can be reproduced
seed = 426 # according to my ID
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

class PatchBasedData(Dataset):
    """
    process the original dataset with the patch-based approach

    the process steps are:
    1. Find all MRC and mask files
    2. For files with cells: generate patches around the cells
    3. For files without cells: generate patches randomly
    4. All patches share the same size, patch_size x patch_size (224 x 224)
    """

    def __init__(self, base_dir: str = "/content/drive/MyDrive/Final Project",
                 stages: Optional[List[int]] = None, # which stages to process
                 patch_size: int = 224, # size of patch; note: it must not exceed h or w of the mask
                 patches_per_image: int = 24, # how many patches are taken from one mrc file with cells
                 bg_patches_per_image: int = 12): # how many patches are taken from a no-cell mrc file

        self.base_dir = Path(base_dir)
        assert patch_size == 224, "patch_size must be 224 for TransUnet"
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.bg_patches_per_image = bg_patches_per_image

        # stage mapping
        # because the name of stage folder of mrc and npy are different
        self.stage_mapping = {1: {'mrc': 'stageI', 'mask': 'stage1'},
                    2: {'mrc': 'stageII', 'mask': 'stage2'},
                    3: {'mrc': 'stageIII', 'mask': 'stage3'},
                    4: {'mrc': 'stageIV', 'mask': 'stage4'}}

        # all 4 stages will be processed as default
        if stages is None:
            stages = [1, 2, 3, 4]
        self.stages = stages

        # load all files
        self.data_pairs = []
        self.load_file_paths()

        # pre-load the data
        # shorten the processing time
        print("Preloading data into memory")
        self.preload_data()

        # generate all patches
        self.patches = []
        self.generate_all_patches()

    def load_file_paths(self):
        """
        load the mrc and mask file
        save them as pairs
        """
        # note: this path is exclusive to my computer
        # it may be changed in different devices
        mrc_base = self.base_dir / "Dataset-processed"
        mask_base = self.base_dir / "Image-segmentation-Level-4"

        # get the stage folder
        for stage in self.stages:
            mrc_dir = mrc_base / self.stage_mapping[stage]['mrc']
            mask_dir = mask_base / self.stage_mapping[stage]['mask']

            # get all mrc files
            mrc_files = sorted([f for f in os.listdir(mrc_dir) if f.endswith('.mrc')])

            # get the npy on the basis of each mrc
            for mrc_file in mrc_files:
                # the path of one mrc
                mrc_path = mrc_dir / mrc_file
                # construct the name of mask with mrc
                mask_file = f"mask_{mrc_file.split('.')[0]}_v4.npy"
                # the path of the mask
                mask_path = mask_dir / mask_file

                # if pair is found, save them with the stage label
                if mrc_path.exists() and mask_path.exists():
                    self.data_pairs.append({'mrc_path': mrc_path,
                                'mask_path': mask_path,
                                'stage': stage})
                else:
                    print(f"Missing file: {mrc_path} or {mask_path}")
                    continue

        # the correct/ideal number: len(self.stages) * 100
        print(f"{len(self.data_pairs)} pairs in total")

    def preload_data(self):
        """
        pre-load the data into memory
        """
        self.cache_mrc = {}
        self.cache_mask = {}
        # load and cache the mrc and mask
        for idx, pair in enumerate(tqdm(self.data_pairs, desc = 'Loading data')):
            # load mrc
            mrc_data = self.load_mrc(pair['mrc_path'])
            # normalize mrc
            mrc_data = self.normalize_mrc(mrc_data)
            # load mask
            mask_data = np.load(pair['mask_path'])

            # cache them
            self.cache_mrc[idx] = mrc_data
            self.cache_mask[idx] = mask_data


    def generate_all_patches(self):
        """
        generate the patches for all files
        """
        print("Generating patches...")

        # number of no cell mask
        no_cell_images = 0

        for idx, pair in enumerate(tqdm(self.data_pairs, desc = 'generating patches')):
            # reuse the mask cached by preload_data() instead of reading the npy again
            mask = self.cache_mask[idx]
            # the corresponding stage label
            stage = pair['stage']

            # find all positions of the cell and return (rows, columns)
            cell_positions = np.where(mask == stage)

            if len(cell_positions[0]) > 0:
                # cell patch
                self.generate_cell_patches(idx, mask, stage)
            else:
                # no cell patch
                no_cell_images += 1
                self.generate_background_patches(idx, mask, stage)

        print(f"Found {no_cell_images} images without cells")
        print(f"Generated {len(self.patches)} total patches")

    def generate_cell_patches(self, pair_idx, mask, stage):
        """
        generate the patches for mrc that have cells

        argument:
        pair_idx: the index of the pair
        mask: the Numpy array mask
        stage: the stage label
        """
        # get the mask shape
        h, w = mask.shape

        # define the stride, half of the patch size
        stride = self.patch_size // 2

        # the list storing the x,y and ratio
        patch_in_single = []

        # get all patch
        for y in range(0, h - self.patch_size + 1, stride):
            for x in range(0, w - self.patch_size + 1, stride):

              patch = mask[y:y + self.patch_size, x:x + self.patch_size]

              # valid pixel only
              all_valid = (patch != -1)

              # some patches consist entirely of -1 in practice; skip them
              valid_sum = int(np.count_nonzero(all_valid))
              if valid_sum == 0:
                continue

              # get the foreground ratio: (patch == stage) / all_valid
              foreground_pixel = int(((patch == stage) & all_valid).sum())
              ratio = float(foreground_pixel) / float(valid_sum)

              # store the info
              patch_in_single.append((x, y, ratio))

        # here we use top-K to select patches:
        # keep the patches_per_image (24 by default) patches with the highest foreground ratio
        k_need = self.patches_per_image

        # sorting based on ratio
        patch_in_single.sort(key = lambda t: t[2], reverse = True)
        # topk
        top_K = patch_in_single[:k_need]


        # record the info
        for (x, y, ratio) in top_K:
          self.patches.append({'image_idx': pair_idx,
                    'y': y,
                    'x': x,
                    'stage': stage,
                    'has_cells': True})

    def generate_background_patches(self, pair_idx, mask, stage):
        """
        generate patches for mrc that have no cells

        argument:
        pair_idx: the index of the pair
        mask: the Numpy array mask
        stage: the stage label
        """
        # get the mask size
        h, w = mask.shape

        # sample bg_patches_per_image (12 by default) patches randomly
        for _ in range(self.bg_patches_per_image):
            # make sure inside the mask
            y = random.randint(0, h - self.patch_size)
            x = random.randint(0, w - self.patch_size)

            # save the patch
            self.patches.append({'image_idx': pair_idx,
                      'y': y,
                      'x': x,
                      'stage': stage,
                      'has_cells': False})

    def load_mrc(self, path: Path) -> np.ndarray:
        """
        load the mrc file

        argument:
        path: the path of the mrc file
        """
        with mrcfile.open(path, mode = 'r') as mrc:
            data = mrc.data.copy()
        return data

    def normalize_mrc(self, data: np.ndarray) -> np.ndarray:
        """
        normalize the mrc data following the '_load_mrc()'

        argument:
        data: the mrc data to be normalized
        """
        # use the quantile normalization
        # get the 1st percentile and 99th percentile
        p1, p99 = np.percentile(data, [1, 99])
        # clip the data
        data = np.clip(data, p1, p99)
        # then map the data to [0, 1]
        data = (data - p1) / (p99 - p1 + 1e-8)
        return data

    def __len__(self):
        # how many patches in total
        return len(self.patches)

    def __getitem__(self, idx):
        # get the dict containing ('image_idx','y','x','stage','has_cells')
        patch_info = self.patches[idx]
        # get the index of the pair
        image_idx = patch_info['image_idx']

        # load the data from cache
        mrc_data = self.cache_mrc[image_idx]
        mask_data = self.cache_mask[image_idx]

        # get the coordinates
        y = patch_info['y']
        x = patch_info['x']
        # clip the mrc
        patch_mrc = mrc_data[y:y + self.patch_size, x:x + self.patch_size]
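        # clip the mask; astype() returns a copy, so the label remapping below never
        # modifies the cached mask, and int64 guarantees that the value 255 fits
        patch_mask = mask_data[y:y + self.patch_size, x:x + self.patch_size].astype(np.int64)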

        # New part
        # preparation: the mrc must be float32 to keep accuracy
        patch_mrc = patch_mrc.astype(np.float32)

        # step 1: stack up to three channels
        if patch_mrc.ndim == 2:
            patch_mrc = np.stack([patch_mrc] * 3, axis = 2) # (H, W, 3)

        # step 2: use ImageNet normalization (float32 keeps the patch dtype unchanged)
        mean = np.array([0.485, 0.456, 0.406], dtype = np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype = np.float32)
        patch_mrc = (patch_mrc - mean) / std

        # step 3: change the (H, W, 3) to (3, H, W)
        patch_mrc = patch_mrc.transpose(2, 0, 1)

        # the background of stage 4 is different from others
        # thus, we allocate a new label to it
        if patch_info['stage'] == 4:
          patch_mask[patch_mask == 0] = 5

        # for "-1" label, map it to 255
        patch_mask[patch_mask == -1] = 255

        # convert to tensors
        # float for the mrc, as required by Conv2d
        patch_mrc = torch.from_numpy(patch_mrc).float()
        # long for the mask, as required by the loss
        patch_mask = torch.from_numpy(patch_mask).long()

        return {'image': patch_mrc,
            'mask': patch_mask,
            'stage': patch_info['stage'],
            'has_cells': patch_info['has_cells']}

def create_dataloaders(base_dir: str = "/content/drive/MyDrive/Final Project",
            stages: Optional[List[int]] = None,
            batch_size: int = 32,
            patch_size: int = 224,
            num_workers: int = 8,
            val_split: float = 0.2,
            min_cell_pixels: int = 200):
    """
    create dataloaders for training and validation

    argument:
    base_dir: the root directory of the dataset
    stages: which stages to process
    batch_size: how many patches in each batch
    patch_size: the size of each patch, must be 224 (asserted in PatchBasedData)
    num_workers: number of subprocesses, 0 on Windows and adjustable in Colab
    val_split: the fraction of patches used for the validation set
    min_cell_pixels: not used in this function
    """
    # get the patch_based dataset
    dataset = PatchBasedData(base_dir = base_dir, stages = stages, patch_size = patch_size)

    # get the size of training and validation dataset
    total_size = len(dataset)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size

    # split the dataset
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # construct the training dataloader
    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers,
                  pin_memory = True)
    # construct the validation dataloader
    val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers,
                 pin_memory = True)

    return train_loader, val_loader

def calculate_batch_statistics(batch):
    """
    calculate the statistics of a batch

    argument: batch: the batch created by dataloader
    """
    # get all mask patches in a batch
    masks = batch['mask']
    # dict to store
    stats = defaultdict(int)

    for mask in masks:
        # get each label and its pixel count
        labels, counts = torch.unique(mask, return_counts = True)
        for label, count in zip(labels.tolist(), counts.tolist()):
            # record them
            stats[label] += count
    # the total pixels
    total = sum(stats.values())
    # the valid pixels, by excluding the 255
    valid = total - stats.get(255, 0)
    # the cell pixels
    # [1,2,3,4] are cells
    cell_pixels = sum(v for k, v in stats.items() if k in [1, 2, 3, 4])
    # summary and return; guard against a batch without any valid pixel
    return {'total_pixels': total,
        'valid_pixels': valid,
        'cell_ratio': cell_pixels / valid if valid > 0 else 0.0,
        'patches_with_cells': int(sum(batch['has_cells']))}
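
The cell-patch selection in `generate_cell_patches` can be illustrated on a toy mask. The sketch below is a simplified pure-Python version under stated assumptions: the function name `top_k_patches` and the tiny 4 x 4 mask are hypothetical, -1 marks invalid pixels as in the dataset, and the windows slide with a stride of half the patch size.

```python
def top_k_patches(mask, label, patch_size, k):
    """Hypothetical sketch: rank sliding windows by foreground ratio, keep the top k."""
    h, w = len(mask), len(mask[0])
    stride = patch_size // 2
    scored = []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            window = [v for row in mask[y:y + patch_size] for v in row[x:x + patch_size]]
            valid = [v for v in window if v != -1]
            if not valid:
                continue  # skip windows that contain only invalid pixels
            ratio = sum(1 for v in valid if v == label) / len(valid)
            scored.append((x, y, ratio))
    # list.sort is stable, so windows with equal ratios keep their scan order
    scored.sort(key=lambda t: t[2], reverse=True)
    return scored[:k]

toy_mask = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 1],
    [-1, -1, 0, 0],
]
print(top_k_patches(toy_mask, label=1, patch_size=2, k=2))
```

The ratio is computed over valid pixels only, so a window that touches the -1 border is not penalized for the pixels it cannot use.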

Next, construct the TransUnet model

The original repository provides the configuration for constructing the model, which should be modified based on the actual dataset. The items that must be changed are listed below:

1. n_classes: changed to 7, for 4 cell classes, 2 backgrounds and 1 black border
2. pretrained_path: the path where the pretrained weights are stored. Here I saved the weight file R50+ViT-B_16.npz
3. patches.grid: the ViT splits the input image into patches. With a patch size of 16 x 16 pixels, a 224 x 224 input is divided into a 14 x 14 grid of patches (224 / 16 = 14). The grid size is therefore the input size divided by the patch size.
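
The grid arithmetic in item 3 can be checked with a few lines. This is only a sketch; the helper name `vit_grid` and its `patch_px` parameter are hypothetical and not part of the TransUnet repository.

```python
def vit_grid(img_size, patch_px=16):
    """Hypothetical helper: number of patches along each side of the input."""
    h, w = img_size
    assert h % patch_px == 0 and w % patch_px == 0, "input size must be a multiple of the patch size"
    return (h // patch_px, w // patch_px)

grid = vit_grid((224, 224))
# the grid and the total number of patches it contains
print(grid, grid[0] * grid[1])
```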


In [None]:
# get the configuration from the official repository
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from tqdm import tqdm

# add the path of the cloned repository (the folder created by git clone is named TransUNet)
sys.path.append('/content/drive/MyDrive/Final Project/TransUnet/TransUNet')

from networks.vit_seg_configs import get_r50_b16_config
from networks.vit_seg_modeling import VisionTransformer

def get_vit_config_customed():
  # get the official version
  official_config = get_r50_b16_config()
  # change the n_classes
  official_config.n_classes = 7
  # change the weights path
  official_config.pretrained_path = '/content/drive/MyDrive/Final Project/TransUnet/model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
  # set the patch grid for a 224 x 224 input
  official_config.patches.grid = (14, 14)

  return official_config

def transunet(img_size = (224, 224)):
  """
  construct the TransUnet model

  argument:
  img_size: the size of the input image
  """
  custom_config = get_vit_config_customed()
  # make sure the grid matches the input size: one grid cell per 16 x 16 pixels
  grid_h = img_size[0] // 16
  grid_w = img_size[1] // 16
  custom_config.patches.grid = (grid_h, grid_w)

  # the name is VisionTransformer
  model = VisionTransformer(custom_config, img_size = img_size, num_classes = custom_config.n_classes)

  # load weights
  weights = np.load(custom_config.pretrained_path, allow_pickle = True)
  model.load_from(weights)

  return model

In [None]:
# define the loss function, use dice loss + crossentropy again
# CrossEntropy: used to ignore the 255 label
# Dice loss: can tackle the category imbalance problem
class DiceLoss(nn.Module):
    """
    calculate the dice loss
    """
    def __init__(self, smooth = 1.0, ignore_index = 255):
        super().__init__()
        # smoothing factor to prevent zero denominator
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, input, target):
        # get the probability of each label
        # logits: [batch_size, 7 labels, 224, 224]
        all_probs = F.softmax(input, dim = 1)

        # ignore 255
        valid_mask = target != self.ignore_index

        # sum the loss for each label
        total_loss = 0
        classes = input.shape[1]

        # in each class
        for one_class in range(classes):

            # get its probability
            prob = all_probs[:, one_class]

            # shift the ground truth to 1/0
            real = (target == one_class).float()

            # in valid zone only
            prob = prob[valid_mask]
            real = real[valid_mask]

            # calculate the dice coefficient:
            # (2 * inter + smooth) / (sum(prob) + sum(real) + smooth)
            # inter: the probability mass on the correct label
            intersection = (prob * real).sum()
            dice = (2. * intersection + self.smooth) / (prob.sum() + real.sum() + self.smooth)

            # sum to total loss
            total_loss += (1 - dice)

        return total_loss / classes

# combine the dice loss to CrossEntropy
class DiceCELoss(nn.Module):
  """
  combine the Dice loss and CrossEntropy loss
  """
  def __init__(self, dice_weight = 0.7, ce_weight = 0.3):
    """
    argument:
    dice_weight: the weight of Dice loss
    ce_weight: the weight of CrossEntropy loss
    """
    super().__init__()
    self.dice_weight = dice_weight
    self.ce_weight = ce_weight

    # the crossentropy with ignore index
    self.ce = nn.CrossEntropyLoss(ignore_index = 255)
    # set up the dice loss
    self.dice = DiceLoss()

  def forward(self, input, target):
    # calculate crossentropy Loss
    ce_loss = self.ce(input, target)

    # calculate dice loss
    dice_loss = self.dice(input, target)

    # combine them
    total_loss = self.ce_weight * ce_loss + self.dice_weight * dice_loss

    return total_loss
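
To make the smoothed Dice coefficient used above concrete, here is a small worked example for a single class on four pixels. It is a pure-Python sketch; the helper name `soft_dice` and the numbers are made up for illustration.

```python
def soft_dice(probs, targets, smooth=1.0):
    """Hypothetical helper: soft Dice for one class over flattened pixels."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    return (2.0 * intersection + smooth) / (sum(probs) + sum(targets) + smooth)

probs = [0.9, 0.8, 0.1, 0.2]   # predicted probability of the class per pixel
targets = [1, 1, 0, 0]         # 1 where the ground truth is this class
dice = soft_dice(probs, targets)
# intersection = 1.7, so dice = (3.4 + 1) / (2.0 + 2 + 1) = 0.88
print(round(dice, 4), round(1 - dice, 4))
```

With `smooth = 1.0`, a class that is absent from the batch and predicted with zero probability gives a Dice of 1, so it adds no loss instead of causing a division by zero.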

In [None]:
# training part
def train_model(model, train_loader, val_loader, epochs = 30, dice_weight = 0.6, ce_weight = 0.4, lr = 1e-3,
        device = 'cuda', save_path  = "TransUnet_seg_weights.pth"):
    """
    training section

    argument:
    model: the TransUnet model
    train_loader: the training dataloader
    val_loader: the validation dataloader
    epochs: how many epochs to train
    dice_weight: the weight of Dice loss
    ce_weight: the weight of CrossEntropy loss
    lr: the initial learning rate
    device: where to train the model, 'cuda' or 'cpu'
    save_path: where the weights of the best model are saved
    """
    # move the model to cuda
    model = model.to(device)
    # loss function
    criterion = DiceCELoss(dice_weight = dice_weight, ce_weight = ce_weight)
    # for the optimizer
    optimizer = optim.AdamW(model.parameters(), lr = lr, weight_decay = 1e-4)
    # for the scheduler, use cosine annealing
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs, eta_min = 0.0)
    # best validation IoU so far
    best_val_iou = 0

    # 1,2,3,4 only
    particle_class = [1,2,3,4]

    # include 0, the background
    all_class = [0,1,2,3,4]

    for epoch in range(epochs):
        model.train()
        # the loss of the training set
        train_loss = 0.0
        # calculate the IOU
        train_iou = {stage: {'inter': 0, 'union': 0} for stage in particle_class}

        # use the tqdm to show the progress bar
        pbar = tqdm(train_loader, desc = f'Epoch {epoch + 1} / {epochs} [Train]')
        for batch in pbar:
            # clear the gradient
            optimizer.zero_grad()
            # input the mrc and mask
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)
            # prediction
            outputs = model(images)
            # calculate the loss
            loss = criterion(outputs, masks)
            # back propagation
            loss.backward()
            # update the weights with AdamW
            optimizer.step()
            # add to sum loss
            train_loss += loss.item()

            # calculate the IOU
            # no gradient here
            with torch.no_grad():
                # get the predicted label
                pred = outputs.argmax(dim = 1)
                # valid pixels only
                valid = (masks != 255) & (masks != 6)
                for one_class in particle_class:
                  # valid prediction for one label
                  pred_class = (pred == one_class) & valid
                  # valid real mask for one label
                  mask_class = (masks == one_class) & valid
                  # intersection and union
                  inter = (pred_class & mask_class).sum().item()
                  union = (pred_class | mask_class).sum().item()
                  train_iou[one_class]['inter'] += inter
                  train_iou[one_class]['union'] += union

            # show the batch loss
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # get the epoch loss
        train_loss /= len(train_loader)
        # get the epoch IOU
        mean_iou = []
        for one_class in particle_class:
            union = train_iou[one_class]['union']

            # make sure union is above 0
            if union >0:
                iou = train_iou[one_class]['inter'] / train_iou[one_class]['union']
                mean_iou.append(iou)
        mean_iou = sum(mean_iou) / len(mean_iou) if len(mean_iou) > 0 else 0.0
        train_iou = mean_iou

        # Validation section
        model.eval()
        # the loss of the validation set
        val_loss = 0.0
        # same metrics
        val_iou = {stage: {'inter': 0, 'union': 0} for stage in particle_class}

        # the mIOU_all
        val_iou_all = {stage: {'inter': 0, 'union': 0} for stage in all_class}

        # also the dice metrics
        val_dice = {stage: {'tp': 0, 'fp': 0, 'fn': 0} for stage in all_class}

        # also the dice particle
        val_dice_particle = {stage: {'tp': 0, 'fp': 0, 'fn': 0} for stage in particle_class}

        # with no gradient
        with torch.inference_mode():
            # same bar
            for batch in tqdm(val_loader, desc = f'Epoch {epoch + 1} / {epochs} [Val]'):
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)
                outputs = model(images)
                # calculate the loss
                loss = criterion(outputs, masks)
                # add to the validation sum loss
                val_loss += loss.item()

                # same process to get the prediction
                pred = outputs.argmax(dim = 1)
                # for valid pixels only
                valid = (masks != 255) & (masks != 6)
                # for each class
                for one_class in particle_class:
                  # get the valid prediction for this class
                  pred_class = (pred == one_class) & valid
                  # valid ground truth
                  mask_class = (masks == one_class) & valid
                  # inter and union
                  inter = (pred_class & mask_class).sum().item()
                  union = (pred_class | mask_class).sum().item()
                  val_iou[one_class]['inter'] += inter
                  val_iou[one_class]['union'] += union

                  # here we also calculate the dice
                  tp_particle = inter
                  fp_particle = (pred_class & (~mask_class)).sum().item()
                  fn_particle = ((~pred_class) & mask_class).sum().item()
                  val_dice_particle[one_class]['tp'] += tp_particle
                  val_dice_particle[one_class]['fp'] += fp_particle
                  val_dice_particle[one_class]['fn'] += fn_particle

                # then, the all class including background
                # copy mask
                mask2 = masks.clone()
                # combine the background label
                mask2[mask2 == 5] = 0

                # copy the result
                pred2 = pred.clone()
                # combine the background label
                pred2[pred2 == 5] = 0

                # the valid
                valid_all = (mask2 != 255) & (mask2 != 6)


                # calculate the mIOU_all
                for one_class in all_class:
                    # get the valid prediction
                    pred_class_all = (pred2 == one_class) & valid_all
                    # valid ground truth
                    mask_class_all = (mask2 == one_class) & valid_all
                    # get the inter
                    inter_all = (pred_class_all & mask_class_all).sum().item()
                    # get the union
                    union_all = (pred_class_all | mask_class_all).sum().item()
                    # record
                    val_iou_all[one_class]['inter'] += inter_all
                    val_iou_all[one_class]['union'] += union_all

                    # here we also calculate the dice
                    tp = inter_all
                    fp = (pred_class_all & (~mask_class_all)).sum().item()
                    fn = ((~pred_class_all) & mask_class_all).sum().item()
                    val_dice[one_class]['tp'] += tp
                    val_dice[one_class]['fp'] += fp
                    val_dice[one_class]['fn'] += fn

        # get the validation epoch loss
        val_loss /= len(val_loader)
        # get the validation epoch mIOU_particle
        mean_val_iou = []
        for one_class in particle_class:
            # get the union
            union = val_iou[one_class]['union']

            # make sure the union is above 0
            if union > 0:
                iou = val_iou[one_class]['inter'] / val_iou[one_class]['union']
                mean_val_iou.append(iou)
        mean_val_iou_number = sum(mean_val_iou) / len(mean_val_iou) if len(mean_val_iou) > 0 else 0.0
        val_iou = mean_val_iou_number

        # get the val dice
        mean_val_dice_particle = []
        for one_class in particle_class:
            tp_p = val_dice_particle[one_class]['tp']
            fp_p = val_dice_particle[one_class]['fp']
            fn_p = val_dice_particle[one_class]['fn']

            # get the denominator first
            denominator_p = 2 * tp_p + fp_p + fn_p
            # make sure the denominator is above 0
            if denominator_p > 0:
                dice_p = 2 * tp_p / (2 * tp_p + fp_p + fn_p)
                mean_val_dice_particle.append(dice_p)
        mean_val_dice_result = sum(mean_val_dice_particle) / len(mean_val_dice_particle) if len(mean_val_dice_particle) > 0 else 0.0
        val_dice_p = mean_val_dice_result

        # get the mIOU_all
        mean_val_iou_all = []
        for one_class in all_class:
            # get the union
            union_all = val_iou_all[one_class]['union']

            # make sure the union is above 0
            if union_all > 0:
               iou_all = val_iou_all[one_class]['inter'] / val_iou_all[one_class]['union']
               mean_val_iou_all.append(iou_all)
        mean_val_iou_all_number = sum(mean_val_iou_all) / len(mean_val_iou_all) if len(mean_val_iou_all) > 0 else 0.0
        val_iou_all = mean_val_iou_all_number

        # get the val dice
        mean_val_dice = []
        for one_class in all_class:
            tp = val_dice[one_class]['tp']
            fp = val_dice[one_class]['fp']
            fn = val_dice[one_class]['fn']

            # get the denominator first
            denominator = 2 * tp + fp + fn
            # make sure the denominator is above 0
            if denominator > 0:
                dice = 2 * tp / (2 * tp + fp + fn)
                mean_val_dice.append(dice)
        mean_val_dice_number = sum(mean_val_dice) / len(mean_val_dice) if len(mean_val_dice) > 0 else 0.0
        val_dice = mean_val_dice_number

        # the cosine annealing update
        scheduler.step()

        # print summary
        print(f'\n Epoch {epoch + 1}:')
        # training set info
        print(f'Train Loss: {train_loss:.4f}, Train IoU_particle: {train_iou:.4f}')
        # validation set info
        print(f'Val Loss: {val_loss:.4f}')
        print(f'Val IoU_particle: {val_iou:.4f}, Val Dice_particle: {val_dice_p:.4f}')
        print(f'Val IoU_all: {val_iou_all:.4f}, Val Dice_all: {val_dice:.4f}')
        print()

        # save the weights of the best model (and keep updating)
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), save_path)
            print('Weights for best model saved')
            print()

    return model
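
The IoU and Dice metrics in the training loop are both derived from accumulated pixel counts. The sketch below reproduces the idea for one class on a handful of pixels; the helper name `iou_and_dice` is hypothetical, and the ignored labels 255 and 6 follow the `valid` mask used above.

```python
def iou_and_dice(pred, target, cls, ignore=(255, 6)):
    """Hypothetical helper: pixel counts -> (IoU, Dice) for one class."""
    tp = fp = fn = 0
    for p, t in zip(pred, target):
        if t in ignore:
            continue  # ignored pixels do not count at all
        if p == cls and t == cls:
            tp += 1
        elif p == cls:
            fp += 1
        elif t == cls:
            fn += 1
    union = tp + fp + fn
    if union == 0:
        return None, None  # class absent in both prediction and ground truth
    return tp / union, 2 * tp / (2 * tp + fp + fn)

pred   = [1, 1, 0, 2, 1, 0]
target = [1, 0, 0, 2, 255, 1]
# class 1: tp = 1, fp = 1, fn = 1, so IoU = 1/3 and Dice = 0.5
print(iou_and_dice(pred, target, cls=1))
```

Since union = TP + FP + FN, the two metrics are linked by Dice = 2 * IoU / (1 + IoU), which is why they rank models the same way for a single class. Classes with an empty union are left out of the mean, as in the loops above.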

In [None]:
# training  section
def unet_stage():

    # load the trainloader and valloader
    train_loader, val_loader = create_dataloaders(stages = [1,2,3,4], patch_size = 224, batch_size = 24,
                          num_workers = 2, val_split = 0.2)
    # for a single batch
    for batch in train_loader:
      print(f"batch shape: {batch['image'].shape}")
      # print the statistic
      stats = calculate_batch_statistics(batch)
      print(f"Patches with cells: {stats['patches_with_cells']} / {len(batch['image'])}")
      print(f"Cell pixel ratio: {stats['cell_ratio'] * 100:.2f}%")
      break

    # construct the model
    model = transunet(img_size = (224, 224))

    # device checking
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"device: {device}")

    # training
    model = train_model(model, train_loader, val_loader, epochs = 30, device = device, lr = 1e-4,
             dice_weight = 0.6, ce_weight = 0.4, save_path = "/content/drive/MyDrive/Final Project/Image-segmentation-Level-4/TransUnet_seg_weights.pth")

if __name__ == "__main__":
  unet_stage()
  print("Training finished")
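
The training loop calls `scheduler.step()` once per epoch as a cosine annealing update. The scheduler construction is not shown in this cell, so the sketch below rests on two assumptions: that it behaves like PyTorch's `CosineAnnealingLR`, and that `T_max` equals the 30 epochs with `eta_min = 0`. Under those assumptions the learning rate would follow this closed form, starting from `lr = 1e-4`. The function name `cosine_annealing_lr` is hypothetical.

```python
import math


def cosine_annealing_lr(epoch, base_lr=1e-4, t_max=30, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# learning rate at the start, middle and end of training (assumed settings)
for epoch in (0, 15, 30):
    print(f"epoch {epoch:2d}: lr = {cosine_annealing_lr(epoch):.2e}")
```

With these settings the rate is halved at the midpoint and approaches zero in the last epochs, so late updates change the weights only slightly.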

In [None]:
# release the memory
import gc
gc.collect()

### Particle picking stage

This stage runs particle picking on the micrographs using the weights saved by the training stage above.

All results are saved as visualizations.

In [None]:
# import the package
import os
import numpy as np
import torch
import torch.nn.functional as F

# normalize first
def normalize_data(data):
    """
    mrc normalization

    argument:
    data: the input data
    """
    # percentile-based normalization
    # get the 1st percentile and 99th percentile
    p1, p99 = np.percentile(data, [1, 99])
    # clip the outliers outside [p1, p99]
    data = np.clip(data, p1, p99)
    # then map the data to [0, 1]
    data = (data - p1) / (p99 - p1 + 1e-8) # add 1e-8 to avoid division by zero
    return data

def particle_picking(input_file, models, patch_size = 256, stride = 128, device = 'cuda'):
  """
  pick the particles in one micrograph with sliding-window inference

  argument:
  input_file: the 2D micrograph array
  models: the trained segmentation model, already on `device` and in eval mode
  patch_size: the size of patch in sliding window sampling
  stride: the stride of sliding window sampling
  device: cuda or cpu
  """

  # make sure 2D file
  assert input_file.ndim == 2, "Input file should be 2D"

  # normalize the target file
  norm_file = normalize_data(input_file)

  # the image size is usually not a multiple of the stride, so we need padding
  def need_padding(origin_image, patch_size, stride):
      """
      add padding to the image for sliding window sampling

      argument:
      origin_image: the input image
      patch_size: the size of patch in sliding window sampling
      stride: the stride of sliding window sampling
      """
      # original size
      h, w = origin_image.shape
      # calculate the required length
      # needed length = stride * n + patch_size, with n the smallest integer that makes it exceed the original length
      need_h = ((h - patch_size) // stride + 1) * stride + patch_size
      need_w = ((w - patch_size) // stride + 1) * stride + patch_size
      # calculate padding length
      pad_h = need_h - h
      pad_w = need_w - w
      # pad with zeros (black) on the bottom and right edges
      new_image = np.pad(origin_image, ((0, pad_h), (0, pad_w)), mode = 'constant', constant_values = 0)
      return new_image, pad_h, pad_w

  # pad this micrograph
  norm_file, pad_h, pad_w = need_padding(norm_file, patch_size, stride)

  # get the file shape
  h, w = norm_file.shape

  # global accumulators: summed class probabilities and per-pixel window counts
  classes = 7
  logit_glob = torch.zeros(classes, h, w, device = device, dtype = torch.float32)
  count_glob = torch.zeros(h, w, device = device, dtype = torch.float32)

  with torch.no_grad():
      for y in range(0, h - patch_size + 1, stride):
          for x in range(0, w - patch_size + 1, stride):
              # get the patch
              patch = norm_file[y:y + patch_size, x:x + patch_size]
              # convert to float32
              patch = patch.astype(np.float32)

              # stack up to rgb
              patch = np.stack([patch] * 3, axis = 2) # (H, W, 3)

              # use Imagenet normalization
              mean = np.array([0.485, 0.456, 0.406])
              std = np.array([0.229, 0.224, 0.225])
              patch = (patch - mean) / std

              # change the (H, W, 3) to (3, H, W)
              patch = patch.transpose(2, 0, 1)

              # convert to tensor
              # the patch is now (3, H, W); add a batch axis to get [1, 3, H, W]
              patch_tensor = torch.from_numpy(patch[None]).float().to(device)

              # forward pass, take the logits of the single patch: [C, H, W]
              result = models(patch_tensor)[0]
              # get the class probabilities
              probs = F.softmax(result, dim = 0)
              # accumulate into the global probability map
              logit_glob[:, y:y + patch_size, x:x + patch_size] += probs
              # add to the count
              count_glob[y:y + patch_size, x:x + patch_size] += 1.0

  # make sure the count is no less than 1
  count_glob = torch.clamp(count_glob, min = 1.0)
  # get the mean probability, [C,H,W] / [1,H,W]
  avg_glob = logit_glob / count_glob.unsqueeze(0)

  # get the prediction label
  pred_label = torch.argmax(avg_glob, dim = 0)

  # cut the padding; h and w are the padded sizes, so this also works when a pad is 0
  # (slicing with [:-0] would return an empty array)
  pred_label = pred_label[:h - pad_h, :w - pad_w]

  # get the boolean mask of each stage
  pred1 = (pred_label == 1)

  pred2 = (pred_label == 2)

  pred3 = (pred_label == 3)

  pred4 = (pred_label == 4)

  return pred1.cpu().numpy(), pred2.cpu().numpy(), pred3.cpu().numpy(), pred4.cpu().numpy(), pred_label.cpu().numpy()
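
The padding formula and the window loop above can be checked with plain integer arithmetic. The sketch below reuses the formula from `need_padding` for one axis, then counts how many windows cover each row when `patch_size = 224` and `stride = 112`, the values used in the next cell. The height of 4096 is a hypothetical example, not a value read from the dataset, and `padded_length` and `window_starts` are illustrative helper names.

```python
def padded_length(length, patch_size, stride):
    """Same formula as need_padding: smallest stride * n + patch_size that exceeds length."""
    return ((length - patch_size) // stride + 1) * stride + patch_size


def window_starts(length, patch_size, stride):
    """Start offsets visited by the sliding-window loop along one axis."""
    return list(range(0, length - patch_size + 1, stride))


patch_size, stride = 224, 112
h = 4096  # hypothetical micrograph height, not taken from the dataset
need_h = padded_length(h, patch_size, stride)
starts = window_starts(need_h, patch_size, stride)

# count how many windows cover each row of the padded image
counts = [0] * need_h
for s in starts:
    for i in range(s, s + patch_size):
        counts[i] += 1

print(f"padding: {need_h - h}, windows: {len(starts)}")
print(f"coverage per row: min {min(counts)}, max {max(counts)}")
```

With a stride of half the patch size, interior rows are covered by two windows and the outermost rows by one, which is why the summed probabilities are divided by the per-pixel count before the argmax.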

In [None]:
import random
from pathlib import Path
import matplotlib.pyplot as plt
import mrcfile
# implement the model
# specify the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

weight_path = '/content/drive/MyDrive/Final Project/Image-segmentation-Level-4/TransUnet_seg_weights.pth'

# create TransUnet
model = transunet(img_size = (224, 224)).to(device)
# load the saved weights
model.load_state_dict(torch.load(weight_path, map_location = device))
# eval mode
model.eval()

root_path = Path('/content/drive/MyDrive/Final Project/Dataset-processed')
mask_base = Path('/content/drive/MyDrive/Final Project/Image-segmentation-Level-4')

# required mapping from stage folder name to class label
stages = {'stageI': 1, 'stageII': 2, 'stageIII': 3, 'stageIV': 4}

# define the path to save result
out_dir = mask_base / 'TransUnet_result_on_MRC'
os.makedirs(out_dir, exist_ok = True)

# random seed
random.seed(2025)

# extract files
for stage in stages.keys():
    print(f"Now extracting from {stage}")
    # connect to get the stage folder path
    stage_path = root_path / stage
    # find the .mrc files (sorted, so the seeded sampling is reproducible)
    all_files = sorted(f for f in os.listdir(stage_path) if f.endswith('.mrc'))
    # randomly sample at most 100 files
    all_files = random.sample(all_files, min(100, len(all_files)))
    print(f"chosen file: {all_files}")
    print()

    # make stage folder
    stage_dir = out_dir / stage
    os.makedirs(stage_dir, exist_ok = True)

    # for one file
    for one_file in all_files:
        results = {}
        # get the file path
        file_path = stage_path / one_file
        # read the file
        with mrcfile.open(file_path) as mrc:
          # convert to float
          file_data = mrc.data.astype(np.float32)

        # fit the mrc file into model
        pred1, pred2, pred3, pred4, pred_mask = particle_picking(file_data, model, patch_size = 224, stride = 112, device = device)

        # the percentage of pixels assigned to each stage
        for single_stage, corresponding_mask in enumerate([pred1, pred2, pred3, pred4], start = 1):
            results[single_stage] = float(corresponding_mask.mean() * 100.0)

        # draw the plot
        norm_origin = normalize_data(file_data.squeeze())
        # to [H,W,3]
        rgb_origin = np.stack([norm_origin] * 3, axis = -1)
        # the mark overlay
        overlay = rgb_origin.copy()
        # four colors
        # stage 1 use blue
        overlay[pred1] = [0.0, 0.0, 1.0]
        # stage 2 use red
        overlay[pred2] = [1.0, 0.0, 0.0]
        # stage 3 use yellow
        overlay[pred3] = [1.0, 1.0, 0.0]
        # stage 4 use magenta
        overlay[pred4] = [1.0, 0.0, 1.0]
        # mix
        integrated = 0.6 * rgb_origin + 0.4 * overlay

        # file name
        name = file_path.stem

        # save file
        save_path = stage_dir / f"{name}_modelpred.png"
        plt.imsave(save_path.as_posix(), integrated)

        # print statistic
        print(f"File name: {file_path}")
        for the_stage_number, the_percentage in results.items():
            print(f"Stage {the_stage_number} percentage: {the_percentage:.2f}%")
        print(f"Saved to {save_path}")
        print()
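
The per-stage percentage and the colour overlay computed above reduce to simple per-pixel arithmetic. The sketch below reproduces both on a tiny made-up label map using plain Python lists. `coverage_percent` and `blend_pixel` are hypothetical helpers, and the 0.6 / 0.4 mix matches the weights used for `integrated`.

```python
def coverage_percent(labels, cls):
    """Percentage of pixels in a 2D label map that carry the given class label."""
    flat = [v for row in labels for v in row]
    return 100.0 * sum(1 for v in flat if v == cls) / len(flat)


def blend_pixel(gray, color=None, alpha=0.4):
    """Mix a grayscale pixel with an overlay colour; unlabelled pixels stay gray."""
    base = [gray, gray, gray]
    overlay = color if color is not None else base
    return [(1 - alpha) * b + alpha * o for b, o in zip(base, overlay)]


# toy 2 x 4 label map, made up for illustration only
labels = [[0, 1, 1, 0],
          [0, 0, 2, 2]]
for stage in (1, 2):
    print(f"Stage {stage} percentage: {coverage_percent(labels, stage):.2f}%")

# a mid-gray pixel labelled as stage 1 (blue) and the same pixel left unlabelled
print(blend_pixel(0.5, color=[0.0, 0.0, 1.0]))
print(blend_pixel(0.5))
```

Because unlabelled pixels copy the original image into the overlay, the blend leaves them unchanged and only tints the predicted particles.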