In [None]:
import logging
import os
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import silhouette_score

# Module-level logging: this module logs through its own named logger, and
# the root logger is configured to emit INFO-and-above records.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)

class K_Means:
    """Lloyd's k-means clustering with k-means++ seeding, implemented in NumPy.

    After :meth:`fit`:

    * ``centroids`` maps cluster index ``0..k-1`` to that cluster's centroid.
    * ``classifications`` maps cluster index to the training rows assigned to
      it in the last assignment step (one 2-D array per cluster, possibly
      with zero rows).
    """

    def __init__(self, k: int = 2, tol: float = 0.001, max_iter: int = 300, verbose: bool = False):
        """Store the clustering configuration.

        Args:
            k: number of clusters; must be at least 1 (checked when fitting).
            tol: convergence threshold. Fitting stops once every centroid moves
                by at most ``tol`` (sum of absolute coordinate changes) in one
                iteration.
            max_iter: upper bound on assignment/update iterations.
            verbose: log per-iteration progress through the module ``logger``.
        """
        self.k = k
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose
        self.centroids = None        # Dict[int, np.ndarray] once fitted
        self.classifications = None  # Dict[int, np.ndarray] once fitted

    @staticmethod
    def _as_float_array(data) -> np.ndarray:
        """Return ``data`` as an ndarray with a floating-point dtype.

        Integer input (e.g. ``uint8`` pixels) is promoted to ``float64`` so
        that ``point - centroid`` cannot wrap around. Floating input keeps
        its dtype and is not copied.
        """
        arr = np.asarray(data)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        return arr

    @staticmethod
    def _squared_distances(data: np.ndarray, centers) -> np.ndarray:
        """Squared Euclidean distance from every sample to every center.

        Args:
            data: samples stacked along axis 0.
            centers: iterable of centers, each shaped like a single sample.

        Returns:
            Array of shape ``(len(data), number_of_centers)``.
        """
        columns = []
        for center in centers:
            sq = np.square(data - center)
            # Sum over every per-sample axis (just axis 1 for 2-D data).
            columns.append(sq.sum(axis=tuple(range(1, sq.ndim))))
        return np.stack(columns, axis=1)

    def fit(self, data: np.ndarray) -> "K_Means":
        """Cluster ``data`` and store the resulting centroids.

        Args:
            data: array-like of samples along axis 0. A 1-D input is treated
                as a sequence of scalar samples.

        Returns:
            ``self``, so calls can be chained (``K_Means(k=3).fit(x).predict(x)``).

        Raises:
            ValueError: if ``data`` is empty or ``k`` is smaller than 1.
        """
        data = self._as_float_array(data)
        if data.size == 0:
            raise ValueError("Empty dataset provided")
        if data.ndim == 1:
            data = data.reshape(-1, 1)

        self.centroids = self.initialize_centroids(data)

        for i in range(self.max_iter):
            # Assignment step: nearest centroid for every sample in a single
            # vectorised pass (ties go to the lowest cluster index).
            labels = np.argmin(self._squared_distances(data, self.centroids.values()), axis=1)
            self.classifications = {j: data[labels == j] for j in range(self.k)}

            prev_centroids = dict(self.centroids)

            # Update step: move each non-empty cluster's centroid to its mean.
            # An empty cluster keeps its previous centroid.
            for j, members in self.classifications.items():
                if len(members) > 0:
                    self.centroids[j] = members.mean(axis=0)

            # Convergence: L1 shift of every centroid must be within tol.
            changes = [np.sum(np.abs(self.centroids[c] - prev_centroids[c]))
                       for c in self.centroids]
            if self.verbose:
                logger.info(f"Iteration {i+1}: Max centroid change: {max(changes):.6f}")

            if all(change <= self.tol for change in changes):
                if self.verbose:
                    logger.info(f"Converged after {i+1} iterations")
                break

        return self

    def initialize_centroids(self, data: np.ndarray) -> Dict[int, np.ndarray]:
        """Choose ``self.k`` starting centroids with k-means++ seeding.

        The first centroid is a uniformly random sample. Each further one is
        drawn with probability proportional to the squared distance from a
        sample to its nearest already-chosen centroid. Draws use the global
        ``np.random`` state, so ``np.random.seed`` makes them reproducible.

        Args:
            data: samples along axis 0. A 1-D input is treated as scalar
                samples.

        Returns:
            ``{cluster_index: centroid}`` with exactly ``self.k`` entries. If
            there are fewer distinct samples than ``k``, some centroids repeat.

        Raises:
            ValueError: if ``k`` is smaller than 1 or ``data`` has no samples.
        """
        if self.k < 1:
            raise ValueError(f"k must be at least 1, got {self.k}")
        data = self._as_float_array(data)
        if data.ndim == 1:
            data = data.reshape(-1, 1)
        n_samples = len(data)
        if n_samples == 0:
            raise ValueError("Empty dataset provided")

        chosen = [data[np.random.randint(0, n_samples)].copy()]
        # Squared distance from each sample to its nearest chosen centroid,
        # updated incrementally so each round is one pass over the data.
        nearest_sq = self._squared_distances(data, chosen)[:, 0]

        for _ in range(1, self.k):
            cumulative = np.cumsum(nearest_sq)
            total = cumulative[-1]
            r = np.random.random()
            if total > 0:
                # First sample whose cumulative weight exceeds r * total.
                idx = int(np.searchsorted(cumulative, r * total, side="right"))
                if idx >= n_samples:
                    # Float rounding pushed r * total up to total: fall back to
                    # the last sample that still has non-zero weight.
                    idx = int(np.flatnonzero(nearest_sq)[-1])
            else:
                # Every sample coincides with a chosen centroid (constant data
                # or fewer distinct samples than k). Pick uniformly so that
                # exactly k centroids are still returned.
                idx = min(int(r * n_samples), n_samples - 1)
            chosen.append(data[idx].copy())
            nearest_sq = np.minimum(nearest_sq, self._squared_distances(data, chosen[-1:])[:, 0])

        return dict(enumerate(chosen))

    def predict(self, data: np.ndarray) -> np.ndarray:
        """Return the index of the nearest centroid for each sample.

        Args:
            data: samples along axis 0. A 1-D input is treated as ONE sample
                and reshaped to a single row.

        Returns:
            Integer array with one cluster index per sample. The index is the
            centroid's position in ``self.centroids`` iteration order.

        Raises:
            ValueError: if the model has not been fitted.
        """
        data = self._as_float_array(data)

        if self.centroids is None:
            raise ValueError("Model not fitted yet")

        if data.ndim == 1:
            data = data.reshape(1, -1)

        return np.argmin(self._squared_distances(data, self.centroids.values()), axis=1)

def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Turn an image into one normalised, lightly blurred feature row per pixel.

    Pixel values are divided by 255 (8-bit input assumed). 3-channel (BGR)
    and 4-channel (BGRA) images are reordered to RGB / RGBA. A 3x3 Gaussian
    blur is then applied to reduce noise.

    Args:
        image: ``(H, W)`` grayscale or ``(H, W, C)`` colour image, e.g. as
            returned by ``cv2.imread``.

    Returns:
        ``float32`` array of shape ``(H * W, C)``, with ``C == 1`` for
        grayscale input.
    """
    img_float = image.astype(np.float32) / 255.0

    channels = image.shape[2] if image.ndim == 3 else 1
    if channels == 3:
        img_rgb = cv2.cvtColor(img_float, cv2.COLOR_BGR2RGB)
    elif channels == 4:
        img_rgb = cv2.cvtColor(img_float, cv2.COLOR_BGRA2RGBA)
    else:
        img_rgb = img_float

    img_blur = cv2.GaussianBlur(img_rgb, (3, 3), 0)

    # Reshape by the real channel count. A hard-coded 3 scrambles (or fails
    # on) grayscale and 4-channel images.
    return img_blur.reshape((-1, channels))

def _sampled_silhouette(
    features: np.ndarray, labels: np.ndarray, sample_size: Optional[int]
) -> float:
    """Silhouette score of ``labels``, estimated on a random subset when large.

    Args:
        features: one feature row per sample.
        labels: cluster index per sample.
        sample_size: maximum number of samples to score; ``None`` scores all.

    Returns:
        The silhouette score, or NaN when it is undefined for the scored
        samples (sklearn requires 2 <= distinct labels <= samples - 1).
    """
    n_samples = len(labels)
    if sample_size is not None and n_samples > sample_size:
        # Private fixed-seed generator: reproducible scores without touching
        # the global np.random state.
        rng = np.random.default_rng(0)
        picked = rng.choice(n_samples, size=sample_size, replace=False)
        features, labels = features[picked], labels[picked]

    n_labels = len(np.unique(labels))
    if not 2 <= n_labels <= len(labels) - 1:
        return float("nan")
    return float(silhouette_score(features, labels))


def segment_image(
    image_path: str,
    k: int = 2,
    silhouette_sample_size: Optional[int] = 10_000,
) -> Tuple[np.ndarray, float, List[np.ndarray]]:
    """Segment an image into ``k`` colour clusters with ``K_Means``.

    Args:
        image_path: path of the image file, read with ``cv2.imread``.
        k: number of colour clusters.
        silhouette_sample_size: number of pixels sampled (fixed seed, so the
            result is reproducible) to estimate the silhouette score. The
            exact score is quadratic in the pixel count, which is impractical
            for real images. ``None`` scores every pixel.

    Returns:
        ``(segmented, silhouette, centers)``:

        * ``segmented`` is a ``uint8`` image shaped like the input, with every
          pixel replaced by its cluster centroid. Channel order follows
          ``preprocess_image`` (RGB for 3-channel input).
        * ``silhouette`` is the (possibly sampled) silhouette score, or NaN
          when it is undefined (fewer than two clusters in the scored pixels).
        * ``centers`` lists the centroids on the 0-1 scale, in the order used
          by the labels.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
        ValueError: if the file cannot be decoded as an image.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    # One feature row per pixel, scaled to 0-1.
    vectorized = preprocess_image(image)

    model = K_Means(k=k, verbose=True)
    model.fit(vectorized)

    labels = model.predict(vectorized)
    # predict() returns positions in the model's centroid dict, so read the
    # centroids in that same order rather than assuming keys 0..k-1 exist.
    centers = list(model.centroids.values())

    # Paint every pixel with its cluster's centroid colour in one gather.
    palette = np.asarray(centers, dtype=vectorized.dtype)
    segmented = palette[labels].reshape(image.shape)
    # Round (rather than truncate) back to 8-bit and guard against float
    # overshoot outside [0, 1].
    segmented = np.clip(np.rint(segmented * 255), 0, 255).astype(np.uint8)

    silhouette = _sampled_silhouette(vectorized, labels, silhouette_sample_size)

    return segmented, silhouette, centers

def visualize_results(image_path: str, k: int = 2) -> None:
    """Segment ``image_path`` into ``k`` colours, log the results, and plot them.

    Logs the silhouette score and each centroid scaled to 0-255 (as ints).
    Then shows the original and the segmented image side by side in a
    matplotlib figure.

    Args:
        image_path: path of the image to segment.
        k: number of colour clusters passed to ``segment_image``.
    """
    original_bgr = cv2.imread(image_path)
    segmented, silhouette, centers = segment_image(image_path, k)

    logger.info(f"Silhouette Score for k={k}: {silhouette:.4f}")
    logger.info("\nCluster Centroid Colors (RGB):")
    for cluster_id, centroid in enumerate(centers):
        logger.info(f"Cluster {cluster_id}: {(centroid * 255).astype(int)}")

    # (subplot position, image data, title) for each panel, left to right.
    panels = [
        (121, cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB), 'Original Image'),
        (122, segmented, f'Segmented Image (K={k})'),
    ]

    plt.figure(figsize=(12, 6))
    for position, picture, caption in panels:
        plt.subplot(position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Demo run: segment a local screenshot into five colour clusters.
    # Point image_path at an image on your own machine before running.
    image_path = r"C:\Users\hanna\ComputerVision\Skrmbillede114304.png"
    visualize_results(image_path=image_path, k=5)


INFO:__main__:Iteration 1: Max centroid change: 0.124090
INFO:__main__:Iteration 2: Max centroid change: 0.041720
INFO:__main__:Iteration 3: Max centroid change: 0.026505
INFO:__main__:Iteration 4: Max centroid change: 0.014041
INFO:__main__:Iteration 5: Max centroid change: 0.007110
INFO:__main__:Iteration 6: Max centroid change: 0.004113
INFO:__main__:Iteration 7: Max centroid change: 0.002630
INFO:__main__:Iteration 8: Max centroid change: 0.001836
INFO:__main__:Iteration 9: Max centroid change: 0.001333
INFO:__main__:Iteration 10: Max centroid change: 0.001109
INFO:__main__:Iteration 11: Max centroid change: 0.000896
INFO:__main__:Converged after 11 iterations
