In [19]:
# 1. Import libraries and src modules
import sys
import os
sys.path.append(os.path.abspath('../src'))

# Now, proceed with the rest of your imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from matplotlib import gridspec as mgs
from adjustText import adjust_text
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics import silhouette_score

from src.clustering.config import *
from src.clustering.preprocessing import load_and_select_numeric, fill_and_standardize, run_pca
from src.clustering.clustering import (
    hierarchical_linkage, find_elbow, silhouette_optimal_clusters, elbow_method_for_hierarchical_clustering
)
from src.clustering.plotting import plot_pca_variance_panels, plot_dendrogram_panels, add_legend_to_axis

# =============================
# Input labels and paths
# =============================
labels_dict = {
    '../data/3_df_wide_Well-watered_Winter_2022-2023.csv': 'A',
    '../data/3_df_wide_Well-watered_combined_seasons.csv': 'B',
    '../data/3_df_wide_Well-watered_Winter_2022-2023_noYield.csv': 'C',
    '../data/3_df_wide_Well-watered_combined_seasons_noYield.csv': 'D'
}
paths = list(labels_dict.keys())
labels = list(labels_dict.values())

# =============================
# Create Figure and Grid
# =============================
fig = plt.figure(figsize=(23, 13))
gs = mgs.GridSpec(2, 2, width_ratios=[2, 4], height_ratios=[1, 1], wspace=0.30, hspace=0.4, figure=fig)

# =============================
# Panel 1 — PCA Explained Variance (2x2)
# =============================
nrows, ncols = 2, 2
gs_elbow = mgs.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=gs[0, 0], hspace=0.25, wspace=0.1)
elbow_axes, explained_var_list, elbow_points = plot_pca_variance_panels(fig, gs_elbow, labels, paths)

# =============================
# Panel 2 — Dendrograms (2x2)
# =============================
gs_dendo = mgs.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=gs[1, 0], hspace=0.55, wspace=0.1)

all_linkage_matrices = []
all_distances_silhouette = {}
all_distances_wcss = {}
entry_lists = []

for i, (label, path) in enumerate(zip(labels, paths)):
    df, df_numeric = load_and_select_numeric(path)
    data_scaled = fill_and_standardize(df_numeric)
    pca, explained_variance = run_pca(data_scaled)
    elbow_point = find_elbow(explained_variance)
    pca = PCA(n_components=elbow_point)
    principal_components = pca.fit_transform(data_scaled)
    linked = hierarchical_linkage(principal_components)

    optimal_clusters, sil_scores = silhouette_optimal_clusters(principal_components)
    distances = linked[:, 2]
    sorted_distances = sorted(distances, reverse=True)
    threshold_silhouette = (
        (sorted_distances[optimal_clusters - 1] + sorted_distances[optimal_clusters - 2]) / 2
    )
    all_linkage_matrices.append(linked)
    all_distances_silhouette[label] = (optimal_clusters, threshold_silhouette)

    wcss = elbow_method_for_hierarchical_clustering(principal_components, linked)
    optimal_clusters_wcss = find_elbow(wcss)
    threshold_wcss = (
        (sorted_distances[optimal_clusters_wcss - 1] + sorted_distances[optimal_clusters_wcss - 2]) / 2
    )
    all_distances_wcss[label] = (optimal_clusters_wcss, threshold_wcss)
    entry_lists.append(df['entry'].tolist())

palette = sns.color_palette('twilight_shifted', n_colors=optimal_clusters)
dendo_axes = plot_dendrogram_panels(
    fig, gs_dendo, labels, paths, all_linkage_matrices,
    all_distances_silhouette, all_distances_wcss, palette, entry_lists
)
add_legend_to_axis(dendo_axes[0])

# =============================
# Panel 3 — PCA Cluster Scatter Plots (2x2)
# =============================

# ---- Manual offsets (customize for your entries) ---
manual_offsets = {
    'A': { 1: (0.5, -0.1), 2: (-0.8, -0.1), 3: (0.5, -0.1), 4: (-0.15, -0.8), 5: (0.6, -0.1), 6: (0.5, -0.1),
           7: (-0.8, -0.1), 8: (-0.8, -0.1), 9: (-0.1, -0.8), 10: (0.5, -0.1), 11: (-0.3, -0.8), 12: (0.5, -0.1),
           13: (-0.5, -0.8), 14: (-0.3, 0.5) },
    'B': { 1: (-0.15, -0.9), 2: (0.6, -0.3), 3: (0.6, -0.3), 4: (-0.5, 0.5), 5: (-0.5, 0.5), 6: (0.5, -0.1),
           7: (-0.1, -0.9), 8: (-0.8, -0.1), 9: (-0.15, -0.9), 10: (0.6, -0.2), 11: (-0.3, 0.5), 12: (0.6, -0.2),
           13: (-0.5, 0.5), 14: (-0.5, -1.0) },
    'C': { 1: (0.5, -0.1), 2: (-0.2, -0.8), 3: (-0.1, -0.7), 4: (-0.2, -0.8), 5: (-0.3, 0.5), 6: (-0.2, -0.8),
           7: (0.5, -0.1), 8: (-0.8, -0.1), 9: (-0.3, 0.5), 10: (0.5, -0.1), 11: (-0.2, -0.7), 12: (-0.3, 0.5),
           13: (-0.3, 0.5), 14: (-0.5, -0.8) },
    'D': { 1: (0.5, -0.2), 2: (-0.2, 0.5), 3: (-0.5, -0.8), 4: (-0.2, -0.8), 5: (-0.2, 0.3), 6: (0.5, -0.2),
           7: (-0.8, -0.1), 8: (-0.2, 0.5), 9: (-0.2, -0.8), 10: (-0.4, 0.5), 11: (0.6, -0.3), 12: (-0.4, -0.8),
           13: (-0.4, 0.5), 14: (-1.1, -0.1) }
}

gs_scatter = mgs.GridSpecFromSubplotSpec(
    nrows, ncols, subplot_spec=gs[:, 1], hspace=0.2, wspace=0.2
)
scatter_axes = []
for i in range(nrows):
    for j in range(ncols):
        if i == 0 and j == 0:
            ax = fig.add_subplot(gs_scatter[i, j])
            ref_ax = ax
        else:
            ax = fig.add_subplot(gs_scatter[i, j], sharex=ref_ax, sharey=ref_ax)
        scatter_axes.append(ax)

all_handles = []
all_labels = []
cluster_dfs = []

for k, (label, path) in enumerate(zip(labels, paths)):
    df, df_numeric = load_and_select_numeric(path)
    data_scaled = fill_and_standardize(df_numeric)
    pca, explained_variance = run_pca(data_scaled)
    elbow_point = find_elbow(explained_variance)
    pca = PCA(n_components=elbow_point)
    principal_components = pca.fit_transform(data_scaled)

    # Silhouette optimal number of clusters
    silhouette_scores = []
    max_clusters = min(11, len(principal_components))
    range_n_clusters = range(2, max_clusters)
    for n_clusters in range_n_clusters:
        clusterer = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
        labels_pca = clusterer.fit_predict(principal_components)
        score = silhouette_score(principal_components, labels_pca)
        silhouette_scores.append(score)
    optimal_clusters = range_n_clusters[np.argmax(silhouette_scores)]

    # Apply KMeans clustering with this k
    kmeans = KMeans(n_clusters=optimal_clusters, random_state=42)
    kmeans_labels = kmeans.fit_predict(principal_components)
    pc_df = pd.DataFrame(
        principal_components,
        columns=[f'PC{i+1}' for i in range(elbow_point)]
    )
    pc_df['genotype'] = df['genotype'].values
    pc_df['entry'] = df['entry'].values
    pc_df['Cluster'] = kmeans_labels

    ax = scatter_axes[k]
    palette = sns.color_palette('twilight_shifted', n_colors=optimal_clusters)
    texts = []
    x_coords = []
    y_coords = []
    handles = []
    labels_ = []
    for cluster in range(optimal_clusters):
        cluster_points = pc_df[pc_df['Cluster'] == cluster]
        sc = ax.scatter(
            cluster_points['PC1'],
            cluster_points['PC2'],
            marker=marker_styles[cluster % len(marker_styles)],
            s=marker_size_pca,
            c=[palette[cluster % len(palette)]],
            edgecolor=edgecolor,
            linewidth=line_width_pca
        )
        if f'Cluster {cluster + 1}' not in labels_:
            handles.append(sc)
            labels_.append(f'Cluster {cluster + 1}')
        cluster_points = cluster_points.copy()
        cluster_points['Dataset'] = label
        cluster_dfs.append(cluster_points)
        for j, entry in enumerate(cluster_points['entry']):
            xi = cluster_points['PC1'].iloc[j]
            yi = cluster_points['PC2'].iloc[j]
            dx, dy = manual_offsets[label].get(entry, (0, 0)) if label in manual_offsets else (0, 0)
            text = ax.text(xi + dx, yi + dy, str(entry), fontsize=fontsize_big)
            texts.append(text)
            x_coords.append(xi)
            y_coords.append(yi)
    adjust_text(
        texts,
        target_x=x_coords,
        target_y=y_coords,
        ax=ax,
        avoid_self=False,
        force_text=0.1,
        force_points=0.1,
        arrowprops=dict(
            arrowstyle='-',
            color='gray',
            alpha=0.7,
            lw=1.2,
            shrinkA=2,
            shrinkB=2
        )
    )
    if k == 0 or len(labels_) > len(all_labels):
        all_handles = handles
        all_labels = labels_
    ax.set_title(label, fontsize=fontsize_big)
    ax.set_xlabel(f'PC1 ({explained_variance[0]:.2%} variance)', fontsize=fontsize_big, fontname=font_name)
    ax.set_ylabel(f'PC2 ({explained_variance[1]:.2%} variance)', fontsize=fontsize_big, fontname=font_name)
    ax.grid(True, linestyle=grid_linestyle, linewidth=grid_linewidth, alpha=grid_alpha)

for i in range(nrows):
    for j in range(ncols):
        ax = scatter_axes[i * ncols + j]
        if i != nrows - 1:
            ax.tick_params(labelbottom=False)
        if j != 0:
            ax.tick_params(labelleft=False)

scatter_axes[0].legend(
    handles=all_handles, 
    labels=all_labels, 
    loc='lower center', 
    bbox_to_anchor=(1, -1.58),
    ncol=(len(all_labels) + 1),
    frameon=True, 
    columnspacing=1.2,
    handletextpad=0.0,
    fontsize=fontsize_big,
    prop=legend_font,
    title_fontproperties=legend_title_font
)

# =============================
# Main Panel Titles and Save
# =============================
fig.text(0.10, 0.92, '(a)', fontsize=title_fontsize, fontname=font_name, weight='bold')
fig.text(0.10, 0.47, '(b)', fontsize=title_fontsize, fontname=font_name, weight='bold')
fig.text(0.43, 0.92, '(c)', fontsize=title_fontsize, fontname=font_name, weight='bold')

fig.savefig('../plots/WW_comparison/cluster_analysis_figure.svg', dpi=300, bbox_inches='tight')
plt.show()

ModuleNotFoundError: No module named 'src'