# Automatic segmentation of precipitate statistics using Mask RCNN architecture

This code utilizes a Mask R-CNN deep learning architecture to automatically extract precipitate statistics from transmission electron microscopy (TEM) images. 
The code was developed as part of a master's thesis in applied physics. It segments precipitates within the images, enabling the automatic measurement of precipitate lengths and cross-sections. 
By automating this process, it significantly accelerates the analysis of precipitate distributions, aiding in materials research and development.

## Author:

**Espen J. Gregory** - Developed for Master thesis in Physics 2024

## Note:

- A GPU and the CUDA version of PyTorch are recommended, but not required.
- Make sure model files (.pth) are placed in the same folder as the notebook
- Data can be loaded in two ways, either by directly uploading the .DM3 file, or converting the .DM3 to an image (.jpeg/.png) and manually selecting the calibration unit *nm_per_px*.
- Mask R-CNN documentation: https://arxiv.org/abs/1703.06870

### Imports/Dependencies and PyTorch initialization

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib qt5
import gc
import cv2
import time
import torch
import matplotlib
import numpy as np
import tkinter as tk
import _dm3_lib as dm
import torchvision as tv
import matplotlib.pyplot as plt
import os

from PIL import Image
from tkinter import filedialog
from matplotlib.widgets import Button, Slider
from skimage.segmentation import clear_border
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from pathlib import Path

# PyTorch initialization: fixed seeds and deterministic cuDNN kernels, so that
# repeated runs on the same images produce the same masks.
torch.backends.cudnn.deterministic = True
# Benchmark mode lets cuDNN auto-tune and possibly select a different algorithm
# from run to run, which works against the determinism requested above.
torch.backends.cudnn.benchmark     = False
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Use the first GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Device type: %s"%(device))


# Use a larger default font size for all matplotlib figures
font = dict(size=18)
matplotlib.rc('font', **font)



Device type: cpu


In [19]:
# Batch evaluation of a whole image folder with the project's DatasetEvaluator
from testMaster.DatasetEvaluator import DatasetEvaluator, RCNNEvaluator

# Folder with the images to evaluate and the checkpoint of the cross-section model
image_dir = r'C:\Users\krist\Documents\masterRepo\data\test_cross'  
cross_model_path = Path(r'C:\Users\krist\Documents\masterRepo\data\models\cross_rcnn.pth')

# RCNN evaluator (cross=True selects cross-section evaluation)
rcnn_evaluator = RCNNEvaluator(path_model=cross_model_path, cross=True, device='cpu')

# DatasetEvaluator is given the image folder; predict() then runs the evaluator
# on the images and generates the plots
dataset_evaluator = DatasetEvaluator(image_dir)
dataset_evaluator.predict(rcnn_evaluator)


Mask-RCNN Model Loaded
Evaluating image 1/6: C:\Users\krist\Documents\masterRepo\data\test_cross\01-dm310h_jpg.rf.72eebbad25b572b6ff1245627e55c004.jpg
Evaluating image 2/6: C:\Users\krist\Documents\masterRepo\data\test_cross\012_jpg.rf.9614581b7c0a0232b0ec72d332a49471.jpg
Evaluating image 3/6: C:\Users\krist\Documents\masterRepo\data\test_cross\06-dm310h_jpg.rf.328259131328f76f78cf56bedf95173b.jpg
Evaluating image 4/6: C:\Users\krist\Documents\masterRepo\data\test_cross\1-dm316h_jpg.rf.0d5872cc83a194e1e4214845e7aa92e9.jpg
Evaluating image 5/6: C:\Users\krist\Documents\masterRepo\data\test_cross\2_jpg.rf.f09d9baa85d7b92179ecf126e40b7d72.jpg
Evaluating image 6/6: C:\Users\krist\Documents\masterRepo\data\test_cross\3-dm32h_jpg.rf.ce5a5c139142e9358a28858a8e6561fe.jpg


### Defining functions

In [5]:
def DM_2_array(img) -> tuple:
    """
    Convert Digital Micrograph file to an 8-bit grayscale numpy array

    The pixel values are clipped to the contrast limits stored on the image
    object and then stretched linearly to the range 0-255.

    img: An instance of the DM3 class from _dm3_lib.py

    returns (image, nm_per_px): the uint8 grayscale image and the first entry
    of img.pxsize, which callers use as the calibration in nm per pixel
    """
    nm_per_px = img.pxsize[0]
    low, high = img.contrastlimits[0], img.contrastlimits[1]
    # np.clip returns a new array, so the data held by img is left untouched
    clipped = np.clip(img.imagedata, low, high)
    if high == low:
        # Degenerate contrast window: avoid dividing by zero, return a black image
        return np.zeros(clipped.shape, dtype=np.uint8), nm_per_px
    scaled = ((clipped - low) / (high - low)) * 255  # 0 to 255
    return scaled.astype(np.uint8), nm_per_px
    
def load_data() -> list:
    """
    Opens dialogbox that allows selection of files

    Returns a list with the selected file paths (empty if the dialog is cancelled)
    """
    root = tk.Tk()
    try:
        root.withdraw()  # Only the dialog should be visible, not the empty root window
        root.call('wm', 'attributes', '.', '-topmost', True)  # Keep the dialog in front
        files = filedialog.askopenfilenames(parent=root, title='Choose a file')
    finally:
        # Without this every call leaves a hidden Tk root window alive
        root.destroy()
    return list(files)


def check_image(n: int, thresh: float, erode: int) -> (np.array, np.array):
    """
    Function used to check the mask overlay of image n, given a masking threshold thresh

    The predictions and images are read from the notebook-level `model` object,
    so `model.evaluate` must have been run before calling this function.

    n      : Index of the image in model.data
    thresh : Mask threshold; pixels whose mask value is above it belong to the mask
    erode  : Number of erosion iterations applied to each thresholded mask

    Returns: Image, Image with predicted mask overlay
    """
    print(f"Image checked: {model.data[n]}")
    detections = model.prediction[n][0]
    frame      = model.images[n][0]
    # Prediction.evaluate stores each image with shape (1, H, W), so frame is a
    # single-channel 2-D array there. COLOR_BGR2RGB expects a multi-channel
    # image, hence the explicit grayscale conversion for the 2-D case.
    if frame.ndim == 2:
        gray = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    else:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    overlay   = gray.copy()
    scr_thres = 0.9  # Minimum confidence score for a detection to be drawn
    kernel    = np.ones((2, 2), np.uint8)
    for i in range(len(detections['masks'])):
        scr = detections['scores'][i].detach().cpu().numpy()
        if scr > scr_thres:
            msk     = detections['masks'][i, 0].detach().cpu().numpy()
            mask_er = cv2.erode((msk > thresh).astype(np.float32), kernel, iterations = erode)
            overlay[mask_er > 0] = [1, 0, 0]  # Makes mask overlay red
    return gray, overlay


def update(erode, val, index=None):
    """
    Function that updates the matplotlib figure when the slider is moved.

    erode : True when val comes from the erosion slider, False when it comes
            from the mask threshold slider
    val   : New slider value
    index : Index of the image to redraw; when omitted the notebook-level n is used
    """
    global erode_it_temp, temp_threshold
    if erode:
        erode_it_temp = round(val)
    else:
        temp_threshold = val
    if index is None:
        index = n
    _, overlay = check_image(index, temp_threshold, erode_it_temp)
    shown = ax[1].images
    if shown:
        # Reuse the existing image artist instead of stacking a new one per slider move
        shown[-1].set_data(overlay)
    else:
        ax[1].imshow(overlay)
    fig.canvas.draw_idle()


def Check_Mask(n, model):
    """
    Opens an interactive figure for tuning the mask threshold and erosion iterations.

    The left panel shows image n, the right panel the same image with the mask
    overlay. Moving a slider redraws the overlay; pressing 'Accept' stores the
    current slider values in model.threshold / model.erode_it and closes the figure.

    n     : Index of the image to check
    model : Object whose threshold and erode_it attributes are read and updated

    Note: the overlay is rendered by check_image, which reads the notebook-level
    `model`; pass that same object here.
    """
    # The widgets are kept in globals so that they stay referenced (and therefore
    # responsive) after this function returns
    global button, thresh_slider, thresh_slider2, temp_threshold, fig, ax, erode_it_temp
    temp_threshold = model.threshold
    erode_it_temp = model.erode_it

    fig, ax = plt.subplots(1,2,figsize = (20,10))
    figure = fig  # Local handle: accept() must close this figure even if fig is rebound later
    gray, overlay = check_image(n, model.threshold ,model.erode_it)
    ax[0].imshow(gray)
    ax[0].axis('off')
    ax[1].imshow(overlay)
    ax[1].axis('off')
    axthresh  = fig.add_axes([0.125, 0.06, 0.775, 0.03])
    axthresh2 = fig.add_axes([0.125, 0.1, 0.775, 0.03])
    thresh_slider2 = Slider(ax=axthresh2, label='Mask threshold', valmin=0, valmax=1,valinit=model.threshold)
    thresh_slider  = Slider(ax=axthresh, label='Erode iterations', valmin=0, valmax=10,valinit=model.erode_it,valfmt="%i")

    def accept(event):
        model.threshold = temp_threshold
        model.erode_it  = erode_it_temp
        plt.close(figure)

    # Pass the index explicitly so the redraw shows the image this figure was opened for
    thresh_slider.on_changed(lambda x: update(True, x, index=n))
    thresh_slider2.on_changed(lambda x: update(False, x, index=n))
    resetax = fig.add_axes([0.8, 0.01, 0.1, 0.04])
    button = Button(resetax, 'Accept', hovercolor='0.975')
    button.on_clicked(accept)


class Prediction():
    """
    Loads a trained Mask R-CNN checkpoint, runs it on user-selected images and
    converts the predicted masks into precipitate statistics.

    cross      : True for the cross-section model, False for the length model
    model_path : Optional path to the checkpoint (.pth). When omitted, the
                 default location for the selected model is used.
    """
    SCORE_THRESHOLD = 0.9   # Detections must score above this to be counted
    REFERENCE_SIZE  = 2048  # Image width (px) that a manual nm_per_px calibration refers to

    def __init__(self, cross, model_path=None):
        self.cross     = cross
        self.size      = 1024  # Side length (px) the images are resized to before prediction

        if self.cross:
            self.threshold = 0.9
            self.erode_it  = 0
            self.PATH = r"C:\Users\krist\Documents\masterRepo\data\models\cross_rcnn.pth"
        else:
            self.erode_it  = 4
            self.threshold = 0.5
            self.PATH = r".\models\length_rcnn.pth"
        if model_path is not None:
            self.PATH = model_path

        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source
        self.checkpoint = torch.load(self.PATH, map_location=torch.device('cpu'))
        self.model      = tv.models.detection.maskrcnn_resnet50_fpn(weights='DEFAULT', min_size=1024, max_size=2048, box_detections_per_img = 500) 
        # Replace the box predictor by a two-class head before loading the trained weights
        in_features     = self.model.roi_heads.box_predictor.cls_score.in_features 
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_classes=2)
        self.model.load_state_dict(self.checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()
        print('Model Loaded')
    
    def to_tensor(self, file) -> torch.Tensor:
        """
        Opens the image in grayscale, resizes it to (size, size), scales it to a
        maximum of 1 and converts it to a pytorch tensor of shape (1, size, size)

        For .dm3 files (any letter case) the calibration stored in the file
        replaces self.nm_per_px when no image has been processed yet. It is
        rescaled by original width / size, because the image is resized.
    
        file   (str)    : Path to file
            
        Returns tensor

        Raises ValueError if the file cannot be read
        """
        try:
            if str(file).lower().endswith('.dm3'):
                image, nm_per_px = DM_2_array(dm.DM3(file))
                if len(getattr(self, 'images', ())) == 0:
                    # Uses the width only; non-square images are stretched by the resize below
                    self.nm_per_px = nm_per_px * image.shape[1] / self.size
            else:
                image = np.array(Image.open(file).convert('L'))
        except Exception as exc:
            # Chain the original error so the real cause stays visible in the traceback
            raise ValueError("Something went wrong when loading the image.") from exc
        image = cv2.resize(image, dsize = (self.size,self.size))
        image = np.expand_dims(image, axis=0)
        
        ## Can add filter if images are very noisy (Median recommended, gaussian makes the images too blurry)
        # image = nd.median_filter(image, size=3) 
        peak = np.max(image)
        if peak > 0:
            image = image/peak
        # An all-black image is left as zeros instead of dividing by zero
        image = torch.tensor(image, dtype = torch.float32)
        
        return image

    def evaluate(self, nm_per_px):
        """
        Prediction Mask R-CNN

        Opens a file dialog, runs the model on every selected file and stores
        the results in self.prediction and self.images.

        nm_per_px: float (Image calibration AS IF IMAGE IS 2048x2048)
        
        """
        self.data       = load_data()
        # The manual calibration refers to a REFERENCE_SIZE px wide image, but the
        # images are resized to self.size px (factor 2 for 2048 -> 1024)
        self.nm_per_px  = nm_per_px * self.REFERENCE_SIZE / self.size
        self.prediction = []
        self.images     = []
        self.lengths    = []

        if len(self.data) == 0:
            # Dialog cancelled: nothing to predict (and no per-image time to report)
            print("No files selected, nothing to evaluate.")
            return

        start_time = time.time()
        for path in self.data:
            batch = self.to_tensor(path).unsqueeze(0).to(device)
            with torch.no_grad(): #Predict
                self.prediction.append(self.model(batch))
            self.images.append(batch[0].detach().cpu().numpy())
        total_time = time.time()-start_time
        print(f"Total inference time: {np.round(total_time,2)}s ; Time per Image: {np.round(total_time/len(self.images),3)}s")

    def _binary_mask(self, soft_mask):
        """Thresholds a predicted mask at self.threshold and erodes it self.erode_it times."""
        mask    = soft_mask > self.threshold
        kernel  = np.ones((2, 2), np.uint8)
        mask_er = cv2.erode(mask.astype(np.float32), kernel, iterations = self.erode_it)
        return mask_er > 0

    def statistics(self):
        """
        Converts the stored predictions into precipitate sizes.

        Only detections scoring above SCORE_THRESHOLD are used, and masks
        removed by clear_border (touching the image border) are skipped.

        Returns a numpy array: areas (pixel count * nm_per_px**2) when
        self.cross is True, otherwise lengths (longest side of the minimum-area
        rectangle around the mask * nm_per_px).
        """
        self.area    = []
        self.lengths = []
        print('Mask threshold set to {0:.2f}'.format(self.threshold))
        print('Calibration used: {0:.4f} nm/px'.format(self.nm_per_px))

        for pred in self.prediction:
            detections = pred[0]
            for i in range(len(detections['masks'])):
                # Check the score first so low-confidence masks are never processed
                if not float(detections['scores'][i]) > self.SCORE_THRESHOLD:
                    continue
                msk = self._binary_mask(detections['masks'][i, 0].detach().cpu().numpy())
                if self.cross:
                    area_inside = np.sum(clear_border(msk))
                    # Equal sums mean clear_border removed nothing, i.e. the mask does
                    # not touch the border; empty masks are not counted as precipitates
                    if area_inside > 0 and area_inside == np.sum(msk):
                        self.area.append(area_inside)
                else:
                    msk = clear_border(msk, buffer_size=10)
                    if np.any(msk):
                        # msk is already boolean, so it is used directly as the point set
                        (center), (width, height), angle = cv2.minAreaRect(np.argwhere(msk))
                        self.lengths.append(max(width, height))

        if self.cross:
            return np.array(self.area)*self.nm_per_px**2
        return np.array(self.lengths)*self.nm_per_px
            
            
        
        

# Cross-section

**Note:** 
- nm_per_px (calibration) should be the calibration for a 2048x2048 image. If the images are .dm3 files, the manual calibration is not needed.
- If program runs slow, restart the kernel

In [6]:
# Load the cross-section model, then choose the images in the file dialog that
# evaluate() opens. nm_per_px is the calibration as if the image is 2048x2048.
model = Prediction(cross = True)
model.evaluate(nm_per_px = 0.069661) 

  self.checkpoint = torch.load(self.PATH, map_location=torch.device('cpu'))


Model Loaded
Total interference time: 20.89s ; Time per Image: 20.89s


#### Use Check_Mask to adjust threshold, and erosion iterations

**Note**
- Default values should be good
- Lowering erosion and mask threshold makes mask bigger
- n : Index of image you want to check

In [7]:
# Index of the image to check. It is kept as a notebook-level variable because
# the slider callback update() reads it when redrawing the overlay.
n = 0

Check_Mask(n, model)

Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg


Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/masterRepo/data/train_cross/3-dm32h_jpg.rf.3359921eb081c5b47c063d42cdfd8363.jpg
Image checked: C:/Users/krist/Documents/

In [5]:
area = model.statistics()

# For the cross-section model statistics() returns areas (pixel count * nm_per_px**2),
# so the values are in nm² rather than nm
print('Average: {0:.2f}nm², STDev: {1:.2f}nm², Number counted: {2:d}'.format(np.mean(area), np.std(area), len(area)))

Mask threshold set to 0.90
Calibration used: 0.1393 nm/px
Average: 10.21nm, STDev: 3.45nm, Number counted: 28


**Clear memory**

In [6]:
# Drop the reference to the model first, then run the garbage collector and
# release the memory cached by the CUDA allocator
model = None
gc.collect()
torch.cuda.empty_cache()

# Length

In [7]:
# Load the length model, then choose the images in the file dialog that
# evaluate() opens. nm_per_px is the calibration as if the image is 2048x2048.
model = Prediction(cross = False)
model.evaluate(nm_per_px = 0.20835) 

  self.checkpoint = torch.load(self.PATH, map_location=torch.device('cpu'))


Model Loaded
Total interference time: 14.83s ; Time per Image: 14.826s


In [8]:
# Index of the image to check. It is kept as a notebook-level variable because
# the slider callback update() reads it when redrawing the overlay.
n = 0
Check_Mask(n,model)

Image checked: C:/Users/krist/OneDrive/Dokumenter/masterProsjekt/training/training_data/valid/3-dm32h_jpg.rf.ce5a5c139142e9358a28858a8e6561fe.jpg


In [9]:
# Precipitate lengths (scaled by the calibration) from the length model
l = model.statistics()
mean_length, std_length, n_counted = np.mean(l), np.std(l), len(l)
print('Average: {0:.2f}nm, STDev: {1:.2f}nm, Number counted: {2:d}'.format(mean_length, std_length, n_counted))

Mask threshold set to 0.50
Calibration used: 0.4167 nm/px
Average: 12.67nm, STDev: 3.20nm, Number counted: 3


In [10]:
# Drop the reference to the model first, then run the garbage collector and
# release the memory cached by the CUDA allocator
model = None
gc.collect()
torch.cuda.empty_cache()