In [2]:
import torch
import tifffile as tiff
import numpy as np
import cv2
from pathlib import Path
import shutil
import sys
from PIL import Image
from detectron2.structures import BoxMode
from torch.utils.data import Dataset
from tqdm import tqdm
from matplotlib import image as mpimg
import matplotlib.pyplot as plt
import os
from natsort import natsorted
from IPython.display import clear_output
import time

In [3]:
class tiffSegmentation:
    def __init__(self, model_file, input_file, output_dir, output_num):
        self.model_file = model_file
        self.input_file = input_file
        self.output_dir = output_dir

        self.output_num = output_num
        
        self.img_dataset = cv2.imreadmulti(input_file)[1] 
        self.img = self.img_dataset[0]

        if output_num == None:
            self.output_num = len(self.img_dataset)
        else:
            output_num = output_num
        
        self.classes = ["amoeba", "yeast"]
        self.predictor = self.modelPrep()

    def phaseSeg(self):
        
        if os.path.exists(self.output_dir):  # if folder phase masks already exists in output directory remakes
            shutil.rmtree(self.output_dir)
        os.makedirs(self.output_dir)
        
        for i,img in enumerate(self.img_dataset[::2][:self.output_num]):
            self.img = img
            
            sys.stdout.write(f'\rSegmenting image {i + 1} / {len(self.img_dataset)}')
            sys.stdout.flush()
        
            image_filename = "phase_" + str(i)
            
            outputs = self.predictor(self.imgPrep())  # Format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
            # input_images_directory = Dataset

            class_masks = {class_name: torch.zeros_like(outputs["instances"].pred_masks[0], dtype=torch.int16, )
                           for class_name in self.classes}

        # Assign a unique integer label to each object in the mask
            for i, pred_class in enumerate(outputs["instances"].pred_classes):
                class_name = self.classes[pred_class]
                instance_mask = outputs["instances"].pred_masks[i]
                class_masks[class_name] = torch.where(instance_mask,
                                                      torch.tensor(i + 1, dtype=torch.float32),
                                                      class_masks[class_name].to(dtype=torch.float32))
                class_masks[class_name] = class_masks[class_name].to(dtype=torch.int16)
    
            for class_name, class_mask in class_masks.items():
                class_mask_np = class_mask.cpu().numpy()
                image_name = image_filename + f'_{class_name}.tif'
    
                output_path = os.path.join(self.output_dir, class_name)
                
                os.makedirs(output_path, exist_ok=True)
    
                Image.fromarray(class_mask_np.astype(np.uint16)).save(Path(output_path) / image_name)
    
    def modelPrep(self):
        from detectron2 import model_zoo
        from detectron2.engine import DefaultPredictor
        from detectron2.config import get_cfg
        
        
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training
        cfg.MODEL.WEIGHTS = self.model_file
        cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
        cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256  # Default is 512, using 256 for this dataset.
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # We have 200 classes.
        cfg.TEST.DETECTIONS_PER_IMAGE = 300

        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5# set a custom testing threshold
        predictor = DefaultPredictor(cfg)
        return predictor
    
    def imgPrep(self):
        im = np.stack([self.img,self.img,self.img])
        return im.transpose(1,2,0)

In [15]:
folder_path = Path(r"F:\Work\UNI\ResProj\TR-Y4-Project\Image_repo\no_yeat_1\TiffStack")
for file in folder_path.iterdir():
    if file.is_file():
        print(Path.joinpath(folder_path, file.name))
        


F:\Work\UNI\ResProj\TR-Y4-Project\Image_repo\no_yeat_1\TiffStack\no_yeat_1_MMStack_Pos0.ome.tif
F:\Work\UNI\ResProj\TR-Y4-Project\Image_repo\no_yeat_1\TiffStack\no_yeat_1_MMStack_Pos0_1.ome.tif
F:\Work\UNI\ResProj\TR-Y4-Project\Image_repo\no_yeat_1\TiffStack\no_yeat_1_MMStack_Pos0_2.ome.tif
F:\Work\UNI\ResProj\TR-Y4-Project\Image_repo\no_yeat_1\TiffStack\no_yeat_1_MMStack_Pos0_3.ome.tif


In [16]:
folder_path = Path(r"F:\Work\UNI\ResProj\TR-Y4-Project\Image_repo\no_yeat_1\TiffStack") # folder with tiffstacks
for i, file in enumerate(folder_path.iterdir()): # iterates through each tiff stack in folder
    if file.is_file():

        tiffstack_path = Path.joinpath(folder_path, file.name)
        model_path = r'F:\Work\UNI\ResProj\TR-Y4-Project\Model_v3\model_final.pth' #model used to segment images

        segDir = 'tiffPhaseSegs' + f'_NoYeast_{i}' #name of output directory

        # Name or version of your current output segmentation directoryu7
        output_path = r'F:\Work\UNI\ResProj\TR-Y4-Project\Research\SavedSegs' # location of masks folders 

        output_path = os.path.join(output_path, segDir)

        print(output_path)
    
        my_seg = tiffSegmentation(model_path, tiffstack_path, output_path, output_num=None) # instance of class used for tiff segementation

        my_seg.phaseSeg() #starts segmentation after intialisation

F:\Work\UNI\ResProj\TR-Y4-Project\Research\SavedSegs\tiffPhaseSegs_NoYeast_0
Segmenting image 255 / 510F:\Work\UNI\ResProj\TR-Y4-Project\Research\SavedSegs\tiffPhaseSegs_NoYeast_1
Segmenting image 255 / 510F:\Work\UNI\ResProj\TR-Y4-Project\Research\SavedSegs\tiffPhaseSegs_NoYeast_2
Segmenting image 255 / 510F:\Work\UNI\ResProj\TR-Y4-Project\Research\SavedSegs\tiffPhaseSegs_NoYeast_3
Segmenting image 135 / 270