# Hyperparameter Sweep

This notebook implements a hyperparameter sweep for a data distillation algorithm that builds on recent findings about neural collapse.

The main references considered here are:
 - data distillation:
     - https://github.com/SsnL/dataset-distillation 
     - https://github.com/google-research/google-research/tree/master/kip
 - Neural Collapse:
     - https://github.com/tding1/Neural-Collapse. 

The network can optionally fix its last-layer weight matrix to a simplex equiangular tight frame (ETF). A simplex ETF is the last-layer geometry empirically observed once a network enters the terminal phase of training, and it is the global optimum of a benign optimization landscape.



# Import

In [31]:
import wandb
# Authenticate with Weights & Biases; login() returns a bool, which the cell displays as its output.
wandb.login()

True

In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
import torch.nn.functional as F
import numpy as np

Import the helper modules from https://github.com/tding1/Neural-Collapse.

In [33]:
import models
from models.res_adapt import ResNet18_adapt
from utils import *

from train_2nd_order import weight_decay, trainer
from validate_NC import compute_Wh_b_relation, compute_W_H_relation, compute_ETF, compute_Sigma_B, compute_Sigma_W,compute_info,FCFeatures

from data.datasets import make_dataset
from arg_loader import *

# Load Parameters

In [34]:
# architecture params
model='resnet18'
bias=True
ETF_fc=False
fixdim=0
SOTA=False

# MLP settings (only when using mlp and res_adapt(in which case only width has effect))
width=1024
depth=6

# hardware settings
gpu_id=0
seed=6
use_cudnn=True

# dataset
dataset='cifar10'
data_dir='~/data'
uid="tmp"
force=True

# learning options
epochs = 2
batch_size=32
loss = 'CrossEntropy'
sample_size = None

# optimization
lr=0.05
optimizer = "SGD"
history_size=10

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA machines and on CPU-only machines.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# When True, the memory cell below reports how much accelerator memory is allocated.
check = False

In [35]:
# Bundle the configuration cell's values into the argument object consumed by
# the training and evaluation helpers (keyword order kept as in the config cell).
args = train_args(
    model=model, bias=bias, ETF_fc=ETF_fc, fixdim=fixdim, SOTA=SOTA,
    width=width, depth=depth,
    gpu_id=gpu_id, seed=seed, use_cudnn=use_cudnn,
    dataset=dataset, data_dir=data_dir, uid=uid, force=force,
    epochs=epochs, batch_size=batch_size, loss=loss, sample_size=sample_size,
    lr=lr, optimizer=optimizer, history_size=history_size,
    device=device,
)

override this uidtmp
<arg_loader.train_args object at 0x296311010>
cudnn is used


In [36]:
# Release cached accelerator memory before training. When ``check`` is set,
# also report the memory currently allocated on the active device (the old
# version called torch.cuda.memory_allocated() even on MPS and discarded the result).
if device == "cuda":
    torch.cuda.empty_cache()
    if check:
        print("CUDA memory allocated: %d bytes" % torch.cuda.memory_allocated())
elif device == "mps" and getattr(torch, "mps", None) is not None:
    torch.mps.empty_cache()
    if check:
        print("MPS memory allocated: %d bytes" % torch.mps.current_allocated_memory())

# Define the training and evaluation functions

In [37]:
def trainer(args, model, trainloader, epoch_id, criterion, optimizer):
    """Train ``model`` for one epoch and report running loss and top-1/top-5 accuracy.

    Note: this definition shadows the ``trainer`` imported from
    ``train_2nd_order`` in the import cell.

    Args:
        args: run configuration; ``args.device`` and ``args.loss``
            (``'CrossEntropy'`` or ``'MSE'``) are read here, and ``args`` is
            forwarded to ``weight_decay``.
        model: network whose forward pass returns a sequence with the logits
            at index 0.
        trainloader: iterable of ``(inputs, targets)`` batches.
        epoch_id: zero-based epoch index (printed 1-based).
        criterion: loss function applied to the logits.
        optimizer: optimizer whose ``step`` receives a closure, so closure-based
            optimizers such as LBFGS are supported.

    Returns:
        tuple: ``(loss_avg, top1_avg, top5_avg)`` averaged over the epoch.

    Raises:
        ValueError: if ``args.loss`` is not ``'CrossEntropy'`` or ``'MSE'``.
    """
    if args.loss not in ('CrossEntropy', 'MSE'):
        # Fail fast instead of an UnboundLocalError on ``loss`` in the middle of an epoch.
        raise ValueError("Unsupported loss %r; expected 'CrossEntropy' or 'MSE'" % (args.loss,))

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    def batch_loss(logits, targets):
        """Criterion value plus ``weight_decay`` for one batch of logits."""
        if args.loss == 'MSE':
            # Size the one-hot vectors from the logits so a batch that misses some
            # classes still matches the output width; ``.float()`` keeps the tensor
            # on its current device (no CPU round trip).
            targets = F.one_hot(targets, num_classes=logits.shape[1]).float()
        return criterion(logits, targets) + weight_decay(args, model)

    for batch_idx, (inputs, targets) in enumerate(trainloader):

        inputs, targets = inputs.to(args.device), targets.to(args.device)

        model.train()

        def closure():
            loss = batch_loss(model(inputs)[0], targets)
            optimizer.zero_grad()
            loss.backward()
            return loss

        optimizer.step(closure)

        # Measure accuracy and loss of the updated weights in eval mode. This pass
        # needs no gradients, so skip building (and holding on to) an autograd graph.
        model.eval()
        with torch.no_grad():
            outputs = model(inputs)
            loss = batch_loss(outputs[0], targets)
        prec1, prec5 = compute_accuracy(outputs[0].data, targets.data, topk=(1, 5))

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        print('[epoch: %d] (%d/%d) | Loss: %.4f | top1: %.4f | top5: %.4f ' %
              (epoch_id + 1, batch_idx + 1, len(trainloader), losses.avg, top1.avg, top5.avg))

    # Report epoch averages to W&B only when a run is active (plain notebook runs
    # have none). "loss" is the metric name the sweep configuration minimises.
    if wandb.run is not None:
        wandb.log({
            "epoch": epoch_id + 1,
            "loss": losses.avg,
            "top1.avg": top1.avg,
            "top5.avg": top5.avg,
        })

    return losses.avg, top1.avg, top5.avg


In [38]:
def evaluate_NC(args, model, testloader, trainloader=None, load_path="model_weights/tmp/"):
    """Replay saved per-epoch checkpoints and compute neural-collapse (NC) metrics.

    For every epoch ``1..args.epochs`` the checkpoint
    ``<load_path>/epoch_<NNN>.pth`` is loaded into ``model``. The collapse
    metric trace(Sigma_W pinv(Sigma_B)) / C, the ``compute_ETF`` /
    ``compute_W_H_relation`` / ``compute_Wh_b_relation`` metrics and the
    train/test accuracies are then computed. Each epoch is printed, appended to
    ``<load_path>/test_log.txt`` and, if a W&B run is active, logged to W&B.
    All collected values are pickled to ``<load_path>/info.pkl``.

    Args:
        args: run configuration; ``epochs``, ``bias`` and ``device`` are read and
            ``args`` is forwarded to the ``compute_*`` helpers.
            ``args.load_path`` is set to ``load_path``.
        model: network with an ``fc`` classifier layer; its weights are
            overwritten by each checkpoint.
        testloader: test-set loader.
        trainloader: training-set loader. When omitted, a notebook-level global
            ``trainloader`` is used (which is what the original version relied on).
        load_path: directory holding the checkpoints (previously hard-coded).

    Returns:
        dict: lists of per-epoch values keyed by metric name.

    Raises:
        ValueError: if ``load_path`` is None or no training loader is available.
    """
    import os
    import pickle

    import scipy.linalg as scilin

    if load_path is None:
        raise ValueError('Need to input the path to a pre-trained model!')
    args.load_path = load_path
    if trainloader is None:
        trainloader = globals().get('trainloader')
    if trainloader is None:
        raise ValueError('evaluate_NC needs a trainloader; pass trainloader=... explicitly.')

    info_dict = {key: [] for key in (
        'collapse_metric', 'ETF_metric', 'WH_relation_metric', 'Wh_b_relation_metric',
        'W', 'b', 'H', 'mu_G_train',
        'train_acc1', 'train_acc5', 'test_acc1', 'test_acc5',
    )}

    fc_features = FCFeatures()
    hook = model.fc.register_forward_pre_hook(fc_features)
    try:
        with open(os.path.join(load_path, 'test_log.txt'), 'w') as logfile:
            for epoch in range(1, args.epochs + 1):
                checkpoint = os.path.join(load_path, 'epoch_' + str(epoch).zfill(3) + '.pth')
                model.load_state_dict(torch.load(checkpoint, map_location=args.device))
                model.eval()

                W = b = None
                for n, p in model.named_parameters():
                    if 'fc.weight' in n:
                        W = p
                    if 'fc.bias' in n:
                        b = p

                mu_G_train, mu_c_dict_train, train_acc1, train_acc5 = compute_info(args, model, fc_features, trainloader, isTrain=True)
                _, _, test_acc1, test_acc5 = compute_info(args, model, fc_features, testloader, isTrain=False)

                Sigma_W = compute_Sigma_W(args, model, fc_features, mu_c_dict_train, trainloader, isTrain=True)
                Sigma_B = compute_Sigma_B(mu_c_dict_train, mu_G_train)

                collapse_metric = np.trace(Sigma_W @ scilin.pinv(Sigma_B)) / len(mu_c_dict_train)
                ETF_metric = compute_ETF(W)
                WH_relation_metric, H = compute_W_H_relation(W, mu_c_dict_train, mu_G_train)
                has_bias = args.bias and b is not None
                # Without a learned bias, compare against zeros created on W's device
                # rather than the CPU default.
                bias_vec = b if has_bias else torch.zeros((W.shape[0],), device=W.device)
                Wh_b_relation_metric = compute_Wh_b_relation(W, mu_G_train, bias_vec)

                info_dict['collapse_metric'].append(collapse_metric)
                info_dict['ETF_metric'].append(ETF_metric)
                info_dict['WH_relation_metric'].append(WH_relation_metric)
                info_dict['Wh_b_relation_metric'].append(Wh_b_relation_metric)

                info_dict['W'].append(W.detach().cpu().numpy())
                if has_bias:
                    info_dict['b'].append(b.detach().cpu().numpy())
                info_dict['H'].append(H.detach().cpu().numpy())
                info_dict['mu_G_train'].append(mu_G_train.detach().cpu().numpy())

                info_dict['train_acc1'].append(train_acc1)
                info_dict['train_acc5'].append(train_acc5)
                info_dict['test_acc1'].append(test_acc1)
                info_dict['test_acc5'].append(test_acc5)

                # Report every epoch (previously only the last one was reported).
                print_and_save('[epoch: %d] | train top1: %.4f | train top5: %.4f | test top1: %.4f | test top5: %.4f ' %
                               (epoch, train_acc1, train_acc5, test_acc1, test_acc5), logfile)
                if wandb.run is not None:
                    wandb.log({
                        "train_acc1": train_acc1,
                        "train_acc5": train_acc5,
                        "test_acc1": test_acc1,
                        "test_acc5": test_acc5,
                        "collapse_metric": collapse_metric,
                        "ETF_metric": ETF_metric,
                        "WH_relation_metric": WH_relation_metric,
                        "Wh_b_relation_metric": Wh_b_relation_metric,
                    })
    finally:
        # Detach the hook so repeated calls do not stack hooks on model.fc.
        hook.remove()

    with open(os.path.join(load_path, 'info.pkl'), 'wb') as f:
        pickle.dump(info_dict, f)

    return info_dict

In [39]:
def evaluater_NC(args, model, testloader, fc_features, info_dict, trainloader=None, epoch_id=None, logfile=None):
    """Compute neural-collapse (NC) metrics for the model's current weights.

    Unlike ``evaluate_NC``, no checkpoint is loaded: the in-memory ``model`` is
    evaluated once, and one value per metric is appended to ``info_dict``.

    Args:
        args: run configuration; ``args.bias`` is read here and ``args`` is
            forwarded to the ``compute_*`` helpers.
        model: network with an ``fc`` classifier layer.
        testloader: test-set loader.
        fc_features: ``FCFeatures`` instance registered as a forward pre-hook on
            ``model.fc``; it supplies the fc-layer inputs to ``compute_info``.
        info_dict: dict of lists (the same keys ``main`` builds); extended in place.
        trainloader: training-set loader. When omitted, a notebook-level global
            ``trainloader`` is used, which is what the original version relied on.
        epoch_id: optional zero-based epoch index for the progress line; when
            omitted, the number of evaluations recorded in ``info_dict`` is shown.
        logfile: optional open text file. If given, the progress line is written
            via ``print_and_save``; otherwise it is only printed.

    Raises:
        ValueError: if no training loader is available.
    """
    import scipy.linalg as scilin

    if trainloader is None:
        trainloader = globals().get('trainloader')
    if trainloader is None:
        raise ValueError('evaluater_NC needs a trainloader; pass trainloader=... explicitly.')

    model.eval()
    # Drop anything the hook captured before this call (e.g. during training
    # forward passes) so compute_info starts from evaluation batches only.
    fc_features.clear()

    W = b = None
    for n, p in model.named_parameters():
        if 'fc.weight' in n:
            W = p
        if 'fc.bias' in n:
            b = p

    mu_G_train, mu_c_dict_train, train_acc1, train_acc5 = compute_info(args, model, fc_features, trainloader, isTrain=True)
    _, _, test_acc1, test_acc5 = compute_info(args, model, fc_features, testloader, isTrain=False)

    Sigma_W = compute_Sigma_W(args, model, fc_features, mu_c_dict_train, trainloader, isTrain=True)
    Sigma_B = compute_Sigma_B(mu_c_dict_train, mu_G_train)

    collapse_metric = np.trace(Sigma_W @ scilin.pinv(Sigma_B)) / len(mu_c_dict_train)
    ETF_metric = compute_ETF(W)
    WH_relation_metric, H = compute_W_H_relation(W, mu_c_dict_train, mu_G_train)
    has_bias = args.bias and b is not None
    # Without a learned bias, compare against zeros created on W's device
    # rather than the CPU default.
    bias_vec = b if has_bias else torch.zeros((W.shape[0],), device=W.device)
    Wh_b_relation_metric = compute_Wh_b_relation(W, mu_G_train, bias_vec)

    info_dict['collapse_metric'].append(collapse_metric)
    info_dict['ETF_metric'].append(ETF_metric)
    info_dict['WH_relation_metric'].append(WH_relation_metric)
    info_dict['Wh_b_relation_metric'].append(Wh_b_relation_metric)

    info_dict['W'].append(W.detach().cpu().numpy())
    if has_bias:
        info_dict['b'].append(b.detach().cpu().numpy())
    info_dict['H'].append(H.detach().cpu().numpy())
    info_dict['mu_G_train'].append(mu_G_train.detach().cpu().numpy())

    info_dict['train_acc1'].append(train_acc1)
    info_dict['train_acc5'].append(train_acc5)
    info_dict['test_acc1'].append(test_acc1)
    info_dict['test_acc5'].append(test_acc5)

    # The original referenced undefined names ``i`` and ``logfile`` here.
    epoch = epoch_id + 1 if epoch_id is not None else len(info_dict['train_acc1'])
    message = ('[epoch: %d] | train top1: %.4f | train top5: %.4f | test top1: %.4f | test top5: %.4f ' %
               (epoch, train_acc1, train_acc5, test_acc1, test_acc5))
    if logfile is None:
        print(message)
    else:
        print_and_save(message, logfile)

    # Log only when a W&B run is active (plain notebook runs have none).
    if wandb.run is not None:
        wandb.log({
            "train_acc1": train_acc1,
            "train_acc5": train_acc5,
            "test_acc1": test_acc1,
            "test_acc5": test_acc5,
            "collapse_metric": collapse_metric,
            "ETF_metric": ETF_metric,
            "WH_relation_metric": WH_relation_metric,
            "Wh_b_relation_metric": Wh_b_relation_metric,
        })

# Set up a sweep configuration

In [10]:
# Random search: each run samples one value per parameter independently.
sweep_config = dict(method="random")

# Objective of the sweep; W&B compares runs on the logged value with this name.
metric = dict(name="loss", goal="minimize")

Set up the parameters to optimize.

In [19]:
parameters_dict = {
    # Swept: one value is drawn from each list per run.
    "learning_rate": {"values": [0.1, 0.05, 0.001]},
    "optimizer": {"values": ["Adam", "SGD", "LBFGS"]},
    # "width": {"values": [1024, 2048, 4096]},  # MLP / res_adapt only
    # "depth": {"values": [4, 8, 12]},          # MLP only
    "batch_size": {"values": [32, 64, 256, 2048]},
    # Held fixed across runs.
    "model": {"values": ["resnet18"]},
    "epochs": {"value": 200},
    "loss": {"value": 'CrossEntropy'},
}

# Create the sweep

In [22]:
# Attach the objective and the search space, then register the sweep with W&B.
sweep_config.update(metric=metric, parameters=parameters_dict)
sweep_id = wandb.sweep(sweep_config, project="hyper_sweep_4_opt_para")

Create sweep with ID: diryez1w
Sweep URL: https://wandb.ai/data-distillation-with-nc/hyper_sweep_4_opt_para/sweeps/diryez1w


In [40]:
def main(arg):
    """Train a model and track neural-collapse metrics after every epoch.

    Args:
        arg: run configuration (a ``train_args`` object). Every setting is read
            from this parameter. The previous version mixed it with the
            notebook-level ``args``, so a modified configuration passed in was
            only partly honoured.

    Returns:
        dict: per-epoch metric lists filled in by ``evaluater_NC``.
    """
    # evaluater_NC resolves ``trainloader`` from the notebook's global namespace
    # when it is not handed one, so publish the loader there instead of keeping
    # it local to this function.
    global trainloader

    trainloader, testloader, num_classes = make_dataset(arg.dataset,
                                                        arg.data_dir,
                                                        arg.batch_size,
                                                        SOTA=arg.SOTA)

    if arg.model == "MLP":
        model = models.__dict__[arg.model](hidden=arg.width, depth=arg.depth, fc_bias=arg.bias, num_classes=num_classes)
    elif arg.model == "ResNet18_adapt":
        model = ResNet18_adapt(width=arg.width, num_classes=num_classes, fc_bias=arg.bias)
    else:
        model = models.__dict__[arg.model](num_classes=num_classes, fc_bias=arg.bias, ETF_fc=arg.ETF_fc,
                                           fixdim=arg.fixdim, SOTA=arg.SOTA)
    model = model.to(arg.device)

    print('# of model parameters: ' + str(count_network_parameters(model)))
    print(type(model))

    criterion = make_criterion(arg)
    optimizer = make_optimizer(arg, model)

    info_dict = {key: [] for key in (
        'collapse_metric', 'ETF_metric', 'WH_relation_metric', 'Wh_b_relation_metric',
        'W', 'b', 'H', 'mu_G_train',
        'train_acc1', 'train_acc5', 'test_acc1', 'test_acc5',
    )}

    for epoch_id in range(arg.epochs):
        trainer(arg, model, trainloader, epoch_id, criterion, optimizer)

        # Attach the fc-input hook only for the evaluation pass. When it was
        # registered once before training, it also fired on every training forward
        # pass. FCFeatures appears to keep each captured input until cleared, so
        # features (with their autograd graphs) piled up over the epoch, which is
        # a likely cause of the MPS out-of-memory error. The NC metrics could
        # also read stale training features.
        fc_features = FCFeatures()
        hook = model.fc.register_forward_pre_hook(fc_features)
        try:
            evaluater_NC(arg, model, testloader, fc_features, info_dict)
        finally:
            hook.remove()

    return info_dict
    

In [41]:
# Standalone training run (outside any W&B sweep) with the configuration above.
# Binding the result to a name keeps the cell from echoing it.
nc_info = main(args)

Dataset: CIFAR10.
Files already downloaded and verified
Files already downloaded and verified
# of model parameters: 11181642
<class 'models.resnet.ResNet'>


RuntimeError: MPS backend out of memory (MPS allocated: 8.82 GB, other allocations: 261.88 MB, max allowed: 9.07 GB). Tried to allocate 256 bytes on shared pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [26]:
def train_n_validate(config = None, args = args):
    """Run one W&B run: apply the run's hyper-parameters to a copy of ``args`` and train.

    Args:
        config: run config passed to ``wandb.init``. It is ``None`` when this is
            called by ``wandb.agent``, which supplies the sweep-sampled values.
            It must contain ``learning_rate``, ``optimizer`` and ``batch_size``;
            ``model``, ``epochs`` and ``loss`` are applied when present.
        args: base configuration; defaults to the notebook-level ``args`` bound
            when this cell ran. It is copied, so runs do not overwrite each
            other's settings in the shared object.

    Returns:
        Whatever ``main`` returns for this run.
    """
    import copy

    # Initialise a new run
    with wandb.init(config=config):
        print("Initialise finished, starting now...")
        config = wandb.config
        run_args = copy.copy(args)
        run_args.lr = config["learning_rate"]
        run_args.optimizer = config["optimizer"]
        run_args.batch_size = config["batch_size"]
        # The sweep also fixes these keys (e.g. epochs=200); previously they were
        # recorded in the run config but ignored by training.
        for key in ("model", "epochs", "loss"):
            setattr(run_args, key, config.get(key, getattr(run_args, key)))
        return main(run_args)

In [None]:
# One manual run outside the sweep agent, using the defaults from the config cell.
# wandb.init needs a plain run config here: passing sweep_config made
# config["learning_rate"] fail, because the sweep definition has no such key.
train_n_validate(config={"learning_rate": lr, "optimizer": optimizer, "batch_size": batch_size})

# Start sweep

In [None]:
# Launch a W&B agent that runs train_n_validate for 5 configurations sampled from
# the sweep; the function is called without arguments, so its config is None and
# wandb.init receives the sampled values from the agent.
wandb.agent(sweep_id,function=train_n_validate,count=5)