In [5]:
!pip install seaborn
import torch

def compute_channel_mean_std(tensor, channel_indices):
    """Return the mean and std of each requested channel.

    For every ``idx`` in ``channel_indices`` the statistics pool all elements of
    ``tensor[:, idx, :]``. The std uses torch's default (unbiased) estimator.

    Returns:
        (means, stds): two 1-D tensors ordered like ``channel_indices``.
    """
    channel_slices = [tensor[:, idx, :] for idx in channel_indices]
    means = torch.stack([chunk.mean() for chunk in channel_slices])
    stds = torch.stack([chunk.std() for chunk in channel_slices])
    return means, stds

def apply_z_score_normalization(tensor, channel_indices, means, stds):
    """Apply per-channel Z-score normalization using precomputed statistics.

    The input tensor is not modified; a normalized copy is returned.

    Args:
        tensor: Tensor indexed as ``tensor[:, channel, :]``.
        channel_indices: Channels to normalize. The i-th entry is paired with
            ``means[i]`` and ``stds[i]`` (e.g. as returned by
            ``compute_channel_mean_std``).
        means: Per-channel means, indexed by position in ``channel_indices``.
        stds: Per-channel standard deviations, indexed like ``means``.

    Returns:
        A clone of ``tensor`` whose selected channels are ``(x - mean) / std``.
        Channels whose std is zero or non-finite (e.g. NaN from the unbiased
        std of a single value) are only mean-centered (std treated as 1.0).
    """
    import math

    normalized = tensor.clone()
    for i, idx in enumerate(channel_indices):
        mean = means[i]
        std = stds[i]
        std_value = float(std)
        # A zero, NaN or inf std would turn the whole channel into inf/NaN.
        if std_value == 0 or not math.isfinite(std_value):
            std = 1.0
        normalized[:, idx, :] = (normalized[:, idx, :] - mean) / std
    return normalized

import matplotlib.pyplot as plt

def plot_normalized_heatmap(X, sample_idx=0, channels=range(18)):
    """
    Plot a heatmap of selected channels for one sample of a normalized tensor.

    Args:
        X (torch.Tensor): Normalized tensor indexed as [sample, channel, position]
            (e.g. [N, 18, 23]).
        sample_idx (int): Index of the sample to visualize.
        channels (iterable of int): Channel indices to plot, in display order
            (default: range(18), i.e. channels 0-17). Any iterable of ints works:
            range, list, tuple, or a 1-D integer tensor/array.

    Raises:
        ValueError: If X is not a torch.Tensor, or `channels` is empty.
        IndexError: If `sample_idx` is outside [0, X.shape[0]).
    """
    # Safety checks
    if not isinstance(X, torch.Tensor):
        raise ValueError("Input X must be a torch.Tensor.")

    if sample_idx < 0 or sample_idx >= X.shape[0]:
        raise IndexError(f"Sample index {sample_idx} out of range. Total samples: {X.shape[0]}")

    # Materialize as plain ints. This accepts any iterable (a list has no
    # .start/.stop) and lets the indices be reused for the title and y ticks.
    channel_list = [int(c) for c in channels]
    if not channel_list:
        raise ValueError("`channels` must contain at least one channel index.")

    # detach() is harmless without grad; cpu() also covers non-CUDA devices,
    # since .numpy() only works on CPU tensors.
    sample = X[sample_idx, channel_list, :].detach().cpu()  # [len(channels), seq_len]
    seq_len = sample.shape[-1]

    # Keep the compact "first-last" form for contiguous ascending runs.
    first, last = channel_list[0], channel_list[-1]
    if channel_list == list(range(first, last + 1)):
        channel_desc = f"{first}-{last}"
    else:
        channel_desc = ", ".join(str(c) for c in channel_list)

    # Plot
    plt.figure(figsize=(10, 6))
    plt.imshow(sample.numpy(), aspect='auto', interpolation='nearest')
    plt.colorbar(label='Z-score normalized signal')
    plt.title(f'Heatmap of Sample {sample_idx} (Channels {channel_desc})')
    plt.xlabel(f'Sequence Position (0 → {seq_len - 1})')
    plt.ylabel('Channel')
    # Label rows with the real channel indices rather than 0..len-1 row positions.
    plt.yticks(range(len(channel_list)), channel_list)
    plt.tight_layout()
    plt.show()




In [12]:
import os
from collections import defaultdict

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# ----------------------------
# 1) Model definition (same as before)
# ----------------------------
class TransformerClassifier(nn.Module):
    """Binary classifier over channel-first sequences.

    Pipeline: linear embedding + learned positional encoding -> LayerNorm ->
    Transformer encoder (sequence-first, i.e. not batch_first) -> mean pooling
    over positions -> linear -> sigmoid.

    Input: [B, input_dim, seq_len] with seq_len <= seq_length.
    Output: probabilities of shape [B] (also [1] when B == 1).
    """

    def __init__(self, input_dim, seq_length,
                 hidden_dim=256, num_layers=3,
                 nhead=8, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # Learned positional encoding, one vector per position up to seq_length.
        self.positional_encoding = nn.Parameter(
            torch.zeros(1, seq_length, hidden_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead,
            dim_feedforward=hidden_dim*2,
            dropout=dropout, activation='relu')
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers)
        self.layernorm = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [B, channels, seq]  (e.g. [B, 18, 23])
        x = x.transpose(1, 2)                    # → [B, seq, channels]
        seq_len = x.size(1)
        # Slice the positional encoding so inputs shorter than seq_length work;
        # identical to the full tensor when seq_len == seq_length.
        x = self.embedding(x) + self.positional_encoding[:, :seq_len]
        x = self.layernorm(x)
        x = x.transpose(0, 1)                    # → [seq, B, hidden] (encoder is not batch_first)
        x = self.transformer_encoder(x)
        x = x.transpose(0, 1).mean(dim=1)        # → [B, hidden]
        # squeeze(-1) drops only the size-1 logit dim, so B == 1 yields shape [1]
        # instead of a 0-d scalar.
        return self.sigmoid(self.fc(x)).squeeze(-1)


# ----------------------------
# 2) Load data & randomly sample up to 300 examples per class
# ----------------------------
# Fixed seed so the random subsets (and therefore the saved attention maps)
# are identical on every Restart & Run All.
SEED = 42
num_samples = 300  # maximum samples kept per class

# map_location='cpu' lets tensors saved from a GPU session load on CPU-only
# machines; batches are moved to the model's device when attention is computed.
# NOTE(review): torch.load unpickles its input — only load trusted files.
positives = torch.load('Positive.pt', map_location='cpu').float()
negatives = torch.load('Negative.pt', map_location='cpu').float()

subsample_gen = torch.Generator().manual_seed(SEED)
if positives.size(0) > num_samples:
    positives = positives[torch.randperm(positives.size(0), generator=subsample_gen)[:num_samples]]
if negatives.size(0) > num_samples:
    negatives = negatives[torch.randperm(negatives.size(0), generator=subsample_gen)[:num_samples]]


# ----------------------------
# 3) Load model weights
# ----------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the architecture on the target device, then restore the trained weights.
model = TransformerClassifier(input_dim=18, seq_length=23).to(device)
ckpt = torch.load('best_model_1th.pth', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # inference mode (disables dropout)


# ----------------------------
# 4) Compute average attention (batch_size=1024)
# ----------------------------
def compute_avg_attn(X, model, batch_size=1024, device=None):
    """Average the per-head self-attention maps over all samples, per encoder layer.

    The encoder is re-run manually, layer by layer, so each layer's
    ``self_attn`` weights can be captured. The stock ``nn.TransformerEncoder``
    forward does not return them.

    Args:
        X (torch.Tensor): Inputs shaped [N, channels, seq], as fed to the model's
            forward (e.g. [N, 18, 23]).
        model: Module exposing ``embedding``, ``positional_encoding``,
            ``layernorm`` and ``transformer_encoder``. The encoder layers must be
            post-norm (``norm_first=False``) and sequence-first.
        batch_size (int): Samples processed per step.
        device (torch.device, optional): Where to run. Defaults to the device of
            the model's parameters.

    Returns:
        dict[int, np.ndarray]: 0-based layer index -> float32 array
        [heads, seq, seq] (query x key), averaged over the N samples.

    Raises:
        ValueError: If X is empty, batch_size < 1, or a layer is pre-norm.
    """
    if X.size(0) == 0:
        raise ValueError("X must contain at least one sample.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    layers = model.transformer_encoder.layers
    for layer in layers:
        # The residual/norm sequence below reproduces the post-norm layer only.
        if getattr(layer, 'norm_first', False):
            raise ValueError("compute_avg_attn only supports post-norm layers (norm_first=False).")
    if device is None:
        device = next(model.parameters()).device

    # Running float64 sums instead of keeping every [B, heads, seq, seq] tensor.
    sums = {}
    was_training = model.training
    # Dropout modules are called explicitly below, so they must be in eval mode
    # or the averaged maps become random.
    model.eval()
    try:
        with torch.no_grad():
            for start in range(0, X.size(0), batch_size):
                batch = X[start:start + batch_size].to(device)   # [B, C, seq]
                x = batch.transpose(1, 2)                          # [B, seq, C]
                x = model.embedding(x) + model.positional_encoding[:, :x.size(1)]
                x = model.layernorm(x)
                x = x.transpose(0, 1)                              # [seq, B, hidden]

                for L, layer in enumerate(layers):
                    attn_out, attn_w = layer.self_attn(
                        x, x, x,
                        need_weights=True,
                        average_attn_weights=False,
                        attn_mask=None,
                        key_padding_mask=None,
                        is_causal=False
                    )
                    # attn_w: [B, heads, seq, seq]
                    batch_sum = attn_w.double().sum(dim=0).cpu()
                    sums[L] = batch_sum if L not in sums else sums[L] + batch_sum

                    # Same computation as a post-norm TransformerEncoderLayer.
                    x = layer.norm1(x + layer.dropout1(attn_out))
                    ff = layer.linear2(
                        layer.dropout(layer.activation(
                            layer.linear1(x))))
                    x = layer.norm2(x + layer.dropout2(ff))
    finally:
        model.train(was_training)

    n_samples = X.size(0)
    return {L: (total / n_samples).float().numpy() for L, total in sums.items()}

# Average attention maps per class: positives first, then negatives.
avg_pos, avg_neg = (compute_avg_attn(samples, model) for samples in (positives, negatives))


# ----------------------------
# 5) Save one figure per layer, with the full word "Layer" in titles and Query/Key positions labeled
# ----------------------------
def save_per_layer_with_labels(avg_attn, out_dir):
    """Save one PNG per encoder layer showing every head's attention map.

    Args:
        avg_attn (dict): 0-based layer index -> array [heads, seq, seq]
            (query x key), e.g. the output of ``compute_avg_attn``. All layers
            are assumed to share the first layer's head count and length.
        out_dir (str): Directory for ``layer_<k>.png`` files (k is 1-based);
            created if missing.

    Returns:
        list[str]: Paths of the written files, in ``avg_attn`` iteration order.

    Raises:
        ValueError: If ``avg_attn`` is empty.
    """
    if not avg_attn:
        raise ValueError("avg_attn is empty; nothing to plot.")
    os.makedirs(out_dir, exist_ok=True)

    first = next(iter(avg_attn.values()))
    n_heads, seq_len = first.shape[0], first.shape[1]
    # Only a few ticks (every 5th position plus the last) to avoid crowding.
    ticks = list(range(0, seq_len, 5))
    if seq_len - 1 not in ticks:
        ticks.append(seq_len - 1)

    saved_paths = []
    for layer_idx, mats in avg_attn.items():
        fig, axes = plt.subplots(
            1, n_heads,
            figsize=(n_heads * 1.5, 2),
            squeeze=False
        )
        try:
            for h, ax in enumerate(axes[0]):
                ax.imshow(mats[h], aspect='auto')
                ax.set_xticks(ticks)
                ax.set_xticklabels(ticks, fontsize=4)
                ax.set_xlabel('Key position', fontsize=6)
                ax.set_yticks(ticks)
                ax.set_yticklabels(ticks, fontsize=4)
                ax.set_ylabel('Query position', fontsize=6)
                # Title includes full "Layer" name
                ax.set_title(f'Layer {layer_idx+1} - Head {h+1}', fontsize=6)

            fig.tight_layout()
            path = os.path.join(out_dir, f"layer_{layer_idx+1}.png")
            fig.savefig(path, dpi=300)
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)
        saved_paths.append(path)
    return saved_paths

# Save the positive and negative sample maps separately.
save_per_layer_with_labels(avg_pos, 'attention_maps/positive_layers')
save_per_layer_with_labels(avg_neg, 'attention_maps/negative_layers')

# Derive the file range from the data instead of hard-coding 3 layers.
print("✅ 每层注意力图已保存并标注：")
print(f"  - attention_maps/positive_layers/layer_1.png … layer_{len(avg_pos)}.png")
print(f"  - attention_maps/negative_layers/layer_1.png … layer_{len(avg_neg)}.png")


✅ 每层注意力图已保存并标注：
  - attention_maps/positive_layers/layer_1.png … layer_3.png
  - attention_maps/negative_layers/layer_1.png … layer_3.png
