# Automatic segmentation of precipitate statistics using U-Net architecture

This code utilizes a U-Net deep learning architecture to automatically extract precipitate statistics from transmission electron microscopy (TEM) images. 
The code is developed as part of a master's thesis in applied physics, the code segments precipitates within the images, enabling the automatic measurement of precipitate length and cross-sections. 
By automating this process, it significantly accelerates the analysis of precipitate distributions, aiding in materials research and development.

## Author:

**Espen J. Gregory** - Developed for Master thesis in Physics 2024

## Note:
- It is recommended to have a GPU and the CUDA-version of Pytorch installed (However it is not required).
- Make sure model files (.pth) are placed in the same folder as the notebook
- Data can be loaded in two ways, either by directly uploading the .DM3 file, or converting the .DM3 to an image (.jpeg/.png) and manually selecting the calibration unit *nm_per_px*.
- U-Net documentation: https://arxiv.org/abs/1505.04597


### Imports and PyTorch initializationk

In [18]:
%matplotlib qt5
%load_ext autoreload
%autoreload 2
%pip install pandas

import cv2
import time
import torch
import numpy as np
import pandas as pd
import tkinter as tk
import _dm3_lib as dm
from pathlib import Path
from PIL import Image
from itertools import product
from tkinter import filedialog
from u_net_pytorch import UNet
from skimage import measure, color, io
from skimage.segmentation import clear_border
from pathlib import Path

"""PyTorch Initialization"""
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark     = True
torch.manual_seed(0)
torch.cuda.manual_seed(0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Device type: %s"%(device))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Note: you may need to restart the kernel to use updated packages.
Device type: cpu


In [21]:
# Using DataSetEvaluator
from testMaster.DatasetEvaluator import UNETEvaluator

# Example usage:
this_dir = Path.cwd()
dataset_path = this_dir.parent / "data" / "test_cross"
model_path = this_dir.parent / "data" / "models" / "cross_unet.pth"


unet_evaluator = UNETEvaluator(
    dataset_dir = dataset_path,
    model = model_path,
    cross = True,
    device = 'cpu'
)

unet_evaluator.statistics()



  self.tile_size = 512


 Unet Model Loaded


## Function definitions

In [7]:
def tile_img(arr, d) -> list:
    
    """
    Tile the image into equal parts of size (d x d) pixels

    img: PIL Image
    d  : Side length of square tile

    Return: List of PIL images
    """
    img = Image.fromarray(arr)
    w, h = img.size
    grid = product(range(0, h-h%d, d), range(0, w-w%d, d))
    img_list = []
    for i, j in grid:
        box = (j, i, j+d, i+d)
        img_list.append(img.crop(box))
    return img_list


def DM_2_array(img) -> np.array:
    """
    Convert Digital Micrograph file to numpy array

    img: An instance of the DM3 class from _dm3_lib.py

    returns a numpy array of the grayscale image
    """
    nm_per_px = img.pxsize[0]
    cons = img.contrastlimits
    im   = img.imagedata
    im[im>cons[1]] = cons[1]
    im[im<cons[0]] = cons[0]
    im =  ((im-cons[0])/(cons[1]-cons[0]))*255  #0 to 1
    return im.astype(np.uint8), nm_per_px

def load_data() -> list:
    
    """
    Opens dialogbox that allows selection of files
    Returns file/files
    """
    root = tk.Tk()
    root.withdraw()
    root.call('wm', 'attributes', '.', '-topmost', True)
    files = filedialog.askopenfilenames(parent=root, title='Choose a file')
    return files


class Prediction():
    def __init__(self,cross):
        self.cross     = cross
        self.size      = 1024
        self.tile_size = 512
        if self.cross:
            self.PATH = r'C:\Users\krist\Documents\masterRepo\data\models\cross_unet.pth'
        else:
            self.PATH = r".\models\length_unet.pth"

        self.checkpoint = torch.load(self.PATH, map_location=torch.device('cpu'))
        self.model      = UNet(in_channels = 1, n_classes = 2, depth = 3, wf = 6, padding = True)
        self.model.load_state_dict(self.checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()
        print('Model Loaded')
    
    def to_tensor(self, file) -> list:
        """
        
        Opens the image in grayscale, resizes it (if applicable), and converts it to a pytorch tensor
    
        file   (str)    : Path to file
        Returns tensor or list of tensors
        
        """
        try:
            if file.endswith('.dm3'):
                
                image, nm_per_px = DM_2_array(dm.DM3(file))
                if len(self.images) == 0:
                    self.nm_per_px = nm_per_px
            else:
                image = np.array(Image.open(file).convert('L'))
        except Exception:
            raise ValueError("Something went wrong when loading the image.")
            pass
            
        image   = cv2.resize(image, dsize = (self.size,self.size))
        tensors = []
        images  = tile_img(np.array(image), self.tile_size)

        for i in images:
            
            im = np.expand_dims(i, axis=0)
            
            ## Can add filter if images are very noisy (Median recommended, gaussian makes the images too blurry)
            # image = nd.median_filter(image, size=3) 
            im = 2*(im/np.max(im)) - 1
            im = torch.tensor(im, dtype = torch.float32)
            tensors.append(im)
        return np.array(image), tensors


    
    def watershed(self,img, plot = False):
        
        """
        Performs the watershed algorithm on the prediction img
    
        img  : PIL.Image (semantic segmentation prediction map)
        plot : bool (True if you want to see the watershed processing steps)
        
        Documentation: https://docs.opencv.org/4.x/d3/db4/tutorial_py_watershed.html
        
        """
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = clear_border(gray)
        ret, bin_img = cv2.threshold(gray, 0, 255, cv2.THRESH_OTSU) 
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        sure_bg = cv2.dilate(bin_img, kernel, iterations=20) 
        dist = cv2.distanceTransform(bin_img, cv2.DIST_L2, 5) 
        
       
        #foreground area 
        ret, sure_fg = cv2.threshold(dist, 0.15 * dist.max(), 255, cv2.THRESH_BINARY) 
        sure_fg = sure_fg.astype(np.uint8)   
          
        # unknown area 
        unknown = cv2.subtract(sure_bg, sure_fg) 
        ret, markers = cv2.connectedComponents(sure_fg) 
          
        # Add one to all labels so that background is not 0, but 1 
        markers += 1
        markers[unknown == 255] = 0
        markers = cv2.watershed(img, markers) 
        
        if plot:
            fig, axes = plt.subplots(2,2)
            axes[0,0].imshow(gray) 
            axes[0, 0].set_title('Img') 
            axes[0,1].imshow(dist) 
            axes[0, 1].set_title('Distance Transform') 
              
            axes[1,0].imshow(sure_fg) 
            axes[1, 0].set_title('Sure Foreground') 
            axes[1,1].imshow(markers) 
            axes[1, 1].set_title('Markers') 
    
        img2 = color.label2rgb(markers,bg_label = 1,bg_color=(0, 0, 0))
        props = measure.regionprops_table(markers, intensity_image=gray, 
                                      properties=['label',
                                                  'area', 'equivalent_diameter',
                                                  'mean_intensity', 'solidity'])
        
        df = pd.DataFrame(props)
        area = list(df[(df.mean_intensity > 100) & (df.area > 1.5/self.nm_per_px**2)].area)
        return area


    
    def calc_length(self, img):
        """
        Estimates the length of precipitates
        """
        grey = img[:,:,0]
        contours, hierarchy = cv2.findContours(grey, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 
        l      = []
        for contour in contours:
            (center), (width,height), angle = cv2.minAreaRect(contour)
            length = np.max([width,height])
            l.append([length, angle+(angle<0)*90])
        for index, (length, angle) in enumerate(l):
            median_angle = np.median([angle for (length, angle) in l if length*self.nm_per_px > 5]) #Find angles of all detections longer than 5nm
            error        = 5.0 #degrees
            if  (median_angle - error<angle<median_angle + error) and length*self.nm_per_px>3: #If precipitate in correct direction (within error) and longer than 3nm, accept detection
                self.lengths.append(length) 
    def evaluate(self, nm_per_px):
        """
        Evaluation/prediction function

        nm_per_px: float (Image calibration AS IF IMAGE IS 2048x2048)
        
        """
        self.data       = load_data()
        self.nm_per_px  = nm_per_px
        self.prediction = []
        self.images     = []
        self.area       = []
        self.lengths    = []

        start_time = time.time()
        for img in iter(self.data):
            true_img, imgs = self.to_tensor(img)
            new_im = Image.new('RGB', (self.size, self.size))
            if len(self.images) == 0:
                self.nm_per_px *=2 #Original calibration for 2048x2048, but images are resized to 1024x1024
            for index, im in enumerate(imgs):
                im = im.unsqueeze(0).to(device)
                with torch.no_grad():
                    pred   = self.model(im)
                    output = torch.argmax(pred, dim=1)  # Get the index of the channel with the highest probability
                    output = output.squeeze(0).cpu().numpy()
                    
                y_offset =  int(self.tile_size*(index>1))
                x_offset =  int(self.tile_size*((index)%2))
                out      = Image.fromarray(output.astype('uint8')*255).convert('RGB')
                new_im.paste(out, box = (x_offset,y_offset))
                
            self.images.append(np.array(true_img))
            self.prediction.append(np.array(new_im))

            if self.cross:
                self.area += self.watershed(self.prediction[-1], plot = False)
            else:
                self.calc_length(self.prediction[-1])
                
        
        total_time = time.time()-start_time
        print(f"Total interference time: {np.round(total_time,2)}s ; Time per Image: {np.round(total_time/len(self.images),3)}s")

        if self.cross:
            return np.array(self.area)*self.nm_per_px**2, self.images, self.prediction
        else:
            return np.array(self.lengths)*self.nm_per_px, self.images, self.prediction        

## Cross-section

Note: nm_per_px (Calibration) should be the calibration for a 2048x2048 image, if the images are .dm3, the manual calibration is not needed. 

The original images as prediction masks are found in (images, prediction)

In [10]:
cross_sections = Prediction(cross = True)
area, images, prediction = cross_sections.evaluate(nm_per_px = 0.069661)
print(len(prediction))

print('Average: {0:.2f}nm, STDev: {1:.2f}nm, Number counted: {2:d}'.format(np.mean(area), np.std(area), len(area)))

  self.checkpoint = torch.load(self.PATH, map_location=torch.device('cpu'))


Model Loaded
Total interference time: 33.06s ; Time per Image: 33.06s
1024
Average: 8.44nm, STDev: 4.85nm, Number counted: 30


In [16]:
import matplotlib.pyplot as plt
plt.imshow(images[0])
plt.show()

TypeError: Invalid shape (1, 1024, 1024) for image data

## Length

Note: nm_per_px (Calibration) should be the calibration for a 2048x2048 image, if the images are .dm3, the manual calibration is not needed. 

The original images as prediction masks are found in (images, prediction)

In [15]:
length = Prediction(cross = False)
length, images, prediction = length.evaluate(nm_per_px = 0.16685)

print('Average: {0:.2f}nm, STDev: {1:.2f}nm, Number counted: {2:d}'.format(np.mean(length), np.std(length), len(length)))

  self.checkpoint = torch.load(self.PATH, map_location=torch.device('cpu'))


Model Loaded
Total interference time: 25.26s ; Time per Image: 25.261s
Average: 7.01nm, STDev: 1.41nm, Number counted: 3


In [16]:
plt.imshow(images[0])
plt.show()