# Demonstration of Data Distill by Neural Collapse algorithms

This notebook implements data distillation with neural collapse. The main papers considered here are https://github.com/SsnL/dataset-distillation and https://github.com/tding1/Neural-Collapse. The neural network is firstly trained to its terminal phase and then distills synthesized data as introduced. Below we configure our path to the dataset distillation packages from https://github.com/SsnL/dataset-distillation.

In [20]:
import sys
sys.path.append('/Users/songzeyang/Documents/GitHub/dataset-distillation/utils')
sys.path.append('/Users/songzeyang/Documents/GitHub/Neural-Collapse')

In [21]:
import sys
sys.path

['/Users/songzeyang/Documents/GitHub/Data_Distill_with_NC',
 '/Users/songzeyang/anaconda3/lib/python311.zip',
 '/Users/songzeyang/anaconda3/lib/python3.11',
 '/Users/songzeyang/anaconda3/lib/python3.11/lib-dynload',
 '',
 '/Users/songzeyang/anaconda3/lib/python3.11/site-packages',
 '/Users/songzeyang/anaconda3/lib/python3.11/site-packages/aeosa',
 '/Users/songzeyang/Documents/GitHub/dataset-distillation/utils',
 '/Users/songzeyang/Documents/GitHub/Neural-Collapse',
 '/Users/songzeyang/Documents/GitHub/dataset-distillation/utils',
 '/Users/songzeyang/Documents/GitHub/Neural-Collapse',
 '/Users/songzeyang/Documents/GitHub/dataset-distillation/utils',
 '/Users/songzeyang/Documents/GitHub/Neural-Collapse']

# Import

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST
from torch.utils.data import DataLoader, Dataset
import numpy as np
import copy
from contextlib import contextmanager

from six import add_metaclass
from typing import Type, Any, Callable, Union, List, Optional

Let's import the file from the https://github.com/tding1/Neural-Collapse.

In [23]:
import models
from utils import *
from datasets import make_dataset

ImportError: cannot import name 'make_dataset' from 'datasets' (/Users/songzeyang/anaconda3/lib/python3.11/site-packages/datasets/__init__.py)

Let's import the file from the https://github.com/SsnL/dataset-distillation.

In [24]:
# from basics import task_loss, final_objective_loss, evaluate_steps
# from utils.distributed import broadcast_coalesced, all_reduce_coalesced
# from utils.io import save_results

# Define Parameters

In [25]:
# architecture params
model='resnet18'

# dataset
dataset='mnist'
data_dir='~/data'

# training params
optimizer="LBFGS"
lr=0.1
history_size=10
batch_size=128
uid="result"
device = "cpu"
SOTA=False

# Network params
loss = 'CrossEntropy'
bias=True
ETF_fc=False
fixdim=0

# Load Data

In [None]:
trainloader, _, num_classes = make_dataset(dataset, 
                                           data_dir, 
                                           batch_size, 
                                           SOTA=SOTA)
print(num_classes)

# Load Model

As we are primarily focusing on the neural collapse, we will build a large Resnet18 model to show the nerual collapse. Let's built a wrap class to facilitate the distillation.

In [None]:
model = models.__dict__[model](num_classes=num_classes, 
                                    fc_bias=bias, 
                                    ETF_fc=ETF_fc, 
                                    fixdim=fixdim, 
                                    SOTA=SOTA).to(device)

print('# of model parameters: ' + str(count_network_parameters(model)))
print(type(model))

In [None]:
class ExtendedResNet(ResNet):  # Assuming your model is ResNet
    def __init__(self, *args, **kwargs):
        super(ExtendedResNet, self).__init__(*args, **kwargs)

    def get_params(self):
        """
        Reture the flat version of parameter as in the data distillation paper, 
        given a model subclass nn.module
        """
        w_modules_names = []
    
        for m in self.modules():
            for n, p in m.named_parameters(recurse=False):
                if p is not None:
                    w_modules_names.append((m, n))
            for n, b in m.named_buffers(recurse=False):
                if b is not None:
                    print("The buffer will be treated as a constant and assumed not to change during gradient steps.")

        self.weights_module_names = tuple(w_modules_names)
        
        # Put to correct device before we do stuff on parameters
#         self = self.to(self.device)

        ws = tuple(m._parameters[n].detach() for m, n in w_modules_names)
        
        print(len(set(w.dtype for w in ws)))

        assert len(set(w.dtype for w in ws)) == 1

        # reparam to a single flat parameter
        self.weights_numels = tuple(w.numel() for w in ws)
        self.weights_shapes = tuple(w.shape for w in ws)
        with torch.no_grad():
            flat_w = torch.cat([w.reshape(-1) for w in ws], 0)

        # remove old parameters, assign the names as buffers
        for m, n in self.weights_module_names:
            delattr(m, n)
            m.register_buffer(n, None)

        # register the flat one
        if not hasattr(self, 'attribute_name'):
            self.register_parameter('flat_w', nn.Parameter(flat_w, requires_grad=True))

        return self.flat_w
    
    @contextmanager
    def unflatten_weight(self, flat_w):
        ws = (t.view(s) for (t, s) in zip(flat_w.split(self.weights_numels), self.weights_shapes))
        for (m, n), w in zip(self.weights_module_names, ws):
            setattr(m, n, w)
        yield
        for m, n in self.weights_module_names:
            setattr(m, n, None)
            
    def forward_with_param(self, inp, new_w):
        with self.unflatten_weight(new_w):
#             return self.model(self, inp)
            return nn.Module.__call__(self, inp)


In [None]:
model = ExtendedResNet(block=BasicBlock, layers=[2, 2, 2, 2], 
                       num_classes=num_classes, 
                       fc_bias=bias,
                       ETF_fc=ETF_fc,
                       fixdim=fixdim,
                       SOTA=SOTA)
print('# of model parameters: ' + str(count_network_parameters(model)))
print(type(model))

However, this large model is very hard to train and we can use a smaller model to test our code base.

In [None]:
class LeNet(nn.Module):
    supported_dims = {28, 32}

    def __init__(self, nc, num_classes, input_size):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(nc, 6, 5, padding=2 if input_size == 28 else 0)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1 if num_classes <= 2 else num_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x), inplace=True)
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out), inplace=True)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out), inplace=True)
        out = F.relu(self.fc2(out), inplace=True)
        out = self.fc3(out)
        return out

# Data Distillation Algorithm with neural collapse

In [None]:
# distillation settings
distill_steps=10
distill_epochs=3
distilled_images_per_class_per_step=1
distill_lr=0.02

In [None]:
images, labels = next(iter(trainloader))
size_train, channels, height, width = images.shape
num_classes = torch.unique(labels)
print(num_classes)
print(labels.shape)
print(size_train, channels, height, width)

In [None]:
class Distiller:
    def __init__(self, model, device = "cpu", epochs = 400, history_size = 10, weight_decay = 5e-4, lr = 0.1,
                 data = trainloader, size_train = size_train, channels = channels, height = height, width = width,
                 distill_steps=10, 
                 distill_epochs=3, 
                 distilled_images_per_class_per_step=1,
                 distill_lr=0.02,
                 decay_factor = 0.5
                ):
        self.model = model
        self.device = device
        self.epochs = epochs
        self.history_size = history_size
        self.weight_decay = weight_decay
        self.lr = lr
        # unpack the dataset
        self.train_loader = trainloader
        self.size_train = size_train
        self.channels = channels
        self.height = height
        self.width = width
        self.num_classes = len(torch.unique(labels))
        # distill setting
        self.num_data_steps = distill_steps  # how much data we have
        self.distill_epochs = distill_epochs
        self.distilled_images_per_class_per_step = distilled_images_per_class_per_step
        self.distill_lr = distill_lr
        self.decay_factor = decay_factor
        self.T = distill_steps * distill_epochs  # how many sc steps we run
        self.num_per_step = self.num_classes * distilled_images_per_class_per_step
        assert distill_lr >= 0, 'distill_lr must >= 0'
        self.init_data_optim()
#         self.model.unflatten_weight = unflatten_weight.__get__(self.model)
        
    def init_data_optim(self, lr=lr):
        self.params = []
        optim_lr = lr

        # labels
        self.labels = []
        distill_label = torch.arange(self.num_classes, dtype=torch.long, device=self.device) \
                             .repeat(self.distilled_images_per_class_per_step, 1)  # [[0, 1, 2, ...], [0, 1, 2, ...]]
        distill_label = distill_label.t().reshape(-1)  # [0, 0, ..., 1, 1, ...]
        for _ in range(self.num_data_steps):
            self.labels.append(distill_label)
        self.all_labels = torch.cat(self.labels)

        # data
        self.data = []
        for _ in range(self.num_data_steps):
            distill_data = torch.randn(self.num_per_step, self.channels, self.height, self.width,
                                       device=self.device, requires_grad=True)
            self.data.append(distill_data)
            self.params.append(distill_data)
        # lr

        # undo the softplus + threshold
        raw_init_distill_lr = torch.tensor(self.distill_lr, device=self.device)
        raw_init_distill_lr = raw_init_distill_lr.repeat(self.T, 1)
        self.raw_distill_lrs = raw_init_distill_lr.expm1_().log_().requires_grad_()
        self.params.append(self.raw_distill_lrs)

        assert len(self.params) > 0, "must have at least 1 parameter"
        
        optimizer_function = optim.LBFGS
        kwargs = {'lr': self.lr,
                  'history_size': self.history_size,
                  'line_search_fn': 'strong_wolfe'
        }

        self.optimizer = optimizer_function(self.params, **kwargs)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=self.distill_epochs,
                                                   gamma=self.decay_factor)
        for p in self.params:
            p.grad = torch.zeros_like(p)
            
    def weight_decays(self, model):
        penalty = 0
        for p in model.parameters():
            if p.requires_grad:
                penalty += 0.5 * self.weight_decay * torch.norm(p) ** 2
        return penalty.to(self.device)
            
    def get_steps(self):
        data_label_iterable = (x for _ in range(self.distill_epochs) for x in zip(self.data, self.labels))
        lrs = F.softplus(self.raw_distill_lrs).unbind()

        steps = []
        for (data, label), lr in zip(data_label_iterable, lrs):
            steps.append((data, label, lr))

        return steps
    

    
    def forward(self, model, rdata, rlabel, steps):

        # forward
        model.train()
        w = model.get_params()
        params = [w]
        gws = []
        
        criterion = nn.CrossEntropyLoss()
        print("loss function is made")
        
        # inner loop
        for step_i, (data, label, lr) in enumerate(steps):
            """
            Begining of the inner loop. Input: distill data, output: distill labels
            goal: store the weight for future updating.
            """
            print("Begining of the inner loop number: ", step_i+1)
            
            data, label = data.to(self.device), label.to(self.device)
            
            with torch.enable_grad():
                output = model.forward_with_param(data, w)
                loss = criterion(output[0], label) + self.weight_decays(model)
#             gw,_ = torch.autograd.grad(loss, w, lr.squeeze(), create_graph=True)
            gw, = torch.autograd.grad(loss, w, lr.squeeze(), create_graph=True)

            with torch.no_grad():
                new_w = w.sub(gw).requires_grad_() # minus the gw weighted by lr.squeeze()
                params.append(new_w)
                gws.append(gw)
                w = new_w
        
        # final L
        model.eval()
        output = model.forward_with_param(rdata, params[-1])
        ll = criterion(output[0], rlabel) + self.weight_decays(model)
        print("The loss now is ", ll, "in the forward mode")
        return ll, (ll, params, gws)
    
    def backward(self, model, rdata, rlabel, steps, saved_for_backward):
        l, params, gws = saved_for_backward

        datas = []
        gdatas = []
        lrs = []
        glrs = []

        dw, = torch.autograd.grad(l, (params[-1],))

        # backward
        model.train()
       
        for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
            
            hvp_in = [w]
            hvp_in.append(data)
            hvp_in.append(lr)
            dgw = dw.neg()  # gw is already weighted by lr, so simple negation
            hvp_grad = torch.autograd.grad(
                outputs=(gw,),
                inputs=hvp_in,
                grad_outputs=(dgw,)
            )
            # Update for next iteration, i.e., previous step
            with torch.no_grad():
                # Save the computed gdata and glrs
                datas.append(data)
                gdatas.append(hvp_grad[1])
                lrs.append(lr)
                glrs.append(hvp_grad[2])

                # Update for next iteration, i.e., previous step
                # Update dw
                # dw becomes the gradients w.r.t. the updated w for previous step
                dw.add_(hvp_grad[0])

        return datas, gdatas, lrs, glrs

    def accumulate_grad(self, grad_infos):
        bwd_out = []
        bwd_grad = []
        for datas, gdatas, lrs, glrs in grad_infos:
            bwd_out += list(lrs)
            bwd_grad += list(glrs)
            for d, g in zip(datas, gdatas):
                if d.grad is None:
                    d.grad = g.clone()
                else:
                    d.grad.add_(g)
        if len(bwd_out) > 0:
            torch.autograd.backward(bwd_out, bwd_grad)
            
    def prefetch_train_loader_iter(self):

        device = self.device
        train_iter = iter(self.train_loader)
        for epoch in range(self.epochs):
            niter = len(train_iter)
            prefetch_it = max(0, niter - 2)
            for it, val in enumerate(train_iter):
                # Prefetch (start workers) at the end of epoch BEFORE yielding
                if it == prefetch_it and epoch < self.epochs - 1:
                    train_iter = iter(self.train_loader)
                yield epoch, it, val
                
    def train(self):

        counts = 0

        for epoch, it, (rdata, rlabel) in self.prefetch_train_loader_iter():

            if it == 0:
                self.scheduler.step()

            if it == 0 and epoch == 0:
                with torch.no_grad():
                    steps = self.get_steps()

            self.optimizer.zero_grad()
            rdata, rlabel = rdata.to(self.device, non_blocking=True), rlabel.to(self.device, non_blocking=True)

            losses = []
            steps = self.get_steps()

            # activate everything needed to run on this process
            grad_infos = []

            l, saved = self.forward(model, rdata, rlabel, steps)
            losses.append(l.detach())
            grad_infos.append(self.backward(model, rdata, rlabel, steps, saved))
            del l, saved
            bwd_out = []
            bwd_grad = []
            for datas, gdatas, lrs, glrs in grad_infos:
                bwd_out += list(lrs)
                bwd_grad += list(glrs)
                for d, g in zip(datas, gdatas):
                    if d.grad is None:
                        d.grad = g.clone()
                    else:
                        d.grad.add_(g)
            if len(bwd_out) > 0:
                torch.autograd.backward(bwd_out, bwd_grad)

            # all reduce if needed
            # average grad
            all_reduce_tensors = [p.grad for p in self.params]
   
            # opt step
            self.optimizer.step()

            del steps, grad_infos, losses, all_reduce_tensors


        with torch.no_grad():
            steps = self.get_steps()
        self.save_results(steps)
        return steps


In [None]:
Distiller(model).train()