# BASELINE

------------------------------------------

In [None]:
! pip install optuna

In [None]:
import sys
from metrics import *

sys.path.append('data')

In [None]:
import optuna
import torchvision.transforms as transforms

from PIL import Image
from data.processing import convert_to_black_and_white


def objective(trial, images, ground_truth):
    """Optuna objective: score one ExG threshold by its mean IoU over a set of tiles.

    For every (image, ground-truth) pair, each pixel of the image is labelled
    vegetation (1) when ``calculate_exg_index(pixel)`` is strictly greater than
    the sampled threshold and background (0) otherwise.  The resulting mask is
    compared with the ground truth (passed through ``convert_to_black_and_white``
    with ``threshold=1``) using ``calculate_miou(pred, gt, 2)``; 2 is presumably
    the number of classes -- confirm against ``metrics``.

    Args:
        trial: the Optuna trial that proposes the integer threshold in [0, 255].
        images: sequence of paths to the input tiles.
        ground_truth: sequence of paths to the matching ground-truth masks,
            in the same order as ``images``.

    Returns:
        The IoU averaged over all image / ground-truth pairs.

    Raises:
        ValueError: if ``images`` is empty or the two sequences differ in length.
        optuna.TrialPruned: if the study's pruner decides to stop this trial.
    """
    import torch  # not imported in any other visible cell of this notebook

    if len(images) != len(ground_truth):
        # zip() would silently drop the unmatched tail and skew the average.
        raise ValueError(
            f"images ({len(images)}) and ground_truth ({len(ground_truth)}) "
            "must have the same length"
        )
    if not images:
        raise ValueError("objective() needs at least one image to evaluate")

    # NOTE: the parameter name keeps its original trailing space on purpose:
    # studies are reopened with load_if_exists=True, and renaming the
    # parameter would split stored and new trials across two names.
    threshold = trial.suggest_int('threshold ', 0, 255)
    transform = transforms.ToTensor()

    total_iou = 0
    pairs = zip(images, ground_truth)
    for step, (image_path, ground_truth_path) in enumerate(pairs, start=1):
        # Context managers release the file handles once each tile is processed.
        with Image.open(image_path) as image, Image.open(ground_truth_path) as gt_image:
            gt = convert_to_black_and_white(image=gt_image, save_results=False, threshold=1)
            # ToTensor gives (C, H, W); drop the leading channel dimension.
            gt = transform(gt).squeeze(0)

            width, height = image.size
            # PIL's getdata() is flattened row by row, so view(height, width)
            # restores the image layout.
            vegetation = torch.tensor(
                [1 if calculate_exg_index(pixel) > threshold else 0
                 for pixel in image.getdata()]
            ).view(height, width)

        total_iou += calculate_miou(vegetation, gt, 2)
        running_mean = total_iou / step
        print(f"Mean IoU: {running_mean}")

        # Report the running mean so a pruner attached to the study (e.g. a
        # MedianPruner) can stop unpromising thresholds early; without a
        # report the pruner never has anything to act on.
        trial.report(running_mean, step)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return total_iou / len(images)
            

In [None]:
import os
import optuna
import numpy as np

from data.constants import *
from data.utils import load_data_path


# Tune one ExG threshold per red-edge entry.  Every study lives in the same
# SQLite file and is reopened (load_if_exists=True) when the cell is re-run,
# so each run adds n_trials more trials to the existing study.
data = load_data_path()
red_edge_data = data[RED_EDGE]
storage_name = "sqlite:///baseline_bw.db"


for key, value in red_edge_data.items():
    tiles, ground_truth = value[TILE], value[GROUND_TRUTH]

    # Per Optuna's MedianPruner docs: no pruning until 5 trials have finished,
    # and never within a trial's first 10 reported steps.
    pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)

    study = optuna.create_study(
        study_name=key,
        direction='maximize',
        storage=storage_name,
        pruner=pruner,
        load_if_exists=True,
    )
    study.optimize(
        lambda t: objective(t, tiles, ground_truth),
        n_trials=20,
    )

    # Optuna raises ValueError from best_params when no trial has completed.
    try:
        best_params = study.best_params
    except ValueError:
        print(f"No best hyperparameters found for {key}")
    else:
        print(f"Best Hyperparameters {key}:", best_params)