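# Feature Importance with Captum

Integrated Gradients (Captum) attributions for the GRU duration-aware readmission model. The notebook covers three views: global feature importance, importance restricted to patients with a fully observed 4-visit history, and per-sample (local) attributions.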

In [2]:
from importlib import import_module

import pandas as pd
import torch

WORKSPACE_DIR = '/workspaces/msc-thesis-recurrent-health-modeling/'
print(f"Workspace directory: {WORKSPACE_DIR}")

# ---- User inputs (EDIT ME) ----
# Folder holding the pre-built train / validation / test datasets (.pt) for this model variant.
_DATASET_DIR = f"{WORKSPACE_DIR}_models/mimic/deep_learning/gru_duration_aware/multiple_hosp_patients/"
DATASET_TEST_PT_PATH = _DATASET_DIR + "test_dataset.pt"
DATASET_VAL_PT_PATH = _DATASET_DIR + "validation_tuning_dataset.pt"
DATASET_TRAIN_PT_PATH = _DATASET_DIR + "train_tuning_dataset.pt"

# Dotted module path + class name of the model to explain
# (e.g. GRUNet, AttentionPoolingNet, CrossAttnPoolingNet).
MODEL_MODULE = "recurrent_health_events_prediction.model.RecurrentHealthEventsDL"
MODEL_CLASS_NAME = "GRUNet"

# YAML config saved with the training run (may override MODEL_CLASS_NAME when loaded).
CONFIG_PATH = f"{WORKSPACE_DIR}_runs/gru_duration_aware_20251020_145937/model_config.yaml"

Workspace directory: /workspaces/msc-thesis-recurrent-health-modeling/


In [3]:
# JSON cache of training-set feature statistics (means / medians) used for IG baselines.
TRAINING_DATA_STATS_PATH = WORKSPACE_DIR + "_models/mimic/deep_learning/training_stats.json"

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


## Load datasets and model

In [5]:
# Load the model class and hyper-parameters from the run's YAML config.
# NOTE: later cells also read feature lists and batch_size from `model_cfg`, which only
# exists when CONFIG_PATH is set.
if CONFIG_PATH is not None:
    import yaml
    with open(CONFIG_PATH, "r") as f:
        model_cfg = yaml.safe_load(f)
    MODEL_CLASS_NAME = model_cfg.get("model_class", MODEL_CLASS_NAME)
    MODEL_PARAMS = model_cfg.get("model_params")
    if MODEL_PARAMS is None:
        # Fail here rather than later at ModelClass(**None).
        raise KeyError(f"'model_params' missing from config: {CONFIG_PATH}")
    print("Loaded config:", CONFIG_PATH)
else:
    # Inline mode requires a hand-written MODEL_PARAMS dict in the configuration cell;
    # give an explicit message instead of a bare NameError on the print below.
    if "MODEL_PARAMS" not in globals():
        raise NameError("CONFIG_PATH is None: define an inline MODEL_PARAMS dict in the configuration cell.")
    print("Using inline MODEL_PARAMS.")
print("Model class:", MODEL_CLASS_NAME)
print("Model params:", MODEL_PARAMS)


Loaded config: /workspaces/msc-thesis-recurrent-health-modeling/_runs/gru_duration_aware_20251020_145937/model_config.yaml
Model class: GRUNet
Model params: {'dropout': 0.2176106168546852, 'hidden_size_head': 64, 'hidden_size_seq': 16, 'input_size_curr': 19, 'input_size_seq': 8, 'num_layers_seq': 1}


In [6]:
# Display the full run configuration (feature lists, batch size, model hyper-parameters).
model_cfg

{'batch_size': 32,
 'current_feat_cols': ['LOG_HOSPITALIZATION_DAYS',
  'LOG_DAYS_IN_ICU',
  'CHARLSON_INDEX',
  'LOG_NUM_DRUGS',
  'NUM_PROCEDURES',
  'LOG_PARTICIPATION_DAYS',
  'HAS_DIABETES',
  'HAS_COPD',
  'HAS_CONGESTIVE_HF',
  'DISCHARGE_LOCATION_POST_ACUTE_CARE',
  'DISCHARGE_LOCATION_HOME',
  'AGE',
  'GENDER_M',
  'ADMISSION_TYPE_ELECTIVE',
  'ETHNICITY_WHITE',
  'ETHNICITY_BLACK',
  'ETHNICITY_HISPANIC',
  'INSURANCE_MEDICAID',
  'INSURANCE_PRIVATE'],
 'learning_rate': 0.008309969116405266,
 'longitudinal_feat_cols': ['LOG_HOSPITALIZATION_DAYS',
  'LOG_DAYS_IN_ICU',
  'CHARLSON_INDEX',
  'LOG_NUM_DRUGS',
  'NUM_PROCEDURES',
  'DISCHARGE_LOCATION_POST_ACUTE_CARE',
  'ADMISSION_TYPE_ELECTIVE',
  'LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION'],
 'lr_scheduler': 'plateau',
 'max_sequence_length': 4,
 'model_class': 'GRUNet',
 'model_name': 'GRU Duration Aware',
 'model_params': {'dropout': 0.2176106168546852,
  'hidden_size_head': 64,
  'hidden_size_seq': 16,
  'input_size_curr': 19,
  

In [7]:
# Features describing the current (index) visit, in model input order.
current_feat_cols = model_cfg.get("current_feat_cols")
print("Current-visit features: ")
for feat_name in current_feat_cols:
    print(f" - {feat_name}")

Current-visit features: 
 - LOG_HOSPITALIZATION_DAYS
 - LOG_DAYS_IN_ICU
 - CHARLSON_INDEX
 - LOG_NUM_DRUGS
 - NUM_PROCEDURES
 - LOG_PARTICIPATION_DAYS
 - HAS_DIABETES
 - HAS_COPD
 - HAS_CONGESTIVE_HF
 - DISCHARGE_LOCATION_POST_ACUTE_CARE
 - DISCHARGE_LOCATION_HOME
 - AGE
 - GENDER_M
 - ADMISSION_TYPE_ELECTIVE
 - ETHNICITY_WHITE
 - ETHNICITY_BLACK
 - ETHNICITY_HISPANIC
 - INSURANCE_MEDICAID
 - INSURANCE_PRIVATE


In [8]:
# Features recorded for each past visit in the history sequence, in model input order.
longitudinal_feat_cols = model_cfg.get("longitudinal_feat_cols")
print("Longitudinal features: ")
for feat_name in longitudinal_feat_cols:
    print(f" - {feat_name}")

Longitudinal features: 
 - LOG_HOSPITALIZATION_DAYS
 - LOG_DAYS_IN_ICU
 - CHARLSON_INDEX
 - LOG_NUM_DRUGS
 - NUM_PROCEDURES
 - DISCHARGE_LOCATION_POST_ACUTE_CARE
 - ADMISSION_TYPE_ELECTIVE
 - LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION


In [9]:
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects (needed here for the
# custom HospReadmDataset class) — only load .pt files you produced yourself.
test_dataset_obj, val_dataset_obj, train_dataset_obj = (
    torch.load(pt_path, weights_only=False)
    for pt_path in (DATASET_TEST_PT_PATH, DATASET_VAL_PT_PATH, DATASET_TRAIN_PT_PATH)
)
print(f"Loaded dataset_object type: {type(test_dataset_obj)}")

Loaded dataset_object type: <class 'recurrent_health_events_prediction.datasets.HospReadmDataset.HospReadmDataset'>


In [10]:
# Inspect one training sample: (current features, past sequence, past mask, label).
x_curr_ex, x_past_ex, mask_ex, label_ex = train_dataset_obj[0]
print("Shapes of tensors from dataset:")
for _name, _tensor in [("x_curr", x_curr_ex), ("x_past", x_past_ex), ("mask", mask_ex), ("label", label_ex)]:
    print(f"{_name}: {_tensor.shape}")

Shapes of tensors from dataset:
x_curr: torch.Size([19])
x_past: torch.Size([4, 8])
mask: torch.Size([4])
label: torch.Size([])


In [11]:
# Unshuffled loaders so batch order (and therefore attributions) is reproducible.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(split_ds, batch_size=model_cfg["batch_size"], shuffle=False)
    for split_ds in (train_dataset_obj, val_dataset_obj, test_dataset_obj)
)

In [12]:
# Inspect one collated training batch to confirm the batch dimension is prepended.
x_curr, x_past, mask, label = next(iter(train_loader))
print("Shapes of tensors from DataLoader:")
for _name, _tensor in (("x_curr", x_curr), ("x_past", x_past), ("mask", mask), ("label", label)):
    print(f"{_name}:", _tensor.shape)

Shapes of tensors from DataLoader:
x_curr: torch.Size([32, 19])
x_past: torch.Size([32, 4, 8])
mask: torch.Size([32, 4])
label: torch.Size([32])


In [13]:
# Import the model class named in the config and instantiate it.
# Path to the trained state_dict for this run. TODO: set it (the weights file name is not
# shown in this notebook). While it is None the model keeps its random initialisation and
# every attribution below describes an UNTRAINED network.
MODEL_WEIGHTS_PATH = None

mod = import_module(MODEL_MODULE)
ModelClass = getattr(mod, MODEL_CLASS_NAME)
model = ModelClass(**MODEL_PARAMS)
if MODEL_WEIGHTS_PATH is not None:
    model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location="cpu"))
else:
    print("WARNING: MODEL_WEIGHTS_PATH is None - explaining a randomly initialised model.")
model = model.eval()  # inference mode: disables dropout so attributions are deterministic
print(model.__class__.__name__, "initialized.")

GRUNet initialized.


## Helpers

In [14]:
# Build one batch (x_current, x_past, mask_past, labels) from a dataset or pre-batched tensors.
def extract_batch(dataset_obj, batch_size):
    """Return the first ``batch_size`` samples as (x_current, x_past, mask_past, labels).

    Supported inputs:
      * a tuple/list of at least 4 already-batched tensors sharing the leading
        batch dimension: each tensor is sliced to its first ``batch_size`` rows;
      * a map-style Dataset whose items are tuples/lists of at least 4 elements:
        the first batch of an unshuffled DataLoader is returned.

    Raises:
      ValueError: if the Dataset is empty or its items do not look like
        (x_current, x_past, mask_past, labels).
      TypeError: for any other input type (dicts are not supported).
    """
    if isinstance(dataset_obj, dict):
        raise TypeError("extract_batch does not support dict inputs; pass a Dataset or a tuple of tensors.")

    # Already-batched tensors: a tuple also has __getitem__/__len__, so handle it before the Dataset path.
    if (
        isinstance(dataset_obj, (list, tuple))
        and len(dataset_obj) >= 4
        and all(isinstance(t, torch.Tensor) for t in dataset_obj)
    ):
        return tuple(t[:batch_size] for t in dataset_obj)

    if hasattr(dataset_obj, '__getitem__') and hasattr(dataset_obj, '__len__'):
        if len(dataset_obj) == 0:
            raise ValueError("Dataset is empty; cannot extract a batch.")
        sample = dataset_obj[0]
        if isinstance(sample, (list, tuple)) and len(sample) >= 4:
            loader = torch.utils.data.DataLoader(dataset_obj, batch_size=batch_size, shuffle=False)
            return next(iter(loader))  # expect (x_current, x_past, mask_past, labels)
        raise ValueError("Dataset sample does not look like (x_current, x_past, mask_past, labels)")

    raise TypeError(f"Unsupported input type for extract_batch: {type(dataset_obj).__name__}")

In [15]:
from typing import Dict


@torch.no_grad()
def compute_train_stats(train_loader, max_rows_for_median: int = 200_000):
    """
    Compute per-feature mean and median of the training data.

    Past (longitudinal) statistics use only valid time steps, i.e. positions where
    ``mask`` is true; padded steps are ignored.

    Args:
      train_loader: iterable of batches ``(x_curr, x_past, mask, y)`` or
        ``(x_curr, x_past, mask)`` with x_curr [B, D_curr], x_past [B, T, D_long]
        and mask [B, T] (any dtype; converted with ``.bool()``).
      max_rows_for_median: cap on rows buffered in memory for the medians. If the
        current or the valid-past row count would exceed it, medians are skipped
        entirely (they are not computed on a subsample).

    Returns:
      {
        "mean_curr":   [D_curr],
        "mean_past":   [D_long]  (zeros if no valid past step was seen),
        "median_curr": [D_curr]  (only if collected),
        "median_past": [D_long]  (only if collected),
        "has_median":  bool
      }

    Raises:
      ValueError: if ``train_loader`` yields no samples.
    """
    sum_curr = None
    sum_past = None
    n_curr = 0
    n_past = 0
    d_long = None  # remembered so mean_past has the right width even with no valid rows

    # Buffers for the medians (kept on CPU).
    buf_curr = []
    buf_past = []
    total_rows_curr = 0
    total_rows_past = 0
    collect_median = True  # switched off for good once the row cap would be exceeded

    for batch in train_loader:
        # Accept (x_curr, x_past, mask, y) or (x_curr, x_past, mask).
        if len(batch) == 4:
            x_curr, x_past, mask, _ = batch
        else:
            x_curr, x_past, mask = batch

        x_curr = x_curr.to("cpu", non_blocking=True)
        x_past = x_past.to("cpu", non_blocking=True)
        mask = mask.to("cpu", non_blocking=True).bool()

        _, _, d_long = x_past.shape
        d_curr = x_curr.shape[-1]

        # --- means ---
        if sum_curr is None:
            sum_curr = torch.zeros(d_curr)
        sum_curr = sum_curr + x_curr.sum(dim=0)
        n_curr += x_curr.size(0)

        # Past features: only valid (unpadded) steps contribute.
        valid_rows = x_past[mask]  # [N_valid, D_long]
        if valid_rows.numel() > 0:
            if sum_past is None:
                sum_past = torch.zeros(d_long)
            sum_past = sum_past + valid_rows.sum(dim=0)
            n_past += valid_rows.shape[0]

        # --- medians (optional, bounded by max_rows_for_median) ---
        if collect_median:
            # current
            if total_rows_curr + x_curr.size(0) <= max_rows_for_median:
                buf_curr.append(x_curr)
                total_rows_curr += x_curr.size(0)
            else:
                collect_median = False
            # past
            if valid_rows.numel() > 0:
                if total_rows_past + valid_rows.shape[0] <= max_rows_for_median:
                    buf_past.append(valid_rows)
                    total_rows_past += valid_rows.shape[0]
                else:
                    collect_median = False

    if n_curr == 0:
        raise ValueError("train_loader yielded no samples; cannot compute training stats.")

    mean_curr = sum_curr / n_curr
    # With no valid past step anywhere, fall back to zeros (the value padded steps get in make_baselines).
    mean_past = sum_past / n_past if n_past > 0 else torch.zeros(d_long)

    stats = {
        "mean_curr": mean_curr,
        "mean_past": mean_past,
        "has_median": False,
    }

    if collect_median and len(buf_curr) > 0 and len(buf_past) > 0:
        curr_mat = torch.cat(buf_curr, dim=0)          # [N_curr, D_curr]
        past_mat = torch.cat(buf_past, dim=0)          # [N_past, D_long]
        median_curr = torch.quantile(curr_mat, 0.5, dim=0)   # [D_curr]
        median_past = torch.quantile(past_mat, 0.5, dim=0)   # [D_long]
        stats.update({
            "median_curr": median_curr,
            "median_past": median_past,
            "has_median": True
        })

    return stats

def make_baselines(x_curr, x_past, mask, strategy="zeros", stats=None):
    """
    Create baseline tensors for current and past features according to strategy.

    Args:
      x_curr: [B, D_curr] tensor of current features
      x_past: [B, T, D_long] tensor of past features
      mask:   [B, T] tensor marking valid past steps; any dtype is accepted and
              converted with ``.bool()`` (non-zero = valid)
      strategy: "zeros", "means", or "medians"
      stats: statistics dict from compute_train_stats (required for "means" or
             "medians"); its tensors are cast to the dtype/device of the inputs
    Returns:
      base_curr: [B, D_curr] baseline tensor for current features
      base_past: [B, T, D_long] baseline tensor for past features; padded steps
                 (mask false) are always zero
    Raises:
      AssertionError: if the statistics required by ``strategy`` are missing.
      ValueError: for an unknown ``strategy``.
    """
    if strategy == "zeros":
        base_curr = torch.zeros_like(x_curr)
        base_past = torch.zeros_like(x_past)

    elif strategy == "means":
        # Message (Portuguese): "pass stats with mean_curr/mean_past for strategy='means'".
        assert stats is not None and "mean_curr" in stats and "mean_past" in stats, \
            "Passe stats com mean_curr/mean_past para strategy='means'."
        base_curr = stats["mean_curr"].to(x_curr).expand_as(x_curr).clone()
        base_past = stats["mean_past"].to(x_past).expand_as(x_past).clone()

    elif strategy == "medians":
        # Message (Portuguese): "medians unavailable (collection may have been disabled by the row limit)".
        assert stats is not None and stats.get("has_median", False), \
            "Medianas não disponíveis (a coleta pode ter sido desativada por limite)."
        base_curr = stats["median_curr"].to(x_curr).expand_as(x_curr).clone()
        base_past = stats["median_past"].to(x_past).expand_as(x_past).clone()

    else:
        # Message (Portuguese): "unknown baseline strategy".
        raise ValueError(f"Estratégia de baseline desconhecida: {strategy}")

    # Zero the baseline on padded steps according to the mask. `.bool()` first: `~` on a
    # float mask raises, and on an int mask yields a non-boolean tensor masked_fill rejects.
    # NOTE(review): presumably padded inputs are zero too, making their IG delta zero — confirm.
    valid_steps = mask.bool().unsqueeze(-1)
    base_past = base_past.masked_fill(~valid_steps, 0.0)
    return base_curr, base_past

def forward_for_attr(model, x_curr, x_past, mask):
    """Call ``model`` with keyword inputs and return only its logits.

    If the model returns a 2-tuple, the first element is taken as the logits and
    the second is discarded; any non-tuple output is returned unchanged.
    """
    output = model(x_current=x_curr, x_past=x_past, mask_past=mask)
    if not isinstance(output, tuple):
        return output
    logits, _discarded = output
    return logits

def save_training_stats(stats: Dict[str, torch.Tensor], out_json_path: str):
    """Write ``stats`` to ``out_json_path`` as indented JSON.

    Tensor values are moved to CPU and converted to (nested) Python lists; all
    other values are written as-is.
    """
    import json
    payload = {
        key: (value.cpu().tolist() if isinstance(value, torch.Tensor) else value)
        for key, value in stats.items()
    }
    with open(out_json_path, "w") as fh:
        json.dump(payload, fh, indent=2)
    print(f"Training data stats saved to: {out_json_path}")

## Global Feature Importance

In [16]:
from recurrent_health_events_prediction.model.explainers import global_feature_importance

# Global attributions over the full data.
# NOTE(review): argument roles (train loader for baselines, val loader to explain) are assumed — confirm.
curr_attr_df, past_attr_df, mean_abs_time = global_feature_importance(
    model,
    train_loader,
    val_loader,
    current_feat_cols,
    longitudinal_feat_cols,
    device,
)

mean_curr shape: torch.Size([19])
mean_past shape: torch.Size([8])
has_median: True


In [17]:
# Check which attribution columns global_feature_importance returned (used by the plots below).
curr_attr_df.columns

Index(['feature', 'attribution_activity', 'attribution_direction'], dtype='object')

In [18]:
from recurrent_health_events_prediction.training.utils_deep_learning import (
    plot_feature_attributions,
)

# Current-visit features ranked by the "attribution_activity" column.
plot_feature_attributions(
    curr_attr_df,
    feature_col="feature",
    attr_col="attribution_activity",
    title="Current Visit Features - Global Attributions",
)

In [19]:
# Top 10 current-visit features by the "attribution_direction" column (presumably signed).
plot_feature_attributions(
    curr_attr_df,
    feature_col="feature",
    attr_col="attribution_direction",
    top_k=10,
    title="Current Visit Features - Global Attributions (Direction)",
)

In [20]:
from recurrent_health_events_prediction.training.utils_deep_learning import (
    plot_feature_attributions,
)

# Top 10 past-visit (longitudinal) features by the "attribution_activity" column.
plot_feature_attributions(
    past_attr_df,
    feature_col="feature",
    attr_col="attribution_activity",
    top_k=10,
    title="Past Visit Features - Global Attributions",
)

In [22]:
import numpy as np
import plotly.graph_objects as go

# Mean absolute attribution per history step (1-based step index on the x-axis).
time_steps = np.arange(1, len(mean_abs_time) + 1)
importance_trace = go.Scatter(
    x=time_steps,
    y=mean_abs_time,
    mode="lines+markers",
    marker=dict(size=6),
)
importance_layout = dict(
    title="Importance over time",
    xaxis_title="Step",
    yaxis_title="Mean absolute importance",
    template="plotly_white",
    width=800,
    height=400,
)
fig = go.Figure(importance_trace)
fig.update_layout(**importance_layout)
fig.show()

## Global Feature Importance - patients with 4 past hospitalizations (no padded history steps)

In [30]:
from torch.utils.data import Subset

def filter_full_mask(dataset):
    """
    Keep only samples whose past-visit mask is valid at every time step.

    Args:
      dataset: map-style dataset (``len`` + integer ``__getitem__``) whose items
        unpack as ``(x_curr, x_past, mask, label)``. Items that do not unpack
        this way are skipped and counted.

    Returns:
      ``torch.utils.data.Subset`` over the indices whose mask is all true
      (i.e. the history has no padded steps).
    """
    indices_full_mask = []
    n_skipped = 0

    # Index explicitly instead of `for sample in dataset`: a Dataset without __iter__ falls back
    # to the legacy __getitem__ protocol, which only terminates on IndexError.
    for idx in range(len(dataset)):
        sample = dataset[idx]
        try:
            _, _, mask, _ = sample
        except (TypeError, ValueError):
            # Not a 4-element sample; count it so the skip is visible instead of silent.
            n_skipped += 1
            continue

        # `.all()` spans every element, so a [T] or [T, 1] mask is handled the same way.
        mask = torch.as_tensor(mask).bool()
        if mask.numel() > 0 and mask.all():
            indices_full_mask.append(idx)

    print(f"Found {len(indices_full_mask)} samples with all-true mask.")
    if n_skipped:
        print(f"Skipped {n_skipped} samples with unexpected format.")
    if indices_full_mask:
        print("Sample indices (preview):", indices_full_mask[:20])

    return Subset(dataset, indices_full_mask)


In [24]:
# Restrict test and train splits to samples with a fully observed (unpadded) history.
full_mask_test_dataset, full_mask_train_dataset = (
    filter_full_mask(split_ds) for split_ds in (test_dataset_obj, train_dataset_obj)
)

Found 144 samples with all-true mask.
Sample indices (preview): [7, 8, 36, 48, 49, 63, 78, 91, 126, 157, 161, 162, 171, 172, 173, 202, 203, 204, 205, 206]
Found 600 samples with all-true mask.
Sample indices (preview): [43, 44, 45, 46, 89, 90, 91, 92, 110, 111, 112, 113, 114, 115, 155, 156, 178, 188, 209, 210]


In [31]:
# Unshuffled loaders over the full-history subsets.
full_mask_test_loader, full_mask_train_dataset_loader = (
    torch.utils.data.DataLoader(subset, batch_size=model_cfg["batch_size"], shuffle=False)
    for subset in (full_mask_test_dataset, full_mask_train_dataset)
)

In [32]:
# Global attributions restricted to patients whose 4-step history is fully observed.
# Uses the full-mask loaders built above; the previous call re-used train_loader/val_loader and
# so simply repeated the global analysis.
# NOTE(review): argument roles assumed from the earlier call (reference loader first, loader to
# explain second) — confirm against global_feature_importance.
curr_attr_df, past_attr_df, mean_abs_time = global_feature_importance(
    model,
    full_mask_train_dataset_loader,
    full_mask_test_loader,
    current_feat_cols,
    longitudinal_feat_cols,
    device,
)

mean_curr shape: torch.Size([19])
mean_past shape: torch.Size([8])
has_median: True


In [33]:
from recurrent_health_events_prediction.model_selection.deep_learning.utils import plot_attribution_over_time


# Mean absolute IG attribution per history time step, using mean_abs_time from the most
# recent global_feature_importance call.
plot_attribution_over_time(mean_abs_time, title="Absolute IG Attributions - Time Indexes")

## Feature Importance per Sample

In [37]:
import os
import json
import torch

# Training-set feature statistics for IG baselines, cached as JSON.
# The cache path is not model-specific, so verify the cached vector widths match this
# model's feature lists before trusting it; recompute (and refresh the cache) otherwise.
stats = None
if os.path.exists(TRAINING_DATA_STATS_PATH):
    with open(TRAINING_DATA_STATS_PATH, "r") as f:
        raw = json.load(f)
    # JSON stores tensors as lists: convert them back; other values (e.g. has_median) are kept.
    stats = {k: torch.tensor(v) if isinstance(v, list) else v for k, v in raw.items()}
    print(f"Loaded training stats from: {TRAINING_DATA_STATS_PATH}")

    expected_dims = {"mean_curr": len(current_feat_cols), "mean_past": len(longitudinal_feat_cols)}
    if stats.get("has_median"):
        expected_dims.update(median_curr=len(current_feat_cols), median_past=len(longitudinal_feat_cols))
    stale_keys = [
        key for key, dim in expected_dims.items()
        if not isinstance(stats.get(key), torch.Tensor) or stats[key].numel() != dim
    ]
    if stale_keys:
        print(f"Cached stats do not match the current feature columns ({stale_keys}); recomputing.")
        stats = None
else:
    print(f"No training stats file found at: {TRAINING_DATA_STATS_PATH}. Will compute stats.")

if stats is None:
    stats = compute_train_stats(train_loader, max_rows_for_median=200_000)
    save_training_stats(stats, TRAINING_DATA_STATS_PATH)

Loaded training stats from: /workspaces/msc-thesis-recurrent-health-modeling/_models/mimic/deep_learning/training_stats.json


In [81]:
# get a batch to explain
batch_example = extract_batch(test_dataset_obj, batch_size=8)
x_curr_b, x_past_b, mask_b, _ = batch_example

# (Optional) layer to split history vs current — for your AttentionPoolingNet:
fc1_layer = model.classifier_head[0]  # nn.Linear

In [82]:
from recurrent_health_events_prediction.model.explainers import explain_deep_learning_model_feat

# Per-sample attributions for the batch above, with training-set means as the baseline.
explain_kwargs = dict(
    layer_for_split=fc1_layer,
    baseline_strategy="means",
    stats=stats,
    n_steps=64,
    internal_batch_size=16,
)
df_curr, df_past, df_split = explain_deep_learning_model_feat(
    model,
    x_curr_b,
    x_past_b,
    mask_b,
    current_feat_cols,
    longitudinal_feat_cols,
    **explain_kwargs,
)

In [83]:
# Group rows by sample; a stable sort keeps each sample's features in their original order
# (the default quicksort is not stable and may shuffle them).
df_curr.sort_values(by="sample_idx", ascending=True, kind="stable")

Unnamed: 0,sample_idx,feature,attribution
0,0,LOG_HOSPITALIZATION_DAYS,-0.000920
1,0,LOG_DAYS_IN_ICU,0.045774
2,0,CHARLSON_INDEX,0.034514
3,0,LOG_NUM_DRUGS,-0.035120
4,0,NUM_PROCEDURES,0.041049
...,...,...,...
147,7,ETHNICITY_WHITE,0.024027
148,7,ETHNICITY_BLACK,-0.009887
149,7,ETHNICITY_HISPANIC,-0.000606
150,7,INSURANCE_MEDICAID,-0.002265


In [84]:
# Long-format table of the feature values fed to the model for the explained batch,
# keyed by (sample_idx, feature) so it can be joined onto df_curr.
n_samples, n_features = x_curr_b.shape
assert n_features == len(current_feat_cols), "x_curr_b width does not match current_feat_cols"
feature_values_df = pd.DataFrame({
    # Row-major flatten of [B, D_curr]: all features of sample 0, then sample 1, ...
    "sample_idx": [i for i in range(n_samples) for _ in current_feat_cols],
    "feature": current_feat_cols * n_samples,
    # .cpu() so this also works when the batch lives on a GPU.
    "value": x_curr_b.detach().cpu().flatten().numpy(),
})

In [85]:
# Attach each feature's input value to its attribution (inner join on sample + feature).
merged_df = df_curr.merge(feature_values_df, on=["sample_idx", "feature"])
merged_df.head(20)

Unnamed: 0,sample_idx,feature,attribution,value
0,0,LOG_HOSPITALIZATION_DAYS,-0.00092,0.055471
1,0,LOG_DAYS_IN_ICU,0.045774,-0.798306
2,0,CHARLSON_INDEX,0.034514,-0.785789
3,0,LOG_NUM_DRUGS,-0.03512,0.64024
4,0,NUM_PROCEDURES,0.041049,1.175824
5,0,LOG_PARTICIPATION_DAYS,-0.007616,-1.238572
6,0,HAS_DIABETES,-0.000227,0.0
7,0,HAS_COPD,0.01288,1.0
8,0,HAS_CONGESTIVE_HF,-0.016857,0.0
9,0,DISCHARGE_LOCATION_POST_ACUTE_CARE,0.009409,0.0


In [None]:
LOCAL_ATTR_CSV_PATH = "local_attributions_current_features.csv"  # relative to the kernel's working directory
merged_df.to_csv(LOCAL_ATTR_CSV_PATH, index=False)

In [86]:
# Fraction of explained samples in which each feature's attribution is strictly positive
# (0 or 1 = consistent sign across samples; ~0.5 = sign flips between samples).
sign_consistency = merged_df["attribution"].gt(0).groupby(merged_df["feature"]).mean()
print(sign_consistency.sort_values())

feature
ADMISSION_TYPE_ELECTIVE               0.000
ETHNICITY_BLACK                       0.000
ETHNICITY_HISPANIC                    0.125
HAS_CONGESTIVE_HF                     0.125
HAS_COPD                              0.125
INSURANCE_MEDICAID                    0.125
LOG_PARTICIPATION_DAYS                0.250
HAS_DIABETES                          0.500
LOG_NUM_DRUGS                         0.500
NUM_PROCEDURES                        0.500
DISCHARGE_LOCATION_HOME               0.500
AGE                                   0.625
CHARLSON_INDEX                        0.625
INSURANCE_PRIVATE                     0.750
LOG_HOSPITALIZATION_DAYS              0.750
ETHNICITY_WHITE                       0.750
LOG_DAYS_IN_ICU                       0.875
DISCHARGE_LOCATION_POST_ACUTE_CARE    1.000
GENDER_M                              1.000
Name: attribution, dtype: float64


In [None]:
# Ten largest past-visit attributions.
# NOTE(review): df_curr exposes an "attribution" column; confirm df_past really has an
# "importance" column (this cell has not been executed yet).
df_past.sort_values(by="importance", ascending=False).head(10)

In [28]:
# Past-history vs. current-visit attribution for the first explained sample.
df_split.iloc[0].to_dict()

{'sample_idx': 0.0,
 'past_attribution': 0.08032890434862353,
 'current_attribution': 0.33738889031117253}