In [None]:
import sys
import os
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QVBoxLayout, QHBoxLayout, 
                             QWidget, QFileDialog, QLabel, QListWidget, QListWidgetItem, QTextEdit, 
                             QSlider, QScrollArea, QMessageBox, QColorDialog, QLineEdit)
from PyQt5.QtGui import QPixmap, QImage, QFont
from PyQt5.QtCore import pyqtSignal, Qt
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import csv
import math
import matplotlib.pyplot as plt
from datetime import datetime

# Module-wide compute device: the first CUDA GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
  

class CustomCNN(nn.Module):
    """Six-stage convolutional binary classifier that returns one raw logit per image.

    Every stage is a 3x3 convolution with padding 1 (spatial size preserved)
    followed by ReLU, widening the channels 3 -> 16 -> 32 -> 64 -> 128 -> 256
    -> 512. The final feature map is global-average-pooled to a 512-vector and
    projected to a single score; no sigmoid is applied here.

    The attribute names ``conv1`` .. ``conv6``, ``adaptive_pool`` and ``fc`` define
    the ``state_dict`` keys, so they must stay stable for saved weights to load.
    """

    # Input/output channel widths of the convolution stack, in order.
    _CHANNELS = (3, 16, 32, 64, 128, 256, 512)

    def __init__(self):
        super().__init__()
        stage_widths = zip(self._CHANNELS, self._CHANNELS[1:])
        for stage, (c_in, c_out) in enumerate(stage_widths, start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))

        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self._CHANNELS[-1], 1)

    def forward(self, x):
        """Map a batch ``(N, 3, H, W)`` to logits of shape ``(N, 1)``."""
        for stage in range(1, len(self._CHANNELS)):
            x = F.relu(getattr(self, f"conv{stage}")(x))
        pooled = self.adaptive_pool(x)
        return self.fc(torch.flatten(pooled, 1))

class ColorAndClassificationApp(QMainWindow):
    """Main window: HSV colour-based object detection plus per-object CNN classification.

    Layout: a left panel (file picker, image list, detection buttons), a centre
    panel (scrollable image view with a zoom slider) and a right panel (text
    log of detection results).

    Detection thresholds the image in HSV space and takes the bounding box of
    every external contour larger than a minimum area. Each box crop is then
    classified with ``model``: a sigmoid probability >= 0.5 is class 0 ("OUT",
    drawn in red) and anything lower is class 1 ("IN", drawn in blue). Every
    run appends a summary row to ``detection_results.csv``.
    """

    # Pixel area assumed for one nucleus when estimating nucleus counts from the
    # stained area inside a detected box. Presumably calibrated for the original
    # imaging setup -- confirm before changing magnification/resolution.
    NUCLEUS_AREA_PX = 500

    def __init__(self, model):
        """Build the three-panel UI and move `model` to the GPU when one is available.

        Args:
            model: classifier that returns a single logit for a one-image batch;
                it is fed 224x224 RGB crops produced by ``self.transform``.
        """
        super().__init__()
        self.setWindowTitle("Cell Detection Application")
        self.setGeometry(100, 100, 2000, 800)
        self.model = model  # image classification model
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Larger font applied to the image list and the results area below.
        font = QFont()
        font.setPointSize(20)

        # Main layout
        self.main_layout = QHBoxLayout()

        # Left panel: file picker, image list and detection buttons
        self.left_layout = QVBoxLayout()
        self.open_dir_button = QPushButton("Open Folder")
        self.open_dir_button.clicked.connect(self.open_folder)
        self.image_list = QListWidget()
        self.detection_button = QPushButton("Detection")
        self.detection_button.clicked.connect(self.perform_detection)
        self.detection_all_button = QPushButton("Detection all images")
        self.detection_all_button.clicked.connect(self.perform_detection_all_images)
        self.left_layout.addWidget(self.open_dir_button)
        self.left_layout.addWidget(self.image_list)
        self.left_layout.addWidget(self.detection_button)
        self.left_layout.addWidget(self.detection_all_button)

        self.open_dir_button.setMinimumSize(150, 80)
        self.detection_button.setMinimumSize(150, 80)
        self.detection_all_button.setMinimumSize(150, 80)

        # Centre panel: scrollable image view and a zoom slider (percent of full size)
        self.center_layout = QVBoxLayout()
        self.scroll_area = QScrollArea()
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.scroll_area.setWidget(self.image_label)
        self.scroll_area.setWidgetResizable(True)
        self.zoom_slider = QSlider(Qt.Horizontal)
        self.zoom_slider.setMinimum(50)
        self.zoom_slider.setMaximum(100)
        self.zoom_slider.setValue(50)
        self.zoom_slider.valueChanged.connect(self.zoom_image)
        self.center_layout.addWidget(self.scroll_area)
        self.center_layout.addWidget(self.zoom_slider)

        # Right panel
        self.right_layout = QVBoxLayout()

        # Text area and label for detection results
        self.detection_results_area = QTextEdit()
        self.detection_results_area.setReadOnly(True)
        self.right_layout.addWidget(QLabel("Detection Results:"))
        self.right_layout.addWidget(self.detection_results_area)

        # Add panels to main layout (stretch 1 : 3 : 1)
        self.main_layout.addLayout(self.left_layout, 1)
        self.main_layout.addLayout(self.center_layout, 3)
        self.main_layout.addLayout(self.right_layout, 1)

        # font size adjust
        self.image_list.setFont(font)
        self.detection_results_area.setFont(font)

        central_widget = QWidget()
        central_widget.setLayout(self.main_layout)
        self.setCentralWidget(central_widget)

        # Detection parameters: HSV lower/upper bounds ("H,S,V") and minimum contour area.
        # NOTE(review): these fields are never added to a layout, so users cannot
        # edit them and the defaults below are always used.
        self.lower_bound_input = QLineEdit("70,200,100")
        self.upper_bound_input = QLineEdit("140,255,255")
        self.contour_area_input = QLineEdit("100")

        self.image_list.itemClicked.connect(self.display_image)

        # Display state. original_pixmap must exist before any zoom event fires.
        self.original_pixmap = None   # pixmap for the selected image (may be a cached annotated one)
        self.processed_pixmap = None  # annotated pixmap from the most recent detection

        self.last_mouse_position = None

        self.zoom_percentage_label = QLabel("50%")  # shows the current zoom level
        self.center_layout.addWidget(self.zoom_percentage_label)

        self.selected_object_info = None
        self.detected_objects = []  # (x, y, w, h, predicted_class) across detection runs
        self.detected_pixmap = {}   # full image path -> annotated QPixmap
        self.is_original_image_displayed = True

    def open_folder(self):
        """Let the user pick image files and list them (full path stored under Qt.UserRole)."""
        file_names, _ = QFileDialog.getOpenFileNames(
            self, "Select Images", "", "Image Files (*.png *.jpg *.jpeg *.bmp *.gif)"
        )
        if file_names:
            self.image_list.clear()
            for file_name in file_names:
                item = QListWidgetItem(os.path.basename(file_name))
                item.setData(Qt.UserRole, file_name)
                self.image_list.addItem(item)

    def wheelEvent(self, event):
        """Ctrl + mouse wheel zooms in/out; other wheel events are left unhandled."""
        if QApplication.keyboardModifiers() == Qt.ControlModifier:
            if event.angleDelta().y() > 0:
                self.zoom_in()
            else:
                self.zoom_out()
            event.accept()
        else:
            event.ignore()

    def zoom_in(self):
        """Raise the zoom slider by 10 points, clamped to its maximum."""
        self.zoom_slider.setValue(min(self.zoom_slider.value() + 10, self.zoom_slider.maximum()))

    def zoom_out(self):
        """Lower the zoom slider by 10 points, clamped to its minimum."""
        self.zoom_slider.setValue(max(self.zoom_slider.value() - 10, self.zoom_slider.minimum()))

    def zoom_image(self, value):
        """Show the current image scaled to `value` percent of its full size.

        Uses the latest annotated pixmap if there is one, otherwise the selected
        image; does nothing when no image has been shown yet.
        """
        pixmap_to_zoom = self.processed_pixmap if self.processed_pixmap else self.original_pixmap
        if not pixmap_to_zoom:
            return

        # QPixmap.scaled takes ints; PyQt5 rejects float arguments on modern Python.
        new_width = int(pixmap_to_zoom.width() * value / 100)
        new_height = int(pixmap_to_zoom.height() * value / 100)
        self.image_label.setPixmap(pixmap_to_zoom.scaled(new_width, new_height, Qt.KeepAspectRatio))
        self.image_label.adjustSize()
        # `value` is already the scale factor in percent.
        self.zoom_percentage_label.setText(f"{value}%")

    def clear_detection_results(self):
        """Redisplay `original_pixmap`, forget recorded detections and empty the results log.

        NOTE(review): if the current image was opened from the detection cache,
        `original_pixmap` is itself annotated, so its boxes remain visible. This
        method is not connected to any widget in this file.
        """
        if self.original_pixmap is not None:
            self.image_label.setPixmap(self.original_pixmap.scaled(
                self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio))

        # Stop zoom_image from re-showing the last annotated image.
        self.processed_pixmap = None
        self.detected_objects = []
        self.detection_results_area.clear()

    def display_image(self, item):
        """Show the clicked image, preferring its annotated version if it was already processed."""
        image_path = item.data(Qt.UserRole)

        cached = self.detected_pixmap.get(image_path)
        self.original_pixmap = cached if cached is not None else QPixmap(image_path)
        # A new selection invalidates the previous detection result for zooming.
        self.processed_pixmap = None

        self.image_label.setPixmap(self.original_pixmap.scaled(
            self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio))

        # reset image information
        self.selected_object_info = None

    def perform_detection_all_images(self):
        """Detect and classify objects in every listed image, logging one CSV row per image.

        Unlike ``perform_detection``, cells are counted one per detected box
        (no nucleus-area estimate). Images that cannot be read are skipped.
        """
        for index in range(self.image_list.count()):
            image_path = self.image_list.item(index).data(Qt.UserRole)
            filename = os.path.basename(image_path)
            detected_objects, original_image = self.count_objects_by_color(image_path)
            if original_image is None:
                continue

            # Draw on a copy so every crop is taken from unannotated pixels.
            annotated = original_image.copy()
            in_count, out_count = 0, 0
            for (x, y, w, h) in detected_objects:
                predicted_class = self._classify_crop(original_image[y:y + h, x:x + w])
                self._draw_box(annotated, (x, y, w, h), predicted_class)
                if predicted_class == 1:
                    in_count += 1
                else:
                    out_count += 1
                self.detected_objects.append((x, y, w, h, predicted_class))

            total = len(detected_objects)
            fusion_index = self._fusion_index(in_count, total)

            self._show_result(image_path, self._bgr_to_pixmap(annotated))
            self.detection_results_area.append(f"Detected objects: {total}")
            self.detection_results_area.append(f"Detected 'IN': {in_count}\nDetected 'OUT': {out_count}\n")
            self.detection_results_area.append(f"Fusion Index (IN/Total) : {fusion_index}")
            self.export_to_csv({
                'Image Name': filename,
                'Total Cells': total,
                'In Cells': in_count,
                'Out Cells': out_count,
                'Fusion Index': fusion_index,
            })

        QMessageBox.information(self, "Detection Completed", "Detection has been completed for all selected images.")

    def perform_detection(self):
        """Detect, classify and count nuclei in the currently selected image.

        Every detected box is classified IN (1) / OUT (0). Its nucleus count is
        estimated as ceil(stained area / NUCLEUS_AREA_PX), where the stained area
        sums the in-range contours inside the box that exceed the minimum contour
        area. Totals and the fusion index (IN nuclei / all nuclei) are drawn,
        logged and appended to the CSV.
        """
        current_item = self.image_list.currentItem()
        if current_item is None:
            QMessageBox.information(self, "Info", "No image selected.")
            return

        image_path = current_item.data(Qt.UserRole)
        filename = os.path.basename(image_path)
        detected_objects, original_image = self.count_objects_by_color(image_path)
        if original_image is None:
            return

        # Use the same thresholds as count_objects_by_color for the per-box masks.
        lower_bound, upper_bound, min_area = self._read_detection_params()

        # Draw on a copy so every crop is taken from unannotated pixels.
        annotated = original_image.copy()
        in_count, out_count = 0, 0
        for (x, y, w, h) in detected_objects:
            crop = original_image[y:y + h, x:x + w]
            predicted_class = self._classify_crop(crop)
            self._draw_box(annotated, (x, y, w, h), predicted_class)

            nuclei_count = self._estimate_nuclei(crop, lower_bound, upper_bound, min_area)
            if predicted_class == 1:
                in_count += nuclei_count
            else:
                out_count += nuclei_count
            self.detected_objects.append((x, y, w, h, predicted_class))

        total = in_count + out_count
        # Guarded ratio: boxes may exist while the estimated nucleus total is 0.
        fusion_index = self._fusion_index(in_count, total)

        self._show_result(image_path, self._bgr_to_pixmap(annotated))
        self.detection_results_area.append(f"Detected objects: {total}")
        self.detection_results_area.append(f"Detected 'IN': {in_count}\nDetected 'OUT': {out_count}\n")
        # Show the same ratio that is written to the CSV.
        self.detection_results_area.append(f"Fusion Index(IN/Total) :{fusion_index}\n")
        self.export_to_csv({
            'Image Name': filename,
            'Total Cells': total,
            'In Cells': in_count,
            'Out Cells': out_count,
            'Fusion Index': fusion_index,
        })

    def count_objects_by_color(self, image_path):
        """Find colour blobs in an image by HSV thresholding.

        Returns:
            ``(detected_objects, image)``. ``detected_objects`` holds the bounding
            boxes ``(x, y, w, h)`` of external contours whose area exceeds the
            minimum-area field. ``image`` is the BGR array from ``cv2.imread``.
            Returns ``([], None)`` if the image cannot be read.

        Raises:
            ValueError: if a parameter field does not contain integers.
        """
        lower_bound, upper_bound, min_area = self._read_detection_params()

        image = cv2.imread(image_path)
        if image is None:
            print(f"Error loading image: {image_path}")
            return [], None

        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv, lower_bound, upper_bound)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        detected_objects = [
            tuple(cv2.boundingRect(contour))
            for contour in contours
            if cv2.contourArea(contour) > min_area
        ]
        return detected_objects, image

    def export_to_csv(self, data):
        """Append one detection summary row to ``detection_results.csv`` in the working directory.

        A 'Time' column (local time, ``YYYY-MM-DD HH:MM:SS``) is added to a copy of
        `data`; the caller's dict is not modified. The header is written only when
        the file does not exist yet, and keys missing from `data` become empty cells.
        """
        filename = 'detection_results.csv'
        file_exists = os.path.isfile(filename)

        row = dict(data, Time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        fieldnames = ['Time', 'Image Name', 'Total Cells', 'In Cells', 'Out Cells', 'Fusion Index']

        with open(filename, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            if not file_exists:
                writer.writeheader()  # header only for a fresh file
            writer.writerow(row)

    # ----- private helpers -------------------------------------------------

    def _read_detection_params(self):
        """Parse the HSV bounds and the minimum contour area from the parameter fields.

        Returns:
            ``(lower_bound, upper_bound, min_area)``. The bounds are uint8 arrays
            holding one value per comma-separated entry; ``min_area`` is an int.

        Raises:
            ValueError: if a field does not contain integers.
        """
        lower_bound = np.array([int(v) for v in self.lower_bound_input.text().split(',')], dtype="uint8")
        upper_bound = np.array([int(v) for v in self.upper_bound_input.text().split(',')], dtype="uint8")
        min_area = int(self.contour_area_input.text())
        return lower_bound, upper_bound, min_area

    def _classify_crop(self, bgr_crop):
        """Classify one BGR crop: 0 when sigmoid(logit) >= 0.5, else 1."""
        rgb_crop = cv2.cvtColor(bgr_crop, cv2.COLOR_BGR2RGB)
        tensor = self.transform(Image.fromarray(rgb_crop)).unsqueeze(0).to(self.device)
        self.model.eval()
        with torch.no_grad():
            probability = torch.sigmoid(self.model(tensor)).item()
        return 0 if probability >= 0.5 else 1

    def _estimate_nuclei(self, bgr_crop, lower_bound, upper_bound, min_area):
        """Estimate nuclei in a crop as ceil(in-range contour area above `min_area` / NUCLEUS_AREA_PX)."""
        mask = cv2.inRange(cv2.cvtColor(bgr_crop, cv2.COLOR_BGR2HSV), lower_bound, upper_bound)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        areas = (cv2.contourArea(contour) for contour in contours)
        total_area = sum(area for area in areas if area > min_area)
        return math.ceil(total_area / self.NUCLEUS_AREA_PX)

    @staticmethod
    def _draw_box(bgr_image, box, predicted_class):
        """Draw `box` on `bgr_image` in place: blue for class 1 ("IN"), red for class 0 ("OUT")."""
        x, y, w, h = box
        color = (255, 0, 0) if predicted_class == 1 else (0, 0, 255)  # BGR order
        cv2.rectangle(bgr_image, (x, y), (x + w, y + h), color, 2)

    @staticmethod
    def _fusion_index(in_count, total):
        """Return ``in_count / total``, or 0.0 when nothing was counted."""
        return in_count / total if total > 0 else 0.0

    @staticmethod
    def _bgr_to_pixmap(bgr_image):
        """Convert a 3-channel BGR array to a QPixmap."""
        rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        height, width, _ = rgb.shape
        # bytesPerLine must be explicit: otherwise QImage assumes 32-bit-aligned
        # scanlines and widths where 3 * width % 4 != 0 render skewed.
        q_image = QImage(rgb.data, width, height, 3 * width, QImage.Format_RGB888)
        return QPixmap.fromImage(q_image)

    def _show_result(self, image_path, pixmap):
        """Cache the annotated `pixmap` for `image_path` and display it scaled to the view."""
        self.processed_pixmap = pixmap
        self.detected_pixmap[image_path] = pixmap
        self.image_label.setPixmap(pixmap.scaled(
            self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio))


def main():
    """Load the trained classifier from ./model.pth and run the GUI until it exits."""
    classifier = CustomCNN().to(device)
    weights = torch.load('./model.pth', map_location=device)
    classifier.load_state_dict(weights)

    app = QApplication(sys.argv)
    # Application-wide default font: Times New Roman at 20 pt.
    app.setFont(QFont("Times New Roman", 20))

    window = ColorAndClassificationApp(classifier)
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()