In [7]:
from img_util import rate_hair
import pandas as pd
import os
import cv2
import numpy as np
from sklearn.metrics import accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt

# --- inputs: image folder and the annotation table ---
folder_path = "../data/hair_annotations"
ah = pd.read_csv("../data/annotations.csv")

# train set = every annotation row whose Group_ID is not 'C';
# its Rating_Final column is the reference labelling used for scoring
train_set = ah.loc[ah['Group_ID'].ne('C')]
true_labels = train_set['Rating_Final']

# running best accuracy and the parameters that produced it,
# tracked separately for the no-blur and the blur variant
best_acc = best_acc_blur = 0
best_param = best_param_blur = None
best_t1 = best_t2 = None
best_t1_blur = best_t2_blur = None

# one record per evaluated (dst, t1, t2) combination
results = []

# Grid search over the bright-pixel cutoff (dst) and the two rating thresholds
# (t1 < t2). For every combination, rate_hair labels each training image with
# and without blur, and the labels are scored against true_labels.

# Load and colour-convert every training image exactly once. The pixels do not
# depend on dst/t1/t2, so reading them inside the search loop would repeat the
# same disk I/O and decoding for every parameter combination.
# NOTE(review): this assumes rate_hair does not modify its input array in
# place (the same array was already passed to both calls before) - confirm.
train_images = []
for file_id in train_set['FileID']:
    img_path = os.path.join(folder_path, file_id + '.png')
    img = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None, which
    # would otherwise only surface as an opaque cvtColor error
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    train_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

# NOTE(review): the search evaluates 29 dst values x ~3,700 threshold pairs,
# each calling rate_hair twice per image. If rate_hair's label is a pure
# thresholding of a ratio that depends only on dst/blur, computing that ratio
# once per (image, dst, blur) would make this far cheaper - verify against
# rate_hair before changing.

# iterate over parameter values
for p in range(110, 255, 5):

    # Thresholds are enumerated as integer hundredths and divided by 100.
    # This gives an exact 0.01 step (t1 in 0.01..0.50, t2 up to 1.00) and,
    # unlike per-t1 np.linspace grids, every t1 shares the same set of t2
    # values, so the results form a regular (t1, t2) grid.
    for t1_pct in range(1, 51):
        t1 = t1_pct / 100

        # t2 must lie strictly above t1 and never below 0.09
        for t2_pct in range(max(t1_pct + 1, 9), 101):
            t2 = t2_pct / 100

            # predictions without and with blur, in train_set row order
            pred_labels = []
            pred_labels_blur = []

            for img_rgb in train_images:
                # rate_hair returns a 3-tuple; only the middle item (label) is used
                _, label, _ = rate_hair(img_rgb, t1=t1, t2=t2, dst=p, blur=False)
                _, label_blur, _ = rate_hair(img_rgb, t1=t1, t2=t2, dst=p)
                pred_labels.append(label)
                pred_labels_blur.append(label_blur)

            # compute accuracy, save it if best result yet
            acc = accuracy_score(true_labels, pred_labels)
            acc_blur = accuracy_score(true_labels, pred_labels_blur)

            # strict '>' keeps the first combination that reaches a given accuracy
            if acc > best_acc:
                best_acc = acc
                best_param = p
                best_t1, best_t2 = t1, t2

            if acc_blur > best_acc_blur:
                best_acc_blur = acc_blur
                best_param_blur = p
                best_t1_blur, best_t2_blur = t1, t2

            # save data
            results.append({
                "t1": t1,
                "t2": t2,
                "dst": p,
                "acc": acc,
                "acc_blur": acc_blur
            })

            # status update
            print(f"Trying dst= {p}, t1= {t1:.2f}, t2= {t2:.2f} | acc= {acc:.3f}, acc_blur={acc_blur:.3f}")

# one row per evaluated combination
df = pd.DataFrame(results)

# Report the best parameter combinations and visualise the search results.

# results
print(f"  NO BLUR | Best thresholds: t1 = {best_t1:.3f}, t2 = {best_t2:.3f}; Best parameter: dst = {best_param}; with accuracy = {best_acc:.4f}")
print(f"WITH BLUR | Best thresholds: t1 = {best_t1_blur:.3f}, t2 = {best_t2_blur:.3f}; Best parameter: dst = {best_param_blur}; with accuracy = {best_acc_blur:.4f}")

# plots

# keep only the rows evaluated at the best no-blur dst value
subset = df[df["dst"] == best_param]

# Build the 2D (t1 x t2) accuracy grid. Thresholds are rounded to two decimals
# first: pivoting on raw floats makes every slightly different t2 value its
# own column, which leaves the heatmap almost entirely NaN when the t2 values
# are not identical across t1 rows. pivot_table (max) also tolerates two
# thresholds landing in the same rounded cell, where pivot would raise.
pivot = (
    subset
    .assign(t1=subset["t1"].round(2), t2=subset["t2"].round(2))
    .pivot_table(index="t1", columns="t2", values="acc", aggfunc="max")
)

# accuracy heatmap (no blur) at the best dst
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(pivot, cmap="viridis", cbar_kws={'label': 'Accuracy'}, ax=ax)
ax.set_title(f"Accuracy heatmap at dst = {best_param}")
ax.set_xlabel("Upper threshold (t2)")
ax.set_ylabel("Lower threshold (t1)")
plt.show()

# best accuracy reached at each dst value, over all threshold pairs
best_per_dst = df.groupby("dst").agg({
    "acc": "max",
    "acc_blur": "max"
}).reset_index()

fig, ax = plt.subplots()
ax.plot(best_per_dst["dst"], best_per_dst["acc"], label="No Blur")
ax.plot(best_per_dst["dst"], best_per_dst["acc_blur"], label="With Blur")
ax.set_xlabel("Parameter value (dst)")
ax.set_ylabel("Best accuracy")
ax.set_title("Best Accuracy vs. Bright Pixel Cutoff")
ax.legend()
ax.grid(True)
plt.show()



Trying dst= 110, t1= 0.01, t2= 0.09 | acc= 0.618, acc_blur=0.709
Trying dst= 110, t1= 0.01, t2= 0.10 | acc= 0.648, acc_blur=0.709
Trying dst= 110, t1= 0.01, t2= 0.11 | acc= 0.648, acc_blur=0.714
Trying dst= 110, t1= 0.01, t2= 0.12 | acc= 0.668, acc_blur=0.714
Trying dst= 110, t1= 0.01, t2= 0.13 | acc= 0.668, acc_blur=0.698
Trying dst= 110, t1= 0.01, t2= 0.14 | acc= 0.673, acc_blur=0.688
Trying dst= 110, t1= 0.01, t2= 0.15 | acc= 0.678, acc_blur=0.683
Trying dst= 110, t1= 0.01, t2= 0.16 | acc= 0.678, acc_blur=0.678
Trying dst= 110, t1= 0.01, t2= 0.17 | acc= 0.663, acc_blur=0.673
Trying dst= 110, t1= 0.01, t2= 0.18 | acc= 0.648, acc_blur=0.663
Trying dst= 110, t1= 0.01, t2= 0.19 | acc= 0.648, acc_blur=0.658
Trying dst= 110, t1= 0.01, t2= 0.20 | acc= 0.638, acc_blur=0.653
Trying dst= 110, t1= 0.01, t2= 0.21 | acc= 0.648, acc_blur=0.648
Trying dst= 110, t1= 0.01, t2= 0.22 | acc= 0.643, acc_blur=0.648
Trying dst= 110, t1= 0.01, t2= 0.23 | acc= 0.633, acc_blur=0.648
Trying dst= 110, t1= 0.01

KeyboardInterrupt: 