## Load model

In [None]:
# --- Load CBM model with relative path (DANCE repo) ---------------------------
from __future__ import annotations
import argparse, json, sys
from pathlib import Path
import torch

def find_repo_root(marker="CBM_training"):
    """
    Locate the repository root by walking up from the current working directory.

    Args:
        marker: Name of a file or directory expected directly under the repo root.

    Returns:
        pathlib.Path of the closest directory (the CWD itself or one of its
        ancestors) that contains `marker`.

    Raises:
        RuntimeError: If neither the CWD nor any of its ancestors contains `marker`.
    """
    cwd = Path.cwd()
    for candidate in (cwd, *cwd.parents):
        if (candidate / marker).exists():
            return candidate
    # Falling off the end used to return None implicitly; the caller then put the
    # literal string "None" on sys.path and the import failed with an unrelated
    # error. Fail here with a message that names the real problem.
    raise RuntimeError(
        f"Cannot find repository root containing {marker!r} (searched upward from {cwd})."
    )
# Put the repository root at the front of the import path (only once) so the
# `CBM_training` package imported below can be found from this notebook.
repo_root = find_repo_root("CBM_training")
if sys.path.count(str(repo_root)) == 0:
    sys.path.insert(0, str(repo_root))


from CBM_training.model import cbm  # from CBM_training/model/cbm.py
from CBM_training.model import plots

# --- Experiment directory (relative to Experiments/) and its saved arguments ---
cfg = argparse.Namespace(
    load_dir=Path("../result/Penn_Action_result/penn-action_Penn_Action_motion_label+Penn_action_object_concept+Penn_action_scene_concept"),
)

# args.txt holds the saved arguments as JSON; stop early if it is missing.
args_path = cfg.load_dir / "args.txt"
if not args_path.exists():
    raise FileNotFoundError(f"Missing args.txt at: {args_path.resolve()}")
args = argparse.Namespace(**json.loads(args_path.read_text(encoding="utf-8")))

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Backbone:", getattr(args, "backbone", "<unknown>"))

# Rebuild the model from the experiment directory on the selected device.
model = cbm.load_cbm_dynamic(str(cfg.load_dir), device, args)
print("Model loaded successfully on", device)


In [4]:
from pathlib import Path
import os

# Class names: one entry per line of class_list.txt under the annotation directory.
cls_file = Path(args.video_anno_path) / "class_list.txt"
classes = cls_file.read_text(encoding="utf-8").splitlines()

def load_all_concepts_by_type(load_dir: Path, train_mode: list[str]):
    """
    Collect the concept names saved for each concept type.

    For every entry of `train_mode`, reads `<load_dir>/<concept_type>/concepts.txt`
    and splits it into lines. Concept types whose file does not exist are left
    out of the result instead of raising.

    Args:
        load_dir: Experiment directory containing one sub-directory per concept type.
        train_mode: Concept type names to look up.

    Returns:
        dict mapping concept type -> list of concept strings, in `train_mode` order.
    """
    root = Path(load_dir)
    per_type = {}
    for kind in train_mode:
        listing = root / kind / "concepts.txt"
        if not listing.exists():
            continue  # a type without a concepts.txt is skipped silently
        per_type[kind] = listing.read_text(encoding="utf-8").splitlines()
    return per_type

# Example usage
# Per-type concept lists; a type whose concepts.txt is missing is absent from the dict.
concept_dict = load_all_concepts_by_type(cfg.load_dir, args.train_mode)

if len(args.train_mode) < 2:
    # Single-mode: use the concept file of the corresponding type
    # (an empty list is used when that type has no concepts.txt).
    tm = args.train_mode[0]
    concepts = concept_dict.get(tm, [])

    if "class_attributes" in Path(args.pose_label).name:
        # UCF101 attributes
        # The attribute names read here replace the list loaded above entirely.
        attr_path = Path("../dataset/UCF101/class_attributes/attribute.txt")
        with attr_path.open("r", encoding="utf-8") as f:
            concepts = f.read().splitlines()


    # Special case: attribute + object/scene combination
    # Keeps the leading entries of the per-type list and swaps its tail for the
    # object/scene attribute names; the total length is unchanged whenever the
    # per-type list is at least as long as the attribute file.
    elif "attr_object_scene" in str(cfg.load_dir):
        scene_obj_path = Path("../dataset/UCF101/class_attributes_only_object_place/attribute.txt")
        with scene_obj_path.open("r", encoding="utf-8") as f:
            scene_object_concepts = f.read().splitlines()
        head_len = max(0, len(concepts) - len(scene_object_concepts))
        concepts = concepts[:head_len] + scene_object_concepts

else:
    # Multi-mode: aggregated concepts are always generated
    agg_path = Path(cfg.load_dir) / "aggregated" / "concepts.txt"
    with agg_path.open("r", encoding="utf-8") as f:
        concepts = f.read().splitlines()

# Every concept name must line up with one column of the final layer's weight matrix.
assert len(concepts) == model.final.weight.shape[1], "Concept count mismatch"
print(f"Concept number: {len(concepts)}")
print(concepts[:5])


Concept number: 147
['0', '1', '2', '3', '4']


## Load video dataset

In [5]:
from CBM_training.video_dataloader import datasets

# Build validation dataset (for evaluation)
# NOTE(review): the meaning of the two positional False flags is defined by
# datasets.build_dataset, which is not visible here — confirm against its signature.
val_video_dataset, _ = datasets.build_dataset(False, False, args)

# Build another dataset instance for visualization
# Presumably a separate instance is used so the flags set below do not affect
# `val_video_dataset` — TODO confirm how the dataset class reads these attributes.
val_visualize_dataset, _ = datasets.build_dataset(False, False, args)
val_visualize_dataset.visualize = True
val_visualize_dataset.no_aug = True

# Extract target labels
# `label_array` is taken from the evaluation dataset and converted to an int64 tensor.
val_targets = val_video_dataset.label_array
val_y = torch.LongTensor(val_targets)


Number of the class = 15
Number of the class = 15


In [6]:
# --- Extract and cache concept activations for the validation set ------------
import os
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np 

device = "cuda" if torch.cuda.is_available() else "cpu"

# I/O config (save next to the loaded experiment directory; no hardcoded paths)
save_path = Path(cfg.load_dir) / "all_concept_activations.pt"

# DataLoader config (portable across OS/CPUs)
batch_size = 32
num_workers = min(4, (os.cpu_count() or 1))  # keep conservative for cross-platform
pin_memory = (device == "cuda")

if save_path.exists():
    # Reuse the cached tensor instead of re-running the model over the whole set.
    print(f"[info] File already exists at: {save_path}")
    all_activations = torch.load(save_path, map_location="cpu")
    # A cache written for a different dataset/split would silently misalign the
    # activations with `val_y` in later cells, so flag an obvious size mismatch.
    if all_activations.shape[0] != len(val_video_dataset):
        print(
            f"[warn] Cached activations have {all_activations.shape[0]} rows but the "
            f"validation set has {len(val_video_dataset)} samples; "
            f"delete {save_path} to regenerate."
        )
else:
    # Make sure train-time behaviour (dropout, batch-norm statistics updates) is off
    # before extracting activations. This is a no-op if the loader already did it.
    if isinstance(model, torch.nn.Module):
        model.eval()

    activation_batches = []
    loader = DataLoader(
        val_video_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep sample order aligned with the dataset's label array
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    with torch.inference_mode():
        for videos, _labels in tqdm(loader, desc="Extracting concept activations"):
            videos = videos.to(device, non_blocking=pin_memory)
            outputs, concept_act = model(videos)  # `concept_act` shape: (N, C)
            activation_batches.append(concept_act.cpu())

    # Concatenate to a single (N_total, C) tensor and save
    all_activations = torch.cat(activation_batches, dim=0)
    torch.save(all_activations, save_path)
    print(f"[info] Saved all_activations to: {save_path}")

print("[info] all_activations shape:", tuple(all_activations.shape))

Extracting concept activations: 100%|██████████| 34/34 [00:58<00:00,  1.71s/it]

[info] Saved all_activations to: ../result/Penn_Action_result/penn-action_Penn_Action_motion_label+Penn_action_object_concept+Penn_action_scene_concept/all_concept_activations.pt
[info] all_activations shape: (1067, 147)





In [None]:
# --- Load backbone features and classifier weights ---------
from pathlib import Path
import os
import torch

# Preconditions (these are defined in earlier cells)
# Fail early with a readable message when this cell is run before the loading cells.
assert "args" in globals(), "`args` is not defined (loaded from args.txt)."
assert "cfg" in globals() and hasattr(cfg, "load_dir"), "`cfg.load_dir` is not configured."
# Re-derive the device here rather than relying on `device` from an earlier cell.
device = "cuda" if torch.cuda.is_available() else "cpu"

def resolve_backbone_feature_paths(args, repo_root: Path | None = None):
    """
    Resolve train/val backbone feature file paths without hardcoded absolute paths.
    Priority:
      1) Use args.backbone_features if it exists.
      2) Otherwise, infer from a conventional layout:
         CBM_training/results/Features/{DATASET}/{BACKBONE}/
           {DATASET}_train_{BACKBONE}.pt
           {DATASET}_val_{BACKBONE}.pt
    """
    # Try provided path in args (may be absolute or relative to CWD)
    if hasattr(args, "backbone_features"):
        p_train = Path(args.backbone_features)
        if p_train.exists():
            p_val = Path(
                str(p_train).replace(f"{args.data_set}_train", f"{args.data_set}_val")
            )
            if not p_val.exists():
                raise FileNotFoundError(f"Could not infer val feature path from: {p_train}")
            return p_train, p_val

    # Fallback: infer a conventional path within the repo
    # Attempt to use the same repo_root detection done earlier, otherwise walk up.
    if repo_root is None:
        from pathlib import Path as _P
        cwd = _P.cwd()
        for p in [cwd, *cwd.parents]:
            if (p / "CBM_training").exists():
                repo_root = p
                break
        if repo_root is None:
            raise RuntimeError("Cannot find repository root containing 'CBM_training'.")

    dataset = getattr(args, "data_set", None)
    backbone = getattr(args, "backbone", None) or getattr(args, "backbone_name", None)
    if dataset is None or backbone is None:
        raise AttributeError("`args` must include `data_set` and `backbone` (or `backbone_name`).")

    # Example file names: UCF101_train_vmae_vit_base_patch16_224.pt
    fname_train = f"{dataset}_train_{backbone}.pt"
    fname_val   = f"{dataset}_val_{backbone}.pt"

    # Some repos add another level (e.g., 'vmae') between dataset and file; try a few common layouts.
    candidates = [
        repo_root / "CBM_training" / "results" / "Features" / dataset / backbone / fname_train,
        repo_root / "CBM_training" / "results" / "Features" / dataset / fname_train,
    ]

    p_train = next((p for p in candidates if p.exists()), None)
    if p_train is None:
        raise FileNotFoundError(
            f"Backbone train features not found. Tried:\n- " + "\n- ".join(map(str, candidates))
        )

    # Derive val path alongside train
    p_val = p_train.with_name(fname_val)
    if not p_val.exists():
        # Also try parallel directory forms if needed
        alt_candidates = [
            p_train.parent / fname_val,
            repo_root / "CBM_training" / "results" / "Features" / dataset / backbone / fname_val,
            repo_root / "CBM_training" / "results" / "Features" / dataset / fname_val,
        ]
        p_val = next((p for p in alt_candidates if p.exists()), None)
        if p_val is None:
            raise FileNotFoundError(
                f"Backbone val features not found. Tried:\n- " + "\n- ".join(map(str, alt_candidates))
            )

    return p_train, p_val

# Locate the cached backbone features (portable, no hardcoded absolute paths)
train_feat_path, val_feat_path = resolve_backbone_feature_paths(args)

# Read both feature tensors onto the selected device, cast with .float()
backbone_features = torch.load(train_feat_path, map_location=device).float()
val_backbone_features = torch.load(val_feat_path, map_location=device).float()

# --- Classifier weights/bias and projection stats ----------------------------
load_dir = Path(cfg.load_dir)

if not (hasattr(args, "train_mode") and len(args.train_mode) == 1):
    # Aggregated (multi-mode) weights: only the final classifier is loaded here.
    # If you need W_c/proj stats for aggregated runs, add them here when available.
    agg_dir = load_dir / "aggregated"
    W_g, b_g = (
        torch.load(agg_dir / f"{stem}.pt", map_location=device)
        for stem in ("W_g", "b_g")
    )
else:
    # Single concept type: everything lives in that type's sub-directory.
    load_sub_dir = load_dir / args.train_mode[0]
    W_c, W_g, b_g, proj_mean, proj_std = (
        torch.load(load_sub_dir / f"{stem}.pt", map_location=device)
        for stem in ("W_c", "W_g", "b_g", "proj_mean", "proj_std")
    )

print("[info] Loaded features and classifier parameters.")

In [8]:
# --- Analyze cumulative activations across all concepts ----------------------
import torch

# Preconditions
assert "all_activations" in globals(), "`all_activations` is not defined."
assert all_activations.ndim == 2, "Expected all_activations shape: (N, C)."

# (1) Sum activations across all samples → (C,)
concept_activation_sum = all_activations.sum(dim=0)

# (2) Find top-k concepts with largest cumulative activations.
#     torch.topk raises when k exceeds the number of elements, so clamp k for
#     models with fewer than five concepts.
top_k = min(5, concept_activation_sum.numel())
topk_vals, topk_indices = torch.topk(concept_activation_sum, k=top_k)

print(f"🔝 Top {top_k} most activated concepts across dataset:")
for i, (val, idx) in enumerate(zip(topk_vals.tolist(), topk_indices.tolist())):
    # Fallback if concept names are unavailable
    if "concepts" in globals() and len(concepts) > idx:
        concept_name = concepts[idx] if isinstance(concepts[idx], str) else f"Concept {idx}"
    else:
        concept_name = f"Concept {idx}"
    print(f"{i+1}. {concept_name:<25} | Total activation: {val:+.4f}")

🔝 Top 5 most activated concepts across dataset:
1. Boot camp class           | Total activation: +191.9189
2. Workout clothing          | Total activation: +171.8232
3. CrossFit box              | Total activation: +171.7858
4. Resistance bands          | Total activation: +171.3524
5. Personal training session | Total activation: +171.1149


In [None]:
# --- Analyze per-class errors: concepts & confused classes -------------------
import torch
from typing import List, Dict, Any, Optional, Tuple

def analyze_wrong_concepts_and_confusions(
    class_idx: int,
    prediction: torch.Tensor,   # (N, C) logits/probs
    val_y: torch.Tensor,        # (N,)
    val_c: torch.Tensor,        # (N, num_concepts) concept activations
    concepts: List[str],
    classes: List[str],
    topk_concept: int = 10,
    topk_confuse: int = 5,
    threshold: float = 0.0,
    sort_by: str = "mean",      # "mean" | "ratio" | "absmean"
    return_stats: bool = False,
) -> Optional[Dict[str, Any]]:
    """
    Print diagnostics for a given class: accuracy among its samples, the most active
    concepts in wrong predictions, and which other classes it is most often confused with.

    Args:
        class_idx: Class index to analyze.
        prediction: (N, C) tensor of logits or probabilities.
        val_y: (N,) ground-truth class indices.
        val_c: (N, num_concepts) concept activations aligned with samples in `prediction`.
        concepts: List of concept names (len = num_concepts).
        classes: List of class names (len = C).
        topk_concept: How many concepts to display from wrong predictions.
        topk_confuse: How many confused classes to display.
        threshold: Concept is considered "active" if activation > threshold.
        sort_by: Ranking for concepts among wrong predictions:
                 - "mean": descending by mean activation
                 - "ratio": descending by active ratio (fraction of wrong samples > threshold)
                 - "absmean": descending by mean absolute activation
        return_stats: If True, return a dictionary of computed stats instead of None.

    Returns:
        None when `return_stats` is False. Otherwise a dict with the keys
        class_idx, class_name, num_total, num_correct, num_wrong, accuracy,
        top_concepts, confusions. For a class with no samples the counts are 0,
        accuracy is NaN and both lists are empty.

    Raises:
        ValueError: If `sort_by` is not one of "mean", "ratio", "absmean".
    """
    # Reject typos up front; previously any unknown value silently behaved like "absmean".
    if sort_by not in ("mean", "ratio", "absmean"):
        raise ValueError(
            f"sort_by must be 'mean', 'ratio' or 'absmean', got {sort_by!r}"
        )

    # Move to CPU for consistent printing/ops
    prediction = prediction.detach().cpu()
    val_y = val_y.detach().cpu()
    val_c = val_c.detach().cpu()

    # Predicted labels
    pred_labels = prediction.argmax(dim=1)

    # Subset: samples of the target class
    target_mask = (val_y == class_idx)
    num_total = int(target_mask.sum().item())
    class_name = classes[class_idx]

    def _stats(num_correct, num_wrong, accuracy, top_concepts, confusions):
        """Assemble the optional return value (None unless `return_stats`)."""
        if not return_stats:
            return None
        return {
            "class_idx": class_idx,
            "class_name": class_name,
            "num_total": num_total,
            "num_correct": num_correct,
            "num_wrong": num_wrong,
            "accuracy": accuracy,
            "top_concepts": top_concepts,
            "confusions": confusions,
        }

    if num_total == 0:
        print(f"\nClass {class_idx} ({class_name}): no samples in validation set.")
        # Honour `return_stats` here too, so callers always get a dict when they ask for one.
        return _stats(0, 0, float("nan"), [], [])

    correct_mask = target_mask & (pred_labels == val_y)
    num_correct = int(correct_mask.sum().item())
    class_acc = num_correct / num_total

    print(f"\nClass {class_idx} ({class_name}) Accuracy: "
          f"{class_acc*100:.2f}% ({num_correct}/{num_total})")

    wrong_mask = target_mask & (pred_labels != val_y)
    num_wrong = int(wrong_mask.sum().item())
    if num_wrong == 0:
        print(f"No wrong predictions for class {class_idx} ({class_name}).")
        return _stats(num_correct, 0, class_acc, [], [])

    # Concept statistics over the misclassified samples of this class
    wrong_concepts = val_c[wrong_mask]                      # (num_wrong, num_concepts)
    mean_activation = wrong_concepts.mean(dim=0)            # (num_concepts,)
    mean_abs_activation = wrong_concepts.abs().mean(dim=0)  # (num_concepts,)
    active_count = (wrong_concepts > threshold).sum(dim=0)  # (num_concepts,)
    active_ratio = active_count.float() / num_wrong         # (num_concepts,)

    # Sorting strategy (NaN scores are ranked as 0)
    score_table = {
        "mean": mean_activation,
        "ratio": active_ratio,
        "absmean": mean_abs_activation,
    }
    sort_scores = torch.nan_to_num(score_table[sort_by], nan=0.0)

    # Top-K concepts
    k_concepts = min(topk_concept, sort_scores.numel())
    top_indices = torch.topk(sort_scores, k=k_concepts, largest=True).indices.tolist()

    print(f"\nTop {k_concepts} concepts among wrong predictions (sorted by {sort_by}):\n")
    top_concepts: List[Dict[str, Any]] = []
    for rank, idx in enumerate(top_indices, 1):
        cname = concepts[idx] if isinstance(concepts[idx], str) else f"Concept {idx}"
        mean_val = float(mean_activation[idx].item())
        absmean_val = float(mean_abs_activation[idx].item())
        ratio_val = float(active_ratio[idx].item()) * 100.0
        cnt_val = int(active_count[idx].item())
        top_concepts.append({
            "concept_idx": idx,
            "concept_name": cname,
            "mean": mean_val,
            "absmean": absmean_val,
            "active_ratio_pct": ratio_val,
        })
        print(f"{rank:>2}. {cname:<30} "
              f"| mean: {mean_val:+.4f} | |mean|: {absmean_val:.4f} "
              f"| active: {cnt_val}/{num_wrong} ({ratio_val:.1f}%)")

    # Confused classes, most frequent wrong prediction first
    wrong_preds = pred_labels[wrong_mask]
    unique_preds, counts = torch.unique(wrong_preds, return_counts=True)
    pairs = sorted(zip(unique_preds.tolist(), counts.tolist()), key=lambda x: x[1], reverse=True)

    k_conf = min(topk_confuse, len(pairs))
    print(f"\nTop {k_conf} confused classes:\n")
    confusions: List[Dict[str, Any]] = []
    for i, (pred_c, cnt) in enumerate(pairs[:k_conf], 1):
        name = classes[pred_c] if 0 <= pred_c < len(classes) else f"Class {pred_c}"
        ratio = (cnt / num_wrong) * 100.0
        confusions.append({
            "pred_class_idx": pred_c,
            "pred_class_name": name,
            "count": int(cnt),
            "ratio_pct": float(ratio),
        })
        print(f"{i}. Predicted as {name} ({pred_c}): {cnt}/{num_wrong} times ({ratio:.1f}%)")

    return _stats(num_correct, num_wrong, class_acc, top_concepts, confusions)

In [None]:
# --- Inspect class-specific concept weights ----------------------------------
import torch

def print_class_concept_weights(model, concepts, classes, class_idx, threshold=0.1, top_k=None):
    """
    Print the concepts whose final-layer weight for one class is large in magnitude.

    Args:
        model: Object exposing `final.weight`, indexable by class to a 1-D tensor.
        concepts (list[str]): Concept names; entry j labels weight j.
        classes (list[str]): Class names; `classes[class_idx]` is used in the header.
        class_idx (int): Row of `model.final.weight` to inspect.
        threshold (float): Only concepts with |weight| > threshold are listed.
        top_k (int | None): If set, list at most this many entries (largest |weight| first).

    Returns:
        None. Output is written with print().
    """
    class_row = model.final.weight[class_idx].detach().cpu().numpy()  # (num_concepts,)

    # Keep (concept index, weight) pairs that clear the magnitude threshold.
    ranked = []
    for j in range(len(concepts)):
        if abs(class_row[j]) > threshold:
            ranked.append((j, class_row[j]))

    # Largest magnitude first; ties keep their original concept order.
    ranked.sort(key=lambda pair: abs(pair[1]), reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]

    print(f"\nClass: {classes[class_idx]} (|weight| > {threshold})")
    for j, weight in ranked:
        label = concepts[j] if isinstance(concepts[j], str) else f"Concept {j}"
        print(f"{label:<30} [{weight:+.4f}]")

In [18]:
# --- Get Prediction -------------------------------------------------
val_c = all_activations.clone()

# Run the final linear layer on the device the loaded weights already live on.
# The earlier version mixed `args.device` (a value saved at training time) with a
# hard-coded `.cuda()`, which fails on CPU-only machines and can disagree with
# the device chosen when W_g/b_g were loaded.
cls_device = W_g.device
cls_layer = torch.nn.Linear(val_c.shape[1], len(classes)).to(cls_device)
cls_layer.load_state_dict({"weight": W_g, "bias": b_g})
with torch.no_grad():
    prediction = cls_layer(val_c.to(cls_device))

# --- Run analysis for two classes --------------------------------------------
observed_class_idx = 0
confused_class_idx = 1

analyze_wrong_concepts_and_confusions(
    class_idx=observed_class_idx,
    prediction=prediction,
    val_y=val_y,
    val_c=val_c,
    concepts=concepts,
    classes=classes,
    topk_concept=40,
    topk_confuse=5,
    threshold=0.0
)

# Compare the final-layer weights of the observed class and the class it is confused with.
print_class_concept_weights(model, concepts, classes, observed_class_idx)
print_class_concept_weights(model, concepts, classes, confused_class_idx)


Class 0 (baseball_pitch) Accuracy: 100.00% (63/63)
No wrong predictions for class 0 (baseball_pitch).

Class: baseball_pitch (|weight| > 0.1)
a pitcher's mound              [+0.5091]
29                             [+0.5024]
11                             [+0.4553]
a baseball glove               [+0.3832]
35                             [+0.3724]
a catcher's mitt               [+0.3440]
a baseball uniform             [+0.2820]
73                             [+0.2512]
49                             [-0.2485]
47                             [+0.2307]
5                              [+0.2288]
Little league field            [+0.1693]
13                             [-0.1416]
Open field                     [+0.1339]
28                             [+0.1236]
Batting cage                   [+0.1200]
8                              [+0.1027]

Class: clean_and_jerk (|weight| > 0.1)
42                             [+1.0891]
59                             [+0.5870]
56                             [+0.336

In [20]:
# --- Intervention on classifier weights ------------
import torch

def resolve_concept_index(concept_idx, concepts):
    """
    Turn a user-supplied concept reference into a positional index.

    Accepted forms:
      - int (returned unchanged),
      - a decimal string such as "24" (converted with int()),
      - any other string, looked up as a concept name in `concepts`,
      - any integer-like object implementing `__index__` (e.g. NumPy integer
        scalars or 0-d integer tensors).

    Note: a decimal string is ALWAYS treated as a positional index, even when a
    concept with that literal name exists in `concepts`.

    Args:
        concept_idx: The reference to resolve.
        concepts: List of concept names used for name lookup.

    Returns:
        The index as given or converted. No bounds check is performed here.

    Raises:
        ValueError: If a non-numeric string is not present in `concepts`.
        TypeError: If `concept_idx` is neither a string nor integer-like.
    """
    import operator  # local import keeps this cell self-contained

    if isinstance(concept_idx, int):
        return concept_idx
    if isinstance(concept_idx, str):
        # isdecimal() only accepts characters int() can parse; isdigit() also
        # accepts e.g. superscript digits, which made int() raise.
        if concept_idx.isdecimal():
            return int(concept_idx)
        # Otherwise treat as concept name
        try:
            return concepts.index(concept_idx)
        except ValueError:
            raise ValueError(f"Unknown concept name: {concept_idx!r}") from None
    # Integer-like objects coming from array/tensor code.
    try:
        return operator.index(concept_idx)
    except TypeError:
        raise TypeError("concept_idx must be int or str.") from None

def intervene_and_predict(
    val_c: torch.Tensor,     # (N, K_concept) concept activations
    W_g: torch.Tensor,       # (C_cls, K_concept) classifier weights
    b_g: torch.Tensor,       # (C_cls,) classifier bias
    concepts,                # list of concept names
    classes,                 # list of class names
    weight_up_class_idx: int,
    concept_idx,             # int or str ('24' or concept name)
    margin: float = 0.5,
    weight_down_class_idx: int | None = None,
    device: str | None = None,
):
    """
    Returns:
        original_prediction: logits with original W_g/b_g
        modified_prediction: logits after adding/subtracting margin on selected (class, concept)
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    val_c = val_c.to(device, non_blocking=True)
    W_g = W_g.to(device)
    b_g = b_g.to(device)

    # Resolve concept index robustly
    cidx = resolve_concept_index(concept_idx, concepts)

    # Base classifier layer
    base_cls = torch.nn.Linear(val_c.shape[1], W_g.shape[0], bias=True).to(device)
    with torch.no_grad():
        base_cls.weight.copy_(W_g)
        base_cls.bias.copy_(b_g)
        original_prediction = base_cls(val_c)

    # Modify W_g for intervention
    W_g_modified = W_g.clone()
    print(f"[intervene] concept = {concepts[cidx] if isinstance(concepts[cidx], str) else cidx}")

    # Decrease weight for a confusing class (optional)
    if weight_down_class_idx is not None and weight_down_class_idx >= 0:
        before = W_g_modified[weight_down_class_idx, cidx].item()
        W_g_modified[weight_down_class_idx, cidx] = before - margin
        after = W_g_modified[weight_down_class_idx, cidx].item()
        print(f"  ↓ {classes[weight_down_class_idx]}: {before:.4f} -> {after:.4f}")

    # Increase weight for the target class
    before = W_g_modified[weight_up_class_idx, cidx].item()
    W_g_modified[weight_up_class_idx, cidx] = before + margin
    after = W_g_modified[weight_up_class_idx, cidx].item()
    print(f"  ↑ {classes[weight_up_class_idx]}: {before:.4f} -> {after:.4f}")

    # Classifier with modified weights
    mod_cls = torch.nn.Linear(val_c.shape[1], W_g.shape[0], bias=True).to(device)
    with torch.no_grad():
        mod_cls.weight.copy_(W_g_modified)
        mod_cls.bias.copy_(b_g)
        modified_prediction = mod_cls(val_c)

    return original_prediction, modified_prediction

# --- Example usage ------------------------------------------------------------
# Raise one concept's weight for class 1 and lower it for class 2 by the same margin.
weight_up_class_idx = 1
weight_down_class_idx = 2
# NOTE: resolve_concept_index treats a digit-only string as a positional index,
# not as a name. The concept list printed earlier contains digit-only names
# ('0', '1', ...), so "24" here means position 24 in `concepts` — verify this is
# the intended concept.
concept_idx = "24"  # can be int(24), "24", or an actual concept name

# Both logit tensors are returned on the device chosen inside the function.
original_pred, modified_pred = intervene_and_predict(
    val_c=val_c,                 # (N, K_concept)
    W_g=W_g,                     # (C_cls, K_concept)
    b_g=b_g,                     # (C_cls,)
    concepts=concepts,
    classes=classes,
    weight_up_class_idx=weight_up_class_idx,
    concept_idx=concept_idx,
    margin=0.5,
    weight_down_class_idx=weight_down_class_idx,
    device=None,                 # auto-selects cuda/cpu
)

[intervene] concept = 24
  ↓ pullup: 0.0000 -> -0.5000
  ↑ clean_and_jerk: 0.0000 -> 0.5000


In [21]:
# --- Compare predictions before/after intervention ------------------
import torch
from typing import List, Optional, Dict, Any

def analyze_prediction_changes(
    original_pred: torch.Tensor,   # (N, C) logits/scores before
    modified_pred: torch.Tensor,   # (N, C) logits/scores after
    val_y: torch.Tensor,           # (N,) ground-truth labels
    class_names: Optional[List[str]] = None,
    show_limit: Optional[int] = None,  # print at most this many changed samples
) -> Dict[str, Any]:
    """
    Print a compact summary of changes and return counts/indices as a dict.

    A sample counts as "changed" when its correctness flips between the two
    predictions: wrong -> right ("improved") or right -> wrong ("degraded").
    Samples whose predicted label changes while staying wrong are not listed.

    Args:
        original_pred: (N, C) scores before the intervention.
        modified_pred: (N, C) scores after the intervention.
        val_y: (N,) ground-truth class indices.
        class_names: Optional names used for display; indices are printed when a
            name is unavailable.
        show_limit: Maximum number of changed samples to print (None = all).

    Returns:
        dict with keys indices_changed, indices_improved, indices_degraded,
        accuracy_before, accuracy_after, delta_accuracy, count_correct_before,
        count_correct_after, count_improved, count_degraded, total_changed.
    """
    # Bring everything to CPU first (as analyze_wrong_concepts_and_confusions does),
    # so predictions on the GPU can be compared with labels on the CPU without a
    # device-mismatch error.
    original_pred = original_pred.detach().cpu()
    modified_pred = modified_pred.detach().cpu()
    val_y = val_y.detach().cpu()

    # Argmax labels
    original_labels = original_pred.argmax(dim=1)
    modified_labels = modified_pred.argmax(dim=1)

    # Correctness masks
    matched_before = (original_labels == val_y)
    matched_after = (modified_labels == val_y)

    # Change masks
    improved_idx = (~matched_before) & matched_after
    degraded_idx = matched_before & (~matched_after)
    changed_idx = improved_idx | degraded_idx

    # Indices
    indices = changed_idx.nonzero(as_tuple=False).flatten()
    total_changed = int(indices.numel())

    def _label(class_id: int) -> str:
        """Class name when available, otherwise the index as text."""
        if class_names and 0 <= class_id < len(class_names):
            return class_names[class_id]
        return str(class_id)

    print(f"\n[Changed Samples Summary] (Total changed samples: {total_changed})\n")
    print(f"{'Δ':<2} {'ID':<6} {'GT':<20} {'Before':<20} {'After':<20}")
    print("-" * 75)

    limit = total_changed if show_limit is None else min(show_limit, total_changed)
    for idx in indices[:limit].tolist():
        gt_name = _label(int(val_y[idx].item()))
        orig_name = _label(int(original_labels[idx].item()))
        modf_name = _label(int(modified_labels[idx].item()))

        change_type = "+" if improved_idx[idx] else "-"
        print(f"{change_type:<2} {idx:<6} {gt_name:<20} {orig_name:<20} {modf_name:<20}")

    # Accuracy summary
    original_acc = float(matched_before.float().mean().item())
    modified_acc = float(matched_after.float().mean().item())

    print(f"\n[Accuracy Summary]")
    print(f" - Before Intervention: {original_acc*100:.2f}%")
    print(f" - After  Intervention: {modified_acc*100:.2f}%")
    print(f" - Change: {(modified_acc - original_acc)*100:+.2f}%")

    # Count summary
    cnt_before = int(matched_before.sum().item())
    cnt_after = int(matched_after.sum().item())
    cnt_improve = int(improved_idx.sum().item())
    cnt_degrade = int(degraded_idx.sum().item())

    print(f"\n[Detailed Count Summary]")
    print(f" - Correct before: {cnt_before} samples")
    print(f" - Correct after : {cnt_after} samples")
    print(f" - Improved      : {cnt_improve} samples")
    print(f" - Degraded      : {cnt_degrade} samples")

    return {
        "indices_changed": indices.tolist(),
        "indices_improved": improved_idx.nonzero(as_tuple=False).flatten().tolist(),
        "indices_degraded": degraded_idx.nonzero(as_tuple=False).flatten().tolist(),
        "accuracy_before": original_acc,
        "accuracy_after": modified_acc,
        "delta_accuracy": modified_acc - original_acc,
        "count_correct_before": cnt_before,
        "count_correct_after": cnt_after,
        "count_improved": cnt_improve,
        "count_degraded": cnt_degrade,
        "total_changed": total_changed,
    }

In [22]:
# Summarize which samples changed correctness after the intervention
# (all three tensors are moved to the CPU before the comparison).
analyze_prediction_changes(original_pred.cpu(), modified_pred.cpu(), val_y.cpu(), class_names=classes)
print()
print("*"*100)
print()
# Re-run the per-class error analysis, this time on the post-intervention logits.
analyze_wrong_concepts_and_confusions(
    class_idx=observed_class_idx,
    prediction=modified_pred,
    val_y=val_y,
    val_c=val_c,
    concepts=concepts,
    classes=classes,  # list of class names
    topk_concept=10,
    topk_confuse=5,
    threshold=0.0
)


[Changed Samples Summary] (Total changed samples: 0)

Δ  ID     GT                   Before               After               
---------------------------------------------------------------------------

[Accuracy Summary]
 - Before Intervention: 98.03%
 - After  Intervention: 98.03%
 - Change: +0.00%

[Detailed Count Summary]
 - Correct before: 1046 samples
 - Correct after : 1046 samples
 - Improved      : 0 samples
 - Degraded      : 0 samples

****************************************************************************************************


Class 0 (baseball_pitch) Accuracy: 100.00% (63/63)
No wrong predictions for class 0 (baseball_pitch).
