# DER系アプローチの分析用ノートブック

In [2]:
from pathlib import Path
import os, sys, json, re
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
# Restrict the physical GPUs visible to CUDA for this process (here: devices 0-3).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

## 1) Project root / config / phase 設定

In [4]:
# === Project root ===
# Absolute path of the PyCIL checkout. The process chdir()s into it and adds it
# to sys.path so the project-local imports below resolve.
PROJECT_ROOT = Path("/home/kouyou/ContinualLearning/repexp/PyCIL").resolve()
os.chdir(PROJECT_ROOT)
sys.path.append(str(PROJECT_ROOT))

from trainer import _set_device  # reuse the project's existing helper as-is
from utils.data_manager import DataManager
from utils import factory

# === Which config file to use and which phase to analyse ===
CONFIG_PATH = "exps/der_mu/baseline0/cifar100.json"   # change as needed
PHASE_ID    = 5                                         # phase index of the checkpoint to analyse

# === Load the JSON config into `args` ===
with open(CONFIG_PATH) as f:
    args = json.load(f)

# Convert args["device"] in place to the format used during training
# (presumably a list of torch.device — confirm against trainer._set_device).
_set_device(args)

print("model_name:", args["model_name"])
print("dataset   :", args["dataset"])
print("device    :", args["device"])

model_name: der-mu
dataset   : cifar100
device    : [device(type='cuda', index=0)]


## 2) Checkpoint path を trainer と同じ規則で組み立てる

In [9]:
from glob import glob

def _seed_to_str(seed):
    """Return ``seed`` as a string, accepting either a scalar or a list/tuple.

    For a list or tuple the first element is used; an empty sequence maps
    to ``"0"``.
    """
    if not isinstance(seed, (list, tuple)):
        return str(seed)
    if len(seed) == 0:
        return "0"
    return str(seed[0])

def build_ckpt_dir(args, root=PROJECT_ROOT / "checkpoint"):
    """Build the (relative) checkpoint directory string from the run config.

    The layout is
    ``logs/<model_name>/<log>/<dataset>/<init_cls>/<increment>/<prefix>_<seed>_<convnet_type>/``.

    Args:
        args: Config mapping. Reads ``model_name``, ``dataset``, ``init_cls``,
            ``increment``, ``prefix``, ``seed``, ``convnet_type`` and the
            optional ``log`` (defaults to ``"baseline"``).
        root: Unused; kept only so existing callers passing it keep working.

    Returns:
        The directory as a string, relative to the current working directory
        and ending with a trailing ``/``.
    """
    # When the first task has the same size as every later task, the
    # directory uses 0 instead of the actual init_cls value.
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    log = args.get("log", "baseline")

    # `seed` may be a list (first entry is used) or a plain int; indexing it
    # directly would crash on an int, so go through the shared helper.
    seed = _seed_to_str(args["seed"])

    return "logs/{}/{}/{}/{}/{}/{}_{}_{}/".format(
        args["model_name"],
        log,
        args["dataset"],
        init_cls,
        args["increment"],
        args["prefix"],
        seed,
        args["convnet_type"],
    )

# Resolve the checkpoint file for the selected phase.
# Note: ckpt_dir already ends with "/", so ckpt_path contains a "//"
# (see the printed output); both values are plain strings, not Path objects.
ckpt_dir  = build_ckpt_dir(args)
ckpt_path = f"{ckpt_dir}/phase_{PHASE_ID}.pkl"

print("ckpt_dir :", ckpt_dir)
print("ckpt_path:", ckpt_path)

# Optional fallback when the file is missing: list candidate files by pattern.
# if not Path(ckpt_path).exists():
#     cand = glob(f"{ckpt_dir}phase_{PHASE_ID}*.pkl")
#     print("fallback candidates:", cand)


ckpt_dir : logs/der-mu/baseline0/cifar100/0/10/reproduce_1993_resnet32/
ckpt_path: logs/der-mu/baseline0/cifar100/0/10/reproduce_1993_resnet32//phase_5.pkl


## 3) DataManager と model を作って checkpoint をロード（DER/TagFex など拡張系対応）

In [49]:
def infer_convnet_count_from_state_dict(state_dict):
    """Count backbones from keys shaped like ``convnets.<idx>.<rest>``.

    Returns the largest index found plus one, or 0 when no key matches.
    """
    prefix_re = re.compile(r"^convnets\.(\d+)\.")
    hits = (prefix_re.match(key) for key in state_dict.keys())
    indices = [int(hit.group(1)) for hit in hits if hit]
    if not indices:
        return 0
    return max(indices) + 1

def classes_at_task(k, args, total_classnum):
    """Number of classes seen once task ``k`` is finished.

    Computed as ``init_cls + k * increment`` from ``args`` and capped at
    ``total_classnum``.
    """
    seen_so_far = args["init_cls"] + k * args["increment"]
    return min(seen_so_far, total_classnum)

def build_network_skeleton_for_ckpt(model, state_dict, data_manager, args, phase_id):
    """Grow ``model._network`` so that ``state_dict`` can be loaded into it.

    Expandable networks (those exposing both ``update_fc`` and ``convnets``)
    get one ``update_fc`` call per missing backbone; other networks get at
    most a single ``update_fc`` call sized from ``fc.weight``.

    A ``DataParallel`` wrapper is stripped and the bare module is stored back
    on ``model._network``. Returns the (unwrapped) network.
    """
    network = model._network
    # The notebook works on the bare module; DataParallel is not used here.
    if isinstance(network, torch.nn.DataParallel):
        network = network.module
        model._network = network

    # Number of backbones in the checkpoint; fall back to phase_id + 1 when
    # the state dict has no `convnets.<i>.` keys.
    n_backbones = infer_convnet_count_from_state_dict(state_dict) or (phase_id + 1)

    # Output size of the classifier stored in the checkpoint.
    has_fc_weight = "fc.weight" in state_dict
    if has_fc_weight:
        ckpt_classes = state_dict["fc.weight"].shape[0]
    else:
        ckpt_classes = classes_at_task(phase_id, args, data_manager.get_total_classnum())

    # Backbones already present, so an expanded model is not grown twice.
    already_built = len(network.convnets) if hasattr(network, "convnets") else 0

    if hasattr(network, "update_fc") and hasattr(network, "convnets"):
        dataset_classes = data_manager.get_total_classnum()
        last_idx = n_backbones - 1
        for task_idx in range(already_built, n_backbones):
            n_cls = classes_at_task(task_idx, args, dataset_classes)
            # The final expansion must match the checkpoint's classifier size.
            network.update_fc(ckpt_classes if task_idx == last_idx else n_cls)
    elif hasattr(network, "update_fc") and has_fc_weight:
        # Non-expanding model: a single resize of the classifier.
        network.update_fc(ckpt_classes)

    return network

# --- DataManager & model ---
data_manager = DataManager(
    dataset_name=args["dataset"],
    shuffle=args.get("shuffle", True),
    seed=int(_seed_to_str(args.get("seed", 0))),
    init_cls=args["init_cls"],
    increment=args["increment"],
)

model = factory.get_model(args["model_name"], args)
model._device = args["device"][0]
model._network.to(model._device)

ckpt = torch.load(ckpt_path, map_location=model._device)
state_dict = ckpt["model_state_dict"]

# 1) skeleton を作る（update_fc を必要回数）
build_network_skeleton_for_ckpt(model, state_dict, data_manager, args, PHASE_ID)

# 2) state_dict ロード
model._network.load_state_dict(state_dict, strict=True)
model._network.to(model._device)

# 3) 追加情報（あれば）
if "forget_classes" in ckpt and hasattr(model, "forget_classes"):
    model.forget_classes = ckpt["forget_classes"]
if "protos" in ckpt:
    model._protos = ckpt["protos"]

model._network.eval()
print("loaded. convnets =", len(getattr(model._network, "convnets", [])))


Files already downloaded and verified
Files already downloaded and verified


  ckpt = torch.load(ckpt_path, map_location=model._device)


loaded. convnets = 6


## 4) seen / forget / retain クラス集合を作る（MU系なら forget_classes を活用）

In [50]:
# Number of classes seen at the checkpoint's phase.
num_classes = classes_at_task(PHASE_ID, args, data_manager.get_total_classnum())
all_seen = np.arange(num_classes)

# Forget classes recorded on the model (empty when the attribute is missing),
# restricted to already-seen classes; retain = seen minus forget.
forget_set = set(getattr(model, "forget_classes", []))
forget = np.array(sorted([c for c in forget_set if c < num_classes]), dtype=int)
retain = np.setdiff1d(all_seen, forget)

print("num_classes:", num_classes)
print("forget:", forget)
print("retain (head):", retain[:10], " ... total", len(retain))

num_classes: 60
forget: []
retain (head): [0 1 2 3 4 5 6 7 8 9]  ... total 60


## 5) Manual forget class 指定（要件2）と unlearn pruning 設定

- このセクション以降は **読み込んだ学習済みモデルに対して unlearning 処理だけ**を行います（要件1）。
- 忘却クラスはここで **手動で指定**します（要件2）。

In [51]:
# ===== Manual forget classes (EDIT HERE) =====
# Example: forget classes that belong to task 5.
FORGET_CLASSES = [50,51]   # <-- edit by hand

# ===== Target backbone selection =====
# By default, target the backbone trained directly on the task that contains
# the forget classes.
TARGET_BACKBONE_ID = PHASE_ID  # usually fine when using task t's backbone (change by hand if needed)

# ===== Unlearn pruning hyperparams =====
LAMBDA_RETAIN = 1.0      # lambda in S = I_forget - lambda * I_retain
PRUNE_RATIO   = 0.005    # fraction of top-scoring weights selected (start small)
MAX_BATCHES_F = 500      # max batches for the forget-set Fisher estimate (saves compute)
MAX_BATCHES_R = 25000    # max batches for the retain-set Fisher estimate
BATCH_SIZE    = 256
NUM_WORKERS   = 4


# ===== Finetune data subsampling (NEW) =====
# Randomly sample N images per class for finetuning (0/None means use all).
N_PER_CLASS_FT = 40
FT_SAMPLE_SEED = 0

# ===== Recompute seen/forget/retain sets (override) =====
# Overrides the sets derived from the checkpoint with the manual choice above.
num_classes = classes_at_task(PHASE_ID, args, data_manager.get_total_classnum())
all_seen = np.arange(num_classes)

# Forget classes not yet seen at this phase are dropped.
forget = np.array(sorted([c for c in FORGET_CLASSES if c < num_classes]), dtype=int)
retain = np.setdiff1d(all_seen, forget)

print("num_classes:", num_classes)
print("FORGET_CLASSES:", forget.tolist())
print("retain classes:", len(retain))
print("TARGET_BACKBONE_ID:", TARGET_BACKBONE_ID)


# ===== Masked finetune hyperparams (NEW) =====
FT_STEPS =450           # number of finetune steps (start around 100-500)
FT_LR    = 1e-4          # a small value is recommended
FT_ALPHA_RETAIN = 1.0    # weight of the retain cross-entropy term
FT_BETA_KD      = 1.0    # weight of the retain KD term (keeps pre-unlearning behavior)
FT_GAMMA_FORGET = 1.0    # strength of pushing the forget cross-entropy up
FT_TEMP         = 2.0    # KD temperature
FT_LOG_EVERY    = 50


num_classes: 60
FORGET_CLASSES: [50, 51]
retain classes: 58
TARGET_BACKBONE_ID: 5


In [52]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

# ===== Helpers: device / logits / backbone access =====
def _unwrap_net(net):
    """Return ``net.module`` if that attribute exists (DataParallel/DDP style), else ``net``."""
    return net.module if hasattr(net, "module") else net

def _get_device_from_model(model):
    """Pick the device to run on for ``model``.

    Uses ``model._device`` when it is a ``torch.device`` or a non-empty
    list/tuple (first entry); otherwise falls back to the device of the first
    parameter of ``model._network``.
    """
    configured = getattr(model, "_device", None)
    if isinstance(configured, torch.device):
        return configured
    if isinstance(configured, (list, tuple)) and len(configured) > 0:
        return configured[0]
    return next(model._network.parameters()).device

def _get_logits_from_network_output(out):
    """Extract logits from a network's forward output.

    - dict: the first present key among ``logits``, ``logit``, ``outputs``,
      ``output``; otherwise ``out["fc"]`` if it is a tensor; otherwise
      ``KeyError``.
    - non-empty tuple/list whose first element is a tensor: that element.
    - anything else: returned unchanged.
    """
    if isinstance(out, dict):
        known_keys = ("logits", "logit", "outputs", "output")
        found = next((key for key in known_keys if key in out), None)
        if found is not None:
            return out[found]
        # Some networks expose the classifier output under "fc".
        if "fc" in out and torch.is_tensor(out["fc"]):
            return out["fc"]
        raise KeyError(f"Cannot find logits key in out.keys()={list(out.keys())}")

    # e.g. (logits, features)
    is_sequence = isinstance(out, (tuple, list))
    if is_sequence and len(out) > 0 and torch.is_tensor(out[0]):
        return out[0]

    # Otherwise treat the value itself as the logits.
    return out

def _get_weight_2d(m):
    """Return ``m.weight`` if it is a 2-D tensor, else ``None``.

    Works for any module exposing a ``weight`` attribute, not only ``nn.Linear``.
    """
    weight = getattr(m, "weight", None)
    is_matrix = torch.is_tensor(weight) and weight.ndim == 2
    return weight if is_matrix else None

def get_target_backbone(net, bb_id: int):
    """Return backbone number ``bb_id`` of ``net``.

    ``net`` is unwrapped first. ``net.convnets`` is preferred when it exists;
    otherwise the attributes ``backbones``, ``nets``, ``models`` and
    ``encoders`` are tried in that order, skipping any whose indexing fails.

    Raises:
        AttributeError: if none of the candidate containers yields a backbone.
    """
    net = _unwrap_net(net)

    if hasattr(net, "convnets"):
        return net.convnets[bb_id]

    for container_name in ("backbones", "nets", "models", "encoders"):
        if not hasattr(net, container_name):
            continue
        container = getattr(net, container_name)
        try:
            return container[bb_id]
        except Exception:
            # Best effort: move on to the next candidate attribute.
            continue

    raise AttributeError("backbone list (e.g., net.convnets) not found.")

def _extract_targets_from_dataset(dataset):
    """Return the dataset's labels as a 1-D ``np.ndarray`` of ``int64``.

    Label attributes (``targets``, ``labels``, ``y``, ``_y``, ``_targets``,
    ``_labels``) are tried first; one is accepted only if it is 1-D and as
    long as the dataset. Otherwise every item is read and its last element is
    taken as the label (potentially slow).
    """

    def _detensor(value, convert):
        # Apply `convert` to torch tensors; leave anything else untouched.
        # Failures are ignored on purpose (best effort).
        try:
            import torch
            if torch.is_tensor(value):
                value = convert(value)
        except Exception:
            pass
        return value

    for attr_name in ("targets", "labels", "y", "_y", "_targets", "_labels"):
        if not hasattr(dataset, attr_name):
            continue
        raw = _detensor(getattr(dataset, attr_name), lambda t: t.detach().cpu().numpy())
        arr = np.asarray(raw)
        if arr.ndim == 1 and len(arr) == len(dataset):
            return arr.astype(np.int64)

    # Fallback: read labels item by item.
    labels = [
        int(_detensor(dataset[i][-1], lambda t: int(t.item())))
        for i in range(len(dataset))
    ]
    return np.asarray(labels, dtype=np.int64)

def _sample_indices_per_class(dataset, n_per_class, seed=0):
    """Return a shuffled list of indices with at most ``n_per_class`` per class.

    Returns ``None`` (meaning "no subsampling") when ``n_per_class`` is
    ``None`` or not positive. Classes with no more than ``n_per_class`` items
    are kept whole; larger ones are sampled without replacement using a
    generator seeded with ``seed``.
    """
    if n_per_class is None or int(n_per_class) <= 0:
        return None

    quota = int(n_per_class)
    labels = _extract_targets_from_dataset(dataset)
    rng = np.random.default_rng(seed)

    selected = []
    for cls in np.unique(labels):
        members = np.flatnonzero(labels == cls)
        if len(members) > quota:
            members = rng.choice(members, size=quota, replace=False)
        selected.extend(members.tolist())

    # Mix the classes so the returned order is not grouped by class.
    rng.shuffle(selected)
    return selected

def build_class_loader(data_manager, class_indices, source="train", mode="train",
                       batch_size=128, num_workers=4, shuffle=True,
                       n_per_class=None, seed=0):
    """Build a ``DataLoader`` over the given classes.

    The dataset comes from ``data_manager.get_dataset(class_indices, source=..., mode=...)``.
    When ``n_per_class`` is a positive number the dataset is wrapped in a
    ``Subset`` holding at most that many samples per class (sampling seeded
    by ``seed``). The loader is created with ``pin_memory=True``.
    """
    # get_dataset is handed a plain list rather than a numpy array.
    classes = class_indices.tolist() if isinstance(class_indices, np.ndarray) else class_indices
    base_dataset = data_manager.get_dataset(classes, source=source, mode=mode)

    chosen = _sample_indices_per_class(base_dataset, n_per_class=n_per_class, seed=seed)
    loader_dataset = base_dataset if chosen is None else Subset(base_dataset, chosen)

    return DataLoader(
        loader_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )

@torch.no_grad()
def eval_accuracy(model, loader):
    """Top-1 accuracy of ``model._network`` over ``loader``.

    The network is switched to eval mode. Each batch is expected to unpack as
    ``(_, inputs, labels)``. Returns 0.0 for an empty loader.
    """
    device = _get_device_from_model(model)
    model._network.eval()

    n_correct = 0
    n_seen = 0
    for batch in loader:
        _, inputs, labels = batch
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = _get_logits_from_network_output(model._network(inputs))
        hits = logits.argmax(dim=1) == labels

        n_correct += hits.sum().item()
        n_seen += labels.numel()

    return n_correct / max(n_seen, 1)


In [53]:
import copy
import torch
import torch.nn.functional as F

# ===== Unlearn pruning core (score + mask) =====
def fisher_diag_on_loader(model, module, loader, max_batches=50):
    """Estimate a diagonal Fisher-style importance for ``module``'s parameters.

    For each batch the mean cross-entropy loss of the full network is
    back-propagated and the squared gradients of ``module``'s trainable
    parameters are accumulated on the CPU. The sums are divided by the number
    of batches processed.

    Args:
        model: Learner exposing ``_network``.
        module: Sub-module (target backbone) whose parameters are scored.
        loader: Yields ``(_, inputs, labels)`` batches.
        max_batches: Stop after this many batches; ``None`` means no limit.

    Returns:
        Dict mapping parameter name (relative to ``module``) to a CPU tensor.
    """
    device = _get_device_from_model(model)
    model._network.eval()  # freeze BN/Dropout behavior to stabilize the scores

    fisher = {
        name: torch.zeros_like(param, device="cpu")
        for name, param in module.named_parameters()
        if param.requires_grad
    }

    n_batches = 0
    for _, inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        model._network.zero_grad(set_to_none=True)
        logits = _get_logits_from_network_output(model._network(inputs))
        F.cross_entropy(logits, labels).backward()

        for name, param in module.named_parameters():
            if param.requires_grad and param.grad is not None and name in fisher:
                fisher[name] += param.grad.detach().cpu() ** 2

        n_batches += 1
        if max_batches is not None and n_batches >= max_batches:
            break

    denom = max(n_batches, 1)
    for name in fisher:
        fisher[name] /= denom
    return fisher

def compute_unlearn_score(fisher_forget, fisher_retain, lambda_retain=1.0):
    """Per-parameter score ``S = I_forget - lambda_retain * I_retain``.

    Entries missing from ``fisher_retain`` get a copy of the forget value.
    Returns a new dict keyed like ``fisher_forget``.
    """
    return {
        name: (
            forget_imp - lambda_retain * fisher_retain[name]
            if name in fisher_retain
            else forget_imp.clone()
        )
        for name, forget_imp in fisher_forget.items()
    }

def build_mask_by_score(module, score, prune_ratio=0.01, only_ndim_ge2=True):
    """Select the top ``prune_ratio`` fraction of positive scores as an edit mask.

    Scores are clamped at zero, pooled over all eligible parameters, and the
    k-th largest value (``k = int(prune_ratio * n)``) becomes the threshold.
    Only entries with a strictly positive score can be selected, so when fewer
    than ``k`` scores are positive the mask holds just those entries instead
    of degenerating to "everything". Weights themselves are never modified.

    Args:
        module: Module whose ``named_parameters()`` define the candidates.
        score: Dict of parameter name -> score tensor (same shape as the parameter).
        prune_ratio: Fraction of eligible elements to select.
        only_ndim_ge2: If true, skip parameters with fewer than 2 dimensions.

    Returns:
        ``(mask_dict, stats)`` where ``mask_dict`` maps parameter name to a bool
        tensor (True = editable) and ``stats`` holds ``threshold``,
        ``total_elems``, ``masked_elems`` and ``masked_ratio_actual``. Ties at
        the threshold may make the actual ratio exceed ``prune_ratio``.

    Raises:
        RuntimeError: if no parameter is eligible, or if ``k`` rounds down to 0.
    """
    # Clamp once per eligible parameter; reused for threshold and masks.
    positive = {}
    for name, param in module.named_parameters():
        if name not in score:
            continue
        if only_ndim_ge2 and param.ndim < 2:
            continue
        positive[name] = torch.clamp(score[name], min=0.0)

    if not positive:
        raise RuntimeError("No parameters selected for mask. (Check only_ndim_ge2 or module structure.)")

    pooled = torch.cat([s_pos.flatten() for s_pos in positive.values()])
    k = int(prune_ratio * pooled.numel())
    if k <= 0:
        raise RuntimeError(f"prune_ratio too small: {prune_ratio} (k=0)")
    # topk rejects k larger than the number of elements (prune_ratio > 1).
    k = min(k, pooled.numel())

    thr = torch.topk(pooled, k=k, largest=True).values.min().item()

    mask_dict = {}
    total_elems = 0
    masked_elems = 0
    for name, s_pos in positive.items():
        # `s_pos > 0` keeps zero/negative-score entries out of the mask when
        # the threshold itself is 0.
        selected = (s_pos >= thr) & (s_pos > 0)  # CPU bool (True = editable)
        mask_dict[name] = selected
        total_elems += selected.numel()
        masked_elems += int(selected.sum().item())

    stats = {
        "threshold": thr,
        "total_elems": total_elems,
        "masked_elems": masked_elems,
        "masked_ratio_actual": masked_elems / max(total_elems, 1),
    }
    return mask_dict, stats

def _register_grad_mask(module, mask_dict):
    """Register gradient hooks that zero gradients outside the editable mask.

    For every parameter of ``module`` named in ``mask_dict``, the mask is moved
    to the parameter's device and dtype and multiplied into incoming
    gradients. Returns the hook handles so the caller can remove them.
    """

    def _make_hook(float_mask):
        # A factory gives each hook its own mask.
        def _apply(grad):
            return grad * float_mask
        return _apply

    handles = []
    for name, param in module.named_parameters():
        if name in mask_dict:
            float_mask = mask_dict[name].to(param.device).to(dtype=param.dtype)
            handles.append(param.register_hook(_make_hook(float_mask)))
    return handles

def _next_batch_cycling(iterator, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        try:
            return next(iterator), iterator
        except StopIteration:
            raise RuntimeError("DataLoader yielded no batches (empty dataset?)") from None


def masked_unlearn_finetune(model, target_bb, mask_dict,
                            forget_loader, retain_loader=None,
                            steps=200, lr=1e-4, wd=0.0,
                            alpha_retain=1.0, beta_kd=1.0, gamma_forget=1.0,
                            T=2.0, log_every=50):
    """Finetune only the masked weights of ``target_bb`` to unlearn the forget set.

    Loss per step:
      - forget: ``-gamma_forget * CE`` (gradient ascent on the forget cross-entropy)
      - retain (if ``retain_loader`` is given): ``alpha_retain * CE`` plus
        ``beta_kd * T^2 * KL`` against a frozen copy of the network taken
        before finetuning.

    Every element of ``target_bb`` that is not marked True in ``mask_dict``
    (including parameters absent from ``mask_dict``) is restored to its
    initial value after each optimizer step, so it stays exactly fixed.

    Args:
        model: Learner exposing ``_network``; modified in place.
        target_bb: Sub-module whose parameters are optimized.
        mask_dict: Parameter name -> bool tensor, True = editable.
        forget_loader, retain_loader: Yield ``(_, inputs, labels)`` batches and
            are cycled as needed.
        steps, lr, wd: Number of Adam steps, learning rate, weight decay.
        alpha_retain, beta_kd, gamma_forget: Loss weights.
        T: Distillation temperature.
        log_every: Print losses every this many steps (and at step 1);
            0 or None disables periodic logging.

    Raises:
        RuntimeError: if a loader yields no batches at all.
    """
    dev = _get_device_from_model(model)

    # Frozen reference model (state before unlearning) for distillation.
    ref_net = copy.deepcopy(_unwrap_net(model._network)).to(dev).eval()

    # The optimizer gets every trainable parameter of target_bb; the gradient
    # mask and the restore step below limit what actually changes.
    params = [p for p in target_bb.parameters() if p.requires_grad]
    opt = torch.optim.Adam(params, lr=lr, weight_decay=wd)

    # Initial values, plus per-parameter "frozen" index masks computed once
    # on the parameter's device (loop-invariant).
    named_params = list(target_bb.named_parameters())
    frozen_snap = {n: p.detach().clone() for n, p in named_params}
    frozen_idx = {n: ~mask_dict[n].to(p.device) for n, p in named_params if n in mask_dict}

    it_f = iter(forget_loader)
    it_r = iter(retain_loader) if retain_loader is not None else None

    model._network.eval()  # keep BN statistics fixed (keeps the edit small)

    # Register the gradient-mask hooks last and always remove them, even when
    # the loop raises or is interrupted, so they never leak onto the model.
    handles = _register_grad_mask(target_bb, mask_dict)
    try:
        for step in range(1, steps + 1):
            (_, xf, yf), it_f = _next_batch_cycling(it_f, forget_loader)
            xf = xf.to(dev, non_blocking=True)
            yf = yf.to(dev, non_blocking=True)

            opt.zero_grad(set_to_none=True)

            # forget: raise the CE (i.e. minimize -CE)
            logits_f = _get_logits_from_network_output(model._network(xf))
            loss_forget = F.cross_entropy(logits_f, yf)
            total_loss = -gamma_forget * loss_forget

            # retain: CE + KD
            if retain_loader is not None:
                (_, xr, yr), it_r = _next_batch_cycling(it_r, retain_loader)
                xr = xr.to(dev, non_blocking=True)
                yr = yr.to(dev, non_blocking=True)

                logits_r = _get_logits_from_network_output(model._network(xr))
                with torch.no_grad():
                    ref_logits_r = _get_logits_from_network_output(ref_net(xr))

                loss_retain_ce = F.cross_entropy(logits_r, yr)

                # F.kl_div(input=log p, target=q) computes KL(q || p), i.e.
                # KL(reference || current) on temperature-softened outputs.
                logp = F.log_softmax(logits_r / T, dim=1)
                q = F.softmax(ref_logits_r / T, dim=1)
                loss_kd = F.kl_div(logp, q, reduction="batchmean") * (T * T)

                total_loss = total_loss + alpha_retain * loss_retain_ce + beta_kd * loss_kd

            total_loss.backward()
            opt.step()

            # Restore everything outside the mask so it is exactly fixed
            # (this also cancels weight decay and similar side effects).
            with torch.no_grad():
                for n, p in named_params:
                    snap = frozen_snap[n]
                    idx = frozen_idx.get(n)
                    if idx is None:
                        p.copy_(snap)
                    else:
                        p.data[idx] = snap.data[idx]

            if step == 1 or (log_every and step % log_every == 0):
                msg = f"[ft step {step:4d}/{steps}] total={float(total_loss.item()):.4f}  forgetCE={float(loss_forget.item()):.4f}"
                if retain_loader is not None:
                    msg += f"  retainCE={float(loss_retain_ce.item()):.4f}  KD={float(loss_kd.item()):.4f}"
                print(msg)
    finally:
        for h in handles:
            h.remove()


In [54]:
# ===== 6) Build forget/retain loaders (train split) =====
# Note: the train split is normally used for unlearning (change here to use test).
# NOTE(review): these loaders use mode="train" and are also used for the
# accuracy evaluation below; presumably that applies training-time
# transforms — confirm against DataManager.get_dataset.

forget_loader = build_class_loader(
    data_manager, forget, source="train", mode="train",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)

# Using all retain data is heavy; restrict classes or cap samples if needed.
retain_loader = build_class_loader(
    data_manager, retain, source="train", mode="train",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)



# ---- Finetune loaders: sample N images per class ----
# NOTE: evaluation and Fisher estimation use the full loaders above; only
# finetuning uses the subsampled loaders.
forget_loader_ft = build_class_loader(
    data_manager, forget, source="train", mode="train",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True,
    n_per_class=N_PER_CLASS_FT, seed=FT_SAMPLE_SEED
)
retain_loader_ft = build_class_loader(
    data_manager, retain, source="train", mode="train",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True,
    n_per_class=N_PER_CLASS_FT, seed=FT_SAMPLE_SEED
)

print(f"[FT subset] forget N={len(getattr(forget_loader_ft, 'dataset', []))}  retain N={len(getattr(retain_loader_ft, 'dataset', []))}")
# ===== 7) Evaluate BEFORE =====
acc_forget_before = eval_accuracy(model, forget_loader)
acc_retain_before = eval_accuracy(model, retain_loader)

print(f"[BEFORE] acc_forget={acc_forget_before:.4f}  acc_retain={acc_retain_before:.4f}")

# ===== 8) Compute Fisher (forget vs retain) on TARGET_BACKBONE =====
target_bb = get_target_backbone(model._network, TARGET_BACKBONE_ID)

# Enable gradients on the target backbone so it can be scored and edited
# even if it was frozen.
for p in target_bb.parameters():
    p.requires_grad_(True)

f_f = fisher_diag_on_loader(model, target_bb, forget_loader, max_batches=MAX_BATCHES_F)
f_r = fisher_diag_on_loader(model, target_bb, retain_loader, max_batches=MAX_BATCHES_R)

score = compute_unlearn_score(f_f, f_r, lambda_retain=LAMBDA_RETAIN)

# ===== 9) Build mask (do NOT zero weights) =====
mask_dict, mask_stats = build_mask_by_score(
    target_bb, score, prune_ratio=PRUNE_RATIO, only_ndim_ge2=True
)
print("[MASK]", mask_stats)

# ===== 10) Masked finetune (unlearning) =====
# Everything outside the editable mask stays fixed; the masked weights are
# tuned to break the forget classes while preserving the retain classes.
masked_unlearn_finetune(
    model, target_bb, mask_dict,
    forget_loader=forget_loader_ft,
    retain_loader=retain_loader_ft,
    steps=FT_STEPS,
    lr=FT_LR,
    wd=0.0,
    alpha_retain=FT_ALPHA_RETAIN,
    beta_kd=FT_BETA_KD,
    gamma_forget=FT_GAMMA_FORGET,
    T=FT_TEMP,
    log_every=FT_LOG_EVERY
)

# ===== 11) Evaluate AFTER finetune =====
acc_forget_after = eval_accuracy(model, forget_loader)
acc_retain_after = eval_accuracy(model, retain_loader)

print(f"[AFTER ] acc_forget={acc_forget_after:.4f}  acc_retain={acc_retain_after:.4f}")

# ===== Optional: save the finetuned model checkpoint (weights only) =====
# The output name is the checkpoint path without its suffix plus a tag that
# embeds the FORGET_CLASSES list literally (brackets and spaces included).
SAVE_FT = False
if SAVE_FT:
    out_path = Path(str(ckpt_path)).with_suffix("").as_posix() + f"_unlearn_maskft{FORGET_CLASSES}.pth"
    torch.save({
        "state_dict": _unwrap_net(model._network).state_dict(),
        "forget_classes": forget.tolist(),
        "phase_id": PHASE_ID,
        "target_backbone_id": TARGET_BACKBONE_ID,
        "mask_stats": mask_stats,
        "ft": {
            "steps": FT_STEPS, "lr": FT_LR,
            "alpha_retain": FT_ALPHA_RETAIN,
            "beta_kd": FT_BETA_KD,
            "gamma_forget": FT_GAMMA_FORGET,
            "temp": FT_TEMP
        }
    }, out_path)
    print("saved:", out_path)


[FT subset] forget N=80  retain N=2320
[BEFORE] acc_forget=0.9720  acc_retain=0.8766
[MASK] {'threshold': 0.00016308830527123064, 'total_elems': 461872, 'masked_elems': 2309, 'masked_ratio_actual': 0.004999220563272942}
[ft step    1/450] total=0.2009  forgetCE=0.1005  retainCE=0.3014  KD=-0.0000
[ft step   50/450] total=-0.3711  forgetCE=0.5348  retainCE=0.1615  KD=0.0021
[ft step  100/450] total=-0.9144  forgetCE=1.3225  retainCE=0.3852  KD=0.0229
[ft step  150/450] total=-1.9316  forgetCE=2.2694  retainCE=0.3200  KD=0.0179
[ft step  200/450] total=-1.2350  forgetCE=3.0123  retainCE=1.2755  KD=0.5017
[ft step  250/450] total=-2.8964  forgetCE=3.4367  retainCE=0.5103  KD=0.0301
[ft step  300/450] total=-3.6712  forgetCE=4.0355  retainCE=0.3384  KD=0.0260
[ft step  350/450] total=-3.8939  forgetCE=4.1912  retainCE=0.2044  KD=0.0929
[ft step  400/450] total=-4.3975  forgetCE=4.5261  retainCE=0.1015  KD=0.0272
[ft step  450/450] total=-4.2395  forgetCE=4.6962  retainCE=0.2849  KD=0.1718
