In [None]:
import os, sys

# Make the repository root (parent of the notebook's working directory)
# importable so the `src.*` imports in later cells resolve.
project_root = os.path.abspath(os.pardir)
already_importable = project_root in sys.path
if not already_importable:
    sys.path.insert(0, project_root)

In [None]:
from src.feature_dice_analysis import compute_feature_dice_correlation, plot_feature_dice_results
from src.model import load_model
from src.hooks import target_layers
from src.utils import visualize_and_classify_bright
from src.model import load_model, run_inference, visualize_results
    
from src.data import load_patients, transpose_volumes, PATIENT_IDS, resize_volume, preprocess_volumes
from src.utils import extract_mask, visualize_class, adjust_class_intensity, visualize_adjusted

# Load the patient volumes/labels and reorder their axes with the project helper.
images, labels = load_patients()
images, labels = transpose_volumes(images, labels)
# Load the model once; a second load_model() call in this cell was redundant.
model = load_model()
# Intensity-adjusted copy of the volumes for the "LV" class (scale=0.1).
brightness_adjusted_images = adjust_class_intensity(images, labels, PATIENT_IDS, class_name="LV", scale=0.1)

df = compute_feature_dice_correlation(
    model,
    images,
    brightness_adjusted_images,
    labels,
    PATIENT_IDS,
    target_layers,
    class_name="LV"
)

print(df.head())
plot_feature_dice_results(df)


In [None]:
from src.feature_dice_analysis import export_feature_dice_results
# Persist the feature/Dice table so it can be inspected outside the notebook.
FEATURE_DICE_CSV = "feature_dice_results.csv"
export_feature_dice_results(df, filepath=FEATURE_DICE_CSV)

In [None]:
from src.feature_dice_analysis import compute_feature_dice_correlation, plot_feature_dice_results
from src.model import load_model
from src.hooks import target_layers
from src.utils import visualize_and_classify_bright
from src.model import load_model, run_inference, visualize_results
    
from src.data import load_patients, transpose_volumes, PATIENT_IDS, resize_volume, preprocess_volumes
from src.utils import extract_mask, visualize_class, adjust_class_intensity, visualize_adjusted

    


# Reload the volumes/labels, reorder their axes and rebuild the LV
# intensity-adjusted copy (scale=0.1) used by the analysis cells below.
images, labels = load_patients()
images, labels = transpose_volumes(images, labels)
# Load the model once; a second load_model() call in this cell was redundant.
model = load_model()
brightness_adjusted_images = adjust_class_intensity(images, labels, PATIENT_IDS, class_name="LV", scale=0.1)

In [None]:
from src.feature_dice_analysis import compute_feature_dice_correlation, analyze_feature_dice_relationship

# Number of channels kept per layer; the printed header is derived from it so
# the two cannot drift apart.
TOPK = 3

df = compute_feature_dice_correlation(
    model,
    images,
    brightness_adjusted_images,
    labels,
    patient_ids=PATIENT_IDS,
    target_layers=target_layers,
    class_name="LV"
)

channel_scores, layer_scores, top_channels = analyze_feature_dice_relationship(df, topk=TOPK, plot=True)

print("=== Layer-level average correlations ===")
print(layer_scores)

print(f"\n=== Top {TOPK} channels per layer ===")
print(top_channels.head(300))


In [None]:
import os, sys

# Re-assert the repository root on sys.path; this is a no-op when the first
# setup cell has already run in this kernel.
project_root = os.path.abspath(os.pardir)
if all(entry != project_root for entry in sys.path):
    sys.path.insert(0, project_root)


In [None]:
from src.evaluation import summarize_top3_per_patient_slice
from src.feature_plots import plot_layer_boxplots, plot_slice_heatmap, plot_mad_mse_scatter

# Text report to parse (path relative to the notebook's working directory).
results_file = "brightness_analysis_results3.txt"

# Parse the report into `summary`, which the plotting cells below consume.
summary = summarize_top3_per_patient_slice(results_file)

In [None]:
from src import feature_plots as fp

# Overview figures for the parsed summary. Metric-specific plots are drawn for
# MAD first, then MSE, in the same sequence as the individual calls.
fp.plot_layer_boxplots(summary)
for metric in ("MAD", "MSE"):
    fp.plot_slice_heatmap(summary, metric=metric)
fp.plot_mad_mse_scatter(summary)
for plot_fn in (fp.plot_patient_slice_profiles, fp.plot_layer_slice_heatmap):
    for metric in ("MAD", "MSE"):
        plot_fn(summary, metric=metric)
fp.plot_channel_frequency(summary, metric="MAD")
fp.plot_channel_frequency(summary, metric="MSE")

In [None]:
from src.feature_plots import plot_top_frequencies_per_layer

# Per-layer channel-frequency figure for the MAD metric, built from the
# `summary` parsed a few cells above (what exactly is drawn is defined in
# src.feature_plots).
plot_top_frequencies_per_layer(summary, metric="MAD")


In [None]:
import re
import os
from collections import defaultdict
import matplotlib.pyplot as plt

results_file = "layer_patient_slice_summary.txt"
save_dir = "meow2/plots"
os.makedirs(save_dir, exist_ok=True)

# layer -> patient -> channel -> slice indices in which that channel appeared
# on a "MAD → Ch" line of the report.
layer_patient_channels = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

current_layer = None
current_patient = None
current_slice = None

# The prefixes matched below contain a non-ASCII arrow ("→"), so decode the
# report as UTF-8 explicitly; under a locale default such as cp1252 the
# startswith() test would never match and every figure would silently be empty.
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("Patient:"):
            current_patient = line.split("Patient:")[1].strip()
        elif line.startswith("-- Slice"):
            current_slice = int(re.search(r"Slice (\d+)", line).group(1))
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                layer_patient_channels[current_layer][current_patient][ch].append(current_slice)

for layer, patients in layer_patient_channels.items():
    for patient, channels in patients.items():
        counts = {ch: len(slices) for ch, slices in channels.items()}
        # The three channels that appear in the most slices (stable sort, so
        # ties keep first-seen order).
        top3 = sorted(counts.items(), key=lambda x: x[1], reverse=True)[:3]

        print(f"\n=== {layer} | {patient} ===")
        for ch, _ in top3:
            slices = sorted(channels[ch])
            print(f"Channel {ch} → slices {slices}")

        fig, ax = plt.subplots(figsize=(8, 5))
        chs, freqs = zip(*top3) if top3 else ([], [])
        ax.bar([f"Ch{c}" for c in chs], freqs, color="skyblue", edgecolor="black")

        for i, ch in enumerate(chs):
            # Annotate with the same sorted slice list that is printed above.
            ax.text(i, freqs[i] + 0.1, f"Slices {sorted(channels[ch])}",
                    ha="center", va="bottom", fontsize=8, rotation=45)

        ax.set_title(f"Top-3 Channels (MAD) – {layer} | {patient}", fontsize=12)
        ax.set_ylabel("Frequency (num slices)")
        ax.set_xlabel("Channel")
        ax.grid(axis="y", linestyle="--", alpha=0.6)

        fig.tight_layout()
        # Layer/patient names become part of the filename; replace path
        # separators so a name like "enc/conv1" cannot point savefig at a
        # non-existent subdirectory.
        file_stem = f"{layer}_{patient}".replace(os.sep, "_").replace("/", "_")
        fig.savefig(os.path.join(save_dir, f"{file_stem}_top3_MAD.png"))
        plt.close(fig)


In [None]:
import re
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_layer_patient_channel.png"

# layer -> patient -> channel -> slice indices (from "MAD → Ch" lines).
layer_patient_channels = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
current_layer, current_patient, current_slice = None, None, None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("Patient:"):
            current_patient = line.split("Patient:")[1].strip()
        elif line.startswith("-- Slice"):
            current_slice = int(re.search(r"Slice (\d+)", line).group(1))
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                layer_patient_channels[current_layer][current_patient][ch].append(current_slice)

layers = list(layer_patient_channels.keys())
patients = sorted({p for layer in layer_patient_channels.values() for p in layer})

layer_index = {layer: i for i, layer in enumerate(layers)}
patient_index = {p: i for i, p in enumerate(patients)}

# Dominant channel per (layer, patient); -1 marks cells without data so they
# are not confused with channel 0.
matrix = np.full((len(layers), len(patients)), -1, dtype=int)
freq_matrix = np.zeros((len(layers), len(patients)))

for layer, patients_dict in layer_patient_channels.items():
    for patient, channels in patients_dict.items():
        counts = {ch: len(slices) for ch, slices in channels.items()}
        if counts:
            ch = max(counts, key=counts.get)
            matrix[layer_index[layer], patient_index[patient]] = ch
            freq_matrix[layer_index[layer], patient_index[patient]] = counts[ch]

from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

unique_channels = np.unique(matrix[matrix >= 0])
if unique_channels.size == 0:
    raise ValueError(f"No 'MAD → Ch' entries found in {results_file}")

# Qualitative palette (60 colours). matplotlib.cm.get_cmap was removed in
# matplotlib 3.9; pyplot.get_cmap is the supported accessor.
palette = []
for name in ("tab20", "tab20b", "tab20c"):
    base = plt.get_cmap(name)
    palette.extend(base(np.linspace(0, 1, base.N)))

# Categorical colouring: distinct channel k (sorted order) is drawn with colour
# k and the norm is pinned to [-0.5, K - 0.5], so each cell matches its legend
# patch. Plotting raw channel ids would let imshow autoscale them to the data
# range and break that correspondence. Colours repeat beyond 60 channels.
channel_to_color = {int(ch): k for k, ch in enumerate(unique_channels)}
cmap = ListedColormap([palette[k % len(palette)] for k in range(len(unique_channels))])
cmap.set_bad("white")

plot_matrix = np.full(matrix.shape, np.nan)
for (i, j), ch in np.ndenumerate(matrix):
    if ch >= 0:
        plot_matrix[i, j] = channel_to_color[int(ch)]

fig, ax = plt.subplots(figsize=(12, 8))

im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto",
               vmin=-0.5, vmax=len(unique_channels) - 0.5)

# Each coloured cell is annotated with how many slices the dominant channel won.
for i in range(matrix.shape[0]):
    for j in range(matrix.shape[1]):
        freq = int(freq_matrix[i, j])
        if freq > 0:
            ax.text(j, i, str(freq), ha="center", va="center", color="black", fontsize=7)

ax.set_xticks(range(len(patients)))
ax.set_xticklabels(patients, rotation=45, ha="right", fontsize=8)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=8)

ax.set_title("Dominant Channel (color) × Frequency (# inside cell)", fontsize=12)

legend_elements = [Patch(facecolor=cmap(k), label=f"Ch{ch}")
                   for k, ch in enumerate(unique_channels)]
ax.legend(handles=legend_elements, title="Channels",
          bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=6)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
import os
from collections import defaultdict
import matplotlib.pyplot as plt

results_file = "layer_patient_slice_summary.txt"
save_dir = "meow2/plots"
os.makedirs(save_dir, exist_ok=True)

# layer -> patient -> channel -> slice indices (from "MAD → Ch" lines).
layer_patient_channels = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

current_layer = None
current_patient = None
current_slice = None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("Patient:"):
            current_patient = line.split("Patient:")[1].strip()
        elif line.startswith("-- Slice"):
            current_slice = int(re.search(r"Slice (\d+)", line).group(1))
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                layer_patient_channels[current_layer][current_patient][ch].append(current_slice)

# Slice colours shared by every figure; tab20 has 20 entries, so slice indices
# wrap around after 20.
slice_colors = plt.cm.tab20.colors
bar_width = 0.25

for layer, patients in layer_patient_channels.items():
    fig, ax = plt.subplots(figsize=(22, 8))
    slices_in_layer = set()

    for i, (patient, channels) in enumerate(patients.items()):
        counts = {ch: len(slices) for ch, slices in channels.items()}
        top3 = sorted(counts.items(), key=lambda x: x[1], reverse=True)[:3]

        for j, (ch, freq) in enumerate(top3):
            # Up to three bars per patient, centred on the patient's tick.
            bar_pos = i + (j - 1) * bar_width

            # Stack one unit-high segment per slice, coloured by slice index.
            for bottom, s in enumerate(sorted(channels[ch])):
                ax.bar(bar_pos, 1, width=bar_width,
                       color=slice_colors[s % len(slice_colors)],
                       edgecolor="black",
                       bottom=bottom)
                slices_in_layer.add(s)

            ax.text(bar_pos, freq + 0.3, f"Ch{ch}",
                    ha="center", va="bottom", fontsize=8, rotation=90)

    ax.set_xticks(range(len(patients)))
    ax.set_xticklabels(list(patients.keys()), rotation=45, ha="right", fontsize=8)
    ax.set_title(f"Top-3 Channels per Patient (MAD) – {layer}", fontsize=14)
    ax.set_ylabel("Frequency (# of slices)")
    ax.grid(axis="y", linestyle="--", alpha=0.6)

    # Legend lists exactly the slices drawn in this layer. The legend_* names
    # keep this cell from overwriting the notebook-global `labels`
    # (segmentation volumes loaded in an earlier cell).
    legend_slices = sorted(slices_in_layer)
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=slice_colors[s % len(slice_colors)])
                      for s in legend_slices]
    legend_labels = [f"Slice {s}" for s in legend_slices]
    ax.legend(legend_handles, legend_labels, title="Slices",
              bbox_to_anchor=(1.05, 1), loc="upper left")

    # Headroom for the rotated channel labels above the bars.
    ymax = ax.get_ylim()[1]
    ax.set_ylim(0, ymax + 3)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    # Replace path separators so the layer name cannot redirect savefig.
    safe_layer = str(layer).replace(os.sep, "_").replace("/", "_")
    fig.savefig(os.path.join(save_dir, f"{safe_layer}_top3_channels_per_patient.png"))
    plt.close(fig)


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_layer_channel_counts.png"

# layer -> Counter(channel -> number of "MAD → Ch" lines in that layer)
counts = defaultdict(Counter)
current_layer = None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                counts[current_layer][ch] += 1

N = 10
total_counts = Counter()
for layer, ch_counter in counts.items():
    total_counts.update(ch_counter)

top_channels = [ch for ch, _ in total_counts.most_common(N)]

print("Top channels overall:")
for ch, cnt in total_counts.most_common(N):
    print(f"Ch{ch}: {cnt}")

layers = list(counts.keys())
matrix = np.zeros((len(layers), len(top_channels)), dtype=int)

for i, layer in enumerate(layers):
    for j, ch in enumerate(top_channels):
        matrix[i, j] = counts[layer][ch]

# Zero counts become NaN so empty cells are left unpainted.
matrix_plot = matrix.astype(float)
matrix_plot[matrix_plot == 0] = np.nan

# Colour encodes the count in each cell, so use a sequential colormap with a
# colorbar; channel identity is already given by the x tick labels.
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(matrix_plot, cmap=plt.get_cmap("viridis"), aspect="auto")

for i in range(matrix.shape[0]):
    for j in range(matrix.shape[1]):
        val = matrix[i, j]
        if val > 0:
            ax.text(j, i, str(val), ha="center", va="center",
                    color="white", fontsize=7, fontweight="bold")

ax.set_xticks(range(len(top_channels)))
ax.set_xticklabels([f"Ch{ch}" for ch in top_channels], rotation=45, ha="right", fontsize=8)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=8)

ax.set_title(f"Top {N} Channels × Frequency in Top-3 MAD (per Layer)", fontsize=12)

cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Frequency of Occurrence", rotation=270, labelpad=15)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_layer_channel_counts.png"

# layer -> Counter(channel -> number of "MAD → Ch" lines in that layer)
counts = defaultdict(Counter)
current_layer = None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                counts[current_layer][ch] += 1

N = 10
total_counts = Counter()
for layer, ch_counter in counts.items():
    total_counts.update(ch_counter)

top_channels = [ch for ch, _ in total_counts.most_common(N)]
if not top_channels:
    raise ValueError(f"No 'MAD → Ch' entries found in {results_file}")

print("Top channels overall:")
for ch, cnt in total_counts.most_common(N):
    print(f"Ch{ch}: {cnt}")

layers = list(counts.keys())
# May be smaller than N when few distinct channels occur in the report.
n_cols = len(top_channels)

count_matrix = np.zeros((len(layers), n_cols), dtype=int)
channel_matrix = np.zeros((len(layers), n_cols), dtype=int)

for i, layer in enumerate(layers):
    for j, ch in enumerate(top_channels):
        count_matrix[i, j] = counts[layer][ch]
        # Column j is coloured with palette entry j, i.e. by channel.
        channel_matrix[i, j] = j

plot_matrix = channel_matrix.astype(float)
plot_matrix[count_matrix == 0] = np.nan

# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
palette = []
for name in ("tab20", "tab20b", "tab20c"):
    base = plt.get_cmap(name)
    palette.extend(base(np.linspace(0, 1, base.N)))

cmap = ListedColormap(palette[:n_cols])
cmap.set_bad("white")

fig, ax = plt.subplots(figsize=(10, 8))
# Pin the norm to the column indices. Without vmin/vmax imshow autoscales to
# the data range, and cell colours drift from the legend when n_cols < N.
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto", vmin=-0.5, vmax=n_cols - 0.5)

for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        val = count_matrix[i, j]
        if val > 0:
            ax.text(j, i, str(val), ha="center", va="center",
                    color="black", fontsize=7)

ax.set_xticks(range(len(top_channels)))
ax.set_xticklabels([f"Ch{ch}" for ch in top_channels], rotation=45, ha="right", fontsize=8)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=8)

ax.set_title(f"Top {N} Channels × Frequency in Top-3 MAD (per Layer)", fontsize=12)

legend_elements = [Patch(facecolor=cmap(j), label=f"Ch{ch}")
                   for j, ch in enumerate(top_channels)]
ax.legend(handles=legend_elements, title="Channels",
          bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=7)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
# Each column of count_matrix sums that channel's counts over all layers, so
# it must reproduce the overall total computed before plotting.
print("\nVerification:")
mismatched_channels = []
for j, ch in enumerate(top_channels):
    col_sum = count_matrix[:, j].sum()
    print(f"Ch{ch}: {col_sum} (expected {total_counts[ch]})")
    if col_sum != total_counts[ch]:
        mismatched_channels.append(ch)
assert not mismatched_channels, (
    f"Column sums disagree with overall totals for channels {mismatched_channels}"
)


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_layer_top10.png"

# layer -> rank (0-based order of "MAD → Ch" lines under each "-- Slice"
# header) -> Counter(channel -> occurrences at that rank)
rank_counts = defaultdict(lambda: defaultdict(Counter))
current_layer, rank_idx = None, None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("-- Slice"):
            rank_idx = 0
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                rank_counts[current_layer][rank_idx][ch] += 1
                rank_idx += 1

layers = list(rank_counts.keys())
max_rank = max(max(ranks.keys()) for ranks in rank_counts.values()) + 1
max_rank = min(max_rank, 10)  # show at most ten rank columns

# Per (layer, rank): most frequent channel and its count; -1 marks ranks that
# never occur for a layer.
count_matrix = np.zeros((len(layers), max_rank), dtype=int)
channel_matrix = np.full((len(layers), max_rank), -1, dtype=int)

for i, layer in enumerate(layers):
    for j in range(max_rank):
        ranked = rank_counts[layer].get(j)
        if ranked:
            ch, freq = ranked.most_common(1)[0]
            count_matrix[i, j] = freq
            channel_matrix[i, j] = ch

# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
palette = []
for name in ("tab20", "tab20b", "tab20c"):
    base = plt.get_cmap(name)
    palette.extend(base(np.linspace(0, 1, base.N)))

unique_channels = np.unique(channel_matrix[channel_matrix >= 0])
if unique_channels.size == 0:
    raise ValueError(f"No 'MAD → Ch' entries found in {results_file}")

# Categorical colouring: distinct channel k (sorted order) is drawn with colour
# k and the norm is pinned to [-0.5, K - 0.5], so each cell matches its legend
# patch. Plotting raw channel ids would let imshow autoscale them to the data
# range and break that correspondence. Colours repeat beyond 60 channels.
channel_to_color = {int(ch): k for k, ch in enumerate(unique_channels)}
cmap = ListedColormap([palette[k % len(palette)] for k in range(len(unique_channels))])
cmap.set_bad("white")

plot_matrix = np.full(channel_matrix.shape, np.nan)
for (i, j), ch in np.ndenumerate(channel_matrix):
    if ch >= 0:
        plot_matrix[i, j] = channel_to_color[int(ch)]

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto",
               vmin=-0.5, vmax=len(unique_channels) - 0.5)

for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        val = count_matrix[i, j]
        if val > 0:
            ax.text(j, i, f"{val}", ha="center", va="center",
                    color="black", fontsize=7)

# Column j is rank j + 1; the channel in each cell is identified by its colour.
# (Tick labels must not name globally frequent channels: those are unrelated
# to the per-layer cell contents and may be fewer than max_rank.)
ax.set_xticks(range(max_rank))
ax.set_xticklabels([f"Top{j+1}" for j in range(max_rank)],
                   rotation=45, ha="right", fontsize=8)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=8)

ax.set_title(f"Most Frequent Channel per Rank (Top1–Top{max_rank})", fontsize=12)

legend_elements = [Patch(facecolor=cmap(k), label=f"Ch{ch}")
                   for k, ch in enumerate(unique_channels)]
ax.legend(handles=legend_elements, title="Channels",
          bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=7)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_layer_top7.png"

# layer -> rank (0-based order of "MAD → Ch" lines under each "-- Slice"
# header) -> Counter(channel -> occurrences at that rank)
rank_counts = defaultdict(lambda: defaultdict(Counter))
current_layer, rank_idx = None, None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("-- Slice"):
            rank_idx = 0
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                rank_counts[current_layer][rank_idx][ch] += 1
                rank_idx += 1

layers = list(rank_counts.keys())
max_rank = 7

# Per (layer, rank): most frequent channel and its count; -1 marks ranks that
# never occur for a layer.
count_matrix = np.zeros((len(layers), max_rank), dtype=int)
channel_matrix = np.full((len(layers), max_rank), -1, dtype=int)

for i, layer in enumerate(layers):
    for j in range(max_rank):
        ranked = rank_counts[layer].get(j)
        if ranked:
            ch, freq = ranked.most_common(1)[0]
            count_matrix[i, j] = freq
            channel_matrix[i, j] = ch

# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
palette = []
for name in ("tab20", "tab20b", "tab20c"):
    base = plt.get_cmap(name)
    palette.extend(base(np.linspace(0, 1, base.N)))

unique_channels = np.unique(channel_matrix[channel_matrix >= 0])
if unique_channels.size == 0:
    raise ValueError(f"No 'MAD → Ch' entries found in {results_file}")

# Categorical colouring: distinct channel k (sorted order) is drawn with colour
# k and the norm is pinned to [-0.5, K - 0.5], so each cell matches its legend
# patch. Plotting raw channel ids would let imshow autoscale them to the data
# range and break that correspondence. Colours repeat beyond 60 channels.
channel_to_color = {int(ch): k for k, ch in enumerate(unique_channels)}
cmap = ListedColormap([palette[k % len(palette)] for k in range(len(unique_channels))])
cmap.set_bad("white")

plot_matrix = np.full(channel_matrix.shape, np.nan)
for (i, j), ch in np.ndenumerate(channel_matrix):
    if ch >= 0:
        plot_matrix[i, j] = channel_to_color[int(ch)]

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto",
               vmin=-0.5, vmax=len(unique_channels) - 0.5)

for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        val = count_matrix[i, j]
        if val > 0:
            ax.text(j, i, f"{val}", ha="center", va="center",
                    color="black", fontsize=7)

ax.set_xticks(range(max_rank))
ax.set_xticklabels([f"Top{j+1}" for j in range(max_rank)], fontsize=9)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=8)

ax.set_title("Most Frequent Channel per Rank (Top1–Top7, per layer)", fontsize=12)

legend_elements = [Patch(facecolor=cmap(k), label=f"Ch{ch}")
                   for k, ch in enumerate(unique_channels)]
ax.legend(handles=legend_elements, title="Channels",
          bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=7)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_rank_based.png"

# layer -> rank (0-based order of "MAD → Ch" lines under each "-- Slice"
# header) -> Counter(channel -> occurrences at that rank)
rank_counts = defaultdict(lambda: defaultdict(Counter))
current_layer, rank_idx = None, None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("-- Slice"):
            rank_idx = 0
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                rank_counts[current_layer][rank_idx][ch] += 1
                rank_idx += 1

layers = list(rank_counts.keys())
max_rank = 3

# Per (layer, rank): most frequent channel and its count; -1 marks ranks that
# never occur for a layer.
count_matrix = np.zeros((len(layers), max_rank), dtype=int)
channel_matrix = np.full((len(layers), max_rank), -1, dtype=int)

for i, layer in enumerate(layers):
    for j in range(max_rank):
        ranked = rank_counts[layer].get(j)
        if ranked:
            ch, freq = ranked.most_common(1)[0]
            count_matrix[i, j] = freq
            channel_matrix[i, j] = ch

# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
palette = []
for name in ("tab20", "tab20b", "tab20c"):
    base = plt.get_cmap(name)
    palette.extend(base(np.linspace(0, 1, base.N)))

unique_channels = np.unique(channel_matrix[channel_matrix >= 0])
if unique_channels.size == 0:
    raise ValueError(f"No 'MAD → Ch' entries found in {results_file}")

# Categorical colouring: distinct channel k (sorted order) is drawn with colour
# k and the norm is pinned to [-0.5, K - 0.5], so each cell matches its legend
# patch. Plotting raw channel ids would let imshow autoscale them to the data
# range and break that correspondence. Colours repeat beyond 60 channels.
channel_to_color = {int(ch): k for k, ch in enumerate(unique_channels)}
cmap = ListedColormap([palette[k % len(palette)] for k in range(len(unique_channels))])
cmap.set_bad("white")

plot_matrix = np.full(channel_matrix.shape, np.nan)
for (i, j), ch in np.ndenumerate(channel_matrix):
    if ch >= 0:
        plot_matrix[i, j] = channel_to_color[int(ch)]

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto",
               vmin=-0.5, vmax=len(unique_channels) - 0.5)

for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        val = count_matrix[i, j]
        if val > 0:
            ax.text(j, i, str(val), ha="center", va="center",
                    color="black", fontsize=7)

ax.set_xticks(range(max_rank))
ax.set_xticklabels([f"Top{j+1}" for j in range(max_rank)], fontsize=9)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=8)

ax.set_title("Rank-based: Most Frequent Channel at Rank1–Rank3", fontsize=12)

legend_elements = [Patch(facecolor=cmap(k), label=f"Ch{ch}")
                   for k, ch in enumerate(unique_channels)]
ax.legend(handles=legend_elements, title="Channels",
          bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=7)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

results_file = "layer_patient_slice_summary.txt"
save_path = "heatmap_frequency_based.png"

# layer -> Counter(channel -> number of "MAD → Ch" lines in that layer)
counts = defaultdict(Counter)
current_layer = None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                counts[current_layer][ch] += 1

layers = list(counts.keys())
N = 7

# Column j holds each layer's (j+1)-th most frequent channel; -1 marks unused slots.
count_matrix = np.zeros((len(layers), N), dtype=int)
channel_matrix = np.full((len(layers), N), -1, dtype=int)

for i, layer in enumerate(layers):
    topN = counts[layer].most_common(N)
    for j, (ch, freq) in enumerate(topN):
        count_matrix[i, j] = freq
        channel_matrix[i, j] = ch

# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
palette = []
for name in ("tab20", "tab20b", "tab20c"):
    base = plt.get_cmap(name)
    palette.extend(base(np.linspace(0, 1, base.N)))

unique_channels = np.unique(channel_matrix[channel_matrix >= 0])
if unique_channels.size == 0:
    raise ValueError(f"No 'MAD → Ch' entries found in {results_file}")

# Categorical colouring: distinct channel k (sorted order) is drawn with colour
# k and the norm is pinned to [-0.5, K - 0.5], so each cell matches its legend
# patch. Plotting raw channel ids would let imshow autoscale them to the data
# range and break that correspondence. Colours repeat beyond 60 channels.
channel_to_color = {int(ch): k for k, ch in enumerate(unique_channels)}
cmap = ListedColormap([palette[k % len(palette)] for k in range(len(unique_channels))])
cmap.set_bad("white")

plot_matrix = np.full(channel_matrix.shape, np.nan)
for (i, j), ch in np.ndenumerate(channel_matrix):
    if ch >= 0:
        plot_matrix[i, j] = channel_to_color[int(ch)]

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto",
               vmin=-0.5, vmax=len(unique_channels) - 0.5)

for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        val = count_matrix[i, j]
        if val > 0:
            ax.text(j, i, str(val), ha="center", va="center",
                    color="black", fontsize=10)

ax.set_xticks(range(N))
ax.set_xticklabels([f"Top{j+1}" for j in range(N)], fontsize=9)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=10)

ax.set_title(f"Frequency-based: Per-Layer Top{N} Channels", fontsize=12)

legend_elements = [Patch(facecolor=cmap(k), label=f"Ch{ch}")
                   for k, ch in enumerate(unique_channels)]
ax.legend(handles=legend_elements, title="Channels",
          bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=7)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
from collections import defaultdict, Counter
import numpy as np

results_file = "layer_patient_slice_summary.txt"

# layer -> rank (0-based order of "MAD → Ch" lines under each "-- Slice"
# header, first three only) -> Counter(channel -> occurrences at that rank)
rank_counts = defaultdict(lambda: defaultdict(Counter))
current_layer, rank_idx = None, None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("-- Slice"):
            rank_idx = 0
        elif line.startswith("MAD → Ch"):
            # Only the first three MAD lines of each slice are counted.
            if rank_idx < 3:
                match = re.search(r"MAD → Ch (\d+)", line)
                if match:
                    ch = int(match.group(1))
                    rank_counts[current_layer][rank_idx][ch] += 1
                rank_idx += 1


print("\n=== Rank-based counts (Top1, Top2, Top3 per layer) ===")
for layer, ranks in rank_counts.items():
    print(f"\nLayer: {layer}")
    # Sorted so ranks always print as 1, 2, 3 regardless of first-seen order.
    for rank, counter in sorted(ranks.items()):
        print(f"  Rank {rank+1}: {counter.most_common(5)}")


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter

results_file = "layer_patient_slice_summary_3.txt"
save_path = "heatmap_frequency_colored_lv.png"

# layer -> Counter(channel -> number of "MAD → Ch" lines in that layer)
counts = defaultdict(Counter)
current_layer = None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                counts[current_layer][ch] += 1

layers = list(counts.keys())
N = 7

# Column j holds each layer's (j+1)-th most frequent channel; -1 marks unused slots.
count_matrix = np.zeros((len(layers), N), dtype=int)
channel_matrix = np.full((len(layers), N), -1, dtype=int)

for i, layer in enumerate(layers):
    topN = counts[layer].most_common(N)
    for j, (ch, freq) in enumerate(topN):
        count_matrix[i, j] = freq
        channel_matrix[i, j] = ch

# Colour encodes frequency; the channel id is written inside each cell.
# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
cmap = plt.get_cmap("viridis")

# Empty slots become NaN and are left unpainted.
plot_matrix = count_matrix.astype(float)
plot_matrix[plot_matrix <= 0] = np.nan

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto")

for i in range(channel_matrix.shape[0]):
    for j in range(channel_matrix.shape[1]):
        ch = channel_matrix[i, j]
        if ch >= 0:
            ax.text(j, i, f"Ch{ch}", ha="center", va="center",
                    color="white", fontsize=8, fontweight="bold")

ax.set_xticks(range(N))
ax.set_xticklabels([f"Top{j+1}" for j in range(N)], fontsize=9)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=10)

ax.set_title(f"Frequency-based: Per-Layer Top{N} Channels", fontsize=12)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Frequency of Occurrence", rotation=270, labelpad=15)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter

results_file = "layer_patient_slice_summary2_rv_new.txt"
save_path = "heatmap_frequency_colored_rv_new.png"

# layer -> Counter(channel -> number of "MAD → Ch" lines in that layer)
counts = defaultdict(Counter)
current_layer = None

# Explicit UTF-8: the matched prefix contains the non-ASCII arrow "→".
with open(results_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line.startswith("=== Layer:"):
            current_layer = line.split("Layer:")[1].strip(" =")
        elif line.startswith("MAD → Ch"):
            match = re.search(r"MAD → Ch (\d+)", line)
            if match:
                ch = int(match.group(1))
                counts[current_layer][ch] += 1

layers = list(counts.keys())
N = 7

# Column j holds each layer's (j+1)-th most frequent channel; -1 marks unused slots.
count_matrix = np.zeros((len(layers), N), dtype=int)
channel_matrix = np.full((len(layers), N), -1, dtype=int)

for i, layer in enumerate(layers):
    topN = counts[layer].most_common(N)
    for j, (ch, freq) in enumerate(topN):
        count_matrix[i, j] = freq
        channel_matrix[i, j] = ch

# Colour encodes frequency; the channel id is written inside each cell.
# matplotlib.cm.get_cmap was removed in matplotlib 3.9; use pyplot.get_cmap.
cmap = plt.get_cmap("viridis")

# Empty slots become NaN and are left unpainted.
plot_matrix = count_matrix.astype(float)
plot_matrix[plot_matrix <= 0] = np.nan

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(plot_matrix, cmap=cmap, aspect="auto")

for i in range(channel_matrix.shape[0]):
    for j in range(channel_matrix.shape[1]):
        ch = channel_matrix[i, j]
        if ch >= 0:
            ax.text(j, i, f"Ch{ch}", ha="center", va="center",
                    color="white", fontsize=8, fontweight="bold")

ax.set_xticks(range(N))
ax.set_xticklabels([f"Top{j+1}" for j in range(N)], fontsize=9)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=10)

ax.set_title(f"Frequency-based: Per-Layer Top{N} Channels", fontsize=12)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Frequency of Occurrence", rotation=270, labelpad=15)

plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()


In [None]:
from src.hooks import extract_features_all, analyze_features_all, target_layers


# Capture activations at `target_layers` for both the original and the
# intensity-adjusted volumes. `model`, `images` and `brightness_adjusted_images`
# come from the data-loading cells earlier in the notebook. The returned dicts
# are indexed as [patient][slice_idx][layer] by
# visualize_top3_mad_featuremaps_specific, defined just below.
features_orig_dict, features_bright_dict = extract_features_all(
    model,
    PATIENT_IDS,
    images,
    brightness_adjusted_images,
    target_layers,
    num_slices=10  # presumably slices processed per patient — TODO confirm in src.hooks
)



def visualize_top3_mad_featuremaps_specific(
    summary,
    features_orig_dict,
    features_bright_dict,
    out_dir="top3_mad_featuremaps1",
    allowed_layers=None,
    allowed_channels=None
):
    """Save side-by-side feature-map figures for the top-MAD channels in ``summary``.

    For every ``summary[layer][patient][slice_idx]["top_mad"]`` entry that passes
    the optional layer/channel filters, a 1x3 figure is written to ``out_dir``:
    the original activation map and the brightness-adjusted map (sharing one
    colour range), plus their absolute difference.

    Parameters
    ----------
    summary : mapping
        ``layer -> patient -> slice_idx -> groups``, where ``groups.get("top_mad", [])``
        is a list of dicts with keys ``"Channel"``, ``"MAD"``, ``"MSE"`` and ``"SSIM"``.
    features_orig_dict, features_bright_dict : mapping
        ``patient -> slice_idx -> layer -> activations`` indexed as ``[0, ch]``
        (presumably ``(batch, channel, H, W)`` — confirm against
        ``extract_features_all``). Objects with ``.detach()`` (e.g. torch
        tensors) are moved to CPU NumPy arrays before plotting.
    out_dir : str
        Output directory; created if it does not exist.
    allowed_layers, allowed_channels : container or None
        When given, only layers / channels contained in them are plotted.

    Returns
    -------
    int
        Number of images written (also printed).
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt

    os.makedirs(out_dir, exist_ok=True)

    def _to_np(x):
        # Convert tensors (anything with .detach()) to NumPy; pass arrays through.
        if hasattr(x, "detach"):
            x = x.detach().cpu().numpy()
        return np.asarray(x)

    def _safe(s):
        # Make an arbitrary name usable as part of a filename.
        return str(s).replace(".", "_").replace("/", "_").replace(" ", "_")

    n_saved = 0
    for layer, per_patient in summary.items():
        if allowed_layers is not None and layer not in allowed_layers:
            continue

        for patient, per_slice in per_patient.items():
            for slice_idx, groups in per_slice.items():
                for e in groups.get("top_mad", []):
                    ch = int(e["Channel"])
                    # Filter before reading the metrics, so entries that are
                    # skipped anyway need not carry every metric key.
                    if allowed_channels is not None and ch not in allowed_channels:
                        continue

                    mad  = float(e["MAD"])
                    mse  = float(e["MSE"])
                    ssim = float(e["SSIM"])

                    orig   = _to_np(features_orig_dict[patient][slice_idx][layer][0, ch])
                    bright = _to_np(features_bright_dict[patient][slice_idx][layer][0, ch])
                    diff   = np.abs(orig - bright)

                    # Shared colour range so the two activation maps are directly comparable.
                    vmin = float(min(orig.min(), bright.min()))
                    vmax = float(max(orig.max(), bright.max()))

                    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
                    try:
                        fig.suptitle(
                            f"{layer} – Ch {ch}, Slice {slice_idx}, {patient}\n"
                            f"MAD={mad:.4f} | MSE={mse:.4f} | SSIM={ssim:.4f}",
                            fontsize=10
                        )

                        axs[0].imshow(orig, cmap="viridis", vmin=vmin, vmax=vmax)
                        axs[0].set_title("Original"); axs[0].axis("off")

                        axs[1].imshow(bright, cmap="viridis", vmin=vmin, vmax=vmax)
                        axs[1].set_title("Brightness Adjusted"); axs[1].axis("off")

                        axs[2].imshow(diff, cmap="magma")
                        axs[2].set_title("|Orig−Bright|"); axs[2].axis("off")

                        fig.tight_layout()

                        out_path = os.path.join(
                            out_dir,
                            f"path_{_safe(patient)}_layer_{_safe(layer)}_ch{ch:02d}_slice{slice_idx:02d}.png"
                        )
                        fig.savefig(out_path, dpi=150)
                    finally:
                        # Release the figure even if drawing or saving fails,
                        # so a long loop cannot accumulate open figures.
                        plt.close(fig)
                    n_saved += 1
    print(f"Saved {n_saved} images to {out_dir}/")
    return n_saved
