In [1]:
import pandas as pd
import numpy as np


In [4]:

from torch.backends.cuda import sdp_kernel

# enable math implementation
sdp_kernel.enable_math(True)
# disable flash and mem-efficient implementations (which may lack backward on this setup)
sdp_kernel.enable_flash(False)
sdp_kernel.enable_mem_efficient(False)

print("Configured torch SDPA: math=True, flash=False, mem_efficient=False")


AttributeError: 'function' object has no attribute 'enable_math'

In [2]:
a = np.array([1, 2, 3])

In [3]:
a

array([1, 2, 3])

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init
import torch.backends.cudnn as cudnn
from torch.utils.data.dataset import Dataset
from torch.optim.lr_scheduler import MultiStepLR




import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data.distributed import DistributedSampler


from PIL import Image

import os
import sys
import time
import types
import numpy as np
import shutil
import warnings
import random
import math




from statistics import mean
import wandb
import h5py


In [6]:
from framework.config import get_config, get_arch, get_dataset, get_transform, get_pin_memory, DistillDataset

In [7]:
from framework.distill_higher import Distill
from framework.util import Summary, AverageMeter, ProgressMeter, accuracy, accuracy_ind, ImageIntervention, init_gaussian
import torch_optimizer

In [8]:
import argparse

from framework.base import main_worker

In [9]:
import scipy

In [10]:
from enum import Enum

In [1]:
import h5py
import torch

for f in [
    "out_step2_mrpc_emb_text_mlp_ipc05_s0.h5",
    "out_step2_mrpc_emb_text_mlp_ipc10_s1.h5"
    # "out_step2_mrpc_emb_text_mlp_ipc15_s2.h5"
]:
    print("\n====", f, "====")
    with h5py.File(f, "r") as h:
        print("data shape:", h["data"][:].shape)
        if "label" in h:
            print("label shape:", h["label"][:].shape)
            print("unique labels:", set(h["label"][:].tolist()))
        else:
            print("NO LABEL FOUND")



==== out_step2_mrpc_emb_text_mlp_ipc05_s0.h5 ====
data shape: (10, 768)
NO LABEL FOUND

==== out_step2_mrpc_emb_text_mlp_ipc10_s1.h5 ====
data shape: (20, 768)
NO LABEL FOUND


In [1]:
import h5py
import torch

def load_h5(path):
    with h5py.File(path, "r") as f:
        data = torch.tensor(f["data"][:]).float()
    return data

s0 = load_h5("out_step2_mrpc_emb_text_mlp_ipc05_s0.h5")   # (10,768)
s1 = load_h5("out_step2_mrpc_emb_text_mlp_ipc10_s1.h5")  # (20,768)

num_classes = 2
ipc0 = s0.shape[0] // num_classes   # 5
ipc1 = s1.shape[0] // num_classes   # 10
old_ipc = ipc0

print("ipc0:", ipc0, "ipc1:", ipc1, "old_ipc:", old_ipc)

# Compare old blocks per class
for c in range(num_classes):
    s0_block = s0[c*ipc0 : c*ipc0 + old_ipc]              # old class-c
    s1_old_block = s1[c*ipc1 : c*ipc1 + old_ipc]          # supposed old class-c in s1

    diff = (s0_block - s1_old_block).abs().max().item()
    mse = ((s0_block - s1_old_block)**2).mean().item()

    print(f"class {c}: max_abs_diff={diff:.6e}, mse={mse:.6e}")

    if diff < 1e-6:
        print("  ✅ old block preserved for this class")
    else:
        print("  ❌ old block NOT preserved for this class")


ipc0: 5 ipc1: 10 old_ipc: 5
class 0: max_abs_diff=0.000000e+00, mse=0.000000e+00
  ✅ old block preserved for this class
class 1: max_abs_diff=0.000000e+00, mse=0.000000e+00
  ✅ old block preserved for this class


In [2]:
import h5py
import torch
from typing import List, Tuple

def load_h5(path: str) -> torch.Tensor:
    with h5py.File(path, "r") as f:
        return torch.tensor(f["data"][:]).float()

def check_boostdd_preservation(
    h5_paths: List[str],
    num_classes: int,
    tol: float = 1e-6,
    verbose: bool = True
) -> List[Tuple[int, int, bool]]:
    """
    h5_paths: ordered list of stage outputs, e.g. [s0.h5, s1.h5, s2.h5, ...]
    num_classes: number of classes (2 for MRPC)
    tol: tolerance for preservation
    Returns list of (k, k+1, passed)
    """
    stages = [load_h5(p) for p in h5_paths]
    results = []

    # sanity: embedding dim must be constant
    dims = [s.shape[1] for s in stages]
    if len(set(dims)) != 1:
        raise ValueError(f"Embedding dims not consistent across stages: {dims}")

    for k in range(len(stages) - 1):
        s_old = stages[k]
        s_new = stages[k + 1]

        ipc_old = s_old.shape[0] // num_classes
        ipc_new = s_new.shape[0] // num_classes

        if verbose:
            print(f"\nStage {k} -> Stage {k+1}: ipc_old={ipc_old}, ipc_new={ipc_new}")

        passed_pair = True
        for c in range(num_classes):
            old_block = s_old[c*ipc_old : c*ipc_old + ipc_old]
            new_old_prefix = s_new[c*ipc_new : c*ipc_new + ipc_old]

            max_abs = (old_block - new_old_prefix).abs().max().item()
            mse = ((old_block - new_old_prefix) ** 2).mean().item()

            if verbose:
                print(f"  class {c}: max_abs_diff={max_abs:.6e}, mse={mse:.6e}")

            if max_abs >= tol:
                passed_pair = False

        if verbose:
            print("  ✅ preserved" if passed_pair else "  ❌ NOT preserved")

        results.append((k, k+1, passed_pair))

    return results


In [3]:
paths = [
    "out_step2_mrpc_emb_text_mlp_ipc05_s0.h5",
    "out_step2_mrpc_emb_text_mlp_ipc10_s1.h5",
    "out_step2_mrpc_emb_text_mlp_ipc15_s2.h5",
    # add more as you generate them
]

check_boostdd_preservation(paths, num_classes=2, tol=1e-6)



Stage 0 -> Stage 1: ipc_old=5, ipc_new=10
  class 0: max_abs_diff=0.000000e+00, mse=0.000000e+00
  class 1: max_abs_diff=0.000000e+00, mse=0.000000e+00
  ✅ preserved

Stage 1 -> Stage 2: ipc_old=10, ipc_new=15
  class 0: max_abs_diff=0.000000e+00, mse=0.000000e+00
  class 1: max_abs_diff=0.000000e+00, mse=0.000000e+00
  ✅ preserved


[(0, 1, True), (1, 2, True)]

In [5]:
def check_against_stage0(h5_paths: List[str], num_classes: int, tol: float=1e-6):
    s0 = load_h5(h5_paths[0])
    ipc0 = s0.shape[0] // num_classes

    for k, p in enumerate(h5_paths[1:], start=1):
        sk = load_h5(p)
        ipck = sk.shape[0] // num_classes
        print(f"\nStage 0 -> Stage {k}: ipc0={ipc0}, ipck={ipck}")

        passed = True
        for c in range(num_classes):
            s0_block = s0[c*ipc0 : c*ipc0 + ipc0]
            sk_prefix = sk[c*ipck : c*ipck + ipc0]
            max_abs = (s0_block - sk_prefix).abs().max().item()
            print(f"  class {c}: max_abs_diff={max_abs:.6e}")
            if max_abs >= tol:
                passed = False

        print("  ✅ preserved from stage0" if passed else "  ❌ NOT preserved from stage0")
paths = [
    "out_step2_mrpc_emb_text_mlp_ipc05_s0.h5",
    "out_step2_mrpc_emb_text_mlp_ipc10_s1.h5",
    "out_step2_mrpc_emb_text_mlp_ipc15_s2.h5",
    "out_step2_mrpc_emb_text_mlp_ipc20_s3.h5",
    "out_step2_mrpc_emb_text_mlp_ipc25_s4.h5"
]

check_against_stage0(paths, num_classes=2, tol=1e-6)


Stage 0 -> Stage 1: ipc0=5, ipck=10
  class 0: max_abs_diff=0.000000e+00
  class 1: max_abs_diff=0.000000e+00
  ✅ preserved from stage0

Stage 0 -> Stage 2: ipc0=5, ipck=15
  class 0: max_abs_diff=0.000000e+00
  class 1: max_abs_diff=0.000000e+00
  ✅ preserved from stage0

Stage 0 -> Stage 3: ipc0=5, ipck=20
  class 0: max_abs_diff=0.000000e+00
  class 1: max_abs_diff=0.000000e+00
  ✅ preserved from stage0

Stage 0 -> Stage 4: ipc0=5, ipck=25
  class 0: max_abs_diff=0.000000e+00
  class 1: max_abs_diff=0.000000e+00
  ✅ preserved from stage0


In [32]:
print(1)

1


In [28]:
print(1)

1


In [29]:
print(1)

1


In [30]:
print(1)

1


In [31]:
print(1)

1
