# CLUSTERING

------------------------------------------

## Common Utils

In [None]:
import sys
from constants import *
from common import *
from metrics import *

sys.path.append("data")
from processing import *
from data.constants import *
from data.utils import load_split_data

## Load Data

In [None]:
import os
# Root output directory for every clustering experiment in this notebook.
# EXPERIMENT_PATH and CLUSTERING are presumably provided by the star imports
# above (constants modules) — TODO confirm which module defines them.
directory_path = os.path.join(EXPERIMENT_PATH, CLUSTERING)
# Create it up front (no error if it already exists) so later cells can
# write beneath it.
os.makedirs(directory_path, exist_ok=True)

In [None]:
# Load the train/test split. The pipeline below iterates these objects as
# mapping[drone][image_dir] -> sequence of images, with the ground truths
# indexed the same way.
# NOTE(review): open_image=True presumably returns opened images rather than
# file paths — confirm against data.utils.load_split_data.
x_train, y_train, x_test, y_test = load_split_data(open_image=True)

## Clustering

In [None]:
from sklearn.exceptions import ConvergenceWarning
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning
from tqdm import TqdmWarning

# Silence noisy warnings for the remainder of the session so that cell output
# stays readable while many images are clustered.
simplefilter("ignore", category=ConvergenceWarning)
# NOTE(review): UserWarning is a very broad category; this hides warnings from
# every library, not only the clustering code.
simplefilter("ignore", category=UserWarning)
simplefilter("ignore", category=TqdmWarning)

### Experiments

In [None]:
from comet_ml import Experiment
from comet_ml.integration.sklearn import log_model
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from torchvision import transforms
from PIL import Image


def clustering_pipeline(images,
                        ground_truths,
                        model,
                        experiment_name,
                        directory_path,
                        num_clusters: int,
                        model_tags: list = None,
                        target_label: int = 1,
                        default_channels: int = 1,
                        num_classes: int = 2,
                        plot_results: bool = False):
    """Cluster every image, save the predicted masks and log the run to Comet.

    For each image the pixels are flattened, the model is fitted on that
    image, the cluster equal to ``target_label`` is written as a PNG mask
    under ``directory_path/<drone>/<image_dir>/frameNNNN.png`` and the mIoU
    against the matching ground truth is accumulated. The mean mIoU over all
    processed images and the model are then logged to the Comet experiment.

    Args:
        images: Nested mapping ``images[drone][image_dir]`` -> sequence of
            images.
        ground_truths: Same nesting as ``images``; indexed as
            ``ground_truths[drone][image_dir][index]``.
        model: Clustering estimator exposing ``fit`` and ``predict``. It is
            refitted on every image, so the logged model reflects the last
            image processed.
        experiment_name: Name given to the Comet experiment and logged model.
        directory_path: Directory that receives the predicted masks.
        num_clusters: Number of clusters; only logged as a parameter here.
        model_tags: Tags attached to the Comet experiment (default: none).
        target_label: Cluster label that is exported as the foreground mask.
        default_channels: Number of channels per pixel used when flattening.
        num_classes: Number of classes used by the mIoU computation.
        plot_results: When True, each predicted mask is also plotted.

    Raises:
        Exception: Any error is re-raised after the experiment files have
            been removed and the Comet experiment closed. An empty ``images``
            collection ends in ``ZeroDivisionError`` when averaging.
    """
    # Avoid the shared-mutable-default pitfall of ``model_tags=[]``.
    tags = [] if model_tags is None else model_tags

    # Bound before the try block so the handler can tell whether the Comet
    # experiment was actually created.
    experiment = None

    try:
        os.makedirs(directory_path, exist_ok=True)
        # NOTE(review): this API key is committed in plain text. It should be
        # rotated and read from configuration/environment instead.
        experiment = Experiment(
            api_key="eI2MJOa5W8d1PcAvxhmyP5VGt",
            project_name="weedmap-image-segmentation",
            workspace="francesco-ranieri"
        )

        experiment.set_name(experiment_name)
        experiment.add_tags(tags)

        experiment.log_parameters(
            {'n_clusters': num_clusters,
             'n_init': 'auto',
             'random_state': SEED,
             'target_label': target_label,
             'default_channels': default_channels,
             'num_classes': num_classes,
             }
        )

        images_dirs = []
        num_tiles = 0
        miou = 0

        for drone in images:
            # exist_ok keeps re-runs into the same directory from failing.
            os.makedirs(os.path.join(directory_path, drone), exist_ok=True)
            print(f'- Processing drone {drone}')

            for image_dir in images[drone]:
                images_dirs.append(image_dir)
                os.makedirs(os.path.join(directory_path, drone, image_dir),
                            exist_ok=True)
                print(f'-- Processing image directory {image_dir}')

                for index, image in enumerate(images[drone][image_dir]):
                    print(f'--- Processing image {index}')

                    pixels = flat_image(image, default_channels)
                    ground_truth = ground_truths[drone][image_dir][index]

                    file_name = f'frame{"{:04d}".format(index)}.png'
                    image_path = os.path.join(directory_path, drone, image_dir, file_name)

                    # Use the model handed to this function (not a notebook
                    # global) and pass options by keyword: positionally,
                    # image_path would land in clustering()'s target_label.
                    predicted_labels, gt_labels = clustering(
                        pixels,
                        ground_truth,
                        model,
                        target_label=target_label,
                        image_path=image_path,
                        plot_results=plot_results,
                    )
                    miou += evaluate_clustering(predicted_labels, gt_labels, num_classes)
                    num_tiles += 1

        experiment.log_parameters({'image_dirs': '_'.join(images_dirs)})
        miou_overall = miou / num_tiles
        print(f'--- MIOU: {miou_overall}')
        experiment.log_metric('mIOU', miou_overall)

        log_model(
            experiment = experiment,
            model = model,
            model_name = experiment_name,
        )

        experiment.end()

    except Exception as e:
        # Roll back partial output, then close the experiment if it exists.
        delete_experiment_files(directory_path)
        print(e)
        if experiment is not None:
            experiment.end()
        # Bare raise preserves the original traceback.
        raise

def clustering(image,
               ground_truth,
               model,
               target_label:int  = 1,
               image_path: str = None,
               plot_results: bool = False):
    """Fit ``model`` on one flattened image and extract the target cluster.

    Args:
        image: Flattened pixel array passed to ``model.fit`` / ``model.predict``.
        ground_truth: Ground-truth image; its ``height`` and ``width`` give
            the shape the predicted labels are reshaped to.
        model: Estimator exposing ``fit`` and ``predict``.
        target_label: Cluster label rendered as foreground (255) in the mask.
        image_path: When truthy, the mask is saved to this path.
        plot_results: When True, the mask is displayed with matplotlib.

    Returns:
        Tuple ``(predicted_labels, gt_labels)``: the predicted mask and the
        ground truth, both converted with torchvision's ``ToTensor``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    # The model is (re)fitted on this single image before predicting on it.
    model.fit(image)
    gt_labels = to_tensor(ground_truth)

    rows, cols = ground_truth.height, ground_truth.width
    cluster_map = model.predict(image).reshape((rows, cols))

    # Boolean mask of the target cluster, scaled to 0 / 255 for an 8-bit image.
    target_mask = cluster_map == target_label
    green_segment = target_mask * 255
    predicted = Image.fromarray(green_segment.astype(np.uint8))
    predicted_labels = to_tensor(predicted)

    if image_path:
        print(f"Saving image to {image_path}")
        predicted.save(image_path)

    if plot_results:
        plt.figure(figsize=(8, 6))
        plt.imshow(green_segment, cmap='Greens')
        plt.axis('off')
        plt.title('Green Pixels')
        plt.show()

    return predicted_labels, gt_labels


def evaluate_clustering(predicted_labels, gt_labels, num_classes:int = 2):
    """Score a predicted mask against its ground truth.

    Delegates to ``calculate_miou`` with the same three arguments and returns
    whatever that function returns.
    """
    score = calculate_miou(predicted_labels, gt_labels, num_classes)
    return score


def flat_image(image, default_channels = 1):
    """Flatten an image into a 2-D array with one row per pixel.

    The input is copied into a NumPy array and its first two dimensions
    (height, width) are collapsed, producing shape
    ``(height * width, default_channels)``. ``default_channels`` must match
    the number of values per pixel, otherwise the reshape raises.
    """
    pixels = np.array(image)
    rows, cols = pixels.shape[:2]
    return pixels.reshape(rows * cols, default_channels)

#### KMeans

In [None]:
# Output directory for the KMeans experiments. The later cells read
# `kmeans_path`; the original spelling `kmeas_path` left that name undefined.
kmeans_path = os.path.join(directory_path, KMEANS)
# Misspelled name kept as an alias in case other cells still reference it.
kmeas_path = kmeans_path
# Base Comet tags shared by every KMeans run; each cell appends its data split.
model_tags = ['kmeans']

In [None]:
# Training data: run the clustering pipeline over the training split.

data_type = 'training'
num_clusters = 2
# Fresh estimator for this run; the pipeline's clustering step refits it on
# every image.
kmeans = KMeans(n_clusters=num_clusters, n_init='auto', random_state=SEED)
experiment_name = f'kmeans_{data_type}_data'
# NOTE(review): `kmeans_path` must already be defined by an earlier cell —
# confirm the name matches the one assigned there.
experiment_path = os.path.join(kmeans_path, data_type)

# Masks are written under experiment_path; the run is tagged with the model
# tags plus the data split.
clustering_pipeline(images = x_train,
                    ground_truths = y_train,
                    model = kmeans, 
                    experiment_name = experiment_name,
                    num_clusters = num_clusters,
                    directory_path = experiment_path,
                    model_tags = model_tags + [data_type])

In [None]:
# Testing data: run the same clustering pipeline over the test split.

data_type = 'test'
num_clusters = 2
# A new estimator is created so nothing carries over from the training run;
# the pipeline's clustering step refits it on every image.
kmeans = KMeans(n_clusters=num_clusters, n_init='auto', random_state=SEED)
experiment_name = f'kmeans_{data_type}_data'
# NOTE(review): `kmeans_path` must already be defined by an earlier cell —
# confirm the name matches the one assigned there.
experiment_path = os.path.join(kmeans_path, data_type)

# Masks are written under experiment_path; the run is tagged with the model
# tags plus the data split.
clustering_pipeline(images = x_test,
                    ground_truths = y_test,
                    model = kmeans, 
                    experiment_name = experiment_name,
                    num_clusters = num_clusters,
                    directory_path = experiment_path,
                    model_tags = model_tags + [data_type])