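# Feature Importance with Captum

This notebook explains the hospital-readmission model (`AttentionPoolingNetCurrentQuery`, configured below) with Captum's Integrated Gradients. It covers three analyses:

- global feature importance over the test set;
- the same analysis restricted to patients with a fully observed 4-visit history;
- per-sample explanations.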

In [8]:
import torch
from importlib import import_module

import os

# Root of the repository checkout. Override it with the WORKSPACE_DIR environment
# variable instead of editing the notebook. os.path.join(..., "") guarantees the
# trailing separator that the f-string paths below rely on.
WORKSPACE_DIR = os.path.join(
    os.environ.get("WORKSPACE_DIR", "/workspaces/msc-thesis-recurrent-health-modeling/"), ""
)
print(f"Workspace directory: {WORKSPACE_DIR}")

# ---- User inputs (EDIT ME) ----
DATASET_TEST_PT_PATH = f"{WORKSPACE_DIR}_models/mimic/deep_learning/attention_pooling_query_curr/multiple_hosp_patients/test_dataset.pt"  # path to test .pt file
DATASET_TRAIN_PT_PATH = f"{WORKSPACE_DIR}_models/mimic/deep_learning/attention_pooling_query_curr/multiple_hosp_patients/train_full_dataset.pt"  # path to train .pt file
MODEL_MODULE = "recurrent_health_events_prediction.model.RecurrentHealthEventsDL"
MODEL_CLASS_NAME = "AttentionPoolingNetCurrentQuery"   # e.g., GRUNet, AttentionPoolingNet, CrossAttnPoolingNet
CONFIG_PATH = f"{WORKSPACE_DIR}_runs/attention_pooling_query_curr_20251020_180454/model_config.yaml" # path to model config yaml

Workspace directory: /workspaces/msc-thesis-recurrent-health-modeling/


In [9]:
# JSON cache of the training-set feature statistics (means/medians) used as IG baselines
TRAINING_DATA_STATS_PATH = WORKSPACE_DIR + "_models/mimic/deep_learning/training_stats.json"

In [10]:
# Prefer the GPU whenever PyTorch can see one
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


## Load datasets and model

In [13]:
if CONFIG_PATH is not None:
    import yaml
    # Values from the run's config override the defaults set in the user-input cell
    with open(CONFIG_PATH, "r") as f:
        model_cfg = yaml.safe_load(f)
    MODEL_CLASS_NAME = model_cfg.get("model_class", MODEL_CLASS_NAME)
    MODEL_PARAMS = model_cfg.get("model_params")
    print("Loaded config:", CONFIG_PATH)
else:
    print("Using inline MODEL_PARAMS.")

# Fail early with a clear message, rather than with a NameError on the print below
# or a TypeError later at ModelClass(**MODEL_PARAMS).
if globals().get("MODEL_PARAMS") is None:
    raise ValueError(
        "MODEL_PARAMS is not set: add 'model_params' to the config file "
        "or define MODEL_PARAMS inline before running this cell."
    )
print("Model class:", MODEL_CLASS_NAME)
print("Model params:", MODEL_PARAMS)


Loaded config: /workspaces/msc-thesis-recurrent-health-modeling/_runs/attention_pooling_query_curr_20251020_180454/model_config.yaml
Model class: AttentionPoolingNetCurrentQuery
Model params: {'dropout': 0.014417996600305944, 'hidden_size_head': 64, 'hidden_size_seq': 16, 'input_size_curr': 19, 'input_size_seq': 8, 'scale_scores': False, 'use_separate_values': True}


In [14]:
# Full run configuration: feature lists, training hyper-parameters and model_params
model_cfg

{'batch_size': 64,
 'current_feat_cols': ['LOG_HOSPITALIZATION_DAYS',
  'LOG_DAYS_IN_ICU',
  'CHARLSON_INDEX',
  'LOG_NUM_DRUGS',
  'NUM_PROCEDURES',
  'LOG_PARTICIPATION_DAYS',
  'HAS_DIABETES',
  'HAS_COPD',
  'HAS_CONGESTIVE_HF',
  'DISCHARGE_LOCATION_POST_ACUTE_CARE',
  'DISCHARGE_LOCATION_HOME',
  'AGE',
  'GENDER_M',
  'ADMISSION_TYPE_ELECTIVE',
  'ETHNICITY_WHITE',
  'ETHNICITY_BLACK',
  'ETHNICITY_HISPANIC',
  'INSURANCE_MEDICAID',
  'INSURANCE_PRIVATE'],
 'learning_rate': 0.007004792359623531,
 'longitudinal_feat_cols': ['LOG_HOSPITALIZATION_DAYS',
  'LOG_DAYS_IN_ICU',
  'CHARLSON_INDEX',
  'LOG_NUM_DRUGS',
  'NUM_PROCEDURES',
  'DISCHARGE_LOCATION_POST_ACUTE_CARE',
  'ADMISSION_TYPE_ELECTIVE',
  'LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION'],
 'lr_scheduler': 'plateau',
 'max_sequence_length': 4,
 'model_class': 'AttentionPoolingNetCurrentQuery',
 'model_name': 'Attention Pooling with Current-Visit Query',
 'model_params': {'dropout': 0.014417996600305944,
  'hidden_size_head': 64,
 

In [15]:
# Current-visit feature names, assumed to follow the column order of x_curr
# (later cells use them to label attributions). May be None if the config lacks
# the key; later cells then fall back to generic "Feature i" names.
current_feat_cols = model_cfg.get("current_feat_cols")
print("Current-visit features: ")
for col in current_feat_cols or []:
    print(" -", col)

Current-visit features: 
 - LOG_HOSPITALIZATION_DAYS
 - LOG_DAYS_IN_ICU
 - CHARLSON_INDEX
 - LOG_NUM_DRUGS
 - NUM_PROCEDURES
 - LOG_PARTICIPATION_DAYS
 - HAS_DIABETES
 - HAS_COPD
 - HAS_CONGESTIVE_HF
 - DISCHARGE_LOCATION_POST_ACUTE_CARE
 - DISCHARGE_LOCATION_HOME
 - AGE
 - GENDER_M
 - ADMISSION_TYPE_ELECTIVE
 - ETHNICITY_WHITE
 - ETHNICITY_BLACK
 - ETHNICITY_HISPANIC
 - INSURANCE_MEDICAID
 - INSURANCE_PRIVATE


In [16]:
# Past-visit (longitudinal) feature names, assumed to follow the last-axis order
# of x_past. May be None if the config lacks the key; later cells then fall back
# to generic "Feature i" names.
longitudinal_feat_cols = model_cfg.get("longitudinal_feat_cols")
print("Longitudinal features: ")
for col in longitudinal_feat_cols or []:
    print(" -", col)

Longitudinal features: 
 - LOG_HOSPITALIZATION_DAYS
 - LOG_DAYS_IN_ICU
 - CHARLSON_INDEX
 - LOG_NUM_DRUGS
 - NUM_PROCEDURES
 - DISCHARGE_LOCATION_POST_ACUTE_CARE
 - ADMISSION_TYPE_ELECTIVE
 - LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION


In [17]:
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects
# (here whole Dataset instances) -- only load files from a trusted source.
test_dataset_obj = torch.load(DATASET_TEST_PT_PATH, weights_only=False)
train_dataset_obj = torch.load(DATASET_TRAIN_PT_PATH, weights_only=False)
print("Loaded dataset_object type: " + str(type(test_dataset_obj)))

Loaded dataset_object type: <class 'recurrent_health_events_prediction.datasets.HospReadmDataset.HospReadmDataset'>


In [18]:
# Inspect a single (un-batched) training sample
x_curr_ex, x_past_ex, mask_ex, label_ex = train_dataset_obj[0]
print("Shapes of tensors from dataset:")
print("\n".join(
    f"{name}: {tensor.shape}"
    for name, tensor in (
        ("x_curr", x_curr_ex),
        ("x_past", x_past_ex),
        ("mask", mask_ex),
        ("label", label_ex),
    )
))

Shapes of tensors from dataset:
x_curr: torch.Size([19])
x_past: torch.Size([4, 8])
mask: torch.Size([4])
label: torch.Size([])


In [19]:
# Sequential (unshuffled) loaders so the batch order is reproducible
train_loader, test_loader = (
    torch.utils.data.DataLoader(ds, batch_size=model_cfg["batch_size"], shuffle=False)
    for ds in (train_dataset_obj, test_dataset_obj)
)

In [20]:
# Same inspection on the first collated batch (adds a leading batch dimension)
x_curr, x_past, mask, label = next(iter(train_loader))
print("Shapes of tensors from DataLoader:")
print("\n".join(
    f"{name}: {tensor.shape}"
    for name, tensor in (("x_curr", x_curr), ("x_past", x_past), ("mask", mask), ("label", label))
))

Shapes of tensors from DataLoader:
x_curr: torch.Size([64, 19])
x_past: torch.Size([64, 4, 8])
mask: torch.Size([64, 4])
label: torch.Size([64])


In [21]:
# Padded past-visit batch [B, T, D_long]; all-zero rows appear to be padding for missing visits (see mask)
x_past

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0650,  0.9799,  0.0150,  ...,  0.0000,  0.0000,  1.4230],
         [ 1.2172,  0.3097, -0.3854,  ...,  0.0000,  0.0000,  0.9125],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        ...,

        [[-1.2898, -0.8091, -0.3854,  ...,  0.0000,  0.0000,  1.6034],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.00

In [22]:
# Import model class and instantiate
# Resolve the model class dynamically from the configured module and class name
mod = import_module(MODEL_MODULE)
ModelClass = getattr(mod, MODEL_CLASS_NAME)
model = ModelClass(**MODEL_PARAMS)
model.eval()  # inference mode (e.g. dropout disabled) for attribution
print(f"{model.__class__.__name__} initialized.")

AttentionPoolingNetCurrentQuery initialized.


## Helpers

In [23]:
def extract_batch(dataset_obj, batch_size):
    """Return the first batch ``(x_current, x_past, mask_past, labels)`` of ``dataset_obj``.

    Supported inputs:
      * a map-style Dataset whose samples are tuples/lists with at least four
        items ``(x_current, x_past, mask_past, labels, ...)``: the first
        ``batch_size`` samples are collated with a sequential DataLoader;
      * a tuple/list of at least four tensors sharing the leading (sample)
        dimension: each tensor is sliced to its first ``batch_size`` rows.

    Raises:
      ValueError: if the dataset is empty or its samples have fewer than four items.
      TypeError: if ``dataset_obj`` is of an unsupported type.
    """
    # Already-stacked tensors: slice directly instead of going through a DataLoader.
    if isinstance(dataset_obj, (list, tuple)) and dataset_obj and all(
        isinstance(t, torch.Tensor) for t in dataset_obj
    ):
        if len(dataset_obj) < 4:
            raise ValueError("Dataset sample does not look like (x_current, x_past, mask_past, labels)")
        return tuple(t[:batch_size] for t in dataset_obj)

    if hasattr(dataset_obj, '__getitem__') and hasattr(dataset_obj, '__len__'):
        # Looks like a Dataset
        if len(dataset_obj) == 0:
            raise ValueError("Dataset is empty; cannot extract a batch.")
        sample = dataset_obj[0]
        if isinstance(sample, (list, tuple)) and len(sample) >= 4:
            loader = torch.utils.data.DataLoader(dataset_obj, batch_size=batch_size, shuffle=False)
            return next(iter(loader))  # expect (x_current, x_past, mask_past, labels)
        raise ValueError("Dataset sample does not look like (x_current, x_past, mask_past, labels)")

    # Previously this fell through and silently returned None.
    raise TypeError(f"Unsupported dataset_obj type: {type(dataset_obj).__name__}")

In [24]:
from typing import Dict


@torch.no_grad()
def compute_train_stats(train_loader, max_rows_for_median: int = 200_000):
    """
    Compute per-feature mean and (optionally) median of the training data.

    Only valid (non-masked) time steps contribute to the past-feature
    statistics; padded steps are ignored.

    Args:
      train_loader: iterable of batches ``(x_curr, x_past, mask, y)`` or
        ``(x_curr, x_past, mask)`` with x_curr [B, D_curr], x_past [B, T, D_long]
        and mask [B, T] (a trailing singleton dim [B, T, 1] is also accepted).
      max_rows_for_median: maximum number of rows buffered (separately for
        current and past features) to compute medians; if either buffer would
        exceed it, medians are skipped.

    Returns:
      {
        "mean_curr":   [D_curr],
        "mean_past":   [D_long]  (zeros if no valid past step was seen),
        "median_curr": [D_curr]  (only if has_median),
        "median_past": [D_long]  (only if has_median),
        "has_median":  bool
      }

    Raises:
      ValueError: if ``train_loader`` yields no batches.
    """
    sum_curr = None
    sum_past = None
    n_curr = 0
    n_past = 0
    d_long = None  # remembered so mean_past can fall back to zeros of the right size

    # Buffers for the medians (kept on CPU)
    buf_curr = []
    buf_past = []
    total_rows_curr = 0
    total_rows_past = 0
    collect_median = True  # switched off permanently once a buffer would exceed the limit

    for batch in train_loader:
        # Datasets may return (x_curr, x_past, mask, y) or (x_curr, x_past, mask)
        if len(batch) == 4:
            x_curr, x_past, mask, _ = batch
        else:
            x_curr, x_past, mask = batch

        x_curr = x_curr.to("cpu", non_blocking=True)
        x_past = x_past.to("cpu", non_blocking=True)
        mask = mask.to("cpu", non_blocking=True).bool()
        if mask.dim() == 3 and mask.shape[-1] == 1:
            mask = mask.squeeze(-1)  # [B, T, 1] -> [B, T] so x_past[mask] selects whole rows

        d_long = x_past.shape[-1]
        d_curr = x_curr.shape[-1]

        # --- means ---
        # current
        sum_curr = (sum_curr if sum_curr is not None else torch.zeros(d_curr)) + x_curr.sum(dim=0)
        n_curr += x_curr.size(0)

        # past: use only valid rows
        valid_rows = x_past[mask]        # [N_valid, D_long]
        if valid_rows.numel() > 0:
            sum_past = (sum_past if sum_past is not None else torch.zeros(d_long)) + valid_rows.sum(dim=0)
            n_past += valid_rows.shape[0]

        # --- medians (optional, up to max_rows_for_median rows) ---
        if collect_median:
            # current
            if total_rows_curr + x_curr.size(0) <= max_rows_for_median:
                buf_curr.append(x_curr)
                total_rows_curr += x_curr.size(0)
            else:
                collect_median = False
            # past
            if valid_rows.numel() > 0:
                if total_rows_past + valid_rows.shape[0] <= max_rows_for_median:
                    buf_past.append(valid_rows)
                    total_rows_past += valid_rows.shape[0]
                else:
                    collect_median = False

    if sum_curr is None:
        raise ValueError("train_loader yielded no batches; cannot compute training statistics.")

    mean_curr = sum_curr / max(n_curr, 1)
    # No valid past step anywhere: return a zero vector instead of failing on None / int
    mean_past = sum_past / max(n_past, 1) if sum_past is not None else torch.zeros(d_long)

    stats = {
        "mean_curr": mean_curr,
        "mean_past": mean_past,
        "has_median": False,
    }

    if collect_median and len(buf_curr) > 0 and len(buf_past) > 0:
        curr_mat = torch.cat(buf_curr, dim=0)          # [N_curr, D_curr]
        past_mat = torch.cat(buf_past, dim=0)          # [N_past, D_long]
        median_curr = torch.quantile(curr_mat, 0.5, dim=0)   # [D_curr]
        median_past = torch.quantile(past_mat, 0.5, dim=0)   # [D_long]
        stats.update({
            "median_curr": median_curr,
            "median_past": median_past,
            "has_median": True
        })

    return stats

def make_baselines(x_curr, x_past, mask, strategy="zeros", stats=None):
    """
    Create Integrated-Gradients baseline tensors for current and past features.

    Args:
      x_curr: [B, D_curr] tensor of current features
      x_past: [B, T, D_long] tensor of past features
      mask:   [B, T] tensor marking valid past steps (bool or 0/1 values;
              a trailing singleton dim [B, T, 1] is also accepted)
      strategy: "zeros", "means", or "medians"
      stats: statistics dict from compute_train_stats; required for "means"
             (mean_curr/mean_past) and "medians" (has_median=True plus
             median_curr/median_past)
    Returns:
      base_curr: [B, D_curr] baseline for current features (dtype/device of x_curr)
      base_past: [B, T, D_long] baseline for past features (dtype/device of x_past),
                 zero at padded (masked-out) steps
    Raises:
      ValueError: if the statistics required by ``strategy`` are missing, or the
                  strategy is unknown.
    """
    if strategy == "zeros":
        base_curr = torch.zeros_like(x_curr)
        base_past = torch.zeros_like(x_past)

    elif strategy == "means":
        # Explicit check instead of assert, which is stripped under `python -O`
        if stats is None or "mean_curr" not in stats or "mean_past" not in stats:
            raise ValueError("Passe stats com mean_curr/mean_past para strategy='means'.")
        base_curr = stats["mean_curr"].to(x_curr).expand_as(x_curr).clone()
        base_past = stats["mean_past"].to(x_past).expand_as(x_past).clone()

    elif strategy == "medians":
        if stats is None or not stats.get("has_median", False):
            raise ValueError("Medianas não disponíveis (a coleta pode ter sido desativada por limite).")
        base_curr = stats["median_curr"].to(x_curr).expand_as(x_curr).clone()
        base_past = stats["median_past"].to(x_past).expand_as(x_past).clone()

    else:
        raise ValueError(f"Estratégia de baseline desconhecida: {strategy}")

    # masked_fill needs a boolean mask that broadcasts against [B, T, D_long]
    mask = mask.to(device=x_past.device).bool()
    if mask.dim() == 3 and mask.shape[-1] == 1:
        mask = mask.squeeze(-1)
    # Zero the baseline at padded steps according to the mask
    base_past = base_past.masked_fill(~mask.unsqueeze(-1), 0.0)
    return base_curr, base_past

def forward_for_attr(model, x_curr, x_past, mask):
    """
    Forward wrapper for Captum: call ``model`` and return only its logits.

    ``model`` is called as ``model(x_current=..., x_past=..., mask_past=...)``.
    It may return either a logits tensor or a tuple/list whose first element is
    the logits (e.g. ``(logits, attention_weights)``); any extra elements are
    dropped so Captum only sees the tensor it attributes.
    """
    out = model(x_current=x_curr, x_past=x_past, mask_past=mask)
    if isinstance(out, (tuple, list)):
        return out[0]
    return out

def save_training_stats(stats: Dict[str, torch.Tensor], out_json_path: str):
    """
    Write a statistics dict (as produced by compute_train_stats) to a JSON file.

    Tensor values are moved to CPU and converted to (nested) Python lists; all
    other values are written unchanged. The file is overwritten if it exists.
    """
    import json

    payload = {
        key: (value.cpu().tolist() if isinstance(value, torch.Tensor) else value)
        for key, value in stats.items()
    }
    with open(out_json_path, "w") as fh:
        json.dump(payload, fh, indent=2)
    print(f"Training data stats saved to: {out_json_path}")

## Global feature importance with Integrated Gradients (Captum)

In [16]:
from captum.attr import IntegratedGradients

def global_feature_importance(
    model,
    train_loader,
    test_loader,
    device,
    n_steps: int = 32,
    baseline_strategy: str = "means",  # "zeros" | "means" | "medians"
    internal_batch_size: int = 64,
    stats=None,
):
    """
    Global feature importance via Captum Integrated Gradients.

    Runs IG on every batch of ``test_loader`` and averages the absolute
    attributions per sample.

    Args:
      model: model called as ``model(x_current=..., x_past=..., mask_past=...)``;
        it is moved to ``device`` and put in eval mode.
      train_loader: loader used to compute baseline statistics (only for
        "means"/"medians" and only when ``stats`` is not given).
      test_loader: loader over the samples to explain; batches are
        ``(x_curr, x_past, mask[, y])``.
      device: device on which the model and inputs are placed.
      n_steps: number of IG integration steps.
      baseline_strategy: "zeros", "means" or "medians" (see make_baselines).
      internal_batch_size: batch size Captum uses when evaluating the IG path.
      stats: optional precomputed output of compute_train_stats; when given,
        the training statistics are not recomputed.

    Returns:
      (mean_curr [D_curr], mean_past_feat [D_long], mean_time [T]): mean absolute
      attribution per current feature, per past feature (summed over time steps)
      and per past time step (summed over features). Tensors are detached.

    Raises:
      ValueError: if ``test_loader`` yields no samples.
    """
    model.to(device)
    model.eval()

    # 1) Training statistics: only needed for "means"/"medians", and only if not supplied
    if baseline_strategy in ["means", "medians"] and stats is None:
        stats = compute_train_stats(train_loader, max_rows_for_median=200_000)
        print("mean_curr shape:", stats["mean_curr"].shape)
        print("mean_past shape:", stats["mean_past"].shape)
        print("has_median:", stats["has_median"])

    # 2) IG: the forward wrapper receives (x_curr, x_past, mask)
    ig = IntegratedGradients(lambda x_c, x_p, m: forward_for_attr(model, x_c, x_p, m))

    sum_abs_curr = None         # [D_curr]
    sum_abs_past_feat = None    # [D_long]
    sum_abs_time = None         # [T]
    n_samples = 0

    for batch in test_loader:
        # Datasets may return (x_curr, x_past, mask, y) or (x_curr, x_past, mask)
        if len(batch) == 4:
            x_curr, x_past, mask, _ = batch
        else:
            x_curr, x_past, mask = batch

        # Captum enables gradients on its own interpolated inputs, so the raw inputs
        # need no requires_grad (it would only make the attributions carry a graph).
        x_curr = x_curr.to(device)
        x_past = x_past.to(device)
        mask = mask.to(device).bool()
        if mask.dim() == 3 and mask.shape[-1] == 1:
            mask = mask.squeeze(-1)  # [B,T]

        # 3) Per-batch baselines (same shape as the current batch)
        base_curr, base_past = make_baselines(
            x_curr, x_past, mask,
            strategy=baseline_strategy,
            stats=stats
        )

        # 4) Per-batch IG (attributes only x_curr and x_past; mask is an extra forward arg)
        attr_curr, attr_past = ig.attribute(
            inputs=(x_curr, x_past),
            baselines=(base_curr, base_past),
            additional_forward_args=(mask,),
            target=None,                 # binary: single logit
            n_steps=n_steps,
            internal_batch_size=internal_batch_size
        )
        # Detach so the running sums never keep autograd graphs alive across batches
        attr_curr = attr_curr.detach()
        attr_past = attr_past.detach()

        # 5) Aggregations (|.| summed over the batch)
        b_curr = attr_curr.abs().sum(dim=0)            # [D_curr]
        b_past_feat = attr_past.abs().sum(dim=(0, 1))  # [D_long]
        b_time = attr_past.abs().sum(dim=(0, 2))       # [T]

        sum_abs_curr = b_curr if sum_abs_curr is None else sum_abs_curr + b_curr
        sum_abs_past_feat = b_past_feat if sum_abs_past_feat is None else sum_abs_past_feat + b_past_feat
        sum_abs_time = b_time if sum_abs_time is None else sum_abs_time + b_time

        n_samples += x_curr.size(0)

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples; nothing to attribute.")

    # 6) Per-sample means
    mean_curr = sum_abs_curr / n_samples
    mean_past_feat = sum_abs_past_feat / n_samples
    mean_time = sum_abs_time / n_samples

    return mean_curr, mean_past_feat, mean_time

## Global Feature Importance

In [None]:
mean_curr, mean_past_feat, mean_time = global_feature_importance(
    model, train_loader, test_loader, device
)

# Top-k current-visit features (k capped at the number of available features)
k = min(10, mean_curr.numel())
top_vals, top_idx = torch.topk(mean_curr, k)
print("Top features (x_curr):")
for i, (v, idx) in enumerate(zip(top_vals, top_idx)):
    feature_name = current_feat_cols[idx.item()] if current_feat_cols is not None else f"Feature {idx.item()}"
    print(f"{i+1:2d}. {feature_name} | Importance: {v.item():.4f}")

# Rank every longitudinal feature. k is taken from the tensor, so the
# "Feature i" fallback below also works when longitudinal_feat_cols is None.
k = mean_past_feat.numel()
top_vals_p, top_idx_p = torch.topk(mean_past_feat, k)
print("\nTop features (x_past):")
for i, (v, idx) in enumerate(zip(top_vals_p, top_idx_p)):
    feature_name = longitudinal_feat_cols[idx.item()] if longitudinal_feat_cols is not None else f"Feature {idx.item()}"
    print(f"{i+1:2d}. {feature_name} | Importance: {v.item():.4f}")

mean_curr shape: torch.Size([19])
mean_past shape: torch.Size([8])
has_median: True
Top features (x_curr):
 1. LOG_NUM_DRUGS | Importance: 0.0426
 2. LOG_HOSPITALIZATION_DAYS | Importance: 0.0306
 3. LOG_PARTICIPATION_DAYS | Importance: 0.0204
 4. NUM_PROCEDURES | Importance: 0.0176
 5. HAS_CONGESTIVE_HF | Importance: 0.0175
 6. CHARLSON_INDEX | Importance: 0.0144
 7. HAS_DIABETES | Importance: 0.0128
 8. AGE | Importance: 0.0100
 9. LOG_DAYS_IN_ICU | Importance: 0.0092
10. DISCHARGE_LOCATION_HOME | Importance: 0.0073

Top features (x_past):
 1. NUM_PROCEDURES | Importance: 0.0114
 2. DISCHARGE_LOCATION_POST_ACUTE_CARE | Importance: 0.0089
 3. CHARLSON_INDEX | Importance: 0.0086
 4. LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION | Importance: 0.0078
 5. LOG_NUM_DRUGS | Importance: 0.0074
 6. LOG_DAYS_IN_ICU | Importance: 0.0067
 7. LOG_HOSPITALIZATION_DAYS | Importance: 0.0042
 8. ADMISSION_TYPE_ELECTIVE | Importance: 0.0012


In [41]:
# Per-feature mean |IG attribution| for current-visit features, indexed like current_feat_cols
mean_curr

tensor([0.0306, 0.0092, 0.0144, 0.0426, 0.0176, 0.0204, 0.0128, 0.0051, 0.0175,
        0.0052, 0.0073, 0.0100, 0.0042, 0.0024, 0.0047, 0.0054, 0.0015, 0.0071,
        0.0055], dtype=torch.float64, grad_fn=<DivBackward0>)

In [38]:
import numpy as np
import plotly.graph_objects as go

# Mean |attribution| per past time step (summed over longitudinal features); steps are 1-based
y = mean_time.detach().cpu().numpy().ravel()
x = np.arange(1, y.size + 1)

fig = go.Figure(
    data=go.Scatter(x=x, y=y, mode="lines+markers", marker={"size": 6})
)
fig.update_layout(
    template="plotly_white",
    title="Importance over time",
    xaxis_title="Step",
    yaxis_title="Mean absolute importance",
    width=800,
    height=400,
)
fig.show()

## Global Feature Importance — samples with all 4 past hospitalizations observed (full mask)

In [57]:
from torch.utils.data import Subset

def filter_full_mask(dataset):
    """
    Return the subset of ``dataset`` whose samples have every past time step valid.

    Each sample is expected to be a tuple/list ``(x_curr, x_past, mask[, label])``.
    Samples with another layout are skipped and counted (and the count is printed)
    instead of being silently swallowed. A sample is kept when its mask is non-empty
    and every entry is true; any mask shape works (e.g. [T] or [T, 1]).

    Args:
      dataset: map-style dataset supporting ``len()`` and integer indexing.
    Returns:
      torch.utils.data.Subset over the indices of the fully observed samples.
    """
    indices_full_mask = []
    n_skipped = 0

    for idx in range(len(dataset)):
        sample = dataset[idx]
        if not isinstance(sample, (list, tuple)) or len(sample) < 3:
            n_skipped += 1
            continue

        mask = torch.as_tensor(sample[2]).bool()
        # check if all time steps are valid
        if mask.numel() > 0 and mask.all():
            indices_full_mask.append(idx)

    if n_skipped:
        print(f"Skipped {n_skipped} samples with unexpected format.")
    print(f"Found {len(indices_full_mask)} samples with all-true mask.")
    if indices_full_mask:
        print("Sample indices (preview):", indices_full_mask[:20])

    # create and return filtered subset
    return Subset(dataset, indices_full_mask)


In [60]:
# Restrict both splits to patients whose whole history window is observed (test first, then train)
full_mask_test_dataset, full_mask_train_dataset = (
    filter_full_mask(ds) for ds in (test_dataset_obj, train_dataset_obj)
)

Found 144 samples with all-true mask.
Sample indices (preview): [7, 8, 36, 48, 49, 63, 78, 91, 126, 157, 161, 162, 171, 172, 173, 202, 203, 204, 205, 206]
Found 693 samples with all-true mask.
Sample indices (preview): [49, 50, 51, 52, 80, 114, 115, 116, 117, 134, 143, 144, 145, 146, 147, 148, 197, 198, 213, 214]


In [61]:
# Sequential loaders over the full-mask subsets
full_mask_test_loader, full_mask_train_dataset_loader = (
    torch.utils.data.DataLoader(subset, batch_size=model_cfg["batch_size"], shuffle=False)
    for subset in (full_mask_test_dataset, full_mask_train_dataset)
)

In [62]:
mean_curr, mean_past_feat, mean_time = global_feature_importance(
    model, full_mask_train_dataset_loader, full_mask_test_loader, device
)
print("\nGlobal feature importance (full-mask samples only):")

# Top-k current-visit features (k capped at the number of available features)
k = min(10, mean_curr.numel())
top_vals, top_idx = torch.topk(mean_curr, k)
print("Top features (x_curr):")
for i, (v, idx) in enumerate(zip(top_vals, top_idx)):
    feature_name = current_feat_cols[idx.item()] if current_feat_cols is not None else f"Feature {idx.item()}"
    print(f"{i+1:2d}. {feature_name} | Importance: {v.item():.4f}")

# Rank every longitudinal feature. k is taken from the tensor, so the
# "Feature i" fallback below also works when longitudinal_feat_cols is None.
k = mean_past_feat.numel()
top_vals_p, top_idx_p = torch.topk(mean_past_feat, k)
print("\nTop features (x_past):")
for i, (v, idx) in enumerate(zip(top_vals_p, top_idx_p)):
    feature_name = longitudinal_feat_cols[idx.item()] if longitudinal_feat_cols is not None else f"Feature {idx.item()}"
    print(f"{i+1:2d}. {feature_name} | Importance: {v.item():.4f}")

mean_curr shape: torch.Size([19])
mean_past shape: torch.Size([8])
has_median: True

Global feature importance (full-mask samples only):
Top features (x_curr):
 1. LOG_HOSPITALIZATION_DAYS | Importance: 0.0258
 2. NUM_PROCEDURES | Importance: 0.0220
 3. CHARLSON_INDEX | Importance: 0.0192
 4. LOG_NUM_DRUGS | Importance: 0.0185
 5. HAS_CONGESTIVE_HF | Importance: 0.0164
 6. HAS_DIABETES | Importance: 0.0121
 7. INSURANCE_MEDICAID | Importance: 0.0100
 8. ETHNICITY_BLACK | Importance: 0.0090
 9. ETHNICITY_WHITE | Importance: 0.0082
10. LOG_DAYS_IN_ICU | Importance: 0.0079

Top features (x_past):
 1. NUM_PROCEDURES | Importance: 0.0190
 2. DISCHARGE_LOCATION_POST_ACUTE_CARE | Importance: 0.0149
 3. LOG_DAYS_IN_ICU | Importance: 0.0121
 4. CHARLSON_INDEX | Importance: 0.0108
 5. LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION | Importance: 0.0094
 6. LOG_NUM_DRUGS | Importance: 0.0089
 7. LOG_HOSPITALIZATION_DAYS | Importance: 0.0074
 8. ADMISSION_TYPE_ELECTIVE | Importance: 0.0007


## Feature Importance per Sample

In [25]:
import os
import json
import torch

# Load cached training statistics (IG baselines) or compute and cache them.
# NOTE(review): the cache is keyed only by path; delete the file if the training data changes.
if os.path.exists(TRAINING_DATA_STATS_PATH):
    with open(TRAINING_DATA_STATS_PATH, "r") as f:
        raw = json.load(f)
    # JSON stores tensors as (nested) lists: turn them back into tensors, keep other values as-is
    stats = {key: torch.tensor(val) if isinstance(val, list) else val for key, val in raw.items()}
    print(f"Loaded training stats from: {TRAINING_DATA_STATS_PATH}")
else:
    print(f"No training stats file found at: {TRAINING_DATA_STATS_PATH}. Will compute stats.")
    stats = compute_train_stats(train_loader, max_rows_for_median=200_000)
    # Create the cache directory first; open(..., "w") fails if it does not exist
    os.makedirs(os.path.dirname(TRAINING_DATA_STATS_PATH) or ".", exist_ok=True)
    save_training_stats(stats, TRAINING_DATA_STATS_PATH)

Loaded training stats from: /workspaces/msc-thesis-recurrent-health-modeling/_models/mimic/deep_learning/training_stats.json


In [26]:
# First 8 test samples form the batch to explain
batch_example = extract_batch(test_dataset_obj, batch_size=8)
x_curr_b, x_past_b, mask_b = batch_example[0], batch_example[1], batch_example[2]

# Layer passed as layer_for_split below: first module of the model's classifier head
# (presumably an nn.Linear; optional, used to split history vs current attribution)
fc1_layer = model.classifier_head[0]

In [27]:
from recurrent_health_events_prediction.model.explainers import explain_deep_learning_model_feat

# Index (within the extracted batch) of the sample to explain
SAMPLE_IDX = 2

df_curr, df_past, df_split = explain_deep_learning_model_feat(
    model,
    x_curr_b[SAMPLE_IDX],
    x_past_b[SAMPLE_IDX],
    mask_b[SAMPLE_IDX],
    current_feat_cols,
    longitudinal_feat_cols,
    layer_for_split=fc1_layer,
    baseline_strategy="means",
    stats=stats,
    n_steps=64,
    internal_batch_size=16,
)

In [21]:
# Ten most important current-visit features for the explained sample
df_curr.sort_values(by="importance", ascending=False).head(10)

Unnamed: 0,sample_idx,feature_name,importance
5,0,LOG_PARTICIPATION_DAYS,0.057342
1,0,LOG_DAYS_IN_ICU,0.046817
7,0,HAS_COPD,0.031268
8,0,HAS_CONGESTIVE_HF,0.028045
11,0,AGE,0.026688
16,0,ETHNICITY_HISPANIC,0.026623
4,0,NUM_PROCEDURES,0.015564
10,0,DISCHARGE_LOCATION_HOME,0.015542
0,0,LOG_HOSPITALIZATION_DAYS,0.013274
14,0,ETHNICITY_WHITE,0.012395


In [22]:
# Past-visit (longitudinal) features ranked by importance for the explained sample
df_past.sort_values(by="importance", ascending=False).head(10)

Unnamed: 0,sample_idx,feature_name,importance
0,0,LOG_HOSPITALIZATION_DAYS,0.039259
1,0,LOG_DAYS_IN_ICU,0.038828
3,0,LOG_NUM_DRUGS,0.034356
4,0,NUM_PROCEDURES,0.013395
7,0,LOG_DAYS_UNTIL_NEXT_HOSPITALIZATION,0.0114
5,0,DISCHARGE_LOCATION_POST_ACUTE_CARE,0.004105
6,0,ADMISSION_TYPE_ELECTIVE,0.002165
2,0,CHARLSON_INDEX,0.001112


In [28]:
# Past vs current attribution split for the explained sample, as returned by explain_deep_learning_model_feat
df_split.iloc[0].to_dict()

{'sample_idx': 0.0,
 'past_attribution': 0.08032890434862353,
 'current_attribution': 0.33738889031117253}