In [1]:
# imports
%matplotlib inline
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

### Functions

In [2]:
def read_images_to_array(folder_path):

  image_array = []
  # Get a sorted list of filenames
  filenames = sorted(os.listdir(folder_path))
  for filename in filenames:
    if filename.endswith(".jpg") or filename.endswith(".png"):
      img_path = os.path.join(folder_path, filename)
      img = cv2.imread(img_path)

      if img is not None:
        image_array.append(img)

  return image_array

def read_bin_files_to_array(folder_path):
    bin_files = []
    filenames = sorted(os.listdir(folder_path))
    for filename in filenames:
        if filename.endswith('.bin'):
            file_path = os.path.join(folder_path, filename)
            with open(file_path, 'rb') as file:
                data = np.fromfile(file, dtype=np.float32)
                bin_files.append(data)

    return bin_files


def split_images(image_array): 

    red_region_images = []
    raw_images = [] 

    for image in image_array:
        if image[25,100].sum() == 255*3 :
            red_region_images.append(image)
        else: 
            raw_images.append(image) 
            
    return red_region_images, raw_images

def split_train_val_test(images, masks):

    train_images = []
    train_masks = []
    val_images = []
    val_masks = []
    test_images = []
    test_masks = []

    for i in range(len(images)): 

    # these numbers are made specifically for this dataset 
        
        if i < 27: 
            train_images.append(images[i])
            train_masks.append(masks[i])
        elif i < 32:
            val_images.append(images[i])
            val_masks.append(masks[i])
        else: 
            test_images.append(images[i])
            test_masks.append(masks[i])

    return train_images, train_masks, val_images, val_masks, test_images, test_masks

def crop_raw_images(image_array): 
    
    cropped_images = [] 
    
    for i in range(len(image_array)): 
        
        image = image_array[i]
        
        mask = np.zeros(image.shape, dtype=np.uint8)
        mask = cv2.circle(mask, (320, 240), 180, (255,255,255), -1)

        res = cv2.bitwise_and(image, mask)
        res[mask==0] = 255
        
        cropped_images.append(res)

    return cropped_images

def crop_masks(image_array):
    cropped_images = []

    for i in range(len(image_array)): 
        image = image_array[i]
        
        mask = np.zeros(image.shape, dtype=np.uint8)
        mask = cv2.circle(mask, (288, 307), 200, (255,255,255), -1)

        res = cv2.bitwise_and(image, mask)
        res[mask==0] = 255
        
        cropped_images.append(res)

    return cropped_images

def add_padding(image_array, amt_x, amt_y): 
    
    padded_images = []
    
    for image in image_array: 

        padded_image = cv2.copyMakeBorder(
            image,
            amt_y,
            amt_y,
            amt_x,
            amt_x,
            cv2.BORDER_CONSTANT,
            value=(255,255,255)
        )
        
        padded_images.append(padded_image)
        
    return padded_images

def zoom_at(image_array, zoom, coord=None):
    
    zoomed_array = []
    
    for img in image_array: 
        
        h, w, _ = [ zoom * i for i in img.shape ]

        if coord is None: cx, cy = w/2, h/2
        else: cx, cy = [ zoom*c for c in coord ]

        img = cv2.resize( img, (0, 0), fx=zoom, fy=zoom)
        img = img[ int(round(cy - h/zoom * .5)) : int(round(cy + h/zoom * .5)),
                   int(round(cx - w/zoom * .5)) : int(round(cx + w/zoom * .5)),
                   : ]
        zoomed_array.append(img)
    
    return zoomed_array


def create_binary_masks(image_array):
    binary_masks = []
    
    for image in image_array:
        # Ensure image is in BGR format (convert if necessary)
        if image.ndim == 2:
            # Convert grayscale to BGR color (assuming gray image)
            image_color = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[2] != 3:
            raise ValueError("Input image must have 3 channels (BGR format).")
        else:
            image_color = image
        
        # Convert BGR to HSV
        hsv = cv2.cvtColor(image_color, cv2.COLOR_BGR2HSV)

        # Define lower and upper bounds for red color in HSV
        lower_red = np.array([0, 150, 115])
        upper_red = np.array([255, 255, 255])

        # Create mask using inRange function
        mask = cv2.inRange(hsv, lower_red, upper_red)

        # Apply bitwise AND operation using color image
        res = cv2.bitwise_and(image_color, image_color, mask=mask)
        
        binary_masks.append(mask)
        
    return binary_masks

def crop_images(image_array): 
    
    cropped_images = []
    
    for i in range(len(image_array) -1): 
        
        image = image_array[i]
        
        image_height, image_width = image.shape[:2]
        
        # Bounding box dimensions
        box_width, box_height = 256, 256

        x_top_left = (image_width - box_width) // 2
        y_top_left = (image_height - box_height) // 2
        x_bottom_right = x_top_left + box_width
        y_bottom_right = y_top_left + box_height
        
        cropped_image = image[y_top_left:y_bottom_right, x_top_left:x_bottom_right]
        cropped_images.append(cropped_image)
                              
    return cropped_images

In [None]:
def get_bounding_box(image_mask):
    
    if np.all(image_mask == 0):
        # If all zeros, create a random bounding box
        H, W = image_mask.shape
        x_min = np.random.randint(0, W)
        x_max = np.random.randint(x_min + 1, W + 1)  # Ensure x_max > x_min
        y_min = np.random.randint(0, H)
        y_max = np.random.randint(y_min + 1, H + 1)  # Ensure y_max > y_min
        
        bbox = [x_min, y_min, x_max, y_max]
    else: 
        if len(image_mask.shape) == 2 or image_mask.shape[2] == 1:
            gray = image_mask
        else:
            gray = cv2.cvtColor(image_mask, cv2.COLOR_BGR2GRAY)

        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return (0, 0, 0, 0)
        
        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)
        
        bbox = [x, y, x+w, y+h]
    
    return bbox

In [None]:
# saving the files 
import matplotlib.pyplot as plt 
import numpy as np 
import cv2, os, random
from scipy.interpolate import griddata
from tqdm import tqdm

def read_bin(file_path): 
    with open(file_path, 'rb') as fid:
        data = np.fromfile(fid, dtype='>f8')
    
    points = data.reshape(-1, 3)

    #points[:, 0] -= np.median(points[:, 0])
    #points[:, 1] -= np.median(points[:, 1])
    #points[:, 2] -= np.median(points[:, 2])
    
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]

    grid_x, grid_y = np.meshgrid(
        np.linspace(min(x), max(x), 256),
        np.linspace(min(y), max(y), 256)
    )
    
    grid_z = griddata((x, y), z, (grid_x, grid_y), method='linear')

    return grid_x, grid_y, grid_z

def read_all_bins(folder_path):

     data_array = []
     filenames = sorted(os.listdir(folder_path))
     
     for filename in tqdm(filenames, desc="Reading Bin Files"):
          if filename.endswith(".bin"):
               file_path = os.path.join(folder_path, filename)
               x, y, z = read_bin(file_path)
               data_array.append((x, y, z, filename)) 
    
     return data_array

def save_contours(data_array, folder_path):
    os.makedirs(folder_path, exist_ok=True)

    for data in tqdm(data_array, desc="Saving Contour Plots"):
          x, y, z, original_filename = data
          base_file_name = os.path.splitext(original_filename)[0]  
          file_name = f"{base_file_name}.png"
          path = os.path.join(folder_path, file_name)

          plt.contourf(x, y, z, levels=100, cmap="gray")
          plt.gca().set_aspect('equal')
          plt.savefig(path)
          plt.close()


# test to see if read_all_bins works
data_array = read_all_bins('./data/inotive_data_bin_files')
#data_array = read_all_bins('/Users/riyajain/Desktop/reading-bin-files/bin files')
save_contours(data_array, './data/depth_images')

In [None]:
#adding depth into blue rgb channel - function + testing
import cv2
import numpy as np
import matplotlib.pyplot as plt

def infuse_depth_into_blue_channel(image_path, depth_map_path, output_path):
    # Load the original color image
    image = cv2.imread(image_path)
    
    # Load the grayscale depth map
    depth_map = cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE)
    
    # Resize the depth map to match the image dimensions
    depth_map = cv2.resize(depth_map, (image.shape[1], image.shape[0]))
    
    # Split the image into RGB channels
    b, g, r = cv2.split(image)
    
    # Normalize depth map to match the blue channel (0-255)
    depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX)
    
    # Infuse the depth map into the blue channel
    infused_blue = cv2.addWeighted(b, 0.5, depth_map_normalized, 0.5, 0)
    
    # Merge the channels back
    infused_image = cv2.merge((infused_blue, g, r))
    
    # Save the resulting image
    cv2.imwrite(output_path, infused_image)

    # Display the result using matplotlib
    infused_image_rgb = cv2.cvtColor(infused_image, cv2.COLOR_BGR2RGB)
    plt.imshow(infused_image_rgb)
    plt.axis('off')
    plt.show()



# Test function with a sample image and depth map
image_path = './data/original_images/L3.3 Parental 2-Control-1-0-38_texture.jpg'
depth_map_path = './data/depth_images/L3.3 Parental 2-Control-1-0-38.png'
output_path = './SAM/data/infused_images/infused_image.png'

infuse_depth_into_blue_channel(image_path, depth_map_path, output_path)


In [None]:
#make original images all png
import os 
from glob import glob
import cv2

def convert_jpg_to_png(input_dir, output_dir):
    jpg_images = glob(os.path.join(input_dir, '*.jpg'))
    if not jpg_images:
        print(f"No jpg images found in directory: {input_dir}")
        return

    os.makedirs(output_dir, exist_ok=True)

    for jpg_image in jpg_images:
        # Load the jpg image
        image = cv2.imread(jpg_image)
        if image is None:
            print(f"Error: Image at path {jpg_image} could not be read.")
            continue

        base_name = os.path.basename(jpg_image)
        png_image_path = os.path.join(output_dir, os.path.splitext(base_name)[0] + '.png')

        cv2.imwrite(png_image_path, image)
        print(f"Converted {jpg_image} to {png_image_path}")

input_dir = './data/original_images'
output_dir = './data/original_images_png'

convert_jpg_to_png(input_dir, output_dir)


In [None]:
#saving the preprocessed images in a directory/folder
import os
from glob import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt

def infuse_depth_into_blue_channel(image_path, depth_map_path, output_path):
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Image at path {image_path} could not be read.")
        return
    
    depth_map = cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE)
    if depth_map is None:
        print(f"Error: Depth map at path {depth_map_path} could not be read.")
        return
    
    # Resize the depth map to match the image dimensions
    depth_map = cv2.resize(depth_map, (image.shape[1], image.shape[0]))
    
    # Split the image into RGB channels
    b, g, r = cv2.split(image)
    
    # Normalize depth map to match the blue channel (0-255)
    depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX)
    
    # Infuse the depth map into the blue channel (e.g., averaging)
    infused_blue = cv2.addWeighted(b, 0.5, depth_map_normalized, 0.5, 0)
    
    # Merge the channels back
    infused_image = cv2.merge((infused_blue, g, r))
    
    # Save the resulting image
    cv2.imwrite(output_path, infused_image)

# Define the function to preprocess the dataset
def preprocess_dataset(input_image_dir, depth_map_dir, output_dir):
    # Print the current working directory and input directory contents for debugging
    print(f"Current working directory: {os.getcwd()}")
    print(f"Contents of input_image_dir ({input_image_dir}): {os.listdir(input_image_dir)}")
    print(f"Contents of depth_map_dir ({depth_map_dir}): {os.listdir(depth_map_dir)}")

    image_paths = glob(os.path.join(input_image_dir, '*.png'))  # Adjust extension if necessary
    if not image_paths:
        print(f"No images found in directory: {input_image_dir}")
        return

    for image_path in image_paths:
        base_name = os.path.basename(image_path).replace('_texture', '')
        depth_map_path = os.path.join(depth_map_dir, base_name)
        output_path = os.path.join(output_dir, os.path.basename(image_path))
        
        # Check if the depth map exists
        if not os.path.exists(depth_map_path):
            print(f"Warning: Depth map at path {depth_map_path} does not exist.")
            continue

        print(f"Processing {image_path} with depth map {depth_map_path}")
        infuse_depth_into_blue_channel(image_path, depth_map_path, output_path)
        print(f"Saved infused image to {output_path}")

# Specify the directory paths
input_image_dir = './data/original_images_png'
depth_map_dir = './data/depth_images'
output_dir = './SAM/data/infused_images'

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Preprocess the dataset
preprocess_dataset(input_image_dir, depth_map_dir, output_dir)

### Apply Functions

In [4]:
# do not use this one, use cell below to adjust
images = read_images_to_array('../data/invotive_data/')
depth = read_images_to_array('../getting_depth_info/data/')
masks, raw = split_images(images)
og_red = masks
images = depth

train_images, train_masks, val_images, val_masks, test_images, test_masks = split_train_val_test(images, masks)

train_images = crop_raw_images(train_images)
train_images = add_padding(train_images, 0, 67)
train_images = crop_images(train_images)
train_masks = crop_masks(train_masks)
train_masks = add_padding(train_masks, 31, 0)
#train_masks = zoom_at(train_masks, 1.156, coord=None)
train_masks = create_binary_masks(train_masks)
train_masks = crop_images(train_masks)

val_images = crop_raw_images(val_images)
val_images = add_padding(val_images, 0, 67)
val_images = crop_images(val_images)
val_masks = crop_masks(val_masks)
val_masks = add_padding(val_masks, 31, 0)
#val_masks = zoom_at(val_masks, 1.156, coord=None)
val_masks = create_binary_masks(val_masks)
val_masks = crop_images(val_masks)

test_images = crop_raw_images(test_images)
test_images = add_padding(test_images, 0, 67)
test_images = crop_images(test_images)
test_masks = crop_masks(test_masks)
test_masks = add_padding(test_masks, 31, 0)
#test_masks = zoom_at(test_masks, 1.156, coord=None)
test_masks = create_binary_masks(test_masks)
test_masks = crop_images(test_masks)

In [None]:
import os
from glob import glob
import cv2
import numpy as np

import os
from glob import glob
import cv2
import numpy as np

def read_images_to_array(directory):
    # Print out the contents of the directory for debugging
    print(f"Contents of directory {directory}: {os.listdir(directory)}")
    
    # Adjust the pattern if needed
    image_paths = glob(os.path.join(directory, '*.png'))
    
    # Print out the found image paths
    print(f"Found {len(image_paths)} images in {directory}")
    
    images = []
    for image_path in image_paths:
        img = cv2.imread(image_path)
        if img is None:
            print(f"Warning: Image at {image_path} could not be read.")
        else:
            images.append(img)
    
    print(f"Loaded {len(images)} images from {directory}")
    return np.array(images)

# Test the function with the infused images directory
depth = read_images_to_array('./SAM/data/infused_images')


# Split images into masks and raw images
def split_images(images):
    # Placeholder for actual logic
    masks = images.copy()
    raw = images.copy()
    print("Split images into masks and raw images.")
    return masks, raw

# Split data into train, val, and test sets
def split_train_val_test(images, masks, val_split=0.2, test_split=0.1):
    total_images = len(images)
    val_size = int(total_images * val_split)
    test_size = int(total_images * test_split)
    train_size = total_images - val_size - test_size
    
    train_images = images[:train_size]
    train_masks = masks[:train_size]
    val_images = images[train_size:train_size + val_size]
    val_masks = masks[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]
    test_masks = masks[train_size + val_size:]
    
    print(f"Split data into {len(train_images)} train, {len(val_images)} val, and {len(test_images)} test images.")
    return train_images, train_masks, val_images, val_masks, test_images, test_masks

# Preprocessing functions
def crop_raw_images(images):
    print("Cropping raw images.")
    return images

def add_padding(images, top, bottom):
    print(f"Adding padding: top={top}, bottom={bottom}.")
    return images

def crop_images(images):
    print("Cropping images.")
    return images

def crop_masks(masks):
    print("Cropping masks.")
    return masks

def create_binary_masks(masks):
    print("Creating binary masks.")
    return masks

# Main processing code
images = read_images_to_array('./data/original_images_png')
depth = read_images_to_array('./SAM/data/infused_images')
                             
masks, raw = split_images(images)
og_red = masks
images = depth

train_images, train_masks, val_images, val_masks, test_images, test_masks = split_train_val_test(images, masks)

train_images = crop_raw_images(train_images)
train_images = add_padding(train_images, 0, 67)
train_images = crop_images(train_images)
train_masks = crop_masks(train_masks)
train_masks = add_padding(train_masks, 31, 0)
train_masks = create_binary_masks(train_masks)
train_masks = crop_images(train_masks)

val_images = crop_raw_images(val_images)
val_images = add_padding(val_images, 0, 67)
val_images = crop_images(val_images)
val_masks = crop_masks(val_masks)
val_masks = add_padding(val_masks, 31, 0)
val_masks = create_binary_masks(val_masks)
val_masks = crop_images(val_masks)

test_images = crop_raw_images(test_images)
test_images = add_padding(test_images, 0, 67)
test_images = crop_images(test_images)
test_masks = crop_masks(test_masks)
test_masks = add_padding(test_masks, 31, 0)
test_masks = create_binary_masks(test_masks)
test_masks = crop_images(test_masks)


### Augmentation

In [14]:
def augment_image_array(image_array_raw, image_array_binary, num_augmentations):
    
    aug_raw = []
    aug_masks = []
    
    for _ in range(num_augmentations):
            for i in range(len(image_array_raw) -1):
                image_raw = image_array_raw[i]
                image_binary = image_array_binary[i]

                flipped_image_raw = cv2.flip(image_raw, 1)
                flipped_image_binary = cv2.flip(image_binary, 1)

                angle = random.uniform(-30, 30)
                (h, w) = flipped_image_raw.shape[:2]
                center = (w // 2, h // 2)

                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                augmented_image_raw = cv2.warpAffine(flipped_image_raw, M, (w, h))
                augmented_image_binary = cv2.warpAffine(flipped_image_binary, M, (w, h))

                aug_raw.append(augmented_image_raw)
                aug_masks.append(augmented_image_binary)

    image_array_raw = np.concatenate((image_array_raw, np.array(aug_raw)))
    image_array_binary = np.concatenate((image_array_binary, np.array(aug_masks)))

    return image_array_raw, image_array_binary

In [15]:
# apply augmentation to arrays 

train_images, train_masks = augment_image_array(train_images, train_masks, 500)
val_images, val_masks = augment_image_array(val_images, val_masks, 500)
test_images, test_masks = augment_image_array(test_images, test_masks, 50)

In [16]:
# print new lengths 

print(len(train_images))
print(len(val_images))
print(len(test_images))

12526
1504
52


### Create the Dataset

In [None]:
# Convert the NumPy arrays to Pillow images and store them in a dictionary
training_dataset_dict = {
    "image": [Image.fromarray(img) for img in train_images],
    "label": [Image.fromarray(mask) for mask in train_masks],
}

val_dataset_dict = {
    "image": [Image.fromarray(img) for img in val_images],
    "label": [Image.fromarray(mask) for mask in val_masks],
}

In [None]:
# Create the dataset using the datasets.Dataset class
training_dataset = Dataset.from_dict(training_dataset_dict)
val_dataset = Dataset.from_dict(val_dataset_dict)

In [17]:
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["label"])

    # get bounding box prompt
    prompt = get_bounding_box(ground_truth_mask)

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k:v.squeeze(0) for k,v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs

In [None]:
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

from torch.utils.data import random_split, DataLoader

training_dataset = SAMDataset(dataset=training_dataset, processor=processor)
val_dataset = SAMDataset(dataset=val_dataset, processor=processor)


train_dataloader = DataLoader(dataset=training_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=2, shuffle=False) 

In [None]:
from transformers import SamModel
local_model_path = './hugging-face'

# Load the model from the local path
model = SamModel.from_pretrained(local_model_path)

# Freeze the parameters of the vision and prompt encoders
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad = False

# Verify which parameters are frozen
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

### Training

In [None]:
from torch.optim import Adam
import monai

# Note: Hyperparameter tuning could improve performance here
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

#Training loop
num_epochs = 20

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []

    # Training loop
    for batch in tqdm(train_dataloader):
        # Forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # Compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Backward pass (compute gradients)
        optimizer.zero_grad()
        loss.backward()

        # Optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    # Logging training results
    print(f'EPOCH: {epoch}')
    print(f'Mean training loss: {mean(epoch_losses)}')

    # Validation loop
    model.eval()  # Set model to evaluation mode
    val_losses = []
    with torch.no_grad():  # Disable gradient computation
        for batch in tqdm(val_dataloader):
            # Forward pass
            outputs = model(pixel_values=batch["pixel_values"].to(device),
                            input_boxes=batch["input_boxes"].to(device),
                            multimask_output=False)

            # Compute loss
            predicted_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = batch["ground_truth_mask"].float().to(device)
            loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
            val_losses.append(loss.item())

    # Logging validation results
    print(f'Mean validation loss: {mean(val_losses)}')
    
# Save the model's state dictionary to a file
torch.save(model.state_dict(), "./models/mito_model_checkpoint6.pth")