### Import Packages

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from skimage.measure import regionprops, regionprops_table
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import load_img
from importlib import reload
import segmenteverygrain as seg
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import trange
import urllib.request
%matplotlib qt

### Enhance training images with Adaptive Equalization

In [24]:
from skimage import exposure
import cv2
from glob import glob

def perform_adaptive_equalization(img_path, clip_lim=0.01):
    img = cv2.imread(img_path)

    # Adaptive Equalization
    img_adapteq = exposure.equalize_adapthist(img, clip_limit=clip_lim)

    return(img_adapteq)

# image_dir = 'images/ara-train/'
image_dir = 'images/ara-test/subset/'
output_dir = image_dir + 'enhanced/'
images = glob(image_dir + "*.JPG")

for image in images:
    output = perform_adaptive_equalization(image)
    img_name = image.split("\\")[1].split(".")[0] + '_enhanced' + '.jpg'
    cv2.imwrite(output_dir + img_name, 255*output)

### Download model checkpoint

In [7]:
import urllib.request
urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', 'sam_vit_h_4b8939.pth')

('sam_vit_h_4b8939.pth', <http.client.HTTPMessage at 0x1dc916b1450>)

### Load models

In [3]:
# %%time

model = seg.Unet()
model.compile(optimizer=Adam(), loss=seg.weighted_crossentropy, metrics=["accuracy"])
model.load_weights('./checkpoints/seg_model_20231009').expect_partial()

sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")

### Check for GPU

In [7]:
import torch

if torch.cuda.is_available():
    sam.to(device='cuda')
    print("cuda enabled")
else:
    sam.to(device='cpu')
    print("cpu only")

cuda enabled


In [5]:
import torch

torch.cuda.is_available()

True

In [9]:
from tensorflow.python.client import device_lib

def get_available_devices():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos]

print(get_available_devices())

['/device:CPU:0']


### Run segmentation

In [33]:
# reload(seg)

# fname = 'images/ara-test/subset/enhanced/0ap9oe_enhanced.jpg'
# fname = 'adapteq_0smyr1.jpg'

big_im = np.array(load_img(fname))
big_im_pred = seg.predict_big_image(big_im, model, I=256)

labels, grains, coords = seg.label_grains(big_im, big_im_pred, dbs_max_dist=10.0)
# all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(sam, big_im, big_im_pred, coords, labels, min_area=50.0)
# _, _, mask_all, _, _, _ = seg.sam_segmentation(sam, big_im, big_im_pred, coords, labels, min_area=50.0)

100%|██████████| 7/7 [00:05<00:00,  1.22it/s]
100%|██████████| 6/6 [00:05<00:00,  1.19it/s]


### QC distribution of SAM prompts

In [22]:
plt.figure()
plt.imshow(big_im_pred)
plt.scatter(coords[:,0], coords[:,1], c='k');

### Delete or merge grains in segmentation result
* click on the grain to remove and press 'x' key
* click on two grains to merge, and press the 'm'm key

In [124]:
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event', 
                              lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax))
cid2 = fig.canvas.mpl_connect('key_press_event', 
                              lambda event: seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax))

Run below cell once finished with editing

In [125]:
fig.canvas.mpl_disconnect(cid1)
fig.canvas.mpl_disconnect(cid2)

Update the 'all_grains' list after deleting and merging grains

In [126]:
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, big_im)

  0%|          | 0/168 [00:00<?, ?it/s]

100%|██████████| 168/168 [00:01<00:00, 115.98it/s]


Plot the updated set of grains:

In [121]:
fig, ax = plt.subplots(figsize=(15,10))
ax.imshow(big_im)
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(big_im, all_grains, ax, cmap='Paired')
# seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
plt.xlim([0, np.shape(big_im)[1]])
plt.ylim([np.shape(big_im)[0], 0]);

## Add new grains
* click on unsegmented grain that you want to add
* press the 'x' key to delete the last grain added
* press the 'm' key to merge the last 2 grains added
* right click outside the grain (but inside mask) to restrict the grain to a smaller mask

In [122]:
predictor = SamPredictor(sam)
predictor.set_image(big_im) # this can take a while
coords = []
cid3 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.onclick(event, ax, coords, big_im, predictor))
cid4 = fig.canvas.mpl_connect('key_press_event', lambda event: seg.onpress(event, ax, fig))

In [123]:
fig.canvas.mpl_disconnect(cid3)
fig.canvas.mpl_disconnect(cid4)

After finished deleting / adding grain masks, run below cell to generate updated set of grains:

In [115]:
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, big_im)

100%|██████████| 155/155 [00:01<00:00, 119.97it/s]


### Save mask and grain labels to PNG files

In [127]:
dirname = 'images/labeled/'
# write grayscale mask to PNG file
cv2.imwrite(dirname + fname.split('/')[-1][:-4] + '_mask.png', mask_all)
# Define a colormap using matplotlib
num_classes = len(all_grains)
cmap = plt.get_cmap('viridis', num_classes)
# Map each class label to a unique color using the colormap
vis_mask = cmap(labels.astype(np.uint16))[:,:,:3] * 255
vis_mask = vis_mask.astype(np.uint8)
# Save the mask as a PNG file
cv2.imwrite(dirname + fname.split('/')[-1][:-4] + '_labels.png', vis_mask)
# Save the image as a PNG file
cv2.imwrite(dirname + fname.split('/')[-1][:-4] + '_image.png', cv2.cvtColor(big_im, cv2.COLOR_BGR2RGB))

True

In [74]:
print(dirname + fname.split('/')[-1][:-4] + '_mask.png')

images/output/0dpocx_mask.png
