In [1]:
import os
import numpy as np
from skimage import io, color, util, feature
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import matplotlib.pyplot as plt
import time


In [4]:
class ImageClassifier:
    """Compare random forests trained on whole images versus small image patches.

    On construction every PNG file in ``folder_path`` is loaded, resized to
    ``desired_height`` x ``desired_width``, labelled 1 when its file name
    contains ``'YES'`` (0 otherwise), split into train/test sets at image level
    and cut into patches.

    Attributes set during construction:
        images, labels: all loaded images, shape (N, H, W, C), and their labels.
        channels: number of colour channels C of the loaded images.
        X_train, X_test, y_train, y_test: image-level split of the flattened images.
        patched_images, patch_labels: flattened patches (and one label per patch)
            cut from the *training* images only, so that no test pixel is seen
            while fitting the patch classifier.
        patched_test_images, patch_test_labels: the same for the test images.

    Attributes set by ``train_and_evaluate``:
        unpatched_classifier, patched_classifier: the fitted random forests.
    """

    def __init__(self, folder_path, desired_height=256, desired_width=256, patch_size=(9,9), n_estimators=100, min_samples_split=5, test_size=0.2, random_state=42, max_patches=None):
        """Store the configuration, then load, split and patch the images.

        Args:
            folder_path: directory containing the PNG images.
            desired_height, desired_width: size every image is resized to.
            patch_size: (height, width) of the patches cut from each image.
            n_estimators, min_samples_split: random forest hyper-parameters.
            test_size: fraction (or count) of images held out for testing.
            random_state: seed used for the split, patch sampling and the forests.
            max_patches: forwarded to ``extract_patches_2d``. ``None`` keeps every
                patch, which for a 256x256 image and 9x9 patches is 61,504
                patches per image; pass an int or a fraction to sub-sample.
        """
        self.folder_path = folder_path
        self.desired_height = desired_height
        self.desired_width = desired_width
        self.patch_size = patch_size
        self.n_estimators = n_estimators
        self.min_samples_split = min_samples_split
        self.test_size = test_size
        self.random_state = random_state
        self.max_patches = max_patches

        # Load images and labels
        self.load_images()

        # Prepare data
        self.prepare_data()

        # Extract patches
        self.extract_patches()

    def load_images(self):
        """Load, resize and label every PNG found directly in ``self.folder_path``.

        Sets ``self.images`` (N, H, W, 3) and ``self.labels`` (N,).

        Raises:
            ValueError: if the folder contains no PNG file.
        """
        images = []
        labels = []

        # os.listdir returns names in arbitrary order; sorting makes the seeded
        # train/test split select the same files on every machine.
        for file_name in sorted(os.listdir(self.folder_path)):
            if not file_name.lower().endswith('.png'):
                continue

            image = io.imread(os.path.join(self.folder_path, file_name))

            if image.ndim == 2:
                # Greyscale: replicate the channel so every image is (H, W, 3).
                image = np.stack((image,) * 3, axis=-1)
            elif image.shape[2] == 2:
                # Greyscale + alpha: drop alpha, then replicate the grey channel.
                image = np.stack((image[:, :, 0],) * 3, axis=-1)
            elif image.shape[2] == 4:
                # RGBA: drop the alpha channel.
                image = image[:, :, :3]

            # resize() already yields exactly (desired_height, desired_width, 3),
            # so no additional padding is required.
            images.append(resize(image, (self.desired_height, self.desired_width)))
            labels.append(1 if 'YES' in file_name else 0)

        if not images:
            raise ValueError(f"No .png images found in {self.folder_path!r}")

        self.images = np.array(images)
        self.labels = np.array(labels)

    def prepare_data(self):
        """Flatten the images and split them into train and test sets.

        Also records ``self.channels`` so flattened rows can be reshaped back
        into images later on.
        """
        num_samples, _, _, channels = self.images.shape
        self.channels = channels
        images_flattened = self.images.reshape(num_samples, -1)
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            images_flattened,
            self.labels,
            test_size=self.test_size,
            random_state=self.random_state,
        )

    def _patches_from(self, flat_images, labels, rng):
        """Return (flattened patches, per-patch labels) for the given flattened images."""
        image_shape = (self.desired_height, self.desired_width, self.channels)
        patch_rows = []
        patch_labels = []

        for flat_image, label in zip(flat_images, labels):
            patches = extract_patches_2d(
                flat_image.reshape(image_shape),
                self.patch_size,
                max_patches=self.max_patches,
                random_state=rng,
            )
            # One row per patch; every patch inherits the label of its image.
            patch_rows.append(patches.reshape(len(patches), -1))
            patch_labels.append(np.full(len(patches), label))

        return np.concatenate(patch_rows), np.concatenate(patch_labels)

    def extract_patches(self):
        """Cut patches from the training and the test images separately.

        Keeping the two sets apart means the patch classifier can be evaluated
        on patches of images it has never seen, and on inputs that have the
        same number of features as the ones it was trained on.
        """
        # Only consulted by extract_patches_2d when max_patches is not None.
        rng = np.random.RandomState(self.random_state)
        self.patched_images, self.patch_labels = self._patches_from(self.X_train, self.y_train, rng)
        self.patched_test_images, self.patch_test_labels = self._patches_from(self.X_test, self.y_test, rng)

    def train_classifier(self, X_train, y_train):
        """Fit and return a random forest on the given features and labels.

        The forest is seeded with ``self.random_state`` so repeated runs give
        the same model.
        """
        rf_classifier = RandomForestClassifier(
            min_samples_split=self.min_samples_split,
            n_estimators=self.n_estimators,
            random_state=self.random_state,
        )
        rf_classifier.fit(X_train, y_train)
        return rf_classifier

    def evaluate_classifier(self, rf_classifier, X_train, y_train, X_test, y_test):
        """Return ``(train_accuracy, test_accuracy)`` as reported by ``score``."""
        train_accuracy = rf_classifier.score(X_train, y_train)
        test_accuracy = rf_classifier.score(X_test, y_test)
        return train_accuracy, test_accuracy

    def train_and_evaluate(self):
        """Train both classifiers and report their accuracies.

        The whole-image classifier is scored on flattened images; the patch
        classifier is scored per patch, on patches of the training images and
        on patches of the held-out test images.

        Returns:
            ``((unpatched_train, unpatched_test), (patched_train, patched_test),
            execution_time_seconds)``.
        """
        start_time = time.time()

        # Train and evaluate unpatched classifier
        self.unpatched_classifier = self.train_classifier(self.X_train, self.y_train)
        unpatched_scores = self.evaluate_classifier(
            self.unpatched_classifier, self.X_train, self.y_train, self.X_test, self.y_test
        )

        # Train and evaluate patched classifier. It must be scored on patch
        # features: full flattened images have a different number of columns.
        self.patched_classifier = self.train_classifier(self.patched_images, self.patch_labels)
        patched_scores = self.evaluate_classifier(
            self.patched_classifier,
            self.patched_images,
            self.patch_labels,
            self.patched_test_images,
            self.patch_test_labels,
        )

        execution_time = time.time() - start_time

        return tuple(unpatched_scores), tuple(patched_scores), execution_time

    def plot_images(self, class_names=('Non-dissected', 'Dissected'), num_columns=5, rf_classifier=None):
        """Show every test image with its true and predicted label.

        Args:
            class_names: display names indexed by label (0, 1).
            num_columns: number of images per row in the grid.
            rf_classifier: classifier fitted on flattened whole images. Defaults
                to the one trained by ``train_and_evaluate``.

        Raises:
            RuntimeError: if no classifier is given and ``train_and_evaluate``
                has not been called yet.
        """
        if rf_classifier is None:
            rf_classifier = getattr(self, 'unpatched_classifier', None)
        if rf_classifier is None:
            raise RuntimeError(
                "No trained classifier available: call train_and_evaluate() first "
                "or pass rf_classifier explicitly."
            )

        num_test_images = len(self.X_test)
        num_rows = (num_test_images + num_columns - 1) // num_columns
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(num_rows, num_columns, figsize=(15, 3*num_rows), squeeze=False)

        # Hide every cell up front so unused slots in the last row stay blank.
        for ax in axes.ravel():
            ax.axis('off')

        image_shape = (self.desired_height, self.desired_width, self.channels)
        predicted_labels = rf_classifier.predict(self.X_test)

        for i in range(num_test_images):
            true_label = self.y_test[i]
            predicted_label = predicted_labels[i]

            ax = axes[i // num_columns][i % num_columns]
            ax.imshow(self.X_test[i].reshape(image_shape))
            ax.set_title(f"True Label: {class_names[true_label]}\nPredicted Label: {class_names[predicted_label]}")

        plt.tight_layout()
        plt.show()

In [None]:
# Usage
# The image folder can be overridden through the PROJECT_IMAGES_DIR environment
# variable so the notebook is not tied to a single machine; the original
# location remains the default.
folder_path = os.environ.get(
    "PROJECT_IMAGES_DIR",
    "/Users/louloules/LOCAL_DISK_PC/ML_Medical_Imaging/ProjectPerso/Project Images",
)
classifier = ImageClassifier(folder_path)

# train_and_evaluate returns (train, test) accuracy pairs for both classifiers
# plus the elapsed wall-clock time in seconds.
(unpatched_train_accuracy, unpatched_test_accuracy), (patched_train_accuracy, patched_test_accuracy), execution_time = classifier.train_and_evaluate()

# Each line prints the training accuracy followed by the test accuracy.
print("Unpatched Classifier Accuracy:", unpatched_train_accuracy, unpatched_test_accuracy)
print("Patched Classifier Accuracy:", patched_train_accuracy, patched_test_accuracy)
print("Execution Time:", execution_time, "seconds")
classifier.plot_images()