In [1]:
import argparse
import os
import random
import shutil
import time
import warnings
import pickle
import numpy as np
import math
import sys
import copy
import re
import pandas as pd
import matplotlib.pyplot as plt
import json
import cv2
from itertools import compress

import torch
import torch.nn as nn
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor,DefaultTrainer,HookBase
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer,ColorMode,GenericMask
from detectron2.structures import BoxMode
from detectron2.evaluation import COCOEvaluator,inference_on_dataset
from detectron2.data import build_detection_test_loader,DatasetMapper,build_detection_train_loader,MetadataCatalog,DatasetCatalog
import detectron2.data.transforms as T
import detectron2.utils.comm as comm

import ray
import time

import uuid as uuid
from operator import itemgetter
import seaborn as sns

import shapely
import shapely.geometry
from shapely.geometry import Polygon,MultiPolygon,GeometryCollection
from shapely.validation import make_valid
from shapely.geometry import mapping
#import geopandas as gpd

#import imgfileutils as imf
#import segmentation_tools as sgt
from aicsimageio import AICSImage, imread
from skimage import measure, segmentation
from skimage.measure import regionprops
from skimage.color import label2rgb
#import progressbar
from IPython.display import display, HTML
#from MightyMosaic import MightyMosaic

import glob
from PIL import Image
import csv

In [2]:
# setup directory
# Each location can be overridden with an environment variable so the notebook
# can run on another machine; the defaults are the original author's paths.
root = os.environ.get('AMF_ROOT', r'/Users/lovely_shufan/')

project_dir = os.path.join(root, r'Dropbox (Edison_Lab@UGA)/AMF/AMF Imaging 2022/0_inference_using_MaskRCNN_2021/')
output_dir = os.path.join(project_dir, r'2_infer_result/GA_GWAS_2022/')

model_dir = os.path.join(root, r'Dropbox (Edison_Lab@UGA)/AMF/AMF Imaging 2021/2_computer_vision/')

data_dir = os.environ.get('AMF_DATA_DIR', r'/Volumes/easystore/GWAS 2022/')

# NOTE(review): this list (Block2/3/8) does not match the blocks processed in the
# cells below (Block2/8/10) and is not referenced by them — confirm which is intended.
blocks = ['Block2/','Block3/','Block8/']

## Model Inference Configuration

In [3]:
# Class names in model output order: the predicted class id indexes this list.
classes = [
    'root',
    'AMF internal hypha',
    'AMF external hypha',
    'AMF arbuscule',
    'AMF vesicle',
    'AMF spore',
    'others',
]

In [4]:
# Start from detectron2's default config, then overlay the model-zoo
# Mask R-CNN R50-FPN 3x config.
cfg = get_cfg()
_base_config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
cfg.merge_from_file(model_zoo.get_config_file(_base_config_name))

# --- model / training configuration ---
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
# Number of RoIs sampled per image when training the ROI heads (detectron2 default: 512).
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
# One output class per entry of `classes`
# (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets).
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)
cfg.MODEL.BACKBONE.FREEZE_AT = 2
cfg.SEED = 1
cfg.AUG_FLAG = 1  # custom key; not read anywhere in this notebook

# --- inference configuration ---
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8  # minimum detection score kept at test time
cfg.MODEL.WEIGHTS = os.path.join(model_dir, "Trainset1_model_best.pth")  # best model trained
cfg.MODEL.DEVICE = 'cpu'  # use cpu for inference

# Metadata for the "inference" dataset name (passed to the Visualizer in `inference`);
# 'AMF spore' and 'others' are deliberately left out of its class names.
_excluded_inference_classes = ('AMF spore', 'others')
inf_metadata = MetadataCatalog.get("inference").set(
    thing_classes=[name for name in classes if name not in _excluded_inference_classes]
)

In [5]:
# Build the inference predictor (loads cfg.MODEL.WEIGHTS onto cfg.MODEL.DEVICE).
predictor = DefaultPredictor(cfg=cfg)

[32m[10/02 22:14:18 d2.checkpoint.detection_checkpoint]: [0m[DetectionCheckpointer] Loading from /Users/lovely_shufan/Dropbox (Edison_Lab@UGA)/AMF/AMF Imaging 2021/2_computer_vision/Trainset1_model_best.pth ...


The checkpoint state_dict contains keys that are not used by the model:
  [35mpixel_mean[0m
  [35mpixel_std[0m


## Helper functions and Ray remote inference function

In [6]:
def centering2train(diff, img, x, y):
    '''
    Shift every pixel of an image by a fixed per-channel offset, then clip to [0, 255].

    :param diff: per-channel offset, one value per image channel (e.g. a list, a numpy
                 array, or a CPU torch tensor of shape (3,) or (1, 3)); element k is
                 added to channel k of ``img``
    :param img: np.ndarray of shape (row, col, channels)
    :param x: image width in columns; kept for backward compatibility (the shape is
              taken from ``img`` itself)
    :param y: image height in rows; kept for backward compatibility
    :return: a new array with the same shape as ``img`` and values clipped to [0, 255];
             ``img`` itself is not modified
    :rtype: numpy.ndarray
    Objective: output an image centered using training set per-channel means
    '''
    # Shape (1, 1, C) so numpy broadcasting applies the offset to every pixel.
    offset = np.asarray(diff).reshape(1, 1, -1)
    work_dtype = np.result_type(img, offset)
    if np.issubdtype(work_dtype, np.integer):
        # Widen narrow/unsigned integer results so the addition cannot wrap around
        # (uint8 250 + 10 would otherwise become 4) before it is clipped.
        work_dtype = np.promote_types(work_dtype, np.int16)
    centred = np.add(img, offset, dtype=work_dtype)
    # Clip values outside the interval to the interval edges
    np.clip(centred, 0, 255, out=centred)
    return centred

def padImg(img, x, y, tilex, tiley):
    '''
    Zero-pad an image at the top and left so both sides are multiples of the tile size.

    :param img: np.ndarray of shape (rows, cols) or (rows, cols, channels)
    :param x: image width (number of columns) of ``img``
    :param y: image height (number of rows) of ``img``
    :param tilex: tile width
    :param tiley: tile height
    :return: padded copy of ``img`` with the same dtype; its size is unchanged when
             x and y are already multiples of the tile size
    :rtype: numpy.ndarray
    Objective: output a padded image divisible by tile size
    '''
    # (-n) % t is 0 when n is already a multiple of t. The form t - (n % t) would
    # add a whole blank tile in that case, shifting real pixels past the tiled area.
    pad_top = (-y) % tiley
    pad_lft = (-x) % tilex
    # Pad only the leading rows/columns; any trailing channel axes are left alone.
    pad_width = [(pad_top, 0), (pad_lft, 0)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode='constant', constant_values=0)

In [7]:
@ray.remote
def inference(pathtofile, block, diff, predictor):
    '''
    Run tiled instance-segmentation inference on every scene of one .czi file.

    Each scene is centred with centering2train, zero-padded at the top/left with
    padImg, cut into 2560 x 1920 tiles, and every tile is passed to ``predictor``.

    :param pathtofile: path of the .czi file (stored verbatim in the 'filename' column)
    :param block: block folder name with a trailing separator (e.g. 'Block2/');
                  its last character is dropped before it is stored
    :param diff: per-channel offset forwarded to centering2train
    :param predictor: callable such as detectron2's DefaultPredictor, returning a dict
                      whose 'instances' provides pred_classes, pred_masks and scores
    :return: pandas.DataFrame with one row per predicted instance and columns
             block, filename, scene, tile, annotations, area, confidenceScore;
             'annotations' and 'area' are NaN when a mask yields no polygon
    '''
    tile_w, tile_h = 2560, 1920  # tile size (width, height) fed to the predictor

    blklist = []
    imgidlist = []
    sceneidlist = []
    tileidlist = []
    classlist = []
    confscorelist = []
    arealist = []

    # read in czi
    czi = AICSImage(pathtofile)

    for scene in czi.scenes:
        # extract image by scene
        czi.set_scene(scene)
        img = czi.get_image_data("YXS", T=0, C=0, Z=0)  # numpy.ndarray in Y, X, S order
        y = img.shape[0]
        x = img.shape[1]
        # centering
        img = centering2train(diff, img, x, y)
        # pad image so both sides are multiples of the tile size
        img = padImg(img, x, y, tile_w, tile_h)
        # Tile over the PADDED extent: padImg adds its border at the top/left, which
        # shifts the real pixels down/right, so looping only up to the original (y, x)
        # could leave the last row/column of real tiles unprocessed.
        padded_h, padded_w = img.shape[:2]
        for ymin in range(0, padded_h, tile_h):
            for xmin in range(0, padded_w, tile_w):
                xmax = xmin + tile_w
                ymax = ymin + tile_h
                # tile coordinates are in the padded image's frame
                tile_id = f"{xmin}_{ymin}_{xmax}_{ymax}"
                subimg = img[ymin:ymax, xmin:xmax]
                outputs = predictor(subimg)

                # inference outputs
                instances = outputs['instances']
                clasind = instances.get('pred_classes')
                allmasks = instances.get('pred_masks')
                allscores = instances.get('scores')

                num_seg = clasind.size()[0]
                if num_seg == 0:
                    continue  # only save entries when the tile contains a segmentation

                blklist.extend([block[:-1]] * num_seg)
                imgidlist.extend([pathtofile] * num_seg)
                sceneidlist.extend([scene] * num_seg)
                tileidlist.extend([tile_id] * num_seg)
                confscorelist.extend(allscores.tolist())

                # GenericMask needs the mask's (height, width), i.e. the tile's shape.
                mask_h, mask_w = subimg.shape[:2]
                class_ids = clasind.tolist()
                # Dedicated index name so the tile-loop variables are never overwritten.
                for seg_idx in range(num_seg):
                    # calculate mask area
                    locmask = np.asarray(allmasks[seg_idx, :, :])
                    gmask = GenericMask(locmask, mask_h, mask_w)
                    if gmask.polygons:
                        # NOTE(review): only the first contour is measured, so masks
                        # split into several pieces (or with holes) are not summed —
                        # confirm this is intended.
                        mergpolygon = gmask.polygons[0]
                        pgon = Polygon(zip(mergpolygon[::2], mergpolygon[1::2]))
                        arealist.append(pgon.area)
                        # class index to class name
                        classlist.append(classes[class_ids[seg_idx]])
                    else:  # assign NAs to non-polygon mask
                        arealist.append(math.nan)
                        classlist.append(math.nan)

    # export inference result as df
    infresults = pd.DataFrame({
        'block': blklist,
        'filename': imgidlist,
        'scene': sceneidlist,
        'tile': tileidlist,
        'annotations': classlist,
        'area': arealist,
        'confidenceScore': confscorelist})

    return infresults

## Prepare for parallel inference

In [8]:
def _find_czi_files(block_dir):
    '''Return every path ending in '.czi' under data_dir/block_dir, in os.walk order.'''
    found = []
    for dirpath, _subdirs, filenames in os.walk(os.path.join(data_dir, block_dir)):
        found.extend(os.path.join(dirpath, fname) for fname in filenames if fname.endswith('.czi'))
    return found


allpath2block2img = _find_czi_files('Block2/')
allpath2block8img = _find_czi_files('Block8/')
allpath2block10img = _find_czi_files('Block10/')

In [10]:
# Get file sizes for each file
file_sizes = [(path, os.path.getsize(path)) for path in allpath2block2img]
# Sort the file paths based on file sizes
sorted_path2block2img = [item[0] for item in sorted(file_sizes, key=lambda x: x[1])]

file_sizes = [(path, os.path.getsize(path)) for path in allpath2block8img]
sorted_path2block8img = [item[0] for item in sorted(file_sizes, key=lambda x: x[1])]

file_sizes = [(path, os.path.getsize(path)) for path in allpath2block10img]
sorted_path2block10img = [item[0] for item in sorted(file_sizes, key=lambda x: x[1])]


In [11]:
# Number of .czi files found per block (Block2, Block8, Block10).
for block_paths in (sorted_path2block2img, sorted_path2block8img, sorted_path2block10img):
    print(len(block_paths))

344
384
337


In [23]:
def findBGRmean(file):
    czi = AICSImage(file)
    scenes = czi.scenes
    x = random.choice(scenes)
    czi.set_scene(x)
    img_array = czi.get_image_data("YXS", T=0,C=0,Z=0)
    # Convert numpy ndarray to PyTorch tensor
    img_tensor = torch.from_numpy(img_array).float()  # Convert data type to float
    # Change the dimensions to (channels, height, width)
    img_tensor = img_tensor.permute(2, 0, 1)
    # Calculate the mean for each channel
    mean_values = img_tensor.mean(dim=[1, 2])
    return mean_values.unsqueeze(0)

In [42]:
# Ray-remote version of findBGRmean. Wrapping the existing function keeps a single
# implementation instead of a verbatim copy that can drift out of sync.
rfindBGRmean = ray.remote(findBGRmean)

In [13]:
# Largest Block2 file: the list is sorted by ascending size, so take the last entry
# rather than a hard-coded index that assumes exactly 344 files.
testczi = AICSImage(sorted_path2block2img[-1])

In [24]:
# Channel means of one random scene of the largest Block2 file (last in size order).
means = findBGRmean(sorted_path2block2img[-1])

In [25]:
# Expected torch.Size([1, 3]): one row of per-channel means.
print(means.size())

torch.Size([1, 3])


In [21]:
# Randomly choose up to 30 Block2 .czi FILES (a random scene of each is picked later
# by rfindBGRmean). A seeded RNG makes the file sample reproducible across runs.
SAMPLE_SIZE = 30
SAMPLE_SEED = 1
sample30 = random.Random(SAMPLE_SEED).sample(
    allpath2block2img, min(SAMPLE_SIZE, len(allpath2block2img))
)
# Same files, smallest first.
sorted_sample30 = sorted(sample30, key=os.path.getsize)

In [44]:
# Start a local Ray instance with 18 CPUs; re-running this cell reuses the existing one.
ray.init(ignore_reinit_error=True, num_cpus=18)

2023-10-02 23:50:01,149	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


0,1
Python version:,3.8.18
Ray version:,2.6.3
Dashboard:,http://127.0.0.1:8265


In [45]:
# Submit one remote task per sampled file, then collect each result as soon as it
# finishes. Results are appended in completion order, not submission order, which
# does not matter for the mean computed below.
id30 = [rfindBGRmean.remote(czi) for czi in sample30]
results = []
t0 = time.time()
iteration = 0
# Loop until every task is collected; a fixed iteration cap would silently drop
# results whenever more tasks are submitted than the cap allows.
while id30:
    ready, id30 = ray.wait(id30, num_returns=1)
    iteration += 1
    print('iteration:', iteration)
    results.extend(ray.get(ready))
print('Time Elapsed:\t{:.4f}'.format(time.time() - t0))

iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
iteration: 11
iteration: 12
iteration: 13
iteration: 14
iteration: 15
iteration: 16
iteration: 17
iteration: 18
iteration: 19
iteration: 20
iteration: 21
iteration: 22
iteration: 23
iteration: 24
iteration: 25
iteration: 26
iteration: 27
iteration: 28
iteration: 29
iteration: 30
Time Elapsed:	34.5561


In [43]:
# Stop the local Ray instance started by ray.init above.
ray.shutdown()

In [46]:
# Per-file channel means, one (1, 3) tensor per sampled file, in task-completion order.
# NOTE(review): this prints the whole list; consider len(results) or a slice for large samples.
print(results)

[tensor([[174.9772, 174.4386, 173.9380]]), tensor([[135.9655, 135.5739, 135.5250]]), tensor([[162.2106, 158.3190, 157.7910]]), tensor([[149.1607, 147.1816, 146.7908]]), tensor([[170.2700, 168.7100, 168.2848]]), tensor([[152.6491, 150.6038, 150.2349]]), tensor([[116.8860, 115.9650, 114.6433]]), tensor([[132.3646, 132.6676, 132.8183]]), tensor([[126.8691, 125.2042, 124.5668]]), tensor([[143.1965, 141.5301, 141.2795]]), tensor([[124.6719, 123.9459, 123.7732]]), tensor([[149.2669, 147.0962, 146.4940]]), tensor([[126.3323, 123.8550, 123.2347]]), tensor([[137.9844, 137.6405, 137.1638]]), tensor([[148.0428, 147.6507, 146.1341]]), tensor([[124.5397, 123.0701, 122.6528]]), tensor([[128.9459, 126.8098, 126.2509]]), tensor([[104.0251, 103.3973, 103.3155]]), tensor([[119.4886, 116.3476, 115.5772]]), tensor([[102.5644,  99.8095,  99.2338]]), tensor([[107.3763, 105.3407, 104.7592]]), tensor([[128.4014, 127.4890, 127.3115]]), tensor([[162.4699, 160.4639, 159.9668]]), tensor([[142.3185, 139.3402, 138.

In [47]:
# Stack the per-file (1, 3) means into an (n_files, 3) tensor, then average over files.
results_ts = torch.cat(tuple(results), 0)
block2BGRmeans = torch.mean(results_ts, dim=0)

In [49]:
# Mean of each channel over the sampled Block2 files (a tensor of shape (3,)).
print(block2BGRmeans)

tensor([135.0869, 133.4352, 132.9082])
