# DER系アプローチの分析用ノートブック

In [1]:
from pathlib import Path
import os, sys, json, re
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# 例: 物理GPU1番だけを見せる
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

## 1) Project root / config / phase 設定

In [3]:
# === プロジェクトルートの指定 ===
# Notebook を CIDL-main の直下で開いているなら "." で OK
PROJECT_ROOT = Path("/home/kouyou/ContinualLearning/repexp/PyCIL").resolve()
os.chdir(PROJECT_ROOT)
sys.path.append(str(PROJECT_ROOT))

from trainer import _set_device  # 既存のヘルパをそのまま使う
from utils.data_manager import DataManager
from utils import factory

# === 使いたい設定ファイルと、どの phase を可視化するか ===
CONFIG_PATH = "exps/der_mu/baseline0/cifar100.json"   # 適宜変更
PHASE_ID    = 5                                         # ex) タスク3終了時のモデル → phase3.pkl

# === json 読み込み → args にする ===
with open(CONFIG_PATH) as f:
    args = json.load(f)

# device を training と同じ形式 (list of torch.device) に変換
_set_device(args)

print("model_name:", args["model_name"])
print("dataset   :", args["dataset"])
print("device    :", args["device"])

model_name: der-mu
dataset   : cifar100
device    : [device(type='cuda', index=0)]


## 2) Checkpoint path を trainer と同じ規則で組み立てる

In [4]:
from glob import glob

def _seed_to_str(seed):
    # trainer.py は args['seed'] をそのまま format しているので、list/int 両対応にしておく
    if isinstance(seed, (list, tuple)):
        return str(seed[0]) if len(seed) else "0"
    return str(seed)

def build_ckpt_dir(args, root=PROJECT_ROOT / "checkpoint"):
    
    """trainer / BaseLearner.save_checkpoint と同じ規則で checkpoint ディレクトリを作る"""
    init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"]
    log = "baseline" if "log" not in args else args["log"]
    
    ckpt_dir = "logs/{}/{}/{}/{}/{}/{}_{}_{}/".format(
        args["model_name"],
        log,
        args["dataset"],
        init_cls,
        args["increment"],
        args["prefix"], args["seed"][0], args["convnet_type"],
    )
    return ckpt_dir

ckpt_dir  = build_ckpt_dir(args)
ckpt_path = f"{ckpt_dir}/phase_{PHASE_ID}.pkl"

print("ckpt_dir :", ckpt_dir)
print("ckpt_path:", ckpt_path)

# # もし exists が False のときは、pattern で探す fallback も書いておくと楽
# if not ckpt_path.exists():
#     cand = glob(str(ckpt_dir / f"phase{PHASE_ID}*.pkl"))
#     print("fallback candidates:", cand)


ckpt_dir : logs/der-mu/baseline0/cifar100/0/10/reproduce_1993_resnet32/
ckpt_path: logs/der-mu/baseline0/cifar100/0/10/reproduce_1993_resnet32//phase_5.pkl


## 3) DataManager と model を作って checkpoint をロード（DER/TagFex など拡張系対応）

In [17]:
def infer_convnet_count_from_state_dict(state_dict):
    # keys like: 'convnets.0.xxx', 'convnets.1.xxx', ...
    idxs = []
    pat = re.compile(r"^convnets\.(\d+)\.")
    for k in state_dict.keys():
        m = pat.match(k)
        if m:
            idxs.append(int(m.group(1)))
    return (max(idxs) + 1) if idxs else 0

def classes_at_task(k, args, total_classnum):
    # task k の終了時点での total_classes（BaseLearnerの挙動に合わせる）
    init_cls = args["init_cls"]
    inc = args["increment"]
    num = init_cls + k * inc
    return min(num, total_classnum)

def build_network_skeleton_for_ckpt(model, state_dict, data_manager, args, phase_id):
    """state_dict をロード可能な形に network を拡張しておく。
    DERNet/TagFexNet 系は update_fc を task 回数ぶん呼ぶ必要がある。
    """
    net = model._network
    # DataParallel だと update_fc が面倒なので notebook では使わない前提
    if isinstance(net, torch.nn.DataParallel):
        net = net.module
        model._network = net

    convnet_count = infer_convnet_count_from_state_dict(state_dict)
    if convnet_count == 0:
        # ふつうはありえないが、念のため
        convnet_count = phase_id + 1

    # 最終クラス数は fc.weight の out_features で確定できる
    if "fc.weight" in state_dict:
        final_classes = state_dict["fc.weight"].shape[0]
    else:
        # fallback
        final_classes = classes_at_task(phase_id, args, data_manager.get_total_classnum())

    # すでに update_fc 済みのモデルなら二重に増やさない
    existing = len(getattr(net, "convnets", [])) if hasattr(net, "convnets") else 0

    # DER/TagFex 系は convnets を持つ。ここが最重要分岐。
    if hasattr(net, "update_fc") and hasattr(net, "convnets"):
        # 必要な回数だけ update_fc を呼ぶ
        # task k でのクラス数を与えながら進める（最後だけ final_classes で整合）
        total_cls = data_manager.get_total_classnum()
        for k in range(existing, convnet_count):
            nb = classes_at_task(k, args, total_cls)
            # 最後の update_fc は checkpoint の fc 次元に合わせる
            if k == convnet_count - 1:
                nb = final_classes
            net.update_fc(nb)
    else:
        # 拡張なしモデルは update_fc 1回で十分なことが多い
        if hasattr(net, "update_fc") and "fc.weight" in state_dict:
            net.update_fc(state_dict["fc.weight"].shape[0])

    return net

# --- DataManager & model ---
data_manager = DataManager(
    dataset_name=args["dataset"],
    shuffle=args.get("shuffle", True),
    seed=int(_seed_to_str(args.get("seed", 0))),
    init_cls=args["init_cls"],
    increment=args["increment"],
)

model = factory.get_model(args["model_name"], args)
model._device = args["device"][0]
model._network.to(model._device)

ckpt = torch.load(ckpt_path, map_location=model._device)
state_dict = ckpt["model_state_dict"]

# 1) skeleton を作る（update_fc を必要回数）
build_network_skeleton_for_ckpt(model, state_dict, data_manager, args, PHASE_ID)

# 2) state_dict ロード
model._network.load_state_dict(state_dict, strict=True)
model._network.to(model._device)

# 3) 追加情報（あれば）
if "forget_classes" in ckpt and hasattr(model, "forget_classes"):
    model.forget_classes = ckpt["forget_classes"]
if "protos" in ckpt:
    model._protos = ckpt["protos"]

model._network.eval()
print("loaded. convnets =", len(getattr(model._network, "convnets", [])))


Files already downloaded and verified
Files already downloaded and verified


  ckpt = torch.load(ckpt_path, map_location=model._device)


loaded. convnets = 6


## 4) seen / forget / retain クラス集合を作る（MU系なら forget_classes を活用）

In [18]:
# checkpoint phase 時点での「見えている」クラス数
num_classes = classes_at_task(PHASE_ID, args, data_manager.get_total_classnum())
all_seen = np.arange(num_classes)

forget_set = set(getattr(model, "forget_classes", []))
forget = np.array(sorted([c for c in forget_set if c < num_classes]), dtype=int)
retain = np.setdiff1d(all_seen, forget)

print("num_classes:", num_classes)
print("forget:", forget)
print("retain (head):", retain[:10], " ... total", len(retain))

num_classes: 60
forget: []
retain (head): [0 1 2 3 4 5 6 7 8 9]  ... total 60


## 5) Manual forget class 指定（要件2）と unlearn pruning 設定

- このセクション以降は **読み込んだ学習済みモデルに対して unlearning 処理だけ**を行います（要件1）。
- 忘却クラスはここで **手動で指定**します（要件2）。

In [19]:
# ===== Manual forget classes (EDIT HERE) =====
# 例: task5 に含まれる忘却クラス c を指定
FORGET_CLASSES = [50,51]   # <-- ここを手で変更

# ===== Target backbone selection =====
# 基本は「forget class を含むタスクを直接学習した backbone」を対象にする
TARGET_BACKBONE_ID = PHASE_ID  # task t の backbone を使うなら通常これでOK（必要なら手動で変更）

# ===== Unlearn pruning hyperparams =====
LAMBDA_RETAIN = 1.0      # S = I_forget - lambda * I_retain の lambda
PRUNE_RATIO   = 0.005    # 上位何割を壊すか（まずは小さめ推奨）
MAX_BATCHES_F = 500      # Fisher推定に使う最大バッチ数（計算節約）
MAX_BATCHES_R = 25000
BATCH_SIZE    = 256
NUM_WORKERS   = 4

# ===== Recompute seen/forget/retain sets (override) =====
num_classes = classes_at_task(PHASE_ID, args, data_manager.get_total_classnum())
all_seen = np.arange(num_classes)

forget = np.array(sorted([c for c in FORGET_CLASSES if c < num_classes]), dtype=int)
retain = np.setdiff1d(all_seen, forget)

print("num_classes:", num_classes)
print("FORGET_CLASSES:", forget.tolist())
print("retain classes:", len(retain))
print("TARGET_BACKBONE_ID:", TARGET_BACKBONE_ID)


num_classes: 60
FORGET_CLASSES: [50, 51]
retain classes: 58
TARGET_BACKBONE_ID: 5


In [20]:
# ===== Helpers: device / logits / backbone access =====
def _unwrap_net(net):
    # DataParallel / DDP を剥がす
    if hasattr(net, "module"):
        return net.module
    return net

def _get_device_from_model(model):
    dev = getattr(model, "_device", None)
    if isinstance(dev, (list, tuple)) and len(dev):
        return dev[0]
    if isinstance(dev, torch.device):
        return dev
    # fallback: network parameter device
    return next(model._network.parameters()).device

def _get_logits_from_network_output(out):
    # DER系は dict を返すことが多い: {"logits": ...}
    if isinstance(out, dict):
        if "logits" in out:
            return out["logits"]
        if "logit" in out:
            return out["logit"]
    # それ以外はテンソルそのものだとみなす
    return out

@torch.no_grad()
def eval_accuracy(model, loader, allowed_classes=None):
    model._network.eval()
    dev = _get_device_from_model(model)
    correct, total = 0, 0
    for _, x, y in loader:
        x = x.to(dev, non_blocking=True)
        y = y.to(dev, non_blocking=True)
        out = model._network(x)
        logits = _get_logits_from_network_output(out)
        if allowed_classes is not None:
            logits = logits[:, allowed_classes]
            pred = logits.argmax(dim=1)
            pred = torch.as_tensor(allowed_classes, device=dev)[pred]
        else:
            pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / max(total, 1)

def get_target_backbone(net, bb_id: int):
    net = _unwrap_net(net)
    # DERNet: net.convnets[bb_id]
    if hasattr(net, "convnets"):
        return net.convnets[bb_id]
    # TagFex系など: net.backbones / net.nets 等の可能性
    for attr in ["backbones", "nets", "models", "encoders"]:
        if hasattr(net, attr):
            obj = getattr(net, attr)
            try:
                return obj[bb_id]
            except Exception:
                pass
    raise AttributeError("backbone list (e.g., net.convnets) not found. Please check your network structure.")

def build_class_loader(data_manager, class_indices, source="train", mode="train",
                       batch_size=128, num_workers=4, shuffle=True):
    idx = class_indices.tolist() if isinstance(class_indices, np.ndarray) else list(class_indices)
    # PyCIL DataManager は実装差があるので、keyword/positional両対応にする
    try:
        ds = data_manager.get_dataset(indices=idx, source=source, mode=mode)
    except TypeError:
        ds = data_manager.get_dataset(idx, source, mode)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=True)

print("helpers ready.")


helpers ready.


In [21]:
# ===== Unlearn pruning core =====
def fisher_diag_on_loader(model, module, loader, max_batches=50):
    """
    module: 対象backbone（ここだけをスコアリングする）
    返り値: dict[name] = fisher_diag (CPU tensor)
    """
    dev = _get_device_from_model(model)
    model._network.eval()
    # 勾配を取るので no_grad は使わない
    fisher = {}
    for n, p in module.named_parameters():
        # print(n)
        if p.requires_grad:
            fisher[n] = torch.zeros_like(p, device="cpu")
            print("fisher[n].shape: ", fisher[n].shape)

    nb = 0
    for _, x, y in loader:
        x = x.to(dev, non_blocking=True)
        y = y.to(dev, non_blocking=True)

        # 既存gradクリア
        model._network.zero_grad(set_to_none=True)

        out = model._network(x)
        logits = _get_logits_from_network_output(out)

        loss = F.cross_entropy(logits, y)
        loss.backward()

        for n, p in module.named_parameters():
            if (not p.requires_grad) or (p.grad is None):
                continue
            fisher[n] += (p.grad.detach().cpu() ** 2)

        nb += 1
        if max_batches is not None and nb >= max_batches:
            break

    for n in fisher:
        fisher[n] /= max(nb, 1)
    return fisher

def compute_unlearn_score(fisher_forget, fisher_retain, lambda_retain=1.0):
    # S = I_forget - lambda * I_retain （CPU tensor dict）
    score = {}
    for n in fisher_forget:
        if n in fisher_retain:
            score[n] = fisher_forget[n] - lambda_retain * fisher_retain[n]
        else:
            score[n] = fisher_forget[n].clone()
    return score

def apply_pruning_by_score(module, score, prune_ratio=0.01, only_ndim_ge2=True):
    """
    score: dict[name] -> CPU tensor (same shape as param)
    prune_ratio: 上位何割の score（正の部分）をゼロ化
    only_ndim_ge2: conv/linear weight を主に対象にする（BNの1Dは除外）
    返り値: mask_dict, stats
    """
    # まず対象スコアを1本に集約
    all_scores = []
    meta = []
    for n, p in module.named_parameters():
        if n not in score:
            continue
        if only_ndim_ge2 and p.ndim < 2:
            continue
        s = score[n]
        s_pos = torch.clamp(s, min=0.0)  # forget重要（正）だけを対象にする
        all_scores.append(s_pos.flatten())
        meta.append(n)

    if len(all_scores) == 0:
        raise RuntimeError("No parameters selected for pruning. (Check only_ndim_ge2 or module structure.)")

    all_scores_cat = torch.cat(all_scores)
    k = int(prune_ratio * all_scores_cat.numel())
    if k <= 0:
        raise RuntimeError(f"prune_ratio too small: {prune_ratio} (k=0)")

    # threshold: 上位k個に入る最小値
    topk = torch.topk(all_scores_cat, k=k, largest=True)
    thr = topk.values.min().item()

    mask_dict = {}
    total_elems = 0
    pruned_elems = 0

    for n, p in module.named_parameters():
        if n not in score:
            continue
        if only_ndim_ge2 and p.ndim < 2:
            continue
        s_pos = torch.clamp(score[n], min=0.0)
        m = (s_pos >= thr)  # CPU bool
        # apply
        m_dev = m.to(p.device)
        with torch.no_grad():
            p.data[m_dev] = 0.0

        mask_dict[n] = m  # keep on CPU
        total_elems += m.numel()
        pruned_elems += int(m.sum().item())

    stats = {
        "threshold": thr,
        "total_elems": total_elems,
        "pruned_elems": pruned_elems,
        "pruned_ratio_actual": pruned_elems / max(total_elems, 1),
    }
    return mask_dict, stats

print("unlearn pruning core ready.")


unlearn pruning core ready.


In [22]:
# ===== 6) Build forget/retain loaders (train split) =====
# 注意: unlearning のため、通常は train split を使う（test を使う場合はここを変える）

forget_loader = build_class_loader(
    data_manager, forget, source="train", mode="test",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)

# retain は全部だと重いので、必要なら一部クラスだけに絞る／サンプル数を制限する
retain_loader = build_class_loader(
    data_manager, retain, source="train", mode="train",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)

# ===== 7) Evaluate BEFORE pruning =====
acc_forget_before = eval_accuracy(model, forget_loader)
acc_retain_before = eval_accuracy(model, retain_loader)

print(f"[BEFORE] acc_forget={acc_forget_before:.4f}  acc_retain={acc_retain_before:.4f}")

# ===== 8) Compute Fisher (forget vs retain) on TARGET_BACKBONE =====
target_bb = get_target_backbone(model._network, TARGET_BACKBONE_ID)

# Fisher 推定のために grad を有効化（freezeされていてもここだけスコアリング/編集したい）
for p in target_bb.parameters():
    p.requires_grad_(True)

f_f = fisher_diag_on_loader(model, target_bb, forget_loader, max_batches=MAX_BATCHES_F)
f_r = fisher_diag_on_loader(model, target_bb, retain_loader, max_batches=MAX_BATCHES_R)

score = compute_unlearn_score(f_f, f_r, lambda_retain=LAMBDA_RETAIN)

# ===== 9) Apply pruning =====
mask_dict, prune_stats = apply_pruning_by_score(
    target_bb, score, prune_ratio=PRUNE_RATIO, only_ndim_ge2=True
)

print("[PRUNE]", prune_stats)

# ===== 10) Evaluate AFTER pruning =====
acc_forget_after = eval_accuracy(model, forget_loader)
acc_retain_after = eval_accuracy(model, retain_loader)

print(f"[AFTER ] acc_forget={acc_forget_after:.4f}  acc_retain={acc_retain_after:.4f}")

# ===== Optional: save the pruned model checkpoint (weights only) =====
SAVE_PRUNED = True
if SAVE_PRUNED:
    out_path = Path(str(ckpt_path)).with_suffix("").as_posix() + f"_unlearn_pruned{FORGET_CLASSES}.pth"
    torch.save({
        "state_dict": _unwrap_net(model._network).state_dict(),
        "forget_classes": forget.tolist(),
        "phase_id": PHASE_ID,
        "target_backbone_id": TARGET_BACKBONE_ID,
        "prune_stats": prune_stats
    }, out_path)
    print("saved:", out_path)


[BEFORE] acc_forget=0.9680  acc_retain=0.8749
fisher[n].shape:  torch.Size([16, 3, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 16, 3, 3])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16])
fisher[n].shape:  torch.Size([16, 1

In [77]:
mask_dict.keys()
mask_dict["conv_1_3x3.weight"].shape
mask_dict["conv_1_3x3.weight"]

tensor([[[[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]]],


        [[[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]]],


        [[[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]]],


        [[[False, False, False],
          [False, False, False],
          [False, F