# In [None]:
import numpy as np
import tifffile
import cv2
import json
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import models, transforms
from efficientnet_pytorch import EfficientNet
import segmentation_models_pytorch as smp

# Run on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# EfficientNet-B7 backbone initialised with torchvision's ImageNet-1K weights
# (torchvision fetches and caches these the first time this line runs).
NN = models.efficientnet_b7(weights="IMAGENET1K_V1")

# Freeze the backbone: every parameter that exists at this point is excluded
# from training. The regression head attached below is created afterwards, so
# its parameters keep the default requires_grad=True and remain trainable.
for param in NN.parameters():
    param.requires_grad = False

# Feature width fed into the original final Linear layer of the classifier.
num_features = NN.classifier[-1].in_features

# Small MLP regression head producing a single unbounded value per image
# (no softmax: the output is a count-like regression target, not class scores).
mlp = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(512, 64),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(64, 1),
)

# Replace the ImageNet classification head with the regression head.
NN.classifier = mlp
    
class ImageProcessor:
    """Detect droplets in a .tif image and estimate a cell count for each one.

    Droplets are located by running Canny edge detection on a greyscale copy
    of the image and thresholding the mean edge intensity of every column (x)
    and then every row (y) of the edge map. Each droplet crop is passed through
    the module-level regression network ``NN``.
    """

    def __init__(self, model_path='../Data/model_efficientNetB7.pt'):
        """Load trained weights into ``NN`` and set up input preprocessing.

        Args:
            model_path: Path to a ``state_dict`` file readable by ``torch.load``.
        """
        self.segmentation_and_regression_model = NN
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host
        # (and vice versa); the model is then moved to the same device that the
        # inputs are sent to in _to_model_input.
        state_dict = torch.load(model_path, map_location=device)
        self.segmentation_and_regression_model.load_state_dict(state_dict)
        self.segmentation_and_regression_model.to(device)
        self.segmentation_and_regression_model.eval()

        # uint8 2-D array -> PIL 'L' image -> 128x128 -> (1, 128, 128) float tensor in [0, 1].
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((128, 128)),  # consistent input shape for the network
            transforms.ToTensor(),
        ])

    def process_image(self, image_path, visualize=True):
        """Estimate the cell count of every droplet found in a .tif image.

        Args:
            image_path: Path to the .tif file, read with ``tifffile.imread``.
            visualize: Forwarded to ``identify_droplet_segments``; when True each
                droplet crop is displayed with matplotlib.

        Returns:
            A list with one ``{"x_coordinate": ..., "cell_count": ...}`` dict per
            detected droplet, in left-to-right order (empty if none were found).
        """
        image = tifffile.imread(image_path)

        droplet_segments, x_coordinates = self.identify_droplet_segments(
            image, visualize=visualize)

        results = []
        for segment, x_coordinate in zip(droplet_segments, x_coordinates):
            input_tensor = self._to_model_input(segment)

            with torch.no_grad():
                output_cell_count = self.segmentation_and_regression_model(input_tensor)

            # The head emits one value per image; sum() collapses it to a scalar.
            cell_count = torch.sum(output_cell_count).item()
            results.append({"x_coordinate": x_coordinate, "cell_count": cell_count})

        # Return only after the loop so every droplet is processed, not just the first.
        return results

    def _to_model_input(self, segment):
        """Convert a 2-D greyscale crop into a (1, 3, 128, 128) float tensor on ``device``."""
        # ToPILImage needs uint8; clip so out-of-range values saturate instead of wrapping.
        segment_uint8 = np.clip(segment, 0, 255).astype(np.uint8)
        tensor = self.transform(segment_uint8)  # (1, 128, 128), values in [0, 1]
        # The backbone's first convolution expects 3 input channels, so the grey
        # channel is replicated.
        # NOTE(review): assumes training fed replicated grey images scaled to
        # [0, 1] at 128x128 — confirm against the training pipeline.
        tensor = tensor.repeat(3, 1, 1)
        return tensor.unsqueeze(0).to(device)

    def identify_droplet_segments(self, image, visualize=True):
        """Locate droplets and return their greyscale crops and x-centres.

        Args:
            image: Array from ``tifffile.imread``; either 2-D greyscale or with
                colour channels on the last axis (only the first three are used).
            visualize: When True, display each crop with matplotlib.

        Returns:
            ``(segments, x_centres)``: parallel lists holding the integer-valued
            2-D greyscale crop of each droplet and the midpoint of its inclusive
            x bounds.
        """
        # 1. Greyscale via BT.601 luma weights; 2-D input is already greyscale.
        if image.ndim == 2:
            grey_image = image.astype(float)
        else:
            grey_image = np.dot(image[..., :3], [0.299, 0.587, 0.114])

        # 2. Edge map used to find droplet boundaries.
        edges = self.canny_filter(grey_image)

        # 3. Column runs, then the row run inside each column run.
        image_segments = []
        mean_x_value = []
        for x_start, x_end in self.slice_along_x(edges):
            y_bounds = self.slice_along_y(image=edges, coordinates=(x_start, x_end))
            if not y_bounds:
                # No vertical run of acceptable length: skip instead of IndexError.
                continue
            y_start, y_end = y_bounds[0]

            # Bounds are inclusive, hence +1 for the exclusive slice end.
            segment = grey_image[y_start:y_end + 1, x_start:x_end + 1].astype(int)
            if visualize:
                self.visualize_image(segment)

            image_segments.append(segment)
            mean_x_value.append((x_start + x_end) / 2)

        return image_segments, mean_x_value

    def visualize_image(self, image):
        """Show a 2-D array as a greyscale image (may block, depending on the matplotlib backend)."""
        plt.imshow(image, cmap='gray')
        plt.show()

    def canny_filter(self, gray_arr, low_threshold=50, high_threshold=150):
        """Return the Canny edge map of a greyscale array.

        Args:
            gray_arr: 2-D greyscale array. It is cast to uint8 because
                ``cv2.Canny`` expects 8-bit input; out-of-range values are not
                clipped (TODO confirm the input range is 0-255).
            low_threshold: Lower hysteresis threshold for ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold for ``cv2.Canny``.
                Both thresholds may need tuning per dataset.

        Returns:
            A uint8 edge map of the same shape (edge pixels 255, others 0).
        """
        return cv2.Canny(gray_arr.astype(np.uint8), low_threshold, high_threshold)

    @staticmethod
    def _find_segments(profile, threshold_start, threshold_end,
                       min_segment_length, max_segment_length):
        """Return inclusive ``(start, end)`` index runs of a 1-D intensity profile.

        A run opens at the first index whose value is ``> threshold_start`` and
        closes at the index before the first subsequent value ``<= threshold_end``.
        A run that is still open at the end of the profile is closed at the last
        index. Only runs with ``min_segment_length <= end - start <= max_segment_length``
        are kept.
        """
        segments = []
        start_index = None
        for i, value in enumerate(profile):
            if start_index is None:
                if value > threshold_start:
                    start_index = i
            elif value <= threshold_end:
                end_index = i - 1
                if min_segment_length <= end_index - start_index <= max_segment_length:
                    segments.append((start_index, end_index))
                start_index = None

        # A droplet touching the image border is still a candidate.
        if start_index is not None:
            end_index = len(profile) - 1
            if min_segment_length <= end_index - start_index <= max_segment_length:
                segments.append((start_index, end_index))

        return segments

    def slice_along_x(self, image, threshold_start=40, threshold_end=40, min_segment_length=500, max_segment_length=800):
        """Find droplet column ranges from the per-column mean of an edge map.

        Returns:
            A list of inclusive ``(start, end)`` column-index pairs (see
            ``_find_segments`` for the run and length rules).
        """
        column_means = np.mean(image, axis=0)
        return self._find_segments(column_means, threshold_start, threshold_end,
                                   min_segment_length, max_segment_length)

    def slice_along_y(self, image, coordinates, y_threshold_start=40, y_threshold_end=40,
                      min_segment_length=450, max_segment_length=800):
        """Find droplet row ranges inside a column range of an edge map.

        Args:
            image: 2-D edge map.
            coordinates: Inclusive ``(start, end)`` column range, as returned
                by ``slice_along_x``.

        Returns:
            A list of inclusive ``(start, end)`` row-index pairs. A run that
            reaches the bottom edge is closed there, matching ``slice_along_x``.
        """
        x_start, x_end = coordinates[0], coordinates[1]
        row_means = np.mean(image[:, x_start:x_end + 1], axis=1)
        return self._find_segments(row_means, y_threshold_start, y_threshold_end,
                                   min_segment_length, max_segment_length)


# Example usage: run the pipeline on a single image and collect its results.
if __name__ == "__main__":
    # Placeholders for the future batch run over a directory of images.
    inputPath = ""
    outputPath = ""

    # Parallel lists: droplet x-position and its estimated cell count.
    DropletPositions = []
    CellCounts = []

    # TO ADD: for loop here, looping through each file in inputPath
    image_path = "../Data/Input/15.tif"
    image_processor = ImageProcessor()
    results = image_processor.process_image(image_path)

    # Split each per-droplet result dict into the two parallel lists.
    DropletPositions.extend(entry["x_coordinate"] for entry in results)
    CellCounts.extend(entry["cell_count"] for entry in results)

    # After: Create a json file from DropletPositions and Cell Counts, Sequence
