In [21]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch

# Load the tab-separated table (first column becomes the row index) and force
# the six sample columns to numeric; anything unparsable becomes NaN.
samples_cols = ["O 1", "O 2", "O 3", "OMS 1", "OMS 2", "OMS 3"]
data = pd.read_csv("./filter_overexpressed_amps_ob.tsv", sep="\t", index_col=0)
data[samples_cols] = data[samples_cols].apply(
    lambda column: pd.to_numeric(column, errors="coerce")
)

# Make orange/red seaborn's global default colour cycle.
# NOTE(review): the clustermap below is coloured through `cmap` and explicit
# row colours, so this palette does not appear to affect that figure.
my_colors = ["orange", "red"]
sns.set_palette(palette=my_colors)

# Row colours for the 'origin' annotation of each row.
# The Series is named "Origin" because clustermap uses the Series name as the
# label of the colour strip.
origin = data["origin"].astype("category").rename("Origin")
origin_colors = ["#D8BFD8", "#98FB98", "#FFFACD", "#FFA07A"]

# Colours are assigned to origin values in order of first appearance. Values
# beyond the number of available colours get no colour (NaN) and no legend
# entry, so warn instead of dropping them silently.
origin_levels = origin.unique()
origin_unique_values = origin_levels[:len(origin_colors)]
if len(origin_levels) > len(origin_colors):
    warnings.warn(
        f"{len(origin_levels)} origin values but only {len(origin_colors)} "
        "colours; the extra values will be uncoloured and missing from the legend.",
        stacklevel=1,
    )

origin_lut = dict(zip(origin_unique_values, origin_colors))
row_colors_origin = origin.map(origin_lut)

# Row colours for the 'overexpressed_group' annotation of each row.
# The Series is named "Group" because clustermap uses the Series name as the
# label of the colour strip.
group = data["overexpressed_group"].astype("category").rename("Group")
group_colors = ["orange", "red"]
desired_group_order = ["O", "OMS"]
group_lut = dict(zip(desired_group_order, group_colors))
row_colors_group = group.map(group_lut)

# Any label outside desired_group_order maps to NaN (an uncoloured cell with
# no legend entry); warn rather than let that pass silently.
unexpected_groups = set(group.dropna()) - set(group_lut)
if unexpected_groups:
    warnings.warn(
        f"Unexpected overexpressed_group values {sorted(map(str, unexpected_groups))}; "
        "these rows will have no group colour.",
        stacklevel=1,
    )

# Put the per-row colour Series side by side (group first, then origin);
# clustermap draws one colour strip per column of this frame.
combined_row_colors = pd.concat((row_colors_group, row_colors_origin), axis="columns")

# log10 is only defined for strictly positive abundances. A zero would become
# -inf, which makes seaborn's colour-scale minimum -inf and breaks the colour
# mapping/colorbar; a negative would become NaN with a RuntimeWarning. Mask
# non-positive values as NaN first. seaborn masks NaN cells, leaving them blank.
abundance = data[samples_cols]
log_abundance = np.log10(abundance.where(abundance > 0))

# Rows and columns keep their table order (no clustering); the row-colour
# strips annotate each row with its group and origin.
heatmap = sns.clustermap(
    log_abundance,
    row_colors=combined_row_colors,
    row_cluster=False,
    col_cluster=False,
    cmap="BuPu",
    cbar_kws={"label": "log10(Abundance)"},
    figsize=(8.5, 6)
)

# One coloured patch per annotation level, for each row-colour strip.
handles_origin = [Patch(facecolor=colour, label=level) for level, colour in origin_lut.items()]
handles_group = [Patch(facecolor=colour, label=level) for level, colour in group_lut.items()]

# Both legends sit to the left of the heatmap (negative x anchor, in axes
# coordinates). Axes.legend() keeps only ONE legend per Axes: the second call
# replaces legend1 as the Axes' legend. legend1 is therefore re-attached
# afterwards with add_artist() so that both are drawn.
legend1 = heatmap.ax_heatmap.legend(
    handles_origin, origin_lut.keys(),
    title="Origin",
    loc="upper left", bbox_to_anchor=(-0.385, 0.6),
    borderaxespad=0, frameon=False,
)
legend2 = heatmap.ax_heatmap.legend(
    handles_group, group_lut.keys(),
    title="Overexpressed\ngroup",
    loc="upper left", bbox_to_anchor=(-0.38, 1),
    borderaxespad=0, frameon=False,
)
heatmap.ax_heatmap.add_artist(legend1)

# Move the colour bar; the tuple is (left, bottom, width, height) as fractions
# of the figure.
heatmap.ax_cbar.set_position((0.05, 0.1, 0.02, 0.17))
# Blank out both axis titles on the heatmap.
heatmap.ax_heatmap.set(xlabel="", ylabel="")

# Save the clustermap's own figure explicitly instead of relying on pyplot's
# "current figure". Then close it: plt.clf() only emptied the figure, leaving
# it open, and Jupyter rendered that blank figure as
# "<Figure size 850x600 with 0 Axes>".
heatmap.fig.savefig("./clustermap_log10_ob_groups_filtered.svg")
heatmap.fig.savefig("./clustermap_log10_ob_groups_filtered.png")
plt.close(heatmap.fig)


<Figure size 850x600 with 0 Axes>