In [14]:
# Updated code to allow multiple ranks per cell if multiple types overlap at the same index
import os
import numpy as np
import matplotlib.pyplot as plt

# Fifteen benchmark datasets; a name's position is its row/column in every matrix.
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Experiment identifiers; each one is embedded in the pair-file names.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real",
    "S-results-dsets_mvalid-exclude-all-3real",
    "S-results-strans_mvalid-all-3real",
    "S-results-strans_mvalid-exclude-all-3real",
]

# Directory holding the *_top30_mv_pairs.txt files.
base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/final/top30_mv_pairs"

# RGB channel index -> colour name for annotation text (hivprot, dpp4, nk1).
color_map = dict(enumerate(["red", "green", "blue"]))

# Model type -> RGB channel; the count and bit variants of a target share a channel.
type_color_map = {
    "hivprot_count": 0,
    "dpp4_count": 1,
    "nk1_count": 2,
    "hivprot_bit": 0,
    "dpp4_bit": 1,
    "nk1_bit": 2,
}
# Split the model types by fingerprint mode, preserving the order above.
count_types = [name for name in type_color_map if name.endswith("_count")]
bit_types = [name for name in type_color_map if name.endswith("_bit")]

# Build one RGB matrix per (experiment prefix, fingerprint mode). Each model group
# owns one colour channel; a cell brightens the higher the pair ranked (rank 1 ->
# 1.0, rank 30 -> 1/30). Every rank that lands on a cell is written into it.
annotated_outputs = []
for prefix in file_prefixes:
    for mode, types in (("count", count_types), ("bit", bit_types)):
        rgb = np.zeros((15, 15, 3), dtype=float)
        cell_ranks = {}  # (row, col) -> [(rank, colour name), ...]

        for model_type in types:
            pair_file = os.path.join(base_dir, f"{model_type}_{prefix}_top30_mv_pairs.txt")
            if not os.path.exists(pair_file):
                print(f"Missing: {pair_file}")
                continue

            with open(pair_file, "r") as handle:
                top_lines = handle.readlines()[:30]

            for rank, raw in enumerate(top_lines, start=1):
                fields = raw.strip().split(",")
                if len(fields) != 2:
                    continue
                first, second = (field.strip() for field in fields)
                if first not in dataset_idx or second not in dataset_idx:
                    continue
                row, col = dataset_idx[first], dataset_idx[second]
                channel = type_color_map[model_type]
                rgb[row, col, channel] = (31 - rank) / 30.0
                cell_ranks.setdefault((row, col), []).append((rank, color_map[channel]))

        # Render the matrix with dataset names on both axes
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb)
        ax.set_title(f"{prefix} - {mode}")
        ticks = np.arange(15)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(dataset_names, rotation=90)
        ax.set_yticklabels(dataset_names)

        # One text line per rank; the text is white when groups of different colours share a cell
        for (row, col), hits in cell_ranks.items():
            distinct_colours = {colour for _, colour in hits}
            text_colour = hits[0][1] if len(distinct_colours) == 1 else "white"
            ax.text(col, row, "\n".join(str(rank) for rank, _ in hits),
                    ha='center', va='center', color=text_colour, fontsize=8, weight='bold')

        ax.grid(False)
        output_path = f"./matrix_{prefix}_{mode}_annotated.png"
        annotated_outputs.append(output_path)
        plt.tight_layout()
        plt.savefig(output_path)
        plt.close(fig)

annotated_outputs


['./matrix_S-results-dsets_mvalid-all-3real_count_annotated.png',
 './matrix_S-results-dsets_mvalid-all-3real_bit_annotated.png',
 './matrix_S-results-dsets_mvalid-exclude-all-3real_count_annotated.png',
 './matrix_S-results-dsets_mvalid-exclude-all-3real_bit_annotated.png',
 './matrix_S-results-strans_mvalid-all-3real_count_annotated.png',
 './matrix_S-results-strans_mvalid-all-3real_bit_annotated.png',
 './matrix_S-results-strans_mvalid-exclude-all-3real_count_annotated.png',
 './matrix_S-results-strans_mvalid-exclude-all-3real_bit_annotated.png']

In [16]:
# Updated: Generate separate matrices for hivprot, dpp4, nk1 instead of combining all into one
import os
import numpy as np
import matplotlib.pyplot as plt

# Fifteen benchmark datasets; list position doubles as matrix row/column.
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Experiment identifiers embedded in the pair-file names.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real",
    "S-results-dsets_mvalid-exclude-all-3real",
    "S-results-strans_mvalid-all-3real",
    "S-results-strans_mvalid-exclude-all-3real",
]

# Directory holding the *_top30_mv_pairs.txt files.
base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/final/top30_mv_pairs"

# Model group -> the count/bit model types whose rankings share one matrix.
type_groups = {
    "hivprot": ["hivprot_count", "hivprot_bit"],
    "dpp4": ["dpp4_count", "dpp4_bit"],
    "nk1": ["nk1_count", "nk1_bit"],
}

# Annotation text colour per group (each figure shows a single group).
color_per_group = {"hivprot": "red", "dpp4": "green", "nk1": "blue"}

# Generate one matrix per model group. The group's count- and bit-model rankings
# share a matrix, so a cell hit by both carries two ranks. Its shade reflects the
# better (higher-intensity) rank instead of whichever file happened to be read last.
annotated_outputs = []
for prefix in file_prefixes:
    for group_name, group_types in type_groups.items():
        matrix_rgb = np.zeros((15, 15), dtype=float)  # 2-D intensity grid, one colormap
        annotations = {}  # {(x, y): [(rank, color), ...]}

        for t in group_types:
            filename = f"{t}_{prefix}_top30_mv_pairs.txt"
            path = os.path.join(base_dir, filename)
            if not os.path.exists(path):
                print(f"Missing: {path}")
                continue

            with open(path, "r") as f:
                lines = f.readlines()

            for i, line in enumerate(lines[:30]):
                parts = line.strip().split(",")
                if len(parts) != 2:
                    continue
                a, b = parts[0].strip(), parts[1].strip()
                if a in dataset_idx and b in dataset_idx:
                    x, y = dataset_idx[a], dataset_idx[b]
                    # rank 1 -> 1.0 ... rank 30 -> 1/30. Keep the strongest value
                    # so the second file cannot dim a cell the first ranked higher.
                    matrix_rgb[x, y] = max(matrix_rgb[x, y], (30 - i) / 30.0)
                    annotations.setdefault((x, y), []).append((i + 1, color_per_group[group_name]))

        # Plot with annotations
        fig, ax = plt.subplots(figsize=(6, 6))
        cmap = plt.cm.Reds if group_name == "hivprot" else (plt.cm.Greens if group_name == "dpp4" else plt.cm.Blues)
        ax.imshow(matrix_rgb, cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"{prefix} - {group_name}")
        ax.set_xticks(np.arange(15))
        ax.set_yticks(np.arange(15))
        ax.set_xticklabels(dataset_names, rotation=90)
        ax.set_yticklabels(dataset_names)

        for (x, y), entries in annotations.items():
            label_text = "\n".join(str(rank) for rank, _ in entries)
            ax.text(y, x, label_text, ha='center', va='center', color=color_per_group[group_name], fontsize=8, weight='bold')

        ax.grid(False)
        base = f"/c2/jinakim/Drug_Discovery_j/analysis/individual/{group_name}/"
        os.makedirs(base, exist_ok=True)

        output_path = base + f"{prefix}_separate.png"
        annotated_outputs.append(output_path)
        plt.tight_layout()
        plt.savefig(output_path)
        plt.close(fig)

annotated_outputs


['/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-exclude-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-exclude-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-strans_mvalid-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-strans_mvalid-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-strans_mvalid-all-3real_separate.png',
 '/c2/jinakim/Drug_Discovery_j/

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Dataset names and index mapping (list position = matrix row/column)
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2", "rat_f", "cb1", "hivprot", "nk1",
    "pgp", "tdi", "dpp4", "ox1", "ppb", "thrombin"
]
dataset_idx = {name: i for i, name in enumerate(dataset_names)}

# File prefixes (experiment identifiers). The assignment target was missing
# (a bare "= [" line), which is a SyntaxError. The loop below reads file_prefixes.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real",
    "S-results-dsets_mvalid-exclude-all-3real",
    "S-results-strans_mvalid-all-3real",
    "S-results-strans_mvalid-exclude-all-3real"
]

# Base directory where *_top30_mv_pairs.txt files are stored
base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/final/top30_mv_pairs"

# Model type -> (group, fingerprint mode, annotation colour)
type_color = {
    "hivprot_count": ("hivprot", "count", "red"),
    "hivprot_bit": ("hivprot", "bit", "red"),
    "dpp4_count": ("dpp4", "count", "green"),
    "dpp4_bit": ("dpp4", "bit", "green"),
    "nk1_count": ("nk1", "count", "blue"),
    "nk1_bit": ("nk1", "bit", "blue")
}

# Generate one matrix per (experiment prefix, model group, fingerprint mode).
# Cell intensity encodes rank (rank 1 -> 1.0, rank 30 -> 1/30); each cell is
# annotated with the rank(s) that landed on it.
annotated_outputs = []
for prefix in file_prefixes:
    for group_name in ["hivprot", "dpp4", "nk1"]:
        for mode in ["count", "bit"]:
            matrix_rgb = np.zeros((15, 15), dtype=float)
            annotations = {}  # {(x, y): [(rank, color), ...]}

            for t, (group, type_, color) in type_color.items():
                if group != group_name or type_ != mode:
                    continue
                filename = f"{t}_{prefix}_top30_mv_pairs.txt"
                path = os.path.join(base_dir, filename)
                if not os.path.exists(path):
                    print(f"Missing: {path}")
                    continue

                with open(path, "r") as f:
                    lines = f.readlines()

                for i, line in enumerate(lines[:30]):
                    parts = line.strip().split(",")
                    if len(parts) != 2:
                        continue
                    a, b = parts[0].strip(), parts[1].strip()
                    if a in dataset_idx and b in dataset_idx:
                        x, y = dataset_idx[a], dataset_idx[b]
                        matrix_rgb[x, y] = (30 - i) / 30.0
                        annotations.setdefault((x, y), []).append((i + 1, color))

            # Plot
            fig, ax = plt.subplots(figsize=(6, 6))
            cmap = plt.cm.Reds if group_name == "hivprot" else (plt.cm.Greens if group_name == "dpp4" else plt.cm.Blues)
            ax.imshow(matrix_rgb, cmap=cmap, vmin=0, vmax=1)
            ax.set_title(f"{prefix} - {group_name} - {mode}")
            ax.set_xticks(np.arange(15))
            ax.set_yticks(np.arange(15))
            ax.set_xticklabels(dataset_names, rotation=90)
            ax.set_yticklabels(dataset_names)

            for (x, y), entries in annotations.items():
                label_text = "\n".join(str(rank) for rank, _ in entries)
                # Take the colour recorded with the entries. The bare name `color`
                # would be whatever the type_color loop above bound last (always
                # "blue"), which mislabels the hivprot and dpp4 figures.
                label_color = entries[0][1]
                ax.text(y, x, label_text, ha='center', va='center', color=label_color, fontsize=8, weight='bold')

            ax.grid(False)

            base = f"/c2/jinakim/Drug_Discovery_j/analysis/individual/{group_name}/"
            os.makedirs(base, exist_ok=True)

            output_path = base + f"{prefix}_separate_{mode}.png"
            annotated_outputs.append(output_path)
            plt.tight_layout()
            plt.savefig(output_path)
            plt.close(fig)

annotated_outputs


['/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-exclude-all-3real_separ

In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Fifteen benchmark datasets; list position = matrix row/column.
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Experiment identifiers embedded in the pair-file names.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real",
    "S-results-dsets_mvalid-exclude-all-3real",
    "S-results-strans_mvalid-all-3real",
    "S-results-strans_mvalid-exclude-all-3real",
]

# Directory holding the *_top30_mv_pairs.txt files.
base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/final/top30_mv_pairs"

# Model type -> (group, fingerprint mode, colour), kept in this exact order.
type_color = dict([
    ("hivprot_count", ("hivprot", "count", "red")),
    ("hivprot_bit", ("hivprot", "bit", "red")),
    ("dpp4_count", ("dpp4", "count", "green")),
    ("dpp4_bit", ("dpp4", "bit", "green")),
    ("nk1_count", ("nk1", "count", "blue")),
    ("nk1_bit", ("nk1", "bit", "blue")),
])

# Pick a legible annotation colour for a given cell background.
def get_text_color(val, cmap):
    """Return 'black' on light backgrounds and 'white' on dark ones.

    ``cmap`` is any callable mapping ``val`` to an RGBA sequence (e.g. a
    matplotlib colormap). Brightness uses the 0.299/0.587/0.114 luma weights.
    """
    colour = cmap(val)
    luma = 0.299 * colour[0] + 0.587 * colour[1] + 0.114 * colour[2]
    if luma > 0.5:
        return 'black'
    return 'white'

# One heat-map per (experiment, model group, fingerprint mode). Annotation text
# switches between black and white depending on the cell's background luminance.
annotated_outputs = []
for prefix in file_prefixes:
    for group_name in ["hivprot", "dpp4", "nk1"]:
        for mode in ["count", "bit"]:
            heat = np.zeros((15, 15), dtype=float)
            ranks_at = {}  # (row, col) -> [(rank, colour), ...]

            # Model types belonging to this group and mode, in type_color order
            wanted = [
                (model_key, colour)
                for model_key, (grp, kind, colour) in type_color.items()
                if grp == group_name and kind == mode
            ]

            for model_key, colour in wanted:
                src = os.path.join(base_dir, f"{model_key}_{prefix}_top30_mv_pairs.txt")
                if not os.path.exists(src):
                    print(f"Missing: {src}")
                    continue

                with open(src, "r") as fh:
                    head = fh.readlines()[:30]

                for rank, raw in enumerate(head, start=1):
                    cols = raw.strip().split(",")
                    if len(cols) != 2:
                        continue
                    left, right = cols[0].strip(), cols[1].strip()
                    if left in dataset_idx and right in dataset_idx:
                        r, c = dataset_idx[left], dataset_idx[right]
                        heat[r, c] = (31 - rank) / 30.0
                        ranks_at.setdefault((r, c), []).append((rank, colour))

            # Draw the heat-map
            fig, ax = plt.subplots(figsize=(7, 7))
            cmap = {"hivprot": plt.cm.Reds, "dpp4": plt.cm.Greens}.get(group_name, plt.cm.Blues)
            im = ax.imshow(heat, cmap=cmap, vmin=0, vmax=1)

            ax.set_title(f"{prefix} - {group_name} - {mode}", fontsize=14, fontweight='bold')
            major = np.arange(15)
            ax.set_xticks(major)
            ax.set_yticks(major)
            ax.set_xticklabels(dataset_names, rotation=90, fontsize=10)
            ax.set_yticklabels(dataset_names, fontsize=10)

            # Faint cell borders drawn on the half-integer minor ticks
            edges = np.arange(-0.5, 15, 1)
            ax.set_xticks(edges, minor=True)
            ax.set_yticks(edges, minor=True)
            ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
            ax.tick_params(which='minor', bottom=False, left=False)

            # Rank labels, coloured for contrast against each cell
            for (r, c), hits in ranks_at.items():
                ax.text(c, r, "\n".join(str(rank) for rank, _ in hits),
                        ha='center', va='center', color=get_text_color(heat[r, c], cmap),
                        fontsize=9, weight='bold')

            # Write the figure next to the group's other outputs
            output_dir = f"/c2/jinakim/Drug_Discovery_j/analysis/individual/{group_name}/"
            os.makedirs(output_dir, exist_ok=True)

            output_path = os.path.join(output_dir, f"{prefix}_separate_{mode}.png")
            annotated_outputs.append(output_path)
            plt.tight_layout()
            plt.savefig(output_path)
            plt.close(fig)

annotated_outputs


['/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate_count.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate_bit.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-exclude-all-3real_separ

In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Fifteen benchmark datasets; list position = matrix row/column.
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Experiment identifiers embedded in the pair-file names.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real",
    "S-results-dsets_mvalid-exclude-all-3real",
    "S-results-strans_mvalid-all-3real",
    "S-results-strans_mvalid-exclude-all-3real",
]

# Directory holding the *_top30_mv_pairs.txt files.
base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/final/top30_mv_pairs"

# Model type -> (group, fingerprint mode, colour)
type_color = {
    "hivprot_count": ("hivprot", "count", "red"),
    "hivprot_bit":   ("hivprot", "bit",   "red"),
    "dpp4_count":    ("dpp4",    "count", "green"),
    "dpp4_bit":      ("dpp4",    "bit",   "green"),
    "nk1_count":     ("nk1",     "count", "blue"),
    "nk1_bit":       ("nk1",     "bit",   "blue"),
}

def get_text_color(val, cmap):
    """Choose black or white text for readability on the colour ``cmap(val)``.

    ``cmap`` must return an RGBA-like sequence; perceived brightness is the
    0.299 R + 0.587 G + 0.114 B luma, compared against 0.5.
    """
    colour = cmap(val)
    luma = 0.299 * colour[0] + 0.587 * colour[1] + 0.114 * colour[2]
    if luma > 0.5:
        return 'black'
    return 'white'

# Build upper-triangle matrices: an unordered pair is always drawn at (min, max),
# so the (a, b) and (b, a) rankings land in the same cell. The cell's shade keeps
# the better rank, and every rank is listed in the annotation.
annotated_outputs = []

for prefix in file_prefixes:
    for group_name in ["hivprot", "dpp4", "nk1"]:
        for mode in ["count", "bit"]:
            matrix_rgb = np.zeros((15, 15), dtype=float)
            annotations = {}  # {(row, col): [(rank, color), ...]}

            for t, (group, type_, color) in type_color.items():
                if group != group_name or type_ != mode:
                    continue
                filename = f"{t}_{prefix}_top30_mv_pairs.txt"
                path = os.path.join(base_dir, filename)
                if not os.path.exists(path):
                    print(f"Missing: {path}")
                    continue

                with open(path, "r") as f:
                    lines = f.readlines()

                for i, line in enumerate(lines[:30]):
                    parts = line.strip().split(",")
                    if len(parts) != 2:
                        continue
                    a, b = parts[0].strip(), parts[1].strip()
                    if a in dataset_idx and b in dataset_idx:
                        # Fold onto the upper triangle (row <= col)
                        x, y = sorted((dataset_idx[a], dataset_idx[b]))
                        # (a, b) and (b, a) share this cell. Keep the higher
                        # intensity so a later, worse rank cannot overwrite it.
                        matrix_rgb[x, y] = max(matrix_rgb[x, y], (30 - i) / 30.0)
                        annotations.setdefault((x, y), []).append((i + 1, color))

            # Plot
            fig, ax = plt.subplots(figsize=(7, 7))
            cmap = plt.cm.Reds if group_name == "hivprot" else (plt.cm.Greens if group_name == "dpp4" else plt.cm.Blues)
            ax.imshow(matrix_rgb, cmap=cmap, vmin=0, vmax=1)

            ax.set_title(f"{prefix} - {group_name} - {mode}", fontsize=14, fontweight='bold')
            ax.set_xticks(np.arange(15))
            ax.set_yticks(np.arange(15))
            ax.set_xticklabels(dataset_names, rotation=90, fontsize=10)
            ax.set_yticklabels(dataset_names, fontsize=10)

            # Light minor grid on half-integer ticks outlines each cell
            ax.set_xticks(np.arange(-0.5, 15, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, 15, 1), minor=True)
            ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
            ax.tick_params(which='minor', bottom=False, left=False)

            # Annotations, coloured for contrast against the cell background
            for (x, y), entries in annotations.items():
                label_text = "\n".join(str(rank) for rank, _ in entries)
                text_color = get_text_color(matrix_rgb[x, y], cmap)
                ax.text(y, x, label_text, ha='center', va='center', color=text_color, fontsize=9, weight='bold')

            # Axes.grid defaults to which='major', so the minor cell borders stay
            ax.grid(False)

            output_dir = f"/c2/jinakim/Drug_Discovery_j/analysis/individual/{group_name}/"
            os.makedirs(output_dir, exist_ok=True)

            # Save PNG and PDF
            base_filename = f"{prefix}_separate_{mode}_fliplower"
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, f"{base_filename}.png"))
            plt.savefig(os.path.join(output_dir, f"{base_filename}.pdf"))
            plt.close(fig)

            annotated_outputs.append(os.path.join(output_dir, f"{base_filename}.png"))

annotated_outputs


['/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_count_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_bit_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_count_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_bit_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_count_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_bit_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate_count_fliplower.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-exclude-all-3real_separate_bit_fliplower.png',
 '/c2/jinakim/Drug_Dis

In [7]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Fifteen benchmark datasets; list position = matrix row/column.
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Experiment identifiers for this run; the "exclude" variants are disabled.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real",
    # "S-results-dsets_mvalid-exclude-all-3real",
    "S-results-strans_mvalid-all-3real",
    # "S-results-strans_mvalid-exclude-all-3real",
]

# Directory holding the *_top30_mv_pairs.txt files.
base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/final/top30_mv_pairs"

# Model type -> (group, fingerprint mode, colour), grouped by target
type_color = {
    "hivprot_count": ("hivprot", "count", "red"), "hivprot_bit": ("hivprot", "bit", "red"),
    "dpp4_count": ("dpp4", "count", "green"), "dpp4_bit": ("dpp4", "bit", "green"),
    "nk1_count": ("nk1", "count", "blue"), "nk1_bit": ("nk1", "bit", "blue"),
}

def get_text_color(val, cmap):
    """Return the text colour ('black' or 'white') that contrasts with ``cmap(val)``.

    Uses the 0.299/0.587/0.114 luma weighting of the first three RGBA channels;
    anything brighter than 0.5 gets black text.
    """
    colour = cmap(val)
    luma = 0.299 * colour[0] + 0.587 * colour[1] + 0.114 * colour[2]
    if luma > 0.5:
        return 'black'
    return 'white'

# Build symmetric matrices: each ranked pair is mirrored into (x, y) and (y, x).
# When both (a, b) and (b, a) are ranked, the cells keep the better rank's
# intensity, and every rank is listed in the annotation.
annotated_outputs = []

for prefix in file_prefixes:
    for group_name in ["hivprot", "dpp4", "nk1"]:
        for mode in ["count", "bit"]:
            matrix_rgb = np.zeros((15, 15), dtype=float)
            annotations = {}  # {(row, col): [(rank, color), ...]}

            for t, (group, type_, color) in type_color.items():
                if group != group_name or type_ != mode:
                    continue
                filename = f"{t}_{prefix}_top30_mv_pairs.txt"
                path = os.path.join(base_dir, filename)
                if not os.path.exists(path):
                    print(f"Missing: {path}")
                    continue

                with open(path, "r") as f:
                    lines = f.readlines()

                for i, line in enumerate(lines[:30]):
                    parts = line.strip().split(",")
                    if len(parts) != 2:
                        continue
                    a, b = parts[0].strip(), parts[1].strip()
                    if a in dataset_idx and b in dataset_idx:
                        x, y = dataset_idx[a], dataset_idx[b]

                        # Fill both (x, y) and (y, x). The set collapses a self-pair
                        # (x == y) to one cell, so its rank is not recorded twice.
                        for xx, yy in {(x, y), (y, x)}:
                            # Keep the higher intensity so a later, worse rank of the
                            # reversed pair cannot overwrite a better one.
                            matrix_rgb[xx, yy] = max(matrix_rgb[xx, yy], (30 - i) / 30.0)
                            annotations.setdefault((xx, yy), []).append((i + 1, color))

            # Plot
            fig, ax = plt.subplots(figsize=(7, 7))
            cmap = plt.cm.Reds if group_name == "hivprot" else (plt.cm.Greens if group_name == "dpp4" else plt.cm.Blues)
            ax.imshow(matrix_rgb, cmap=cmap, vmin=0, vmax=1)

            ax.set_title(f"{prefix} - {group_name} - {mode}", fontsize=14, fontweight='bold')
            ax.set_xticks(np.arange(15))
            ax.set_yticks(np.arange(15))
            ax.set_xticklabels(dataset_names, rotation=90, fontsize=10)
            ax.set_yticklabels(dataset_names, fontsize=10)

            # Minor grid on half-integer ticks outlines each cell
            ax.set_xticks(np.arange(-0.5, 15, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, 15, 1), minor=True)
            ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
            ax.tick_params(which='minor', bottom=False, left=False)

            # Annotations, coloured for contrast against the cell background
            for (x, y), entries in annotations.items():
                label_text = "\n".join(str(rank) for rank, _ in entries)
                text_color = get_text_color(matrix_rgb[x, y], cmap)
                ax.text(y, x, label_text, ha='center', va='center', color=text_color, fontsize=9, weight='bold')

            # Axes.grid defaults to which='major', so the minor cell borders stay
            ax.grid(False)

            output_dir = f"/c2/jinakim/Drug_Discovery_j/analysis/individual/{group_name}/"
            os.makedirs(output_dir, exist_ok=True)

            base_filename = f"{prefix}_separate_{mode}_symmetric"
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, f"{base_filename}.png"))
            plt.savefig(os.path.join(output_dir, f"{base_filename}.pdf"))
            plt.close(fig)

            annotated_outputs.append(os.path.join(output_dir, f"{base_filename}.png"))

annotated_outputs


['/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_count_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-dsets_mvalid-all-3real_separate_bit_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_count_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/dpp4/S-results-dsets_mvalid-all-3real_separate_bit_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_count_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/nk1/S-results-dsets_mvalid-all-3real_separate_bit_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-strans_mvalid-all-3real_separate_count_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analysis/individual/hivprot/S-results-strans_mvalid-all-3real_separate_bit_symmetric.png',
 '/c2/jinakim/Drug_Discovery_j/analy

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import re

# Fifteen benchmark datasets; list position = matrix row/column.
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Experiment log(s) to read, named without the .txt extension.
file_prefixes = [
    "S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real",
    # "S-results-strans_mvalid-all-3real",
]

# Model-type labels that mark the start of a block in the log.
model_types = [
    "hivprot_count",
    "hivprot_bit",
    "dpp4_count",
]

# Input logs are looked up directly under exp_base_dir as "<prefix>.txt".
exp_base_dir = "/c2/jinakim/Drug_Discovery_j/experiments/"
# Output root for the figures; created if missing (no-op when it exists).
output_base_dir = "/c2/jinakim/Drug_Discovery_j/analysis/individual/mean_value_matrices_split_RYV1/"
os.makedirs(output_base_dir, exist_ok=True)

def get_text_color(val, cmap):
    """Return 'black' or 'white', whichever reads better on ``cmap(val)``.

    ``val`` is expected already normalised to the colormap's input range;
    brightness is the 0.299/0.587/0.114 weighted sum of the RGB channels.
    """
    colour = cmap(val)
    luma = 0.299 * colour[0] + 0.587 * colour[1] + 0.114 * colour[2]
    if luma > 0.5:
        return 'black'
    return 'white'

from decimal import Decimal, ROUND_DOWN

# Regexes for the two kinds of log line this parser cares about
_MV_LIST_RE = re.compile(r"mv\s*:\s*(\[.*?\])")
_LAST_MU_RE = re.compile(r"last performance mu:\s*([0-9.eE+-]+)")


def _mv_indices(list_text):
    """Map the two quoted dataset names in an ``mv`` list literal to matrix indices.

    Returns None unless exactly two quoted names are found and both are known
    datasets. Names are pulled out with a regex rather than eval(), because the
    text comes from a log file. This also tolerates wrappers such as np.str_('cb1').
    """
    names = re.findall(r"""['"]([^'"]*)['"]""", list_text)
    if len(names) == 2 and all(name in dataset_idx for name in names):
        return dataset_idx[names[0]], dataset_idx[names[1]]
    return None


def _truncate_3dp(val):
    """Format ``val`` with three decimals, truncating toward zero (no rounding)."""
    # Decimal(str(val)) keeps the printed digits and also copes with values that
    # str() renders in exponent form (e.g. 5e-05), which a '.'-based slice mangles.
    return format(Decimal(str(val)).quantize(Decimal("0.001"), rounding=ROUND_DOWN), "f")


n_datasets = len(dataset_names)

for prefix in file_prefixes:
    exp_file_path = os.path.join(exp_base_dir, f"{prefix}.txt")
    if not os.path.exists(exp_file_path):
        print(f"Missing: {exp_file_path}")
        continue

    with open(exp_file_path, "r") as f:
        lines = f.readlines()

    # --- Initialize matrices (NaN = no result for that dataset pair) ---
    matrices = {model: np.full((n_datasets, n_datasets), np.nan) for model in model_types}

    current_model = None
    current_mv_pair = None

    # --- Parse the file ---
    for line in lines:
        # A line naming a model type starts a new block; the first match wins.
        for model in model_types:
            if model in line:
                current_model = model
                mv_match = _MV_LIST_RE.search(line)
                if mv_match:
                    current_mv_pair = _mv_indices(mv_match.group(1))
                # NOTE(review): a model line without "mv:" keeps the previous
                # block's pair. Confirm the log always puts mv on the model line.
                break

        # Record the block's "last performance mu" at both (x, y) and (y, x)
        last_mu_match = _LAST_MU_RE.search(line)
        if last_mu_match and current_model and current_mv_pair is not None:
            try:
                mean_val = float(last_mu_match.group(1))
            except ValueError:
                # The character class also matches fragments such as "-" or "."
                continue
            x, y = current_mv_pair
            matrices[current_model][x, y] = mean_val
            matrices[current_model][y, x] = mean_val  # symmetric

    # --- Plotting ---
    for model, matrix in matrices.items():
        if np.isnan(matrix).all():
            print(f"Skipping {model} for {prefix} because matrix is all NaNs")
            continue

        vmin, vmax = np.nanmin(matrix), np.nanmax(matrix)
        span = vmax - vmin

        fig, ax = plt.subplots(figsize=(7, 7))
        cmap = plt.cm.Greys_r
        ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)

        ax.set_title(f"{prefix} - {model} Mean Last Performance μ", fontsize=14, fontweight='bold')
        ax.set_xticks(np.arange(n_datasets))
        ax.set_yticks(np.arange(n_datasets))
        ax.set_xticklabels(dataset_names, rotation=90, fontsize=10)
        ax.set_yticklabels(dataset_names, fontsize=10)

        # Minor grid on half-integer ticks outlines each cell
        ax.set_xticks(np.arange(-0.5, n_datasets, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, n_datasets, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
        ax.tick_params(which='minor', bottom=False, left=False)

        # Fill with values
        for i in range(n_datasets):
            for j in range(n_datasets):
                val = matrix[i, j]
                if np.isnan(val):
                    continue
                # Normalise as imshow does. A constant matrix has zero span, and
                # matplotlib's Normalize maps everything to 0 in that case, so use 0.
                normalized_val = (val - vmin) / span if span > 0 else 0.0
                text_color = get_text_color(normalized_val, cmap)
                ax.text(j, i, _truncate_3dp(val), ha='center', va='center', color=text_color, fontsize=7, weight='bold')

        output_dir = os.path.join(output_base_dir, model)
        os.makedirs(output_dir, exist_ok=True)

        output_path_png = os.path.join(output_dir, f"{prefix}_mean_matrix.png")
        output_path_pdf = os.path.join(output_dir, f"{prefix}_mean_matrix.pdf")

        plt.tight_layout()
        plt.savefig(output_path_png)
        plt.savefig(output_path_pdf)
        plt.close(fig)

        print(f"✅ Saved: {output_path_png}, {output_path_pdf}")


Skipping hivprot_count for S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real because matrix is all NaNs
Skipping hivprot_bit for S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real because matrix is all NaNs
Skipping dpp4_count for S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real because matrix is all NaNs


In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import re

# Dataset names, in the row/column order shared by every matrix built below
dataset_names = [
    "3a4", "logd", "hivint", "metab", "ox2",
    "rat_f", "cb1", "hivprot", "nk1", "pgp",
    "tdi", "dpp4", "ox1", "ppb", "thrombin",
]
# Dataset name -> matrix row/column index
dataset_idx = dict(zip(dataset_names, range(len(dataset_names))))

# Model labels searched for in the experiment logs (group name + count/bit variant)
model_types = [
    "hivprot count", "hivprot bit",
    "dpp4 count", "dpp4 bit",
    "nk1 count", "nk1 bit",
]

# Project root. Override it with the DRUG_DISCOVERY_ROOT environment variable
# when running on a machine where the checkout lives elsewhere; the default
# reproduces the original hard-coded paths.
project_root = os.environ.get("DRUG_DISCOVERY_ROOT", "/c2/jinakim/Drug_Discovery_j")

# Experiment result logs to read (each is processed independently)
exp_files = [
    os.path.join(
        project_root, "experiments",
        "S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real.txt",
    ),
    # os.path.join(project_root, "experiments", "S-results-strans_mvalid-all-3real.txt"),
]

# Output base directory for the rendered figures (created if missing)
output_base_dir = os.path.join(
    project_root, "analysis", "individual", "mean_value_matrices_split_color_RYV1"
)
os.makedirs(output_base_dir, exist_ok=True)

# Colormap per model group (the group name is a substring of each model label)
group_cmaps = {
    "hivprot": plt.cm.Reds,
    "dpp4": plt.cm.Greens,
    "nk1": plt.cm.Blues
}

def get_text_color(val, cmap):
    """Choose a readable label colour for a heatmap cell.

    Args:
        val: Value handed straight to ``cmap`` (callers pass a value
            normalized to the colormap's range).
        cmap: Colormap callable returning an RGBA sequence for ``val``.

    Returns:
        ``'black'`` when the mapped colour's weighted luminance is above 0.5,
        otherwise ``'white'``.
    """
    red, green, blue = cmap(val)[:3]
    # Rec. 601 luma weights: green dominates perceived brightness
    brightness = 0.299 * red + 0.587 * green + 0.114 * blue
    if brightness > 0.5:
        return 'black'
    return 'white'

# Matches each quoted name inside the "[...]" text that follows "mv:".
_QUOTED_NAME_RE = re.compile(r"""['"]([^'"]*)['"]""")


def _parse_mv_pair(list_text):
    """Turn the bracketed text after ``mv:`` into a pair of matrix indices.

    The quoted names are extracted with a regex instead of ``eval``. The log
    file is external input, and ``eval`` would execute any code embedded in it.

    Args:
        list_text: Bracketed text such as ``"['hivprot', 'dpp4']"``.

    Returns:
        ``(row, col)`` indices into ``dataset_names`` when exactly two names are
        present and both are known datasets; otherwise ``None``.
    """
    names = _QUOTED_NAME_RE.findall(list_text)
    if len(names) == 2 and all(name in dataset_idx for name in names):
        return dataset_idx[names[0]], dataset_idx[names[1]]
    return None


def _parse_mu_matrices(exp_path):
    """Build one symmetric "last performance mu" matrix per entry of ``model_types``.

    The log is scanned line by line:
    - A line containing a model label makes that model current. The first
      match in ``model_types`` order wins.
    - If that line also carries ``mv: [...]``, the current dataset pair is
      updated.
    - A line with ``last performance mu: <number>`` then stores the number at
      both ``(x, y)`` and ``(y, x)`` of the current model's matrix.

    Args:
        exp_path: Path to the experiment results text file.

    Returns:
        Dict mapping each model label to an ``(n, n)`` float array, where
        ``n = len(dataset_names)``. Cells that are never reported stay NaN.
    """
    n = len(dataset_names)
    matrices = {model: np.full((n, n), np.nan) for model in model_types}
    current_model = None
    current_mv_pair = None

    with open(exp_path, "r") as f:
        for line in f:
            model = next((m for m in model_types if m in line), None)
            if model is not None:
                current_model = model
                mv_match = re.search(r"mv\s*:\s*(\[.*?\])", line)
                # NOTE(review): a model line without "mv:" keeps the previous
                # pair; confirm this matches the log format.
                if mv_match:
                    current_mv_pair = _parse_mv_pair(mv_match.group(1))

            if "last performance mu" in line and current_model and current_mv_pair:
                match = re.search(r"last performance mu:\s*([0-9.eE+-]+)", line)
                if match:
                    mean_val = float(match.group(1))
                    x, y = current_mv_pair
                    matrices[current_model][x, y] = mean_val
                    matrices[current_model][y, x] = mean_val  # symmetric
    return matrices


def _truncate_3dp(val):
    """Format ``val`` cut (not rounded) to at most three decimal places."""
    # Positional formatting keeps very small or large values out of scientific
    # notation. Cutting at the '.' of such a string would mangle the exponent
    # (e.g. '1e-05' -> '1e-'). For ordinary values this text equals str(val).
    text = np.format_float_positional(val, trim='0')
    dot = text.find('.')
    return text if dot < 0 else text[:dot + 4]


def _plot_mu_matrix(matrix, model, prefix, cmap):
    """Draw ``matrix`` as an annotated heatmap and save it as PNG and PDF.

    Files are written to ``<output_base_dir>/<model, spaces -> '_'>/<prefix>.png``
    and ``.pdf``.
    """
    n = len(dataset_names)
    lo = np.nanmin(matrix)
    hi = np.nanmax(matrix)

    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        ax.imshow(matrix, cmap=cmap, vmin=lo, vmax=hi)

        ax.set_title(f"{prefix} - {model} μ", fontsize=14)
        ax.set_xticks(np.arange(n))
        ax.set_yticks(np.arange(n))
        ax.set_xticklabels(dataset_names, rotation=90, fontsize=10)
        ax.set_yticklabels(dataset_names, fontsize=10)

        # Minor ticks on the cell edges draw a grid between cells
        ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
        ax.tick_params(which='minor', bottom=False, left=False)

        # Annotate every filled cell with its truncated value
        for i, j in zip(*np.nonzero(~np.isnan(matrix))):
            val = matrix[i, j]
            # Scale into [0, 1] so the label colour tracks the cell colour;
            # a constant matrix maps to the colormap midpoint
            normalized_val = (val - lo) / (hi - lo) if hi != lo else 0.5
            ax.text(j, i, _truncate_3dp(val), ha='center', va='center',
                    color=get_text_color(normalized_val, cmap), fontsize=7, weight='bold')

        model_dir = os.path.join(output_base_dir, model.replace(" ", "_"))
        os.makedirs(model_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(os.path.join(model_dir, f"{prefix}.png"))
        fig.savefig(os.path.join(model_dir, f"{prefix}.pdf"))
    finally:
        # Always release the figure so the loop does not accumulate open figures
        plt.close(fig)


for exp_path in exp_files:
    prefix = os.path.basename(exp_path).replace(".txt", "")
    print(f"Processing: {prefix}")

    matrices = _parse_mu_matrices(exp_path)

    for model, matrix in matrices.items():
        if np.isnan(matrix).all():
            print(f"Skipping {model} (empty matrix)")
            continue

        # The first group whose name appears in the model label picks the colormap
        cmap = next((cm for group, cm in group_cmaps.items() if group in model), plt.cm.Greys_r)
        _plot_mu_matrix(matrix, model, prefix, cmap)

        print(f"✅ Saved for {model} - {prefix}")


Processing: S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
✅ Saved for hivprot count - S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
✅ Saved for hivprot bit - S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
✅ Saved for dpp4 count - S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
✅ Saved for dpp4 bit - S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
✅ Saved for nk1 count - S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
✅ Saved for nk1 bit - S-results-dsets_mvalid-all-3real-ml0-xmix-mvdef1-mNctFalse-RYV1_real
