In [89]:
import os
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Filtered over-expression table: tab-separated, first column used as the row index.
data = pd.read_csv("./filter_overexpressed_amps_ob.tsv", sep="\t", index_col=0)

# Per-sample abundance columns; entries that cannot be parsed as numbers become NaN.
samples_cols = ["O 1", "O 2", "O 3", "OMS 1", "OMS 2", "OMS 3"]
for sample_col in samples_cols:
    data[sample_col] = pd.to_numeric(data[sample_col], errors='coerce')

# Default seaborn colour cycle for this notebook (the clustermap below uses its own cmap).
my_colors = ["orange", "red"]
sns.set_palette(my_colors)

# Origin colours. The "origin" column is shown under the label "Produced by".
# Build the categorical Series directly rather than round-tripping a temporary
# column through `data`, which would clobber any existing "Produced by" column.
origin = data["origin"].astype("category").rename("Produced by")
origin_colors = ["#D8BFD8", "#98FB98", "#FFFACD", "#FFA07A"]  # Pastel purple, pastel green, pastel yellow, pastel red
# Only the first len(origin_colors) distinct values (in order of appearance) get a
# colour; any further values are missing from the lut and map to NaN below.
origin_unique_values = origin.unique()[:len(origin_colors)]
origin_lut = dict(zip(origin_unique_values, origin_colors))
row_colors = origin.map(origin_lut)

# Over-expressed-group colours (per row, despite the `col_colors` name kept for
# compatibility with later cells).
group = data["overexpressed_group"].astype("category").rename("Group")
group_colors = ["orange", "red"]
group_unique_values = group.unique()[:len(group_colors)]
group_lut = dict(zip(group_unique_values, group_colors))
col_colors = group.map(group_lut)

# Two annotation strips side by side; the Series names ("Produced by", "Group")
# become the column labels.
combined_row_colors = pd.concat([row_colors, col_colors], axis=1)

# log10 is undefined for zero/negative abundances (log10(0) is -inf, which is not
# treated as missing and breaks the colour scaling). Mask them to NaN so those
# cells are left blank instead.
abundance = data[samples_cols]
log_abundance = np.log10(abundance.where(abundance > 0))

heatmap = sns.clustermap(
    log_abundance,
    # Two row-annotation strips: producing origin and over-expressed group.
    row_colors=combined_row_colors,
    row_cluster=False,
    col_cluster=False,
    cmap="BuPu",
    cbar_kws={"label": "log10(Abundance)"},
    figsize=(7.5, 7.5),
)
fig = heatmap.ax_heatmap.figure

# One legend entry per origin colour followed by one per group colour; each
# Patch carries its own label, so handles and labels cannot fall out of step.
legend_handles = [
    Patch(facecolor=color, label=label)
    for lut in (origin_lut, group_lut)
    for label, color in lut.items()
]
# Figure-level legend anchored by its upper-left corner in figure coordinates,
# so its position does not depend on which axes is current or on where the
# colour bar is moved below.
fig.legend(
    handles=legend_handles,
    title="Legend",
    loc="upper left",
    bbox_to_anchor=(0.0, 0.934),
    bbox_transform=fig.transFigure,
    borderaxespad=0,
)

heatmap.ax_cbar.set_position((0.05, 0.05, 0.01, 0.17))

heatmap.ax_heatmap.set_xlabel("")
heatmap.ax_heatmap.set_ylabel("")

fig.savefig("./clustermap_log10_ob_groups_filtered.svg")
fig.savefig("./clustermap_log10_ob_groups_filtered.png")
# Close (rather than just clear) the figure so it is released by pyplot.
plt.close(fig)


<Figure size 750x750 with 0 Axes>