## Import packages

In [None]:
#pip uninstall numpy>=1.19.5

In [None]:
#import numpy as np
#print(np.__version__)

In [None]:
#pip install typing-extensions>=4.8.0

In [None]:
#! python --version

## The code starts from here

In [None]:
pip install segmenteverygrain

In [None]:
pip install pandas

In [None]:
pip install scikit-learn

In [None]:
pip install rtree

In [None]:
pip install torchvision

In [None]:
pip install PyQt5

In [None]:
pip install geopandas

## Load models

In [1]:
from importlib import reload
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from skimage.measure import regionprops, regionprops_table
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import load_img
import segmenteverygrain as seg
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import trange
%matplotlib qt

In [2]:
# Build the segmenteverygrain U-Net and load its trained weights.
# The 'segmenteverygrain' model checkpoint has to be downloaded into ./checkpoints for this to work.
unet_weights_path = './checkpoints/seg_model'
model = seg.Unet()
model.compile(
    optimizer=Adam(),
    loss=seg.weighted_crossentropy,
    metrics=["accuracy"],
)
model.load_weights(unet_weights_path);

# Load the Segment Anything Model (SAM); its checkpoint can be downloaded from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam_checkpoint_path = "./checkpoints/sam_vit_h_4b8939.pth"
sam = sam_model_registry["default"](checkpoint=sam_checkpoint_path)

  state_dict = torch.load(f)


In [None]:
#fname = './AL2_H1_r.tif'

In [None]:
#big_im = np.array(load_img(fname))
#big_im_pred = seg.predict_image(big_im, model, I=256)
# decreasing the 'dbs_max_dist' parameter results in more SAM prompts (and longer processing times):
#labels, grains = seg.label_grains(big_im, big_im_pred, dbs_max_dist=10.0)
#all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(sam, big_im, big_im_pred, labels, min_area=50.0)

In [3]:
from PIL import Image

# Path of the image that will be segmented
fname = "./SM1_H1.2.jpg"
# Remove Pillow's pixel-count limit; needed if working with very large images
Image.MAX_IMAGE_PIXELS = None

In [None]:
# Segment the image tile by tile, then create the grain masks with SAM
all_grains = seg.predict_large_image(
    fname,
    model,
    sam,
    min_area=1000.0,
    patch_size=2000,
    overlap=200,
)

segmenting image tiles...


100%|██████████| 6/6 [00:07<00:00,  1.23s/it]
100%|██████████| 5/5 [00:05<00:00,  1.19s/it]


creating masks using SAM...


 81%|████████  | 648/804 [01:08<00:17,  8.92it/s]

In [None]:
# plot results
image = np.array(load_img(fname))
fig, ax = plt.subplots(figsize=(15,10))
ax.imshow(image)
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired')
plt.axis('equal')
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)
* press the 'g' key to hide the grain masks (so that you can see the original image better); press the 'g' key again to show the grain masks

In [None]:
# List shared by both handlers below
grain_inds = []


def _on_grain_click(event):
    """Pass a mouse click on the figure to segmenteverygrain's `onclick2` handler."""
    return seg.onclick2(event, all_grains, grain_inds, ax=ax)


def _on_grain_key(event):
    """Pass a key press on the figure to segmenteverygrain's `onpress2` handler."""
    return seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax)


# Keep the connection ids so that the handlers can be disconnected later
cid1 = fig.canvas.mpl_connect('button_press_event', _on_grain_click)
cid2 = fig.canvas.mpl_connect('key_press_event', _on_grain_key)

Run this cell if you do not want to delete / merge existing grains anymore; it is a good idea to do this before moving on to the next step.

In [None]:
# Detach the click and key handlers used for deleting / merging grains
for handler_id in (cid1, cid2):
    fig.canvas.mpl_disconnect(handler_id)

Use this function to update the 'all_grains' list after deleting and merging grains:

In [None]:
# Update the grain list (together with `labels` and `mask_all`) after the interactive edits.
# Note that this call also rebinds `fig` and `ax`.
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, image)

Plot the updated set of grains:

In [None]:
# Redraw the image with the updated grains, their axes and their centroids
img_height, img_width = image.shape[:2]
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(image)
ax.set_xticks([])
ax.set_yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired')
seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
ax.set_xlim([0, img_width])
ax.set_ylim([img_height, 0]);  # upper limit first, so that row 0 is drawn at the top

## Add new grains using the Segment Anything Model

* click on the unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right click outside the grain (but inside the most recent mask) if you want to restrict the grain to a smaller mask - this adds a background prompt

In [None]:
# Prepare SAM for point prompts on the full image
predictor = SamPredictor(sam)
predictor.set_image(image)  # this can take a while
coords = []


def _on_prompt_click(event):
    """Pass a mouse click on the figure to segmenteverygrain's `onclick` handler."""
    return seg.onclick(event, ax, coords, image, predictor)


def _on_prompt_key(event):
    """Pass a key press on the figure to segmenteverygrain's `onpress` handler."""
    return seg.onpress(event, ax, fig)


# Keep the connection ids so that the handlers can be disconnected later
cid3 = fig.canvas.mpl_connect('button_press_event', _on_prompt_click)
cid4 = fig.canvas.mpl_connect('key_press_event', _on_prompt_key)

After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [None]:
# Detach the click and key handlers used for adding grains
for handler_id in (cid3, cid4):
    fig.canvas.mpl_disconnect(handler_id)

In [None]:
# Generate the updated set of grains (together with `labels` and `mask_all`) after adding grains.
# Note that this call also rebinds `fig` and `ax`.
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, image)

In [None]:
# Show how many grains there are and preview the first few; echoing the whole
# list would write one polygon per grain into the cell output.
print(f"{len(all_grains)} grains")
all_grains[:5]

## Get grain size distribution

Run this cell and then click (left mouse button) on one end of the scale bar in the image and click (right mouse button) on the other end of the scale bar:

In [None]:
def _on_scale_click(event):
    """Pass a mouse click on the figure to segmenteverygrain's `click_for_scale` handler."""
    return seg.click_for_scale(event, ax)


cid5 = fig.canvas.mpl_connect('button_press_event', _on_scale_click)

Use the length of the scale bar in pixels (it should be printed above) to get the scale of the image (in units / pixel):

In [None]:
# Real-world length of the scale bar (centimeters, matching the histogram label below)
n_of_units = 300
# Length of the scale bar in pixels, as measured by clicking on its two ends
scale_bar_length_pixels = 1194.81
# Image scale in units / pixel
units_per_pixel = n_of_units / scale_bar_length_pixels

In [None]:
# Measure every labeled grain; regionprops_table works on the label image, so the values are in pixels
measured_properties = (
    'label', 'area', 'centroid', 'major_axis_length', 'minor_axis_length',
    'orientation', 'perimeter', 'max_intensity', 'mean_intensity', 'min_intensity',
)
props = regionprops_table(labels.astype('int'), intensity_image=image, properties=measured_properties)
grain_data = pd.DataFrame(props)

# Apply the image scale: lengths are multiplied by the scale, the area by its square
for length_column in ('major_axis_length', 'minor_axis_length', 'perimeter'):
    grain_data[length_column] = grain_data[length_column].values * units_per_pixel
grain_data['area'] = grain_data['area'].values * units_per_pixel**2

In [None]:
# raw property table as returned by regionprops_table (values still in pixels)
props

In [None]:
# first rows of the scaled grain table
grain_data.head()

In [None]:
# number of rows in the grain table
len(grain_data)

In [None]:
# Grain size distribution: histogram of the major axis lengths (25 bins)
hist_fig, hist_ax = plt.subplots()
hist_ax.hist(grain_data['major_axis_length'], bins=25)
hist_ax.set_xlabel('major axis length (cm)')
hist_ax.set_ylabel('count');

## Save mask and image to PNG files

In [None]:
#dirname = 'C:/Users/ana.torresferreira/Desktop/Segmenteverygrain/Outputs/ML2_H2'
# write grayscale mask to PNG file
#cv2.imwrite(dirname + fname.split('/')[-1][:-4] + '_mask.png', mask_all)
# Save the image as a PNG file
#cv2.imwrite(dirname + fname.split('/')[-1][:-4] + '_image.png', cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

## Convert polygon row and column coordinates to projected coordinates and save them to a shapefile

In [None]:
# raster opened below with rasterio; presumably the georeferenced counterpart of `fname` (TODO confirm)
fname1 = "./SM1_H1.2_r.tif"

In [None]:
import rasterio
# the dataset stays open for the following cells and is closed in the last cell of the notebook
dataset = rasterio.open(fname1)

In [None]:
# metadata of the opened raster
dataset.meta

In [None]:
# convert polygon coordinates from row, col to UTM
from shapely.geometry import Polygon
projected_polys = []
for grain in all_grains:
    x, y = rasterio.transform.xy(dataset.transform, grain.exterior.xy[1], grain.exterior.xy[0])
    poly = Polygon(np.vstack((x, y)).T)
    projected_polys.append(poly)

In [None]:
# create georeferenced pandas dataframe
import geopandas
gdf = geopandas.GeoDataFrame(projected_polys, columns = ['geometry'])
gdf.head(5)

In [None]:
# create property dataframe from labeled image
props = regionprops_table(labels.astype('int'), intensity_image = image, properties =\
        ('label', 'area', 'centroid', 'major_axis_length', 'minor_axis_length'))
grain_data = pd.DataFrame(props)
grain_data['major_axis_length'] = grain_data['major_axis_length'].values
grain_data['minor_axis_length'] = grain_data['minor_axis_length'].values
grain_data['area'] = grain_data['area'].values
grain_data.head()

In [None]:
# Ensure centroid_x and centroid_y are the same length as gdf
centroid_x, centroid_y = rasterio.transform.xy(dataset.transform, grain_data['centroid-0'], 
                                               grain_data['centroid-1'])

In [None]:
# Check if lengths match between gdf and centroid arrays
if len(centroid_x) != len(gdf):
    print(f"Length of centroids (x: {len(centroid_x)}, y: {len(centroid_y)}) does not match gdf length ({len(gdf)})")
    
    # Truncate gdf to match the length of centroids if gdf is longer
    if len(gdf) > len(centroid_x):
        gdf = gdf.iloc[:len(centroid_x)]
    # Or truncate the centroid arrays if they are longer than gdf
    else:
        centroid_x = centroid_x[:len(gdf)]
        centroid_y = centroid_y[:len(gdf)]

# Assign the adjusted centroid_x and centroid_y to the GeoDataFrame
gdf['centroid_x'] = centroid_x
gdf['centroid_y'] = centroid_y

# Check the output to ensure the lengths now match
print(f"Updated gdf length: {len(gdf)}, Centroid lengths: x = {len(centroid_x)}, y = {len(centroid_y)}")


In [None]:
# convert centroids from row, col to UTM and add them to geodataframe
centroid_x, centroid_y = rasterio.transform.xy(dataset.transform, grain_data['centroid-0'], 
                                               grain_data['centroid-1'])
gdf['centroid_x'] = centroid_x
gdf['centroid_y'] = centroid_y

In [None]:
# convert grain axis lengths to UTM units
gdf['major_axis_length'] = grain_data['major_axis_length'] * dataset.transform[0]
gdf['minor_axis_length'] = grain_data['minor_axis_length'] * dataset.transform[0]
gdf.head()

In [None]:
# check if everything looks good
band1 = dataset.read(1)
band2 = dataset.read(2)
band3 = dataset.read(3)
plt.figure()
plt.imshow(np.stack((band1, band2, band3), axis=2), extent = [dataset.bounds[0], dataset.bounds[2], 
                                         dataset.bounds[1], dataset.bounds[3]])
plt.scatter(gdf['centroid_x'], gdf['centroid_y']);

In [None]:
# set geodataframe CRS to the CRS of the raster that the coordinates were computed from
gdf.crs = dataset.crs # set geodataframe CRS

In [None]:
# Folder that receives the shapefile; adjust it to your machine / sample.
output_dir = Path('C:/Users/ana.torresferreira/Desktop/Segmenteverygrain/Outputs/SM1_H1.2')
# Create the folder first so that the export does not fail when it does not exist yet
output_dir.mkdir(parents=True, exist_ok=True)
gdf.to_file(str(output_dir / 'SM1_H1.2.shp'))

In [None]:
#pip install openpyxl

In [None]:
# Folder that receives the Excel file; adjust it to your machine / sample.
output_dir = Path('C:/Users/ana.torresferreira/Desktop/Segmenteverygrain/Outputs/SM1_H1.2')
# Create the folder first so that the export does not fail when it does not exist yet
output_dir.mkdir(parents=True, exist_ok=True)
gdf.to_excel(str(output_dir / 'SM1_H1.2.xlsx'))

In [None]:
# close the raster that was opened with rasterio.open(fname1)
dataset.close()