In [2]:
#######################
### LOAD IN MODULES ###
#######################

import cv2
from scipy.interpolate import interp1d
from sklearn.decomposition import PCA
from scipy.spatial import procrustes
from scipy.spatial import ConvexHull # Not directly used in the provided segment, but kept for completeness
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # Not directly used, but kept
from sklearn.metrics import confusion_matrix # Not directly used, but kept
import scipy.stats as stats # Not directly used, but kept
import statsmodels.stats.multitest as multitest # Not directly used, but kept
import itertools # Not directly used, but kept
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import numpy as np
import math
import pandas as pd
import seaborn as sns # Used, but for plotting commented out
from matplotlib.colors import LogNorm # Not directly used, but kept
import phate # Not directly used, but kept
import scprep # Not directly used, but kept
import h5py
import pickle # Not directly used, but kept
import os

#################
### FUNCTIONS ###
#################

def angle_between(p1, p2, p3):
    """
    Return the angle in degrees swept from ray p2->p1 to ray p2->p3, p2 being the vertex.

    Each ray's bearing is computed as math.degrees(math.atan2(dx, dy)) normalised
    into [0, 360), i.e. measured from the +y axis towards the +x axis. The result
    is the bearing of p3 minus the bearing of p1, wrapped by 360 when that
    difference would be negative.
    inputs: three points as (x, y) tuples; p2 is the vertex
    output: angle in degrees
    """
    vertex_x, vertex_y = p2

    def bearing(point):
        # Note the (dx, dy) argument order: bearings are taken from the +y axis.
        px, py = point
        return (360 + math.degrees(math.atan2(px - vertex_x, py - vertex_y))) % 360

    start = bearing(p1)
    end = bearing(p3)
    if start > end:
        return 360 - (start - end)
    return end - start

def rotate_points(xvals, yvals, degrees):
    """
    Rotate 2D points about the origin by (90 - degrees) degrees.

    The rotation uses the standard matrix [[cos, -sin], [sin, cos]], which is
    counter-clockwise when the y axis points up. Passing the angle of a ray
    (e.g. the base->tip angle from angle_between) therefore turns the points by
    the complement of that angle, not by the angle itself.

    inputs: x and y values (NumPy arrays or pandas columns work element-wise) and
            an angle in degrees; the applied rotation is 90 - degrees
    outputs: the rotated x values and the rotated y values
    """
    angle_to_move = 90 - degrees
    rads = np.deg2rad(angle_to_move)

    new_xvals = xvals * np.cos(rads) - yvals * np.sin(rads)
    new_yvals = xvals * np.sin(rads) + yvals * np.cos(rads)

    return new_xvals, new_yvals

def interpolation(x, y, number):
    """
    Return `number` points equally spaced by arc length along a polyline.

    inputs: arrays (or sequences) of x and y values describing the polyline, and
            the number of points to produce
    outputs: two arrays (x, y) of interpolated points, inclusive of the start and
             end points of the polyline

    A single point, or a polyline whose points all coincide, yields `number`
    copies of the first point. Consecutive duplicate vertices are dropped before
    the arc-length parameterisation: a zero-length step would give interp1d two
    identical abscissae and a 0/0 slope (NaN output).

    raises: ValueError if x and y differ in length or are empty
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(x) != len(y):
        raise ValueError(f"x and y must have the same length, got {len(x)} and {len(y)}")
    if len(x) == 0:
        raise ValueError("cannot interpolate an empty polyline")

    steps = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
    # Keep the first vertex plus every vertex that actually moves away from its predecessor.
    moving = steps > 0
    if not moving.any():
        return np.full(number, x[0]), np.full(number, y[0])

    keep = np.concatenate(([True], moving))
    x, y = x[keep], y[keep]

    # Cumulative arc length starting at 0, in the same dtype as the step lengths.
    distance = np.cumsum(np.concatenate((np.zeros(1, dtype=steps.dtype), steps[moving])))
    distance = distance / distance[-1]

    fx, fy = interp1d(distance, x), interp1d(distance, y)

    alpha = np.linspace(0, 1, number)
    return fx(alpha), fy(alpha)

def euclid_dist(x1, y1, x2, y2):
    """
    Return the Euclidean (straight-line) distance between (x1, y1) and (x2, y2).

    inputs: x and y of the first point, then x and y of the second point; NumPy
            arrays are handled element-wise
    output: the Euclidean distance (an array when array inputs are given)
    """
    delta_x = x2 - x1
    delta_y = y2 - y1
    return np.sqrt(delta_x**2 + delta_y**2)

def poly_area(x, y):
    """
    Return the (unsigned) area of a polygon via the shoelace formula.

    inputs: separate NumPy arrays of the vertex x and y coordinates, in order
    outputs: the polygon's area as a non-negative number
    """
    prev_x = np.roll(x, 1)
    prev_y = np.roll(y, 1)
    twice_signed_area = np.dot(x, prev_y) - np.dot(y, prev_x)
    return 0.5 * np.abs(twice_signed_area)

def gpa_mean(leaf_arr, landmark_num, dim_num, tol=10**(-30), max_iter=1000):
    """
    Return the Generalized Procrustes Analysis (GPA) mean shape of a set of shapes.

    Starting from the first shape as the reference, every shape is Procrustes-
    aligned (scipy.spatial.procrustes) to the current mean, the aligned shapes are
    averaged, and the process repeats until the Procrustes disparity between
    successive means is <= `tol`.

    inputs:
        leaf_arr: 3D array of samples x landmarks x coordinate values
        landmark_num: number of landmarks per shape
        dim_num: number of coordinate dimensions
        tol: convergence threshold on the disparity between successive means
        max_iter: upper bound on iterations; a tolerance this small may never be
                  reached because of floating-point noise, so the loop is capped
    output: an array (landmark_num x dim_num) holding the GPA mean shape
    raises: ValueError if max_iter < 1
    """
    import warnings

    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    old_mean = leaf_arr[0, :, :]
    for _ in range(max_iter):
        aligned = np.zeros((len(leaf_arr), landmark_num, dim_num))
        for i, shape in enumerate(leaf_arr):
            _, aligned[i], _ = procrustes(old_mean, shape)
        new_mean = np.mean(aligned, axis=0)
        _, _, disparity = procrustes(old_mean, new_mean)
        old_mean = new_mean
        if disparity <= tol:
            break
    else:
        warnings.warn(
            f"gpa_mean did not converge to tol={tol} within {max_iter} iterations "
            f"(last disparity {disparity}); returning the latest mean.",
            RuntimeWarning,
        )
    return new_mean

class _LeafSkipped(Exception):
    """Raised by _extract_leaf_landmarks when a leaf cannot yield usable landmarks."""


def _extract_leaf_landmarks(img_path, base_xy, tip_xy, high_res_points, points_per_side):
    """
    Turn one leaf image into a fixed number of rotated pseudo-landmarks.

    The external contour with the largest bounding-box area is resampled to
    `high_res_points` points, re-indexed to start at the contour point nearest the
    annotated base, and split at the point nearest the annotated tip. Each side is
    resampled to `points_per_side` points; the shared tip point is kept once,
    giving 2 * points_per_side - 1 landmarks, which are then passed through
    rotate_points with the angle returned by angle_between for the base->tip ray.

    inputs: image path, (x, y) of the annotated base, (x, y) of the annotated tip,
            number of high-resolution contour points, landmarks per leaf side
    output: array of shape (2 * points_per_side - 1, 2)
    raises: _LeafSkipped carrying a user-facing message when the leaf must be
            skipped; other exceptions (e.g. from OpenCV) propagate to the caller
    """
    image_name = os.path.basename(img_path)
    expected_landmarks = (points_per_side * 2) - 1

    img = cv2.imread(img_path)
    if img is None:  # cv2.imread signals an unreadable file by returning None
        raise _LeafSkipped(f"Warning: Could not read image {image_name}. Skipping.")
    # Invert the grayscale image so the (presumably dark) leaf becomes non-zero foreground.
    img = cv2.bitwise_not(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        raise _LeafSkipped(f"Warning: No contours found for image {image_name}. Skipping.")

    def bbox_area(contour):
        xs, ys = contour[:, 0, 0], contour[:, 0, 1]
        return (xs.max() - xs.min()) * (ys.max() - ys.min())

    leaf = max(contours, key=bbox_area)
    high_res_x, high_res_y = interpolation(leaf[:, 0, 0].astype(np.float32),
                                           leaf[:, 0, 1].astype(np.float32),
                                           high_res_points)

    base_x, base_y = np.asarray(base_xy, dtype=np.float64)
    tip_x, tip_y = np.asarray(tip_xy, dtype=np.float64)

    # Re-index the contour so that it starts at the point nearest the annotated base.
    base_ind = int(np.argmin(euclid_dist(base_x, base_y, high_res_x, high_res_y)))
    lf_contour = np.column_stack((np.roll(high_res_x, -base_ind), np.roll(high_res_y, -base_ind)))
    tip_ind = int(np.argmin(euclid_dist(tip_x, tip_y, lf_contour[:, 0], lf_contour[:, 1])))

    # Left side runs base -> tip; right side runs tip -> end and closes back at the base.
    left_segment = lf_contour[:tip_ind + 1]
    right_segment = np.vstack((lf_contour[tip_ind:], lf_contour[:1]))
    if len(left_segment) < 2 or len(right_segment) < 2:
        raise _LeafSkipped(f"Warning: Segments for image {image_name} are too short for interpolation. Skipping.")

    left_x, left_y = interpolation(left_segment[:, 0], left_segment[:, 1], points_per_side)
    right_x, right_y = interpolation(right_segment[:, 0], right_segment[:, 1], points_per_side)

    # Drop the left side's last point (the tip): the right side already starts there.
    lf_pts = np.vstack((np.column_stack((left_x[:-1], left_y[:-1])),
                        np.column_stack((right_x, right_y))))
    if lf_pts.shape[0] != expected_landmarks:
        raise _LeafSkipped(f"Warning: Leaf {image_name} generated {lf_pts.shape[0]} landmarks, "
                           f"expected {expected_landmarks}. Check interpolation logic.")

    tip_point = lf_pts[points_per_side - 1]
    base_point = lf_pts[0]
    ang = angle_between(tip_point, base_point, (base_point[0] + 1, base_point[1]))
    rot_x, rot_y = rotate_points(lf_pts[:, 0], lf_pts[:, 1], ang)
    return np.column_stack((rot_x, rot_y))


def run_morphometric_analysis(metadata_file_path, image_data_dir, output_base_dir, dataset_name):
    """
    Runs the full morphometric analysis pipeline for a given dataset.

    Steps: read the metadata CSV; extract rotated pseudo-landmarks for every leaf
    image listed in its "file" column (using the base_x/base_y/tip_x/tip_y
    annotations); compute the GPA mean and Procrustes-align every leaf to it; save
    a plot of the aligned shapes; run a full PCA on the flattened coordinates; and
    write an explained-variance report plus two HDF5 files (PCA model parameters;
    PC scores, class labels and flattened coordinates). Images that are missing or
    fail to process are reported and excluded rather than aborting the run.

    Args:
        metadata_file_path (str): Path to the CSV metadata file.
        image_data_dir (str): Path to the directory containing image files.
        output_base_dir (str): Base directory where all outputs for this dataset will be saved.
        dataset_name (str): A descriptive name for the dataset (e.g., "plant_predict")
                            used in print statements and specific output filenames.

    Raises:
        KeyError: if the metadata lacks the "file" column or the class label column.
    """
    print(f"\n{'='*10} Starting Analysis for {dataset_name.upper()} Dataset {'='*10}")

    # --- Configuration ---
    # Parameters for Preprocessing
    HIGH_RES_INTERPOLATION_POINTS = 10000
    FINAL_PSEUDO_LANDMARKS_PER_SIDE = 50
    NUM_LANDMARKS = (FINAL_PSEUDO_LANDMARKS_PER_SIDE * 2) - 1  # the tip is shared by both sides
    NUM_DIMENSIONS = 2

    # Parameters for Morphospace Visualization (2-Component PCA) - mostly commented out, but parameters remain if needed
    # MORPHOSPACE_PLOT_LENGTH = 10
    # MORPHOSPACE_PLOT_WIDTH = 10
    # MORPHOSPACE_PC1_INTERVALS = 20
    # MORPHOSPACE_PC2_INTERVALS = 6
    MORPHOSPACE_HUE_COLUMN = "plantID"  # only used by the commented-out morphospace plot
    # EIGENLEAF_SCALE = 0.08
    # EIGENLEAF_COLOR = "lightgray"
    # EIGENLEAF_ALPHA = 0.5
    # POINT_SIZE = 80
    # POINT_LINEWIDTH = 0
    # POINT_ALPHA = 0.6
    # AXIS_LABEL_FONTSIZE = 12
    # AXIS_TICK_FONTSIZE = 8
    # FACE_COLOR = "white"
    # GRID_ALPHA = 0.5

    # Parameters for Output Files
    GPA_MEAN_SHAPE_PLOT_FILENAME = f"gpa_mean_shape_{dataset_name}.png"
    PCA_EXPLAINED_VARIANCE_REPORT_FILENAME = f"pca_explained_variance_{dataset_name}.txt"
    MORPHOSPACE_PLOT_FILENAME = f"morphospace_plot_{dataset_name}.png"  # Retained filename, but plot is not generated
    PCA_PARAMS_H5_FILENAME = f"leaf_pca_model_parameters_{dataset_name}.h5"
    ORIGINAL_PCA_SCORES_AND_LABELS_H5_FILENAME = f"original_pca_scores_and_class_labels_{dataset_name}.h5"
    CLASS_LABEL_COLUMN_FOR_SAVING = "plantID"

    os.makedirs(output_base_dir, exist_ok=True)
    print(f"Saving outputs to directory: {output_base_dir}")

    # --- Read in Metadata ---
    mdata = pd.read_csv(metadata_file_path)
    print(f"Metadata loaded from: {metadata_file_path}")
    print("First 5 rows of loaded metadata:")
    print(mdata.head())

    # Fail fast instead of after the expensive image processing.
    for required_column in ("file", CLASS_LABEL_COLUMN_FOR_SAVING):
        if required_column not in mdata.columns:
            raise KeyError(f"Metadata is missing required column '{required_column}'")

    # --- Print plantID counts and structure ---
    print("\n--- Plant ID Class Information ---")
    if 'plantID' in mdata.columns and 'full_name' in mdata.columns:
        print("\nNumber of different plantID classes for each full_name (variety):")
        plant_id_counts_per_variety = mdata.groupby('full_name')['plantID'].nunique()
        print(plant_id_counts_per_variety.sort_index().to_string())

        print(f"\nTotal number of unique plantID classes: {mdata['plantID'].nunique()}")
        print("\nCounts of each plantID class:")
        plant_id_counts_overall = mdata['plantID'].value_counts().sort_index()
        print(plant_id_counts_overall.to_string())
    else:
        print("Warning: 'plantID' or 'full_name' column not found in metadata. Cannot print class information.")
    print("-----------------------------------")

    print(f"Found {len(mdata)} image files to process from metadata.")

    # --- Interpolate Points Creating Pseudo-Landmarks and Pre-process ---
    print("\n--- Preprocessing Images and Interpolating Pseudo-Landmarks ---")

    existing_image_files = {f for f in os.listdir(image_data_dir)
                            if os.path.isfile(os.path.join(image_data_dir, f))}

    # Positions (not index labels) of the rows that produced landmarks, for .iloc below.
    valid_positions = []
    processed_points_list = []

    for pos, (_, row) in enumerate(mdata.iterrows()):
        curr_image_filename = row["file"]
        img_path = os.path.join(image_data_dir, curr_image_filename)

        if curr_image_filename not in existing_image_files:
            print(f"Warning: Image file not found at {img_path}. Skipping and will exclude from analysis.")
            continue

        try:
            rot_pts = _extract_leaf_landmarks(img_path,
                                              (row["base_x"], row["base_y"]),
                                              (row["tip_x"], row["tip_y"]),
                                              HIGH_RES_INTERPOLATION_POINTS,
                                              FINAL_PSEUDO_LANDMARKS_PER_SIDE)
        except _LeafSkipped as skip:
            print(skip)
            continue
        except Exception as e:  # best-effort: one bad image must not abort the whole dataset
            print(f"Error processing image {curr_image_filename}: {e}. Skipping.")
            continue

        processed_points_list.append(rot_pts)
        valid_positions.append(pos)

    # Rebuild mdata and cult_cm_arr with only successfully processed images
    mdata = mdata.iloc[valid_positions].reset_index(drop=True)
    cult_cm_arr = np.array(processed_points_list)

    if cult_cm_arr.shape[0] == 0:
        print(f"No valid images processed for {dataset_name}. Exiting analysis for this dataset.")
        return
    if cult_cm_arr.shape[0] < 2:
        print(f"Only one valid image processed for {dataset_name}; PCA needs at least two. "
              f"Exiting analysis for this dataset.")
        return

    # --- Calculate GPA Mean ---
    print("--- Calculating GPA Mean ---")
    mean_shape = gpa_mean(cult_cm_arr, NUM_LANDMARKS, NUM_DIMENSIONS)

    # --- Align Leaves to GPA Mean ---
    # procrustes() returns the second shape translated, scaled and rotated onto the first.
    print("--- Aligning Leaves to GPA Mean ---")
    proc_arr = np.zeros(np.shape(cult_cm_arr))
    for i, leaf_pts in enumerate(cult_cm_arr):
        _, proc_arr[i], _ = procrustes(mean_shape, leaf_pts)

    # --- Visualize GPA Aligned Shapes and Mean ---
    print("--- Visualizing GPA Aligned Shapes ---")
    gpa_plot_path = os.path.join(output_base_dir, GPA_MEAN_SHAPE_PLOT_FILENAME)
    aligned_mean = np.mean(proc_arr, axis=0)
    plt.figure(figsize=(8, 8))
    for aligned_leaf in proc_arr:
        plt.plot(aligned_leaf[:, 0], aligned_leaf[:, 1], c="k", alpha=0.08)
    plt.plot(aligned_mean[:, 0], aligned_mean[:, 1], c="magenta")
    plt.gca().set_aspect("equal")
    plt.axis("off")
    plt.title(f"Procrustes Aligned Leaf Shapes and GPA Mean ({dataset_name.replace('_', ' ').title()})")
    plt.savefig(gpa_plot_path)
    plt.close()
    print(f"GPA mean shape plot saved to {gpa_plot_path}")

    # --- Calculate Percent Variance All PCs ---
    print("\n--- Performing Full PCA and Generating Explained Variance Report ---")
    flat_arr = proc_arr.reshape(proc_arr.shape[0], proc_arr.shape[1] * proc_arr.shape[2])

    max_pc_components = min(flat_arr.shape[0], flat_arr.shape[1])
    pca = PCA(n_components=max_pc_components)
    PCs = pca.fit_transform(flat_arr)

    pca_explained_variance_filepath = os.path.join(output_base_dir, PCA_EXPLAINED_VARIANCE_REPORT_FILENAME)
    cumulative_ratio = pca.explained_variance_ratio_.cumsum()
    with open(pca_explained_variance_filepath, 'w') as f:
        f.write(f"PCA Explained Variance Report ({dataset_name.replace('_', ' ').title()} Dataset):\n")
        f.write(f"Total Samples: {flat_arr.shape[0]}\n")
        f.write(f"Total Features (landmarks * dimensions): {flat_arr.shape[1]}\n")
        f.write(f"Number of PCs Calculated: {pca.n_components_}\n\n")
        f.write("PC: var, overall\n")
        for i, (ratio, cumulative) in enumerate(zip(pca.explained_variance_ratio_, cumulative_ratio)):
            pc_variance = round(ratio * 100, 2)
            cumulative_variance = round(cumulative * 100, 2)
            line = f"PC{i+1}: {pc_variance}%, {cumulative_variance}%\n"
            print(line.strip())
            f.write(line)
    print(f"PCA explained variance report saved to {pca_explained_variance_filepath}")

    # --- Save PCA Model Parameters, PC Scores, and Class Labels ---
    print("\n--- Saving PCA model parameters, PC scores, and class labels ---")
    pca_components = pca.components_
    pca_mean = pca.mean_
    pca_explained_variance = pca.explained_variance_
    pca_explained_variance_ratio = pca.explained_variance_ratio_
    n_pca_components = pca.n_components_

    print(f"  PCA Components shape: {pca_components.shape}")
    print(f"  PCA Mean shape: {pca_mean.shape}")
    print(f"  PCA Explained Variance shape: {pca_explained_variance.shape}")
    print(f"  PCA Explained Variance Ratio shape: {pca_explained_variance_ratio.shape}")
    print(f"  Number of PCA components: {n_pca_components}")
    print(f"  Original PCA Scores (PCs) shape: {PCs.shape}")
    print(f"  Class Labels ({CLASS_LABEL_COLUMN_FOR_SAVING}) length: {len(mdata[CLASS_LABEL_COLUMN_FOR_SAVING])}")

    pca_params_filepath = os.path.join(output_base_dir, PCA_PARAMS_H5_FILENAME)
    with h5py.File(pca_params_filepath, 'w') as f:
        f.create_dataset('components', data=pca_components, compression="gzip")
        f.create_dataset('mean', data=pca_mean, compression="gzip")
        f.create_dataset('explained_variance', data=pca_explained_variance, compression="gzip")
        f.create_dataset('explained_variance_ratio', data=pca_explained_variance_ratio, compression="gzip")
        f.attrs['n_components'] = n_pca_components
    print(f"PCA parameters saved to {pca_params_filepath}")

    pca_scores_labels_filepath = os.path.join(output_base_dir, ORIGINAL_PCA_SCORES_AND_LABELS_H5_FILENAME)
    with h5py.File(pca_scores_labels_filepath, 'w') as f:
        f.create_dataset('pca_scores', data=PCs, compression="gzip")
        # Labels are stored as fixed-length byte strings (NumPy 'S' dtype); readers must decode them.
        # NOTE(review): non-ASCII labels would fail to encode here.
        f.create_dataset('class_labels', data=np.array(mdata[CLASS_LABEL_COLUMN_FOR_SAVING]).astype('S'), compression="gzip")
        f.create_dataset('original_flattened_coords', data=flat_arr, compression="gzip")
    print(f"Original PCA scores, class labels, AND original flattened coordinates saved to {pca_scores_labels_filepath}")

    # --- Create Morphospace (COMMENTED OUT as requested) ---
    # print("\n--- Creating Morphospace Plot ---")
    # morphospace_pca = PCA(n_components=2)
    # morphospace_PCs = morphospace_pca.fit_transform(flat_arr)

    # mdata["PC1"] = morphospace_PCs[:, 0]
    # mdata["PC2"] = morphospace_PCs[:, 1]

    # plt.figure(figsize=(MORPHOSPACE_PLOT_LENGTH, MORPHOSPACE_PLOT_WIDTH))
    # plt.gca().set_facecolor(FACE_COLOR)
    # plt.gca().set_axisbelow(True)

    # PC1_vals = np.linspace(np.min(PCs[:, 0]), np.max(PCs[:, 0]), MORPHOSPACE_PC1_INTERVALS)
    # PC2_vals = np.linspace(np.min(PCs[:, 1]), np.max(PCs[:, 1]), MORPHOSPACE_PC2_INTERVALS)

    # for i in PC1_vals:
    #     for j in PC2_vals:
    #         inv_leaf = morphospace_pca.inverse_transform(np.array([i, j]))
    #         inv_leaf_coords = inv_leaf.reshape(NUM_LANDMARKS, NUM_DIMENSIONS)

    #         inv_x = inv_leaf_coords[:, 0]
    #         inv_y = inv_leaf_coords[:, 1]

    #         plt.fill(inv_x * EIGENLEAF_SCALE + i, inv_y * EIGENLEAF_SCALE + j,
    #                  c=EIGENLEAF_COLOR, alpha=EIGENLEAF_ALPHA)

    # sns.scatterplot(data=mdata, x="PC1", y="PC2", hue=MORPHOSPACE_HUE_COLUMN,
    #                 s=POINT_SIZE, linewidth=POINT_LINEWIDTH, alpha=POINT_ALPHA)

    # plt.legend(bbox_to_anchor=(1.00, 1.02), prop={'size': 8.9})
    # xlab = f"PC1, {round(pca.explained_variance_ratio_[0] * 100, 1)}%"
    # ylab = f"PC2, {round(pca.explained_variance_ratio_[1] * 100, 1)}%"
    # plt.xlabel(xlab, fontsize=AXIS_LABEL_FONTSIZE)
    # plt.ylabel(ylab, fontsize=AXIS_LABEL_FONTSIZE)
    # plt.xticks(fontsize=AXIS_TICK_FONTSIZE)
    # plt.yticks(fontsize=AXIS_TICK_FONTSIZE)
    # plt.gca().set_aspect("equal")

    # plt.savefig(os.path.join(output_base_dir, MORPHOSPACE_PLOT_FILENAME), bbox_inches='tight')
    # plt.close()
    # print(f"Morphospace plot saved to {os.path.join(output_base_dir, MORPHOSPACE_PLOT_FILENAME)}")

    print(f"\n{'='*10} Analysis for {dataset_name.upper()} Dataset Completed {'='*10}")


# --- Define paths relative to the current notebook (COCA_PROJECT/data/PLANTPREDICT/) ---
# Paths for the plant prediction dataset
PLANT_PREDICT_METADATA_FILE = "./01_cultivated2nd_landmarks.csv" # In the same directory as the script
PLANT_PREDICT_IMAGE_DIR = "../../data/CULTIVATED2ND/00_cultivated2nd_data/" # Up two levels to COCA_PROJECT, then into data/CULTIVATED2ND/
PLANT_PREDICT_OUTPUT_DIR = "./03_morphometrics_output_plant_predict/" # New unique output dir within PLANTPREDICT

# --- Run analysis for the plant prediction dataset ---
# Guarded so that importing this module (e.g. to reuse the helper functions) does not
# launch the full analysis; a notebook cell or `python script.py` runs as __main__.
if __name__ == "__main__":
    run_morphometric_analysis(
        metadata_file_path=PLANT_PREDICT_METADATA_FILE,
        image_data_dir=PLANT_PREDICT_IMAGE_DIR,
        output_base_dir=PLANT_PREDICT_OUTPUT_DIR,
        dataset_name="plant_predict" # Descriptive name for this analysis
    )

    print("\nAll analyses for the plant prediction dataset completed and outputs saved.")


Saving outputs to directory: ./03_morphometrics_output_plant_predict/
Metadata loaded from: ./01_cultivated2nd_landmarks.csv
First 5 rows of loaded metadata:
          file            species  px_cm  base_x  base_y   tip_x   tip_y  \
0  AMA1A_a.tif  Erythroxylym coca   28.4  271.00  195.25  127.50  189.00   
1  AMA1A_b.tif  Erythroxylym coca   28.4  275.00  202.75  125.00  193.25   
2  AMA1A_c.tif  Erythroxylym coca   28.4  271.25  202.50  127.75  194.00   
3  AMA1A_d.tif  Erythroxylym coca   28.4  281.00  208.00  118.00  193.00   
4  AMA1A_e.tif  Erythroxylym coca   28.4  284.25  210.25  115.50  200.00   

  variety plant leaf full_name  type plantID  
0     AMA    1A    a   amazona  coca  AMA_1A  
1     AMA    1A    b   amazona  coca  AMA_1A  
2     AMA    1A    c   amazona  coca  AMA_1A  
3     AMA    1A    d   amazona  coca  AMA_1A  
4     AMA    1A    e   amazona  coca  AMA_1A  

--- Plant ID Class Information ---

Number of different plantID classes for each full_name (variety):