In [1]:
#!/usr/bin/env python
"""
Explainability analysis for trained Transformer model.

Techniques:
1. Attention Visualization - Shows which time steps the model focuses on
2. Integrated Gradients - Shows which features drive predictions

Usage:
    python explainability_analysis.py
"""

'\nExplainability analysis for trained Transformer model.\n\nTechniques:\n1. Attention Visualization - Shows which time steps the model focuses on\n2. Integrated Gradients - Shows which features drive predictions\n\nUsage:\n    python explainability_analysis.py\n'

In [2]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Import your model architecture and dataset
# Adjust these imports based on your actual file structure
# from your_training_script import TinyMelTransformer, MelTransformerDataset, collate_variable_length


# ==========================
# Config - Update these paths!
# ==========================
MODEL_PATH = "/scratch/jbfrantz/jess_explainability_model/jess_outputs/best_female_transformer.pt"
METADATA_PATH = "/scratch/jbfrantz/jess_explainability_model/jess_outputs/dataset_metadata.npz"
OUT_DIR = "/scratch/jbfrantz/jess_explainability_model/explainability_outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# Data paths (same as training)
MEL_DIR = "/scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum"
LABEL_DIR = "/scratch/sshuvo13/project_shared_folder_bspml_1/rml_analysis/fixed_rml_analysis/labels_again/fixed_30s_label_outputs"



DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_SAMPLES_TO_EXPLAIN = 20  # Number of test samples to analyze

In [11]:
def find_file_pairs(mel_dir, label_dir):
    mel_files = sorted(glob.glob(os.path.join(mel_dir, "*.npy")))
    print(f"[INFO] Total mel files found: {len(mel_files)}")

    pairs = []
    for mel_path in mel_files:
        mel_name = os.path.basename(mel_path)
        stem = os.path.splitext(mel_name)[0]  # e.g. 00001006-100507
        label_name = f"{stem}_segments_labels.npy"
        label_path = os.path.join(label_dir, label_name)

        if os.path.exists(label_path):
            pairs.append((mel_path, label_path))
        else:
            print(f"[WARN] Missing label for {mel_name} -> {label_name}")

    print(f"[INFO] Usable (mel, label) pairs: {len(pairs)}")
    return pairs

def split_file_pairs(file_pairs, seed=42):
    rng = random.Random(seed)
    rng.shuffle(file_pairs)
    n = len(file_pairs)
    n_train = int(0.7 * n)
    n_val = int(0.2 * n)
    n_test = n - n_train - n_val

    train_pairs = file_pairs[:n_train]
    val_pairs = file_pairs[n_train:n_train + n_val]
    test_pairs = file_pairs[n_train + n_val:]

    print(f"[INFO] Patient-level split (files):")
    print(f"       Train: {len(train_pairs)}")
    print(f"       Val:   {len(val_pairs)}")
    print(f"       Test:  {len(test_pairs)}")
    return train_pairs, val_pairs, test_pairs

# ==========================
# Dataset: Transformer-ready sequences
# ==========================
class MelTransformerDataset(Dataset):
    """
    Optimized version: preprocess all segments in __init__ for fast training.
    
    For each segment:
      - mel: (C, n_mels, frames)
      - time pooling by factor TIME_POOL
      - flatten C and n_mels -> feature_dim = C * n_mels
      - output sequence: (T', feature_dim)

    Labels:
      - Normal -> 0, all other classes -> 1 (event)
    """

    def __init__(self, file_pairs, time_pool=10):
        self.time_pool = time_pool
        self.sequences = []  # Store preprocessed sequences
        self.labels = []     # Store corresponding labels
        self.skipped = []
        self.label_to_int = None
        all_label_strings = []
        self.uses_string_labels = False

        total_segments = 0
        
        for mel_path, label_path in file_pairs:
            # Load mel - no mmap, load fully into memory
            try:
                X = np.load(mel_path, allow_pickle=True)
            except Exception as e:
                print(f"[WARN] Skipping MEL file: {mel_path}, reason: {repr(e)}")
                self.skipped.append((mel_path, label_path))
                continue

            # Load labels
            try:
                y = np.load(label_path, allow_pickle=True)
            except Exception as e:
                print(f"[WARN] Skipping LABEL file: {label_path}, reason: {repr(e)}")
                self.skipped.append((mel_path, label_path))
                continue

            if y.ndim > 1:
                y = np.squeeze(y)

            N_mel = X.shape[0]
            N_lab = y.shape[0]
            if N_mel != N_lab:
                N = min(N_mel, N_lab)
                print(
                    f"[WARN] Length mismatch for {mel_path} vs {label_path}: "
                    f"mel={N_mel}, label={N_lab}. Using first {N} segments."
                )
                X = X[:N]
                y = y[:N]
            else:
                N = N_mel

            if y.dtype.kind in {"U", "S", "O"}:
                self.uses_string_labels = True
                all_label_strings.extend(y.tolist())

            # Preprocess ALL segments from this file
            for i in range(N):
                seq = self._segment_to_sequence(X[i])
                self.sequences.append(seq)
                self.labels.append(y[i])
            
            total_segments += N
            print(f"[INFO] Processed {mel_path}: {N} segments, shape={X.shape[1:]}")

        if total_segments == 0:
            raise RuntimeError("No valid female segments loaded!")

        print(f"[INFO] Total female segments: {total_segments}")

        # Set up label mapping
        if self.uses_string_labels:
            unique = sorted(set(all_label_strings))
            print(f"[INFO] Found label classes: {unique}")
            # Normal -> 0, all others -> 1
            self.label_to_int = {lab: (0 if lab == "Normal" else 1) for lab in unique}
            print("[INFO] Binary label mapping:")
            print("       Normal -> 0")
            for lab in unique:
                if lab != "Normal":
                    print(f"       {lab} -> 1 (event)")
        else:
            self.label_to_int = None

    def __len__(self):
        return len(self.sequences)

    def _segment_to_sequence(self, seg):
        """
        seg: (C, n_mels, frames)
        Apply time pooling and flatten freq+channel.
        Return: (T', feature_dim) float32
        """
        C, n_mels, T = seg.shape

        # Time pooling
        if self.time_pool > 1 and T >= self.time_pool:
            T_new = T // self.time_pool
            T_use = T_new * self.time_pool
            seg = seg[:, :, :T_use]
            seg = seg.reshape(C, n_mels, T_new, self.time_pool).mean(axis=-1)
            # (C, n_mels, T_new)
        else:
            T_new = T

        seg = seg.reshape(C * n_mels, T_new)   # (feat_dim, T_new)
        seg = seg.transpose(1, 0)             # (T_new, feat_dim)
        return seg.astype(np.float32)

    def __getitem__(self, idx):
        # Simple lookup - no bisect, no on-the-fly processing!
        seq = self.sequences[idx]
        y_np = self.labels[idx]
        
        x = torch.from_numpy(seq)

        if self.label_to_int is not None:
            y_bin = float(self.label_to_int[str(y_np)])
        else:
            y_bin = float(y_np)

        y = torch.tensor(y_bin, dtype=torch.float32)

        return x, y

def collate_variable_length(batch):
    """
    batch: list of (seq, label)
    seq: (T_i, feat_dim)
    Returns:
      X_padded: (B, max_T, feat_dim)
      y: (B, 1)
    """
    seqs, labels = zip(*batch)
    lengths = [s.shape[0] for s in seqs]
    feat_dim = seqs[0].shape[1]

    max_T = max(lengths)
    B = len(seqs)

    X_padded = torch.zeros(B, max_T, feat_dim, dtype=torch.float32)
    for i, s in enumerate(seqs):
        T_i = s.shape[0]
        X_padded[i, :T_i, :] = s

    y = torch.stack(labels).view(-1, 1)
    return X_padded, y


# ==========================
# Modified Model for Attention Extraction
# ==========================
class TinyMelTransformerWithAttention(nn.Module):
    """
    Same as TinyMelTransformer but returns attention weights.
    """
    def __init__(
        self,
        input_dim,
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=128,
        dropout=0.1,
    ):
        super().__init__()
        self.proj = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout, max_len=2000)

        # Store encoder layers separately so we can extract attention
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(64, 1),
        )
        
        self.attention_weights = []  # Store attention from each layer

    def forward(self, x, return_attention=False):
        # x: (B, T, feat_dim)
        x = self.proj(x)
        x = self.pos_encoder(x)
        
        self.attention_weights = []
        
        # Pass through encoder layers and optionally collect attention
        for layer in self.encoder_layers:
            if return_attention:
                # Hook to extract attention weights
                def get_attention_hook(module, input, output):
                    # TransformerEncoderLayer doesn't directly expose attention,
                    # so we need to modify the self_attn module
                    pass
                
                # Note: PyTorch's TransformerEncoderLayer doesn't expose attention by default
                # We'll use a workaround by accessing the self_attn module
                x = layer(x)
            else:
                x = layer(x)
        
        h = x
        pooled = h.mean(dim=1)
        logits = self.classifier(pooled)
        
        return logits, self.attention_weights if return_attention else None


# ==========================
# Tiny Transformer
# ==========================
# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model, dropout=0.1, max_len=2000):
#         super().__init__()
#         self.dropout = nn.Dropout(p=dropout)

#         pe = torch.zeros(max_len, d_model)
#         position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
#         div_term = torch.exp(
#             torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
#         )
#         pe[:, 0::2] = torch.sin(position * div_term)
#         if d_model % 2 == 1:
#             pe[:, 1::2] = torch.cos(position * div_term[:-1])
#         else:
#             pe[:, 1::2] = torch.cos(position * div_term)
#         pe = pe.unsqueeze(0)
#         self.register_buffer("pe", pe)

#     def forward(self, x):
#         """
#         x: (batch_size, seq_len, d_model)
#         """
#         seq_len = x.size(1)
#         x = x + self.pe[:, :seq_len, :]
#         return self.dropout(x)


# class TinyMelTransformer(nn.Module):
#     """
#     Tiny Transformer over mel-spectrogram time axis.
#     Input to forward: (B, T, feat_dim)
#     """

#     def __init__(
#         self,
#         input_dim,
#         d_model=64,
#         nhead=4,
#         num_layers=2,
#         dim_feedforward=128,
#         dropout=0.1,
#     ):
#         super().__init__()
#         self.proj = nn.Linear(input_dim, d_model)
#         self.pos_encoder = PositionalEncoding(d_model, dropout=dropout, max_len=2000)

#         encoder_layer = nn.TransformerEncoderLayer(
#             d_model=d_model,
#             nhead=nhead,
#             dim_feedforward=dim_feedforward,
#             dropout=dropout,
#             batch_first=True,
#         )
#         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

#         self.classifier = nn.Sequential(
#             nn.Linear(d_model, 64),
#             nn.ReLU(),
#             nn.Dropout(p=0.3),
#             nn.Linear(64, 1),
#         )

#     def forward(self, x):
#         # x: (B, T, feat_dim)
#         x = self.proj(x)
#         x = self.pos_encoder(x)
#         h = self.transformer(x)
#         pooled = h.mean(dim=1)
#         logits = self.classifier(pooled)
#         return logits




class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


# ==========================
# Integrated Gradients
# ==========================
# def integrated_gradients(model, input_tensor, target_class, baseline=None, steps=50):
#     """
#     Compute Integrated Gradients for a given input.
    
#     Args:
#         model: PyTorch model
#         input_tensor: (1, T, feat_dim) - single sample
#         target_class: class index to explain (0 or 1)
#         baseline: baseline input (default: zeros)
#         steps: number of interpolation steps
    
#     Returns:
#         attributions: (T, feat_dim) - importance scores
#     """
#     if baseline is None:
#         baseline = torch.zeros_like(input_tensor)
    
#     # Generate interpolated inputs between baseline and actual input
#     alphas = torch.linspace(0, 1, steps).to(input_tensor.device)
    
#     # Interpolate: baseline + alpha * (input - baseline)
#     interpolated_inputs = baseline + alphas.view(-1, 1, 1) * (input_tensor - baseline)
    
#     # Compute gradients for each interpolated input
#     interpolated_inputs.requires_grad_(True)
    
#     model.eval()
#     gradients = []
    
#     for i in range(steps):
#         inp = interpolated_inputs[i:i+1]
#         logits, _ = model(inp)
        
#         # Get gradient for target class
#         model.zero_grad()
#         logits[0, target_class].backward(retain_graph=True)
        
#         gradients.append(inp.grad.detach().clone())
#         inp.grad.zero_()
    
#     # Average gradients
#     avg_gradients = torch.stack(gradients).mean(dim=0)
    
#     # Integrated gradients = (input - baseline) * avg_gradients
#     attributions = (input_tensor - baseline) * avg_gradients
    
#     return attributions.squeeze(0)  # (T, feat_dim)

def integrated_gradients(model, input_tensor, target_class, baseline=None, steps=50):
    """
    Compute Integrated Gradients for a given input.
    
    Args:
        model: PyTorch model
        input_tensor: (1, T, feat_dim) - single sample
        target_class: class index to explain (0 or 1)
        baseline: baseline input (default: zeros)
        steps: number of interpolation steps
    
    Returns:
        attributions: (T, feat_dim) - importance scores
    """
    if baseline is None:
        baseline = torch.zeros_like(input_tensor)
    
    # Generate interpolated inputs between baseline and actual input
    alphas = torch.linspace(0, 1, steps).to(input_tensor.device)
    
    model.eval()
    gradients = []
    
    for i in range(steps):
        # Create interpolated input for this step
        alpha = alphas[i]
        inp = baseline + alpha * (input_tensor - baseline)
        inp.requires_grad_(True)  # Set requires_grad on each individual input
        
        logits, _ = model(inp)
        
        # Get gradient for target class
        model.zero_grad()
        logits[0, target_class].backward()
        
        gradients.append(inp.grad.detach().clone())
    
    # Average gradients
    avg_gradients = torch.stack(gradients).mean(dim=0)
    
    # Integrated gradients = (input - baseline) * avg_gradients
    attributions = (input_tensor - baseline) * avg_gradients
    
    return attributions.squeeze(0)  # (T, feat_dim)


# ==========================
# Attention Rollout
# ==========================
def attention_rollout(attention_weights, head_fusion='mean'):
    """
    Combine attention weights across layers.
    
    Args:
        attention_weights: list of attention tensors from each layer
                          Each tensor: (batch, num_heads, seq_len, seq_len)
        head_fusion: 'mean' or 'max' - how to combine multiple heads
    
    Returns:
        rollout: (seq_len, seq_len) - aggregated attention
    """
    if not attention_weights:
        return None
    
    # Fuse heads
    if head_fusion == 'mean':
        fused_attention = [attn.mean(dim=1) for attn in attention_weights]
    else:
        fused_attention = [attn.max(dim=1)[0] for attn in attention_weights]
    
    # Roll out across layers
    rollout = fused_attention[0]
    for attn in fused_attention[1:]:
        rollout = torch.matmul(attn, rollout)
    
    return rollout


# ==========================
# Visualization Functions
# ==========================
def plot_attention_heatmap(attention_matrix, sample_idx, true_label, pred_label, filename):
    """
    Plot attention weights as a heatmap.
    
    Args:
        attention_matrix: (seq_len, seq_len) numpy array
        sample_idx: sample identifier
        true_label: ground truth label
        pred_label: predicted label
        filename: output filename
    """
    plt.figure(figsize=(10, 4))
    sns.heatmap(attention_matrix, cmap='viridis', cbar=True)
    plt.title(f'Attention Weights - Sample {sample_idx}\nTrue: {true_label}, Pred: {pred_label}')
    plt.xlabel('Key Position (Time Step)')
    plt.ylabel('Query Position (Time Step)')
    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    print(f"[INFO] Saved attention heatmap: {filename}")


def plot_integrated_gradients(attributions, sample_idx, true_label, pred_label, filename):
    """
    Plot Integrated Gradients attribution scores.
    
    Args:
        attributions: (T, feat_dim) numpy array
        sample_idx: sample identifier
        true_label: ground truth label
        pred_label: predicted label
        filename: output filename
    """
    # Aggregate over feature dimension to get per-timestep importance
    timestep_importance = np.abs(attributions).sum(axis=1)
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    
    # Plot 1: Per-timestep importance
    ax1.plot(timestep_importance, linewidth=2)
    ax1.set_title(f'Integrated Gradients - Sample {sample_idx}\nTrue: {true_label}, Pred: {pred_label}')
    ax1.set_xlabel('Time Step')
    ax1.set_ylabel('Importance Score')
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Full attribution heatmap (time x features)
    im = ax2.imshow(attributions.T, aspect='auto', cmap='RdBu_r', 
                    vmin=-np.abs(attributions).max(), 
                    vmax=np.abs(attributions).max())
    ax2.set_title('Feature-level Attributions')
    ax2.set_xlabel('Time Step')
    ax2.set_ylabel('Feature Dimension')
    plt.colorbar(im, ax=ax2)
    
    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    print(f"[INFO] Saved IG visualization: {filename}")


def plot_combined_analysis(attention_avg, ig_importance, sample_idx, true_label, pred_label, filename):
    """
    Plot both attention and IG on the same time axis for comparison.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    
    # Attention (average across all positions)
    ax1.plot(attention_avg, linewidth=2, color='blue', label='Attention')
    ax1.set_title(f'Explainability Analysis - Sample {sample_idx}\nTrue: {true_label}, Pred: {pred_label}')
    ax1.set_ylabel('Avg Attention Weight')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    
    # Integrated Gradients
    ax2.plot(ig_importance, linewidth=2, color='red', label='Integrated Gradients')
    ax2.set_xlabel('Time Step')
    ax2.set_ylabel('IG Importance')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    print(f"[INFO] Saved combined analysis: {filename}")

In [9]:
# ==========================
# Main Analysis
# ==========================
# def main():
#     print(f"[INFO] Using device: {DEVICE}")
#     print(f"[INFO] Loading model from: {MODEL_PATH}")
    
#     # Load model
#     # Note: You'll need to know the input_dim from your training
#     # You can save this info during training or infer it from test data
    
#     # For now, assuming you saved metadata
#     metadata = np.load(METADATA_PATH.replace('best_female_transformer.pt', 'dataset_metadata.npz'))
#     feat_dim = int(metadata['feature_dim'])
    
#     model = TinyMelTransformerWithAttention(
#         input_dim=feat_dim,
#         d_model=64,
#         nhead=4,
#         num_layers=2,
#         dim_feedforward=128,
#         dropout=0.1,
#     ).to(DEVICE)
    
#     # Load trained weights
#     state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    
#     # Handle potential key mismatches (transformer vs encoder_layers)
#     new_state_dict = {}
#     for k, v in state_dict.items():
#         if k.startswith('transformer.layers'):
#             # Rename transformer.layers.X to encoder_layers.X
#             new_key = k.replace('transformer.layers', 'encoder_layers')
#             new_state_dict[new_key] = v
#         else:
#             new_state_dict[k] = v
    
#     model.load_state_dict(new_state_dict)
#     model.eval()
#     print("[INFO] Model loaded successfully")

#     pairs_all = find_file_pairs(MEL_DIR, LABEL_DIR)
#     train_pairs, val_pairs, test_pairs = split_file_pairs(pairs_all, seed=42)
    
#     # Load test data (you'll need to recreate your test dataset)
#     # This is just a placeholder - adjust based on your actual data loading
#     print("\n[INFO] Loading test data...")
#     # test_dataset = MelTransformerDataset(test_pairs, time_pool=10)
#     # test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

#     test_dataset = MelTransformerDataset(test_pairs, time_pool=10)
#     test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
#     print(f"\n[INFO] Running explainability on {NUM_SAMPLES_TO_EXPLAIN} samples...")
    
#     # Analyze samples
#     sample_count = 0
    
#     # Placeholder for actual data loading
#     for idx, (x, y) in enumerate(test_loader):
#         if sample_count >= NUM_SAMPLES_TO_EXPLAIN:
#             break
    
#     # Example analysis (you'll replace this with actual loop)
#     x = x.to(DEVICE)  # (1, T, feat_dim)
#     y = y.item()
    
#     ## Get prediction
#     with torch.no_grad():
#         logits, _ = model(x)
#         pred_prob = torch.sigmoid(logits).item()
#         pred_label = int(pred_prob >= 0.5)
    
#     # 1. Integrated Gradients
#     ig_attributions = integrated_gradients(model, x, target_class=0, steps=50)
#     ig_attr_np = ig_attributions.cpu().numpy()
    
#     # 2. Visualize
#     ig_importance = np.abs(ig_attr_np).sum(axis=1)
    
#     plot_integrated_gradients(
#         ig_attr_np, 
#         sample_idx=idx,
#         true_label=y,
#         pred_label=pred_label,
#         filename=os.path.join(OUT_DIR, f'ig_sample_{idx}.png')
#     )
    
#     sample_count += 1
    
#     print("\n[INFO] Explainability analysis complete!")
#     print(f"[INFO] Results saved to: {OUT_DIR}")


# if __name__ == "__main__":
#     main()

[INFO] Using device: cpu
[INFO] Loading model from: /scratch/jbfrantz/jess_explainability_model/jess_outputs/best_female_transformer.pt
[INFO] Model loaded successfully
[INFO] Total mel files found: 71
[INFO] Usable (mel, label) pairs: 71
[INFO] Patient-level split (files):
       Train: 49
       Val:   14
       Test:  8

[INFO] Loading test data...
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001157-100507.npy: 1272 segments, shape=(3, 64, 3001)
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001163-100507.npy: 1572 segments, shape=(3, 64, 3001)
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001198-100507.npy: 1599 segments, shape=(3, 64, 3001)
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001285-100507.npy: 1502 segments, shape=(3, 64, 

In [12]:
def main():
    print(f"[INFO] Using device: {DEVICE}")
    print(f"[INFO] Loading model from: {MODEL_PATH}")
    
    # Load model metadata
    metadata = np.load(METADATA_PATH)
    feat_dim = int(metadata['feature_dim'])
    
    model = TinyMelTransformerWithAttention(
        input_dim=feat_dim,
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=128,
        dropout=0.1,
    ).to(DEVICE)
    
    # Load trained weights
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    
    # Handle potential key mismatches (transformer vs encoder_layers)
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('transformer.layers'):
            new_key = k.replace('transformer.layers', 'encoder_layers')
            new_state_dict[new_key] = v
        else:
            new_state_dict[k] = v
    
    model.load_state_dict(new_state_dict)
    model.eval()
    print("[INFO] Model loaded successfully")
    
    # Load test data
    pairs_all = find_file_pairs(MEL_DIR, LABEL_DIR)
    train_pairs, val_pairs, test_pairs = split_file_pairs(pairs_all, seed=42)
    
    print("\n[INFO] Loading test data...")
    test_dataset = MelTransformerDataset(test_pairs, time_pool=10)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    print(f"\n[INFO] Running explainability on {NUM_SAMPLES_TO_EXPLAIN} samples...")
    
    # Analyze samples
    sample_count = 0
    
    for idx, (x, y) in enumerate(test_loader):
        if sample_count >= NUM_SAMPLES_TO_EXPLAIN:
            break
        
        # Move to device
        x = x.to(DEVICE)  # (1, T, feat_dim)
        y = y.item()
        
        # Get prediction
        with torch.no_grad():
            logits, _ = model(x)
            pred_prob = torch.sigmoid(logits).item()
            pred_label = int(pred_prob >= 0.5)
        
        # 1. Integrated Gradients
        ig_attributions = integrated_gradients(model, x, target_class=0, steps=50)
        ig_attr_np = ig_attributions.cpu().numpy()
        
        # 2. Visualize
        ig_importance = np.abs(ig_attr_np).sum(axis=1)
        
        plot_integrated_gradients(
            ig_attr_np, 
            sample_idx=idx,
            true_label=y,
            pred_label=pred_label,
            filename=os.path.join(OUT_DIR, f'ig_sample_{idx}.png')
        )
        
        sample_count += 1
    
    print("\n[INFO] Explainability analysis complete!")
    print(f"[INFO] Results saved to: {OUT_DIR}")

if __name__ == "__main__":
    main()

[INFO] Using device: cpu
[INFO] Loading model from: /scratch/jbfrantz/jess_explainability_model/jess_outputs/best_female_transformer.pt
[INFO] Model loaded successfully
[INFO] Total mel files found: 71
[INFO] Usable (mel, label) pairs: 71
[INFO] Patient-level split (files):
       Train: 49
       Val:   14
       Test:  8

[INFO] Loading test data...
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001157-100507.npy: 1272 segments, shape=(3, 64, 3001)
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001163-100507.npy: 1572 segments, shape=(3, 64, 3001)
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001198-100507.npy: 1599 segments, shape=(3, 64, 3001)
[INFO] Processed /scratch/sshuvo13/project_shared_folder_bspml_1/segments_30s/features/female/mel_spectrum/00001285-100507.npy: 1502 segments, shape=(3, 64, 