# Overview
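Use this notebook to convert an OCT image you have into a virtual H&E histology image, segment it with Segment Anything (SAM), project the segmentation back onto the OCT image, and evaluate how well the pipeline works.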

To get started, [open this notebook in Colab](https://colab.research.google.com/github/WinetraubLab/zero_shot_segmentation/blob/main/zero_shot_segmentation_oct.ipynb) and run all cells.


In [None]:
# --- Input OCT image ---------------------------------------------------------
# Location of the OCT image to convert (on the mounted Google Drive)
oct_input_image_path = (
    "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/"
    "2020-11-10 10x Raw Data Used In Paper (Paper V2)/"
    "LG-19 - Slide04_Section02 (Fig 3.c)/OCTAligned.tiff"
)

# Physical pixel size of the input image, in microns, along each axis
microns_per_pixel_x = 1
microns_per_pixel_z = 1

In [None]:
# --- Segment Anything (SAM) settings -----------------------------------------
# SAM backbone and the checkpoint file that matches it
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"

# Notebook behaviour switches
using_colab = True
visualize_sam_outputs = True
inject_real_histology_and_segmentation = True

Assumptions:

OCT scan resolution (x/z):
*   microns per pixel z = 1
*   microns per pixel x = 1

pix2pix input sizes:
*   virtual histology input width = 256
*   virtual histology input height = 256

pix2pix input resolution (x/z), matching `MICRONS_PER_PIXEL_*_TARGET` and the 1024x512 → 256x256 crop below:
*   microns per pixel z = 2
*   microns per pixel x = 4

In [None]:
# Input size (pixels) the pix2pix OCT->histology model is run on
VIRTUAL_HIST_WIDTH, VIRTUAL_HIST_HEIGHT = 256, 256

# Target pixel size (microns per pixel) of the model input along each axis
MICRONS_PER_PIXEL_X_TARGET = 4
MICRONS_PER_PIXEL_Z_TARGET = 2

# Get Roboflow input

In [None]:
# Plotting options
FIG_SIZE = (10, 5)  # matplotlib figure size (width, height) in inches
visualize_oct2hist_outputs = False  # show OCT vs. histology side-by-side plots

# Install pip requirements and git repos

In [None]:
init_dir = %pwd
!pip install roboflow


from IPython.display import clear_output
# Clone repository
!git clone --recurse-submodules https://github.com/WinetraubLab/OCT2Hist-UseModel

base_folder = "/content/rf_dir/OCT2Hist-UseModel/pytorch-CycleGAN-and-pix2pix"

# Install dependencies
!pip install -r {base_folder}/requirements.txt
# Clean up this window once install is complete
clear_output()

%cd init_dir

# inputs

In [None]:
import torch
import torchvision

# oct2hist setup

In [None]:
# Mount google drive
from google.colab import drive
drive.mount('/content/drive/')

# This is the folder that the pre-trained model is in
model_folder = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2020-11-10 10x Model (Paper V2)"

# Copy model to this folder over
!mkdir {base_folder}/checkpoints
!mkdir {base_folder}/checkpoints/pix2pix/
!cp "{model_folder}/latest_net_G.pth" {base_folder}/checkpoints/pix2pix/
!cp "{model_folder}/latest_net_D.pth" {base_folder}/checkpoints/pix2pix/

# Preprocess

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

# Load OCT image. cv2.imread returns None (instead of raising) when the path is
# wrong or unreadable, which would otherwise surface as a cryptic cvtColor error.
oct_image_orig = cv2.imread(oct_input_image_path)
if oct_image_orig is None:
    raise FileNotFoundError(f"Could not read OCT image at: {oct_input_image_path}")
# OpenCV loads BGR; convert to RGB for matplotlib and the rest of the pipeline
oct_image_orig = cv2.cvtColor(oct_image_orig, cv2.COLOR_BGR2RGB)

# Work on a copy so the original stays available for comparison
oct_image = oct_image_orig.copy()
oct_image_orig_shape = oct_image.shape

# Show image to user
fig, ax = plt.subplots(figsize=FIG_SIZE)
ax.imshow(oct_image)
ax.axis("off")
ax.set_title(f"Original OCT image ({oct_image_orig_shape})")
plt.show()



In [None]:
# Run from the OCT2Hist-UseModel repo root so its `utils` package can be imported
%cd /content/rf_dir/OCT2Hist-UseModel
from utils.masking_utils import mask_image
# mask_image returns two images; preprocessed_img is what the crop cell uses.
# NOTE(review): filt_img is not used anywhere in the visible notebook — confirm it's needed.
preprocessed_img, filt_img = mask_image(oct_image)

In [None]:
from utils.img_utils import showImg
# Display the output of mask_image for a visual sanity check
showImg(preprocessed_img)

# Crop

In [None]:
from utils.img_utils import showImg

# Size of the crop, in input pixels. The model runs on
# VIRTUAL_HIST_WIDTH x VIRTUAL_HIST_HEIGHT pixels at MICRONS_PER_PIXEL_{X,Z}_TARGET
# microns each, so cut the region covering that physical extent. With the default
# configuration this is 1024 x 512 (previously hard-coded as 256*4 x 256*2).
width = int(round(VIRTUAL_HIST_WIDTH * MICRONS_PER_PIXEL_X_TARGET / microns_per_pixel_x))
height = int(round(VIRTUAL_HIST_HEIGHT * MICRONS_PER_PIXEL_Z_TARGET / microns_per_pixel_z))
# Top-left corner of the region of interest (x: column, z: row)
x0 = 135
z0 = 350
cropped = preprocessed_img[z0:z0+height, x0:x0+width]
# Slicing silently truncates at the image border. A short crop would be stretched
# by the resize below, which breaks the microns-per-pixel assumption of the model.
if cropped.shape[:2] != (height, width):
    raise ValueError(
        f"Crop at (x0={x0}, z0={z0}) of size {width}x{height} extends past the image "
        f"({preprocessed_img.shape[1]}x{preprocessed_img.shape[0]})."
    )
showImg(cropped)
# cv2.resize takes dsize as (width, height); INTER_AREA is the usual choice for downsampling
resized = cv2.resize(cropped, [VIRTUAL_HIST_WIDTH, VIRTUAL_HIST_HEIGHT], interpolation=cv2.INTER_AREA)
o2h_input = resized

Write the cropped image to the pix2pix dataset folder,

after verifying that it fits the model's input requirements (256x256 pixels).

# Run OCT2Hist

In [None]:
# Create a folder and place OCT image
!mkdir {base_folder}/dataset
!mkdir {base_folder}/dataset/test/

# Before writting image to file, check size
if o2h_input.shape[:2] != (256, 256):
        raise ValueError("Image size must be 256x256 pixels to run model on.")

# Padd image and write it to the correct place
padded = np.zeros([256,512,3], np.uint8)
padded[:,:256,:] = o2h_input[:,:,:]
cv2.imwrite(f"{base_folder}/dataset/test/im1.jpg", padded)

In [None]:
# This is the folder that the pre-trained model is in
model_folder = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2020-11-10 10x Model (Paper V2)"

# Copy model to this folder over
!mkdir {base_folder}/checkpoints
!mkdir {base_folder}/checkpoints/pix2pix/
!cp "{model_folder}/latest_net_G.pth" {base_folder}/checkpoints/pix2pix/
!cp "{model_folder}/latest_net_D.pth" {base_folder}/checkpoints/pix2pix/

In [None]:
# Run pix2pix inference on {base_folder}/dataset/test/ using the checkpoints copied above.
# Results are read back from {base_folder}/results/pix2pix/test_latest/images/ in the next cell.
# NOTE(review): --netG resnet_9blocks presumably matches the architecture the checkpoint was trained with.
!python {base_folder}/test.py --netG resnet_9blocks --dataroot "{base_folder}/dataset/"  --model pix2pix --name pix2pix --checkpoints_dir "{base_folder}/checkpoints" --results_dir "{base_folder}/results"

# Optional: visualize output

In [None]:
# Load the virtual histology produced by pix2pix. cv2.imread returns None when the
# file is missing (e.g. the test.py run above failed), so check explicitly.
virtual_histology_path = f"{base_folder}/results/pix2pix/test_latest/images/im1_fake_B.png"
histology_image = cv2.imread(virtual_histology_path)
if histology_image is None:
    raise FileNotFoundError(f"pix2pix output not found at: {virtual_histology_path}")
histology_image = cv2.cvtColor(histology_image, cv2.COLOR_BGR2RGB)

# Resize back to the pixel size of the OCT crop so the two overlay pixel-for-pixel
height, width = cropped.shape[:2]
histology_image_resized = cv2.resize(histology_image, [width, height], interpolation=cv2.INTER_AREA)

# Controlled by visualize_oct2hist_outputs in the configuration cell (no longer
# force-enabled here, which made the flag meaningless).
if visualize_oct2hist_outputs:
  # present side by side
  fig, axes = plt.subplots(1, 2, figsize=FIG_SIZE)
  axes[0].imshow(cropped)
  axes[0].axis("off")
  axes[0].set_title("OCT")
  axes[1].imshow(histology_image_resized)
  axes[1].axis("off")
  axes[1].set_title("Virtual Histology")
  plt.show()

In [None]:
# Inject ground-truth histology: when enabled, replace histology_image /
# histology_image_resized with the real H&E image, cropped from the same region as
# the OCT, so the SAM cells below run on the real image instead of the virtual one.
histology_input_image_path = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2020-11-10 10x Raw Data Used In Paper (Paper V2)/LG-19 - Slide04_Section02 (Fig 3.c)/HistologyAligned.tiff"

if inject_real_histology_and_segmentation:
  histology_image = cv2.imread(histology_input_image_path)
  if histology_image is None:
    raise FileNotFoundError(f"Could not read histology image at: {histology_input_image_path}")
  histology_image = cv2.cvtColor(histology_image, cv2.COLOR_BGR2RGB)

  # Compute the crop size here instead of relying on values left over from a
  # previous cell, then cut the same region that was cut from the OCT image
  height, width = cropped.shape[:2]
  cropped_histology = histology_image[z0:z0+height, x0:x0+width]
  histology_image_resized = cv2.resize(cropped_histology, [width, height], interpolation=cv2.INTER_AREA)

  # Controlled by visualize_oct2hist_outputs in the configuration cell
  if visualize_oct2hist_outputs:
    # present side by side
    fig, axes = plt.subplots(1, 2, figsize=FIG_SIZE)
    axes[0].imshow(cropped)
    axes[0].axis("off")
    axes[0].set_title("OCT")
    axes[1].imshow(histology_image_resized)
    axes[1].axis("off")
    axes[1].set_title("Real Histology")
    plt.show()


# Run SAM on virtual histology

If running locally using Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` in the configuration cell at the top and run the cells. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [None]:
from zero_shot_utils.utils import init_sam
# Build the SAM mask generator (used below via mask_generator.generate) from the
# model_type / sam_checkpoint set in the configuration cell.
# NOTE(review): segment_anything is pip-installed and the checkpoint is downloaded in
# a later cell; on a fresh kernel that cell has to run before this one.
mask_generator =  init_sam(model_type,sam_checkpoint )


In [None]:
from zero_shot_utils.utils import get_roboflow_data
!mkdir rf_dir
get_roboflow_data()


In [None]:
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

!mkdir images
!wget -P images https://pbs.twimg.com/media/FvpQj7UWYAAgxfo?format=jpg&name=large
#https://twitter.com/JMGardnerMD/status/1655724394805706752/photo/1
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

## Set-up

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [None]:
def show_anns(anns, ax=None, alpha=0.35, rng=None):
  """Overlay SAM automatic-mask annotations on a matplotlib axis.

  Masks are painted largest-first so smaller masks stay visible on top. Each
  mask gets a random RGB color with opacity `alpha`.

  Args:
    anns: list of annotation dicts (as produced by mask_generator.generate);
      each needs a boolean 2-D 'segmentation' array and an 'area' value.
    ax: axis to draw on; defaults to the current pyplot axis (plt.gca()).
    alpha: opacity of each mask color (default 0.35, as before).
    rng: optional numpy Generator for reproducible colors; when None the
      global numpy RNG is used (original behaviour).

  Returns:
    None. Draws an RGBA overlay image on `ax`; does nothing for an empty list.
  """
  if len(anns) == 0:
    return
  sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
  if ax is None:
    ax = plt.gca()
  # Keep the limits set by the underlying image instead of rescaling to the overlay
  ax.set_autoscale_on(False)

  height, width = sorted_anns[0]['segmentation'].shape[:2]
  # RGBA canvas, fully transparent except where a mask is painted
  overlay = np.ones((height, width, 4))
  overlay[:, :, 3] = 0
  random = np.random.random if rng is None else rng.random
  for ann in sorted_anns:
    overlay[ann['segmentation']] = np.concatenate([random(3), [alpha]])
  ax.imshow(overlay)

## Example image

In [None]:
%matplotlib notebook
%matplotlib inline

## Automatic mask generation

To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. The SAM checkpoint path is set by `sam_checkpoint` in the configuration cell at the top. Running on CUDA and with the default model is recommended.

To generate masks, just run `generate` on an image.

In [None]:
# Run SAM automatic mask generation on the histology crop.
# NOTE: if the ground-truth injection cell above ran, histology_image_resized holds
# the real H&E crop rather than the virtual histology.
masks = mask_generator.generate(histology_image_resized)

In [None]:
if visualize_sam_outputs:
  # Histology with the SAM masks from `masks` drawn on top
  fig, ax = plt.subplots(figsize=FIG_SIZE)
  ax.imshow(histology_image_resized)
  show_anns(masks)  # draws on the current axis, i.e. `ax`
  ax.axis('off')
  plt.show()

Mask generation returns a list of masks, where each mask is a dictionary containing various data about the mask. These keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the boundary box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format

Show all the masks overlaid on the image.

## Automatic mask generation options

There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low-quality or duplicate masks. Additionally, generation can be run automatically on crops of the image to improve performance on smaller objects, and post-processing can remove stray pixels and holes. See the [SAM automatic mask generator example](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb) for an example configuration that samples more masks.


Inject ground truth: run SAM on the true H&E image.

In [None]:
from zero_shot_utils.utils import sam_masking
# Pass the configuration flag through instead of hard-coding True, so the switch in
# the configuration cell actually controls this step (default is still True).
masks2 = sam_masking(inject_real_histology_and_segmentation = inject_real_histology_and_segmentation)

In [None]:
if visualize_sam_outputs:
  # Histology with the ground-truth-injected SAM masks (`masks2`) drawn on top
  fig, ax = plt.subplots(figsize=FIG_SIZE)
  ax.imshow(histology_image_resized)
  show_anns(masks2)  # draws on the current axis, i.e. `ax`
  ax.axis('off')
  plt.show()

In [None]:
if visualize_sam_outputs:
  # NOTE(review): this cell was left without a body, which is an IndentationError
  # that stops "Run all". Placeholder until the intended visualization is added.
  pass

# Project on OCT

In [None]:
# Overlay the histology-derived masks (`masks2`) on the matching OCT crop
fig, ax = plt.subplots(figsize=FIG_SIZE)
ax.imshow(cropped)
show_anns(masks2)  # draws on the current axis, i.e. `ax`
ax.axis('off')
plt.show()

# Eval results

In [None]:
from zero_shot_utils.utils import score_masking
# Replace with the path to your segmentation mask file
segmentation_mask_path = "/content/rf_dir/Zero-shot-oct-1/train/Hist-1_png.rf.168d6c48c79ccf974a2a1ecac761d3f5_mask.png"
# Load the segmentation mask image as-is (IMREAD_UNCHANGED: no color conversion).
# cv2.imread returns None instead of raising when the file is missing.
segmentation_mask = cv2.imread(segmentation_mask_path, cv2.IMREAD_UNCHANGED)
if segmentation_mask is None:
    raise FileNotFoundError(f"Could not read segmentation mask at: {segmentation_mask_path}")
print(score_masking(masks2, segmentation_mask))
