In [1]:
# Full pipeline: FD + class-balanced gradient agreement subset selection

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.models import resnet18
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from functorch import make_functional_with_buffers, vmap, grad
from collections import defaultdict
from torch.utils.data import Dataset
import os
import PIL.Image as Image
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
import torch
import torchvision.transforms as T
from timm.data import create_transform

device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- data stats ----------
# Per-channel mean / std of the raw CIFAR-100 training pixels, scaled to [0, 1].
# `train_raw.data` already holds every image as a uint8 (H, W, C) array, so
# concatenating it along the image axis gives exactly the same (N*32, 32, 3)
# array as decoding each sample through __getitem__ (no per-image PIL trip).
train_raw = CIFAR100('./data100', train=True,  download=True)
x = np.concatenate(train_raw.data)
mean = (x.mean((0, 1))/255).tolist()
std  = (x.std((0, 1))/255).tolist()

# ---------- strong augmentation ----------
# RandAugment parameters: N=2 ops, M=9 magnitude (common default)
rand_augment = T.RandAugment(num_ops=2, magnitude=9)

base_transforms = [
    T.RandomCrop(32, padding=4, padding_mode='reflect'),
    T.RandomHorizontalFlip(),
    rand_augment,
    T.ToTensor(),
    T.Normalize(mean, std, inplace=True),
    # RandomErasing runs AFTER Normalize, i.e. on zero-mean tensors, so the
    # dataset-mean colour is 0 here.  `value=mean` (values on the [0, 1] pixel
    # scale) would paint patches with a colour offset from the dataset mean.
    T.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0),
]

transform_train = T.Compose(base_transforms)
transform_test  = T.Compose([T.ToTensor(), T.Normalize(mean, std)])

train_dataset = CIFAR100('./data100', train=True,
                         transform=transform_train, download=False)
test_dataset  = CIFAR100('./data100', train=False,
                         transform=transform_test,  download=False)


  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified


In [3]:
# class TinyImageNet(Dataset):
#     def __init__(self, root, split='train', transform=None):
#         self.root_dir = root
#         self.split = split
#         self.transform = transform
#         with open(os.path.join(self.root_dir, 'wnids.txt'), 'r') as f:
#             self.classes = f.read().strip().split()
#         self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
#         self.class_name = self._get_names()
#         self.images = self._load_images()

#     def _load_images(self):
#         images = []
#         if self.split == 'train':
#             for cls in self.classes:
#                 cls_dir = os.path.join(self.root_dir, self.split, cls, 'images')
#                 for image_file in os.listdir(cls_dir):
#                     image_path = os.path.join(cls_dir, image_file)
#                     images.append((image_path, self.class_to_idx[cls]))
#         elif self.split == 'val':
#             val_dir = os.path.join(self.root_dir, self.split, 'images')
#             image_to_cls = {}
#             with open(os.path.join(self.root_dir, self.split, 'val_annotations.txt'), 'r') as f:
#                 for line in f.read().strip().split('\n'):
#                     image_to_cls[line.split()[0].strip()] = line.split()[1].strip()
#             for image_file in os.listdir(val_dir):
#                 image_path = os.path.join(val_dir, image_file)
#                 images.append((image_path, self.class_to_idx[image_to_cls[image_file]]))
#         return images

#     def __len__(self):
#         return len(self.images)

#     def __getitem__(self, idx):
#         img_path, label = self.images[idx]
#         image = Image.open(img_path).convert('RGB')
#         if self.transform:
#             image = self.transform(image)
#         return image, label

#     def _get_names(self):
#         entity_dict = {}
#         with open(os.path.join(self.root_dir, 'words.txt'), 'r') as file:
#             for line in file:
#                 key, value = line.strip().split('\t')
#                 first = value.strip().split(',')
#                 entity_dict[key] = first[0]
#         return entity_dict

In [4]:
# import torchvision.transforms as T
# from timm.data.mixup import Mixup   # <─ pip install timm if you haven’t
# # --------------------------------------------
# dirs="./tiny-imagenet-200/"
# device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # colour statistics you already use
# mean = (0.480, 0.448, 0.397)
# std  = (0.276, 0.269, 0.282)

# train_tf = T.Compose([
#     # PIL-space augmentations
#     T.RandomResizedCrop(64, scale=(0.8, 1.0),
#                         interpolation=T.InterpolationMode.BICUBIC),
#     T.RandomHorizontalFlip(),
#     T.RandAugment(num_ops=2, magnitude=9),

#     # tensor-space ops
#     T.ToTensor(),
#     T.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=mean),
#     T.Normalize(mean, std, inplace=True),
# ])

# # ---------- validation / test ----------
# val_tf = T.Compose([
#     T.Resize(64, interpolation=T.InterpolationMode.BICUBIC),
#     T.CenterCrop(64),
#     T.ToTensor(),
#     T.Normalize(mean, std, inplace=True),
# ])

# # 4️⃣  datasets
# train_dataset = TinyImageNet(dirs, split='train', transform=train_tf)
# test_dataset   = TinyImageNet(dirs, split='val',   transform=val_tf)


In [5]:
# dirs="./imagenet/"
# traindir = os.path.join(dirs, 'train')
# valdir = os.path.join(dirs, 'val')

# device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_dataset = datasets.ImageFolder(
#     traindir,
#     transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])
    
# )

# test_dataset = datasets.ImageFolder(
#     valdir,
#     transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])
# )


In [6]:
# trainloader = torch.utils.data.DataLoader(fullset, batch_size=trn_batch_size,
#                                           shuffle=False, pin_memory=True, num_workers=1)

# valloader = torch.utils.data.DataLoader(testset, batch_size=val_batch_size,
#                                             shuffle=False, pin_memory=True, num_workers=1)


In [7]:
import numpy as np, torch
from tqdm import tqdm

@torch.no_grad()
def gpu_fd_stream(
        A: np.ndarray,
        ell: int,
        device=None,
        batch_size: int = 2048,
        dtype: torch.dtype = torch.float16,
        svd_dtype: torch.dtype = torch.float32,
):
    """Stream the rows of ``A`` through a Frequent Directions (FD) sketch.

    Rows are copied to ``device`` in chunks of ``batch_size`` and appended to an
    ``(ell, n)`` buffer ``B``.  Whenever the buffer is full it is compressed by
    an SVD.  When ``n >= ell``, every squared singular value is reduced by the
    smallest one (``s[-1] ** 2``), which zeroes at least one row (standard FD
    shrink; in exact arithmetic ``A.T A - B.T B`` is PSD with spectral norm at
    most ``||A||_F**2 / ell``).  When ``n < ell`` the buffer's rank is at most
    ``n``, so rotating it to ``diag(s) @ Vt`` is lossless and already frees
    rows; no shrink is applied.

    NOTE: each compression typically frees a single row, so once the buffer is
    full an SVD of an ``(ell, n)`` matrix runs for almost every incoming row.

    Args:
        A: 2-D array ``(m, n)``, one row per sample.
        ell: number of sketch rows; must be >= 1.
        device: torch device (or device string); defaults to CUDA when
            available, otherwise CPU.
        batch_size: number of rows transferred to ``device`` per chunk.
        dtype: storage dtype of the sketch buffer.  NOTE(review): float16 can
            overflow (max ~65504) if singular values grow large.
        svd_dtype: dtype used for the SVD computation.

    Returns:
        float32 NumPy array of shape ``(k, n)`` with ``k <= ell``.

    Raises:
        ValueError: if ``ell < 1`` (the insertion loop could never progress).
    """
    if ell < 1:
        raise ValueError(f"ell must be >= 1, got {ell}")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)

    m, n = A.shape
    B = torch.zeros((ell, n), device=device, dtype=dtype)
    next_row = 0                                  # occupied leading rows of B

    for start in range(0, m, batch_size):
        batch_gpu = torch.from_numpy(A[start:start + batch_size])\
                 .to(device=device, dtype=dtype, non_blocking=True)

        insert = 0
        while insert < batch_gpu.shape[0]:
            space = ell - next_row
            take = min(space, batch_gpu.shape[0] - insert)
            B[next_row:next_row + take] = batch_gpu[insert:insert + take]
            next_row += take
            insert += take

            if next_row == ell:                       # overflow → compress
                U, s, Vt = torch.linalg.svd(
                    B.to(svd_dtype), full_matrices=False
                )
                k = s.shape[0]                        # = min(ell, n)
                if k == ell:
                    # standard FD shrink: subtract the smallest squared
                    # singular value so at least the last row becomes zero
                    delta = s[-1] ** 2
                    s = torch.sqrt(torch.clamp(s ** 2 - delta, min=0.0))
                # keep the buffer at (ell, n) even when the SVD has k < ell rows
                B = torch.zeros((ell, n), device=device, dtype=dtype)
                B[:k] = torch.diag(s).to(dtype) @ Vt.to(dtype)
                # singular values are sorted descending, so non-zero rows
                # form a prefix of B
                next_row = (s > 1e-8).sum().item()
                if next_row < ell:
                    B[next_row:].zero_()
                if device.type == "cuda":
                    torch.cuda.empty_cache()

    return B[:next_row].float().cpu().numpy()


In [8]:
# --------- helper -------------
class FDStreamer:
    """Incrementally build a Frequent Directions sketch from row mini-batches.

    Rows handed to :meth:`add` are buffered on the CPU.  Once at least
    ``batch_size`` rows are buffered, they are merged into the running sketch
    by re-running ``gpu_fd_stream`` over ``[current sketch; buffered rows]``.

    Args:
        ell: number of sketch rows.
        batch_size: buffered-row threshold that triggers a merge; also the
            chunk size forwarded to ``gpu_fd_stream``.
        dtype: sketch storage dtype forwarded to ``gpu_fd_stream``.
        device: device forwarded to ``gpu_fd_stream`` (``None`` → its default).
    """

    def __init__(self, ell, batch_size=2048, dtype=torch.float16, device=None):
        self.ell = ell
        self.batch_size = batch_size
        self.dtype = dtype
        self.device = device
        self._buf = []          # pending CPU mini-batches, each (rows, D)
        self._buf_rows = 0      # total rows in self._buf (O(1) threshold check)
        self._B   = None        # current sketch as a NumPy array, or None

    def add(self, grad_batch: np.ndarray):
        """Buffer a 2-D batch (one row per sample); merge once the buffer is full."""
        self._buf.append(grad_batch)
        self._buf_rows += grad_batch.shape[0]
        if self._buf_rows >= self.batch_size:
            self._flush()

    def _flush(self):
        """Merge all buffered rows into the sketch; no-op if nothing is buffered."""
        if not self._buf:
            return
        # FD sketches are mergeable: sketching [old sketch; new rows] is valid
        pending = self._buf if self._B is None else [self._B, *self._buf]
        A_cpu = np.vstack(pending)
        self._B = gpu_fd_stream(A_cpu, self.ell,
                                device=self.device,
                                batch_size=self.batch_size,
                                dtype=self.dtype)
        # drop the buffer only after the merge succeeded, so a failure
        # (e.g. out-of-memory) does not silently lose rows
        self._buf.clear()
        self._buf_rows = 0

    def finalize(self):
        """Merge any remaining rows and return the sketch (``None`` if no rows were added)."""
        self._flush()
        return self._B
# ------------------------------


In [9]:
# def compute_per_sample_sketches(
#         model, loader, criterion, device, proj_matrix,
#         chunk_size: int = 4             # <-- new arg, default 4
# ):
#     model.eval()
#     fmodel, params, buffers = make_functional_with_buffers(model)

#     def compute_loss(p, b, x, y):
#         out = fmodel(p, b, x.unsqueeze(0)).squeeze(0)
#         return criterion(out, y)

#     grad_fn = grad(compute_loss)
#     grads_all = []

#     for x, y in loader:
#         x, y = x.to(device), y.to(device)

#         # OOM-safe vmap
#         grads_batch = vmap(
#             grad_fn, in_dims=(None, None, 0, 0),
#             chunk_size=chunk_size        # <-----------------
#         )(params, buffers, x, y)

#         flat = torch.cat(
#             [g.reshape(g.shape[0], -1) for g in grads_batch], dim=1
#         )
#         projected = flat @ proj_matrix.T        # stays on device
#         grads_all.append(projected.cpu())       # move to host, free VRAM
#         del grads_batch, flat, projected
#         torch.cuda.empty_cache()

#     return torch.cat(grads_all, dim=0)          # (N, ℓ) on CPU


In [10]:
# # --------------------- Agreement Selector ---------------------
# def select_agreeing_subset(grads: torch.Tensor, subset_size: int):
#     grads = grads / (grads.norm(dim=1, keepdim=True) + 1e-8)
#     sim_matrix = grads @ grads.T
#     agreement = sim_matrix.sum(dim=1)
#     top_indices = torch.topk(agreement, subset_size).indices
#     return top_indices.numpy()

In [11]:
# def select_agreeing_subset_fast(
#         grads: torch.Tensor,          # (N, ℓ)  on **any device**
#         subset_size: int
# ) -> np.ndarray:
#     """
#     Return indices of the `subset_size` samples whose (ℓ-dim) gradients
#     agree most with the centroid.  Memory / time:  O(N·ℓ).
#     """
#     grads = grads / (grads.norm(dim=1, keepdim=True) + 1e-8)   # L2-normalise
#     centroid = grads.mean(dim=0)                               # (ℓ,)
#     scores = grads @ centroid                                   # (N,)
#     top = torch.topk(scores, subset_size, largest=True).indices
#     return top.cpu().numpy()

With diversity penalty (currently disabled: the cell below is commented out)

In [12]:

# from __future__ import annotations
# from collections import defaultdict
# from typing      import List

# import torch, tqdm
# from torch.utils.data import DataLoader
# import torch.nn.functional as F


# # --------------------------------------------------------------
# # utility: project the gradient of ONE sample to ℓ dims
# # --------------------------------------------------------------
# def _project_single_grad(
#         model, x, y,
#         criterion,
#         proj_matrix: torch.Tensor      # (ℓ, D)  — on same device as model
# ) -> torch.Tensor:                    # (ℓ,)
#     model.zero_grad(set_to_none=True)           # free old grads
#     out   = model(x.unsqueeze(0)).squeeze(0)    # keep batch dim
#     loss  = criterion(out, y)
#     loss.backward()

#     # flatten all parameter grads in registration order
#     g_proj = torch.zeros(proj_matrix.size(0), device=proj_matrix.device)
#     offset = 0
#     for p in model.parameters():
#         if p.grad is None:
#             continue
#         g_flat = p.grad.flatten().to(proj_matrix.dtype)          # (Pi,)
#         P_slice = proj_matrix[:, offset: offset + g_flat.numel()]  # (ℓ, Pi)
#         g_proj += P_slice @ g_flat                               # accumulate
#         offset += g_flat.numel()
#     return g_proj


# # --------------------------------------------------------------
# # main selector (OOM-proof, no functorch)
# # --------------------------------------------------------------
# def class_balanced_agreeing_subset_fast(
#         model,
#         dataset,
#         num_classes        : int,
#         samples_per_class  : int,
#         criterion,
#         device,
#         proj_matrix        : torch.Tensor,   # (ℓ, D)  on same device as model
#         batch_size_data    : int = 64,       # images per forward pass
#         chunk_size_grad    : int = 4         # #images whose grads we keep at once
# ) -> List[int]:
#     """
#     Pick `samples_per_class` images per class with the highest
#     gradient-agreement score, **without ever running out of GPU memory**.
#     Returns a list of dataset indices.
#     """
#     model.eval()
#     proj_matrix = proj_matrix.to(device)

#     loader = DataLoader(dataset,
#                         batch_size=batch_size_data,
#                         shuffle=False,
#                         num_workers=4,
#                         pin_memory=True)

#     # buckets for low-dim projected grads (stored on CPU)
#     grads_per_class   = defaultdict(list)   # list[(Ni_c , ℓ)]
#     indices_per_class = defaultdict(list)

#     running_idx = 0
#     for X, Y in tqdm.tqdm(loader, desc="one-pass projected grads"):
#         X = X.to(device, non_blocking=True)
#         Y = Y.to(device, non_blocking=True)

#         # ----- split current mini-batch into micro-chunks ------------
#         B = Y.size(0)
#         for s in range(0, B, chunk_size_grad):
#             xc = X[s : s + chunk_size_grad]
#             yc = Y[s : s + chunk_size_grad]

#             # compute projected grad for **each** sample in micro-chunk
#             proj_chunk = torch.stack([
#                 _project_single_grad(model, xc[i], yc[i],
#                                      criterion, proj_matrix)
#                 for i in range(yc.size(0))
#             ])                                          # (m, ℓ) on GPU
#             proj_chunk_cpu = proj_chunk.cpu()           # immediately off-load

#             # bucket by class
#             for cls in range(num_classes):
#                 mask = (yc == cls)
#                 if mask.any():
#                     grads_per_class[cls].append(proj_chunk_cpu[mask.cpu()])
#                     base = running_idx + s
#                     idxs = torch.arange(base, base + yc.size(0))[mask.cpu()]
#                     indices_per_class[cls].append(idxs)

#             del proj_chunk, proj_chunk_cpu
#             torch.cuda.empty_cache()

#         running_idx += B

#     # ---------------- agreement scoring -----------------------------
#     selected = []
#     for cls in range(num_classes):
#         if cls not in grads_per_class:
#             continue                                    # class absent
#         G = torch.cat(grads_per_class[cls], dim=0)      # (Nc , ℓ)
#         I = torch.cat(indices_per_class[cls], dim=0)    # (Nc ,)

#         # cosine agreement with centroid
#         G_norm   = F.normalize(G, dim=1)
#         centroid = G_norm.mean(0)
#         # scores   = G_norm @ centroid
#         # top_k    = torch.topk(scores, samples_per_class).indices
#         # -----------------------------------------------------------------
#         # inside  for cls in range(num_classes):   (same place as before)
#         # -----------------------------------------------------------------
#         G_cls = torch.cat(grads_per_class[cls], dim=0)       # (Nc , ℓ)
#         I_cls = torch.cat(indices_per_class[cls], dim=0)     # (Nc ,)
        
#         G_norm   = torch.nn.functional.normalize(G_cls, dim=1)
#         centroid = G_norm.mean(dim=0)
#         scores   = G_norm @ centroid                         # (Nc,)
        
#         λ_div = 0.2            # diversity weight  (hyper-parameter)
#         k     = samples_per_class

#         chosen = []         # local indices inside G_cls
#         for _ in range(k):
#             if not chosen:                              # ---------- first pick
#                 idx = scores.argmax()                   # pure agreement
#             else:                                       # ---------- later picks
#                 idxs_t   = torch.tensor(chosen, device=G_norm.device)
#                 S        = G_norm[idxs_t]               # (p, ℓ)   p = len(chosen)
#                 if S.dim() == 1:                        # happens when p == 1
#                     S = S.unsqueeze(0)                  # make it (1, ℓ)
        
#                 sim_to_S  = (G_norm @ S.T).max(dim=1).values
#                 adj_score = scores - λ_div * sim_to_S
#                 idx       = adj_score.argmax()
        
#             chosen.append(idx.item())  
        
#         selected.extend(I_cls[torch.tensor(chosen)].tolist())
#         # -----------------------------------------------------------------

#         # selected.extend(I[top_k].tolist())

#     return selected


Without diversity penalty (active variant)

In [13]:
from __future__ import annotations
from collections import defaultdict
from typing      import List

import torch, tqdm
from torch.utils.data import DataLoader
import torch.nn.functional as F


# --------------------------------------------------------------
# utility: project the gradient of ONE sample to ℓ dims
# --------------------------------------------------------------
def _project_single_grad(
        model, x, y,
        criterion,
        proj_matrix: torch.Tensor      # (ℓ, D)  — on same device as model
) -> torch.Tensor:                    # (ℓ,)
    model.zero_grad(set_to_none=True)           # free old grads
    out   = model(x.unsqueeze(0)).squeeze(0)    # keep batch dim
    loss  = criterion(out, y)
    loss.backward()

    # flatten all parameter grads in registration order
    g_proj = torch.zeros(proj_matrix.size(0), device=proj_matrix.device)
    offset = 0
    for p in model.parameters():
        if p.grad is None:
            continue
        g_flat = p.grad.flatten().to(proj_matrix.dtype)          # (Pi,)
        P_slice = proj_matrix[:, offset: offset + g_flat.numel()]  # (ℓ, Pi)
        g_proj += P_slice @ g_flat                               # accumulate
        offset += g_flat.numel()
    return g_proj


# --------------------------------------------------------------
# main selector (OOM-proof, no functorch)
# --------------------------------------------------------------
def class_balanced_agreeing_subset_fast(
        model,
        dataset,
        num_classes        : int,
        samples_per_class  : int,
        criterion,
        device,
        proj_matrix        : torch.Tensor,   # (ℓ, D)  on same device as model
        batch_size_data    : int = 64,       # images per forward pass
        chunk_size_grad    : int = 4,        # #images whose grads we keep at once
        num_workers        : int = 4,        # DataLoader worker processes
) -> List[int]:
    """
    Pick up to `samples_per_class` images per class whose projected gradients
    agree most (cosine similarity) with their class centroid.

    A single unshuffled pass over `dataset` computes, per sample, the ℓ-dim
    projection of its loss gradient via `_project_single_grad`.  Gradients are
    processed one sample at a time and at most `chunk_size_grad` projections
    are held on the device before being moved to the CPU and bucketed by
    label.  Within each class the projections are L2-normalised, their mean is
    the centroid, and the highest-scoring samples are kept.

    A class with fewer than `samples_per_class` samples contributes all of
    them.  Classes absent from the dataset, and labels outside
    `range(num_classes)`, contribute nothing.

    Returns:
        Dataset indices grouped by class in ascending class order; within a
        class they are sorted by descending agreement score.
    """
    model.eval()
    device = torch.device(device)
    on_cuda = device.type == "cuda"
    proj_matrix = proj_matrix.to(device)

    loader = DataLoader(dataset,
                        batch_size=batch_size_data,
                        shuffle=False,           # running_idx relies on dataset order
                        num_workers=num_workers,
                        pin_memory=on_cuda)      # pinning only helps host→GPU copies

    # buckets for low-dim projected grads (stored on CPU)
    grads_per_class   = defaultdict(list)   # cls -> list[(n_i, ℓ)]
    indices_per_class = defaultdict(list)   # cls -> list[(n_i,)] dataset indices

    running_idx = 0
    for X, Y in tqdm.tqdm(loader, desc="one-pass projected grads"):
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)

        # ----- split current mini-batch into micro-chunks ------------
        B = Y.size(0)
        for s in range(0, B, chunk_size_grad):
            xc = X[s : s + chunk_size_grad]
            yc = Y[s : s + chunk_size_grad]

            # projected grad for each sample, off-loaded to CPU right away
            proj_chunk_cpu = torch.stack([
                _project_single_grad(model, xc[i], yc[i],
                                     criterion, proj_matrix)
                for i in range(yc.size(0))
            ]).cpu()                                    # (m, ℓ)

            yc_cpu = yc.cpu()
            base = running_idx + s
            chunk_idx = torch.arange(base, base + yc_cpu.size(0))
            # visit only the labels present in this chunk, not every class
            for cls in torch.unique(yc_cpu).tolist():
                mask = yc_cpu == cls
                grads_per_class[cls].append(proj_chunk_cpu[mask])
                indices_per_class[cls].append(chunk_idx[mask])

            del proj_chunk_cpu
            if on_cuda:
                torch.cuda.empty_cache()

        running_idx += B

    # ---------------- agreement scoring -----------------------------
    selected = []
    for cls in range(num_classes):
        if cls not in grads_per_class:
            continue                                    # class absent
        G = torch.cat(grads_per_class[cls], dim=0)      # (Nc , ℓ)
        I = torch.cat(indices_per_class[cls], dim=0)    # (Nc ,)

        # cosine agreement with the class centroid
        G_norm   = F.normalize(G, dim=1)
        centroid = G_norm.mean(0)
        scores   = G_norm @ centroid
        # topk raises if k exceeds the class size, so clamp to what exists
        k        = min(samples_per_class, scores.numel())
        top_k    = torch.topk(scores, k).indices
        selected.extend(I[top_k].tolist())

    return selected


In [14]:
# from __future__ import annotations
# from typing import List
# import torch, tqdm
# from torch.utils.data import DataLoader
# import torch.nn.functional as F


# # ------------------------------------------------------------------
# # project ONE sample’s gradient into ℓ-dim space (unchanged)
# # ------------------------------------------------------------------
# def _project_single_grad(
#     model, x, y,
#     criterion,
#     proj_matrix: torch.Tensor        # (ℓ, D) on same device as model
# ) -> torch.Tensor:                   # (ℓ,)
#     model.zero_grad(set_to_none=True)
#     loss = criterion(model(x.unsqueeze(0)).squeeze(0), y)
#     loss.backward()

#     g_proj = torch.zeros(proj_matrix.size(0), device=proj_matrix.device)
#     offset = 0
#     for p in model.parameters():
#         if p.grad is None:
#             continue
#         g_flat  = p.grad.flatten().to(proj_matrix.dtype)
#         P_slice = proj_matrix[:, offset : offset + g_flat.numel()]
#         g_proj += P_slice @ g_flat
#         offset += g_flat.numel()
#     return g_proj


# # ------------------------------------------------------------------
# # GLOBAL agreeing-subset selector  (no per-class buckets)
# # ------------------------------------------------------------------
# def agreeing_subset_fast(
#     model,
#     dataset,
#     subset_size        : int,           # total samples to keep
#     criterion,
#     device,
#     proj_matrix        : torch.Tensor,  # (ℓ, D)  on same device
#     batch_size_data    : int = 64,
#     chunk_size_grad    : int = 4
# ) -> List[int]:
#     """
#     Pick `subset_size` images whose projected gradients have the highest
#     agreement with the global centroid.  OOM-safe and single-pass.
#     """
#     model.eval()
#     proj_matrix = proj_matrix.to(device)

#     loader = DataLoader(
#         dataset,
#         batch_size=batch_size_data,
#         shuffle=False,
#         num_workers=4,
#         pin_memory=True,
#     )

#     all_grads   = []            # list[(m, ℓ)] on CPU
#     all_indices = []            # matching dataset indices (tensor CPU)

#     running_idx = 0
#     for X, Y in tqdm.tqdm(loader, desc="one-pass projected grads"):
#         X = X.to(device, non_blocking=True)
#         Y = Y.to(device, non_blocking=True)

#         B = Y.size(0)
#         for s in range(0, B, chunk_size_grad):
#             xc = X[s : s + chunk_size_grad]
#             yc = Y[s : s + chunk_size_grad]

#             proj_chunk = torch.stack([
#                 _project_single_grad(model, xc[i], yc[i],
#                                      criterion, proj_matrix)
#                 for i in range(yc.size(0))
#             ])                                       # (m, ℓ) GPU
#             all_grads.append(proj_chunk.cpu())       # off-load
#             base = running_idx + s
#             all_indices.append(
#                 torch.arange(base, base + yc.size(0)).cpu()
#             )

#             del proj_chunk
#             torch.cuda.empty_cache()

#         running_idx += B

#     # ---------------- agreement scoring ---------------------------
#     G = torch.cat(all_grads,   dim=0)                # (N, ℓ)  CPU
#     I = torch.cat(all_indices, dim=0)                # (N,)    CPU

#     G_norm   = F.normalize(G, dim=1)
#     centroid = G_norm.mean(0)
#     scores   = G_norm @ centroid                     # (N,)

#     top = torch.topk(scores, subset_size).indices
#     return I[top].tolist()


In [15]:
# --------------------- Train/Eval ---------------------
def train(model, loader, criterion, optimizer, device):
    """Run one training epoch of ``model`` over ``loader``.

    ``criterion`` must return a scalar loss (its ``.item()`` is taken).

    Returns:
        ``(mean loss, accuracy)`` averaged over the samples actually yielded by
        ``loader``.  This stays correct with ``drop_last=True`` or a sampler
        that visits only part of ``loader.dataset``.  Returns ``(0.0, 0.0)``
        if the loader yields no samples.
    """
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        batch = inputs.size(0)
        total_loss += loss.item() * batch        # undo the per-batch mean
        correct += (outputs.argmax(1) == targets).sum().item()
        seen += batch
    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    ``criterion`` must return a scalar loss (its ``.item()`` is taken).

    Returns:
        ``(mean loss, accuracy)`` averaged over the samples actually yielded by
        ``loader``.  This stays correct with ``drop_last=True`` or a sampler
        that visits only part of ``loader.dataset``.  Returns ``(0.0, 0.0)``
        if the loader yields no samples.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            batch = inputs.size(0)
            total_loss += loss.item() * batch    # undo the per-batch mean
            correct += (outputs.argmax(1) == targets).sum().item()
            seen += batch
    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen

In [16]:

# --------------------- Main ---------------------
# Subset budget: `subset_size` images split evenly over the classes.
num_classes = 100
subset_size = 2500
samples_per_class = subset_size // num_classes

# Gradient-scoring network: ImageNet-pretrained ResNet-18 (V2 weights also exist).
from torchvision.models import resnet18, ResNet18_Weights
weights = ResNet18_Weights.IMAGENET1K_V1
model_for_grad = resnet18(weights=weights)

# Swap in a fresh classification head sized for `num_classes`.
in_features = model_for_grad.fc.in_features
model_for_grad.fc = torch.nn.Linear(in_features, num_classes)

# Only the new head is trainable: freeze every parameter, then re-enable the
# head's, so per-sample gradients cover just `fc`.
model_for_grad.requires_grad_(False)
for param in model_for_grad.fc.parameters():
    param.requires_grad = True

model_for_grad = model_for_grad.to(device)
criterion = nn.CrossEntropyLoss()

# model_for_grad = SimpleMLP().to(device)
# criterion = nn.CrossEntropyLoss()



In [17]:
samples_per_class

25

In [18]:
from torch.utils.data import DataLoader

# Unshuffled pass over the training set (ordering is irrelevant for the
# sketch).  Tune the batch size to the available GPU memory.
batch_size_data = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_data,
    shuffle=False,
    num_workers=4,        # load batches in parallel worker processes
    pin_memory=True,      # page-locked host memory → faster copies to the GPU
)


In [19]:
import torch.nn.functional as F
import numpy as np

def per_sample_grads_slow(model, x, y, criterion=None):
    """
    Returns a (B, D) NumPy array: one flattened gradient vector per sample.

    Each sample is run through its own forward/backward pass.  The gradient
    layout is the concatenation, in ``model.parameters()`` order, of every
    parameter with ``requires_grad=True``.  A trainable parameter that
    receives no gradient contributes zeros, so ``D`` always equals the total
    number of trainable elements.

    Side effects: puts ``model`` in eval mode and overwrites parameter ``.grad``.

    Args:
        model: network mapping a ``(1, ...)`` input batch to logits.
        x: batch of inputs (first dim = batch).
        y: matching batch of targets.
        criterion: ``criterion(logits, targets)`` → scalar loss for a
            single-sample batch; defaults to ``F.cross_entropy``.

    Returns:
        NumPy array of shape ``(B, D)`` (shape ``(0, D)`` for an empty batch).
    """
    if criterion is None:
        criterion = F.cross_entropy
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    B = x.size(0)

    if B == 0:
        # np.stack([]) would raise; return an empty matrix with the right width
        D = sum(p.numel() for p in params)
        dtype = params[0].dtype if params else torch.float32
        return torch.zeros((0, D), dtype=dtype).numpy()

    grads = []
    for i in range(B):
        model.zero_grad(set_to_none=True)
        logits = model(x[i : i + 1])          # keep batch dim
        loss   = criterion(logits, y[i : i + 1])
        loss.backward()

        # flatten all trainable parameter grads into one long vector
        g = torch.cat([
            p.grad.flatten() if p.grad is not None else p.new_zeros(p.numel())
            for p in params
        ]).cpu().numpy()
        grads.append(g)
        torch.cuda.empty_cache()
    return np.stack(grads, axis=0)            # (B, D)

# def per_sample_proj_grad(
#         model, x, y, criterion,
#         proj_matrix: torch.Tensor | None = None,
#         dtype: torch.dtype = torch.float16,          # keeps slices small
# ) -> torch.Tensor:
#     """
#     • If `proj_matrix` is given → returns (ℓ,) projected gradient.
#     • If `proj_matrix is None` → returns the *full* flattened gradient (D,).
#     Both results live on **CPU**; only tiny temporaries sit on GPU.
#     """
#     model.zero_grad(set_to_none=True)
#     out  = model(x.unsqueeze(0)).squeeze(0)
#     loss = criterion(out, y)
#     loss.backward()

#     # flatten all parameter grads
#     g_flat = torch.cat([p.grad.flatten().to(dtype)
#                         for p in model.parameters() if p.grad is not None])

#     if proj_matrix is None:                # bootstrap mode
#         return g_flat.cpu()                # (D,)
#     else:                                  # normal projected mode
#         proj_matrix = proj_matrix.to(dtype)
#         g_proj = (proj_matrix @ g_flat).cpu()   # (ℓ,)
#         return g_proj




In [20]:
# d = 256                      # rows of the projection / sketch
# # fd = FDStreamer(d)         # create once
# # criterion = nn.CrossEntropyLoss(reduction='none')
# fd = FDStreamer(
#         d,
#         batch_size = 32,           # ①  never accumulate >2 rows
#         dtype       = torch.float16, #    rows are 2× smaller   
# )  
# # model_for_grad.to(device)
# # model_for_grad.eval()        # (no dropout / BN stats)



# # --------- configurable quota -----------------------------------
# quota_per_class = None      # None → use *all* samples
# # quota_per_class = 10      # e.g. set to 10 to keep ≤10 rows / class
# # ---------------------------------------------------------------

# seen_count = [0] * num_classes        # per-class counter

# for x, y in tqdm.tqdm(train_loader, desc="stream rows into sketch"):
#     if quota_per_class is None:
#         # --- no quota: take the whole batch --------------------
#         x_sub, y_sub = x.to(device), y.to(device)

#     else:
#         # --- quota active: mask out classes already full -------
#         need_mask = torch.tensor(
#             [seen_count[c.item()] < quota_per_class for c in y]
#         )
#         if not need_mask.any():          # batch gives nothing new
#             continue

#         x_sub, y_sub = x[need_mask].to(device), y[need_mask].to(device)

#     # --------- compute projected gradients ----------------------
#     grads = per_sample_grads_slow(model_for_grad, x_sub, y_sub)
#     fd.add(grads)                       # stream into FD sketch

#     # --------- update counters ----------------------------------
#     if quota_per_class is not None:
#         for c in y_sub.cpu():
#             seen_count[c] += 1
#         if min(seen_count) >= quota_per_class:
#             break                       




In [21]:
# B_sketch    = fd.finalize()          # shape (d, D)  -- NumPy CPU
# proj_matrix = torch.from_numpy(B_sketch)   

In [22]:

# final_subset_idx = class_balanced_agreeing_subset_fast(
#     model_for_grad,
#     train_dataset,
#     num_classes=num_classes,
#     samples_per_class=samples_per_class,
#     criterion=criterion,
#     device=device,
#     proj_matrix=proj_matrix,
#     batch_size_data=128,           # data loader batch
#     chunk_size_grad=64             # NEW — vmap micro-batch
# )

# final_subset = Subset(train_dataset, final_subset_idx)
# final_loader = DataLoader(final_subset, batch_size=64, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [23]:

from timm.data.mixup import Mixup
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch
import torch.nn as nn
import torch.optim as optim

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from model_factory import create_model   # project-local model registry

num_classes = 100
# True enables timm MixUp/CutMix: train_one_epoch then averages per-sample
# losses and scores accuracy against the argmax of the mixed soft targets.
cut_mix = True   # Set to True to enable Mixup/CutMix

# ------------------------------------------------------------------
# 1.  Model (freshly initialised, re-binds the notebook-global `model`)
# ------------------------------------------------------------------
# NOTE(review): `model_for_grad` (used by the subset-selection cell) is not
# re-assigned here, so it presumably still refers to a network built in an
# earlier cell rather than this one — confirm that is intended.
model = create_model(
    "resnext",    # alias → cifar_resnext29_32x4d
    num_classes=num_classes,
    pretrained=False
).to(device)

# ------------------------------------------------------------------
# 2.  Augment helpers
# ------------------------------------------------------------------
if cut_mix:
    # timm Mixup turns hard labels into soft (mixed + smoothed) targets.
    # NOTE(review): timm's Mixup is documented to require an even batch
    # size — make sure the loaders never yield an odd batch.
    mixup_fn = Mixup(
        mixup_alpha   = 0.2,
        cutmix_alpha  = 1.0,
        prob          = 1.0,      # always apply either MixUp or CutMix
        switch_prob   = 0.5,      # 50-50 split
        label_smoothing = 0.1,
        num_classes   = num_classes,
    )
    # reduction='none' returns one loss per sample; the training loop takes
    # the mean itself.
    # NOTE(review): this same `criterion` is also handed to
    # class_balanced_agreeing_subset_fast — confirm it expects per-sample losses.
    criterion = nn.CrossEntropyLoss(reduction='none')   # per-sample for MixUp/CutMix
else:
    mixup_fn = None
    criterion = nn.CrossEntropyLoss()   # standard

# ------------------------------------------------------------------
# 3.  Optimiser & cosine schedule
# ------------------------------------------------------------------
optimizer  = optim.SGD(model.parameters(),
                       lr=0.1, momentum=0.9, weight_decay=5e-4)
# Total epochs; the schedule's T_max and the training cell's NUM_EPOCHS /
# T_REFRESH are derived from this value.
epochs     = 200
# Stepped once per epoch by the training loop, so the LR reaches its minimum
# after `epochs` steps.
scheduler  = CosineAnnealingLR(optimizer, T_max=epochs)
# Loss scaler used by train_one_epoch (scale → backward → step → update).
scaler     = torch.cuda.amp.GradScaler()  # AMP

# ------------------------------------------------------------------
# 4.  Training / eval loops
# ------------------------------------------------------------------
def train_one_epoch(model, loader):
    """Run one training epoch over ``loader`` with AMP and optional MixUp/CutMix.

    Reads the notebook globals ``device``, ``cut_mix``, ``mixup_fn``,
    ``criterion``, ``optimizer`` and ``scaler``.

    Args:
        model:  network to train; it is switched to ``train()`` mode here.
        loader: iterable of ``(x, y)`` batches with hard integer labels.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually trained on. With
        MixUp/CutMix enabled, accuracy compares ``logits.argmax(1)`` with the
        argmax of the mixed soft target. If no sample was trained on (empty
        loader), returns ``(nan, nan)`` instead of dividing by zero.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if cut_mix:
            # timm's Mixup asserts an even batch size. Drop one sample from an
            # odd batch (e.g. the partial last batch of a subset loader)
            # instead of crashing the epoch.
            if x.size(0) % 2:
                x, y = x[:-1], y[:-1]
            if x.size(0) == 0:
                continue
            x, y = mixup_fn(x, y)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            logits = model(x)
            # Mixup/CutMix: criterion returns an (N,) vector → average here;
            # otherwise it already returns a scalar mean.
            if cut_mix:
                loss = criterion(logits, y).mean()
            else:
                loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_n = x.size(0)
        loss_sum += loss.item() * batch_n
        total    += batch_n
        with torch.no_grad():
            preds = logits.argmax(1)
            # Soft (mixed / smoothed) targets are scored by their dominant class.
            targets = y.argmax(1) if cut_mix else y
            correct += (preds == targets).sum().item()

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, correct / total

@torch.no_grad()
def evaluate(model, loader):
    """Measure mean cross-entropy and top-1 accuracy of ``model`` on ``loader``.

    The model is put in ``eval()`` mode, gradients are disabled, and each
    batch is moved to the notebook-global ``device``. A fresh mean-reduced
    ``nn.CrossEntropyLoss`` is used (independent of the training
    ``criterion``), so labels must be hard integer class ids.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample in ``loader``.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    weighted_loss = 0.0
    n_right = 0
    n_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        out = model(inputs)
        batch = inputs.size(0)
        # Batch loss is a mean, so re-weight by batch size before summing.
        weighted_loss += loss_fn(out, labels).item() * batch
        n_seen += batch
        n_right += (out.argmax(1) == labels).sum().item()
    return weighted_loss / n_seen, n_right / n_seen


# ------------------------------------------------------------------
# 5.  Train                                                                   
# ------------------------------------------------------------------

    
    
# for epoch in range(1, epochs + 1):
#     train_loss, train_acc = train_one_epoch(model, final_loader)
#     val_loss,   val_acc   = evaluate(model, test_loader)

#     scheduler.step()            # cosine annealing step

#     lr_now = scheduler.get_last_lr()[0]
#     print(f"[{epoch:3d}/{epochs}] lr {lr_now:7.5f} | "
#           f"train {train_acc*100:5.2f}% | val {val_acc*100:5.2f}%")

In [24]:
def build_sketch(model, loader, proj_matrix=None, d=256,
                 batch_size_fd=32, device="cuda"):
    """Stream per-sample gradients of ``model`` over ``loader`` into an FD sketch.

    Args:
        model:         network whose per-sample gradients are computed via
                       ``per_sample_grads_slow``.
        loader:        iterable of ``(inputs, labels)`` batches.
        proj_matrix:   accepted for interface compatibility but not used.
        d:             sketch width passed to ``FDStreamer``.
        batch_size_fd: row-buffer size forwarded to ``FDStreamer``.
        device:        device the batches are moved to before differentiation.

    Returns:
        The finalized sketch as a CPU tensor (built with ``torch.from_numpy``);
        shape presumably ``(d, D)`` with ``D`` the gradient dimension.
    """
    streamer = FDStreamer(d, batch_size=batch_size_fd, dtype=torch.float16)
    progress = tqdm.tqdm(loader, desc="stream rows into sketch")
    for inputs, labels in progress:
        grad_rows = per_sample_grads_slow(model, inputs.to(device), labels.to(device))
        streamer.add(grad_rows)
    sketch = streamer.finalize()
    return torch.from_numpy(sketch)


In [25]:
# ---------------- hyper-parameters ---------------------------------
T_REFRESH    = epochs         # epochs between refreshes (== epochs → never refresh)
SUBSET_FRACTION = 0.05   # e.g. 5% of training set
SKETCH_L     = 256       # FD width
BATCH_DATA   = 200
CHUNK_GRAD   = 32
NUM_EPOCHS   = epochs    # keep equal to the scheduler's T_max

# New parameters:
WARMUP_EPOCHS         = 0      # Set >0 for warmup, 0 for immediate selection
FIRST_SELECTION_ON_INIT = True # If True, select at epoch 1, else after warmup
# -------------------------------------------------------------------

# T_REFRESH is used as a modulus below; 0 would raise ZeroDivisionError mid-run.
if T_REFRESH < 1:
    raise ValueError(f"T_REFRESH must be >= 1, got {T_REFRESH}")

full_loader  = DataLoader(train_dataset, batch_size=BATCH_DATA, shuffle=True)
train_loader = full_loader                       # full data until the first selection
val_loader   = DataLoader(test_dataset, batch_size=128, shuffle=False)

if WARMUP_EPOCHS > 0:
    first_selection_epoch = WARMUP_EPOCHS + 1
elif FIRST_SELECTION_ON_INIT:
    first_selection_epoch = 1
else:
    raise ValueError("Either WARMUP_EPOCHS > 0 or FIRST_SELECTION_ON_INIT must be True")

# Class-balanced quota so the subset holds ~SUBSET_FRACTION of the training set.
k_per_cls = int(SUBSET_FRACTION * len(train_dataset) / num_classes)

for epoch in range(1, NUM_EPOCHS + 1):

    # --- selection schedule: first selection, then every T_REFRESH epochs ---
    do_selection = False
    if epoch == first_selection_epoch:
        if WARMUP_EPOCHS > 0:
            print(f"\n⇢ Finished {WARMUP_EPOCHS} warmup epochs. Selecting subset at epoch {epoch}")
        else:
            print(f"\n⇢ Selecting subset at epoch {epoch}")
        do_selection = True
    elif epoch > first_selection_epoch and (epoch - first_selection_epoch) % T_REFRESH == 0:
        print(f"\n⇢ Refreshing subset at epoch {epoch}")
        do_selection = True

    if do_selection:
        # NOTE(review): selection uses `model_for_grad`, which is not re-bound
        # to the `model` created in the setup cell; warmup/refresh training of
        # `model` therefore presumably does not influence which samples are
        # picked. Confirm this is intended.
        B_sketch = build_sketch(model_for_grad, full_loader,
                                proj_matrix=None, d=SKETCH_L,
                                batch_size_fd=BATCH_DATA, device=device)
        proj_matrix = B_sketch.to(device, torch.float16)
        new_ids = class_balanced_agreeing_subset_fast(
            model_for_grad, train_dataset,
            num_classes, k_per_cls,
            criterion, device, proj_matrix,
            batch_size_data=BATCH_DATA, chunk_size_grad=CHUNK_GRAD
        )
        subset_ds   = Subset(train_dataset, new_ids)
        train_loader = DataLoader(subset_ds, batch_size=BATCH_DATA, shuffle=True)
        print(f"   new subset size: {len(subset_ds)}")

    # ---- training ----
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss,   val_acc   = evaluate(model, val_loader)

    scheduler.step()            # cosine annealing: one step per epoch

    lr_now = scheduler.get_last_lr()[0]
    print(f"[{epoch:3d}/{NUM_EPOCHS}] lr {lr_now:7.5f} | "
          f"train {train_acc*100:5.2f}% | val {val_acc*100:5.2f}%")



⇢ Selecting subset at epoch 1


stream rows into sketch: 100%|████████████████| 250/250 [19:29<00:00,  4.68s/it]
one-pass projected grads: 100%|███████████████| 250/250 [02:31<00:00,  1.65it/s]


   new subset size: 2500
[  1/200] lr 0.09999 | train  0.92% | val  1.38%
[  2/200] lr 0.09998 | train  1.36% | val  1.01%
[  3/200] lr 0.09994 | train  1.72% | val  1.19%
[  4/200] lr 0.09990 | train  1.56% | val  2.05%
[  5/200] lr 0.09985 | train  3.20% | val  3.79%
[  6/200] lr 0.09978 | train  4.44% | val  4.91%
[  7/200] lr 0.09970 | train  6.20% | val  5.86%
[  8/200] lr 0.09961 | train  6.72% | val  6.07%
[  9/200] lr 0.09950 | train  8.00% | val  7.27%
[ 10/200] lr 0.09938 | train  6.24% | val  7.32%
[ 11/200] lr 0.09926 | train  8.88% | val  7.97%
[ 12/200] lr 0.09911 | train  8.40% | val  7.86%
[ 13/200] lr 0.09896 | train 10.24% | val  8.45%
[ 14/200] lr 0.09880 | train  9.04% | val  8.97%
[ 15/200] lr 0.09862 | train 11.28% | val  8.91%
[ 16/200] lr 0.09843 | train 10.68% | val  9.42%
[ 17/200] lr 0.09823 | train 10.16% | val  9.79%
[ 18/200] lr 0.09801 | train 11.84% | val  9.56%
[ 19/200] lr 0.09779 | train 11.28% | val  9.60%
[ 20/200] lr 0.09755 | train 12.92% | val 10