## Notebook to find features in a STEM image with the Segment Anything Model (SAM)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pycroscopy/DTMicroscope/blob/utk/notebooks/4_stem_feature_finding_COLAB-Hackathon.ipynb)


## Server setup

In [None]:
!pip install -q pyro5
!pip install -q scifireaders
!pip install -q sidpy
!pip install -q pynsid
!pip install -q git+https://github.com/pycroscopy/DTMicroscope.git@utk
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install torch torchvision
!pip install matplotlib



In [None]:
# Start the microscope server by invoking the `run_server` shell command.
# NOTE(review): presumably this is a console script installed with DTMicroscope and it
# serves on localhost:9091, the address the client cell connects to — confirm.
# NOTE(review): if the command does not return on its own, this cell blocks the
# notebook and the client cells below never run — verify that it detaches.
!run_server

## Client side starts

In [1]:
import matplotlib.pylab as plt
import numpy as np
import Pyro5.api
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
import requests
import os

### 1. Connect to the server

In [None]:
# Connect to the microscope server.
# The URI names the Pyro object "microscope.server" at localhost, port 9091.
uri = "PYRO:microscope.server@localhost:9091"
# `mic_server` is the Pyro5 proxy that every later client cell calls methods on.
mic_server = Pyro5.api.Proxy(uri)




### 2. Download and Register dataset

#### 2a. Download the dataset

In [None]:
# download dataset
!gdown --id 16tqc8yqO5Vex6RHljBea3j_fV7Pkbhyq

#### 2b. Register the dataset in the digital twin

In [None]:
# Put the server into STEM mode and hand it the downloaded file.
mic_server.initialize_microscope("STEM")
mic_server.register_data("test_stem.h5")

# The overview image arrives as (values, shape, dtype); rebuild the array from
# those three pieces.
array_list, shape, dtype = mic_server.get_overview_image()
im_array = np.reshape(np.array(array_list, dtype=dtype), shape)

# Show the overview image without axes.
fig, ax = plt.subplots()
ax.imshow(im_array)
ax.set_axis_off()
ax.set_title("Overview Image")
plt.show()


## Use the SAM model to find features (a GPU instance is recommended; the code falls back to the CPU)

In [None]:


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Specify the model type and checkpoint URL
model_type = "vit_b"  # Options: 'vit_b', 'vit_l', 'vit_h'
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
checkpoint_path = "sam_vit_b_01ec64.pth"

# Download the checkpoint if not already present
if not os.path.exists(checkpoint_path):
    print("Downloading SAM model checkpoint...")
    # Write to a temporary name and rename only after a complete, successful
    # download. Otherwise an interrupted or failed request would leave a broken
    # file at `checkpoint_path`, and the existence check above would skip the
    # download on every later run.
    partial_path = checkpoint_path + ".part"
    try:
        # stream=True avoids holding the whole checkpoint in memory at once.
        with requests.get(checkpoint_url, stream=True, timeout=60) as response:
            # Fail loudly on HTTP errors instead of saving an error page as weights.
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, checkpoint_path)
    finally:
        # Only present if the download did not complete; remove the leftovers.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete.")
else:
    print("SAM model checkpoint already exists.")


In [None]:

# Build the SAM variant selected by `model_type` from the downloaded checkpoint
# and move it onto the chosen device.
sam = sam_model_registry[model_type](checkpoint=checkpoint_path).to(device=device)

# Wrap the model in the automatic (prompt-free) mask generator.
mask_generator = SamAutomaticMaskGenerator(sam)


In [None]:
# Pixels at or below this value are treated as background.
threshold = 150  # Adjust this value based on your image
black = [0, 0, 0]

# Replicate the greyscale overview into three identical channels (a new array,
# so `im_array` itself is left untouched).
rgb_image = np.repeat(im_array[..., np.newaxis], 3, axis=-1)
img_np = rgb_image

# Keep a pixel only when every channel is above the threshold; paint the rest black.
mask = np.all(img_np > threshold, axis=-1)
img_np[np.logical_not(mask)] = black

# Convert to a PIL image for display and for the next cell.
img_pil_filtered = Image.fromarray(img_np)

# Show the filtered image without axes.
fig, ax = plt.subplots()
ax.imshow(img_pil_filtered)
ax.set_axis_off()
plt.show()

In [None]:
# Rebind `img_np` to the array form of the filtered PIL image; this is the array
# that the mask-generation cell passes to the mask generator.
img_np = np.array(img_pil_filtered)


In [None]:
# Run the automatic mask generator on the filtered image.
print("Generating masks...")
# `masks` is a sized collection; the visualisation cell reads each entry's
# 'segmentation' field.
masks = mask_generator.generate(img_np)
print(f"Number of masks generated: {len(masks)}")


In [None]:
import cv2

# Fixed seed so the overlay colours, and therefore the figures, are identical on
# every "Restart & Run All".
color_rng = np.random.default_rng(0)

visual_image = img_np.copy()
centroids = []  # one (x, y, label) tuple per non-empty mask; labels start at 1

# Single pass over the masks: blend each one into the image and record its centroid.
for idx, mask in enumerate(masks, 1):  # Start counting from 1
    segmentation = mask['segmentation']

    # Random colour per mask. The upper bound is exclusive, so 256 makes the full
    # 0-255 range available.
    color = color_rng.integers(0, 256, size=3, dtype=np.uint8)
    # Paint the mask on an otherwise black canvas and blend it at 50% weight, so
    # only pixels inside the mask are changed.
    colored_mask = np.zeros_like(visual_image)
    colored_mask[segmentation] = color
    visual_image = cv2.addWeighted(visual_image, 1.0, colored_mask, 0.5, 0)

    # Centroid = mean row / mean column of the mask's pixels.
    rows, cols = np.nonzero(segmentation)
    if rows.size == 0:
        continue  # Skip if mask is empty
    centroids.append((cols.mean(), rows.mean(), idx))  # (x, y, label)

# Display the image with colored masks
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(visual_image)
ax.set_axis_off()
ax.set_title('Image with Segmentation Masks')
plt.show()

# Same image again, this time with each mask's label drawn at its centroid.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(visual_image)

for (x, y, label) in centroids:
    # Choose a contrasting text colour from the brightness of the pixel under the label.
    text_color = 'white' if np.mean(visual_image[int(y), int(x)]) < 128 else 'black'
    ax.text(x, y, str(label), color=text_color, fontsize=12,
            bbox=dict(facecolor='red' if text_color == 'white' else 'yellow', alpha=0.5, pad=1))

ax.set_axis_off()
ax.set_title('Image with Segmentation Masks and Labels')
plt.show()
