In [1]:
import os
import sys

import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tqdm import trange
import wandb
from torchsummary import summary

sys.path.append(os.path.dirname(os.getcwd()))
from pyfed.models.cnns import SimpleCNN
from pyfed.metrics.hessian import Hessian
from pyfed.metrics.basic import ClassificationMetrics
from pyfed.utils import set_random_state

In [6]:
# Close a W&B run left open by an interrupted execution of train().
# wandb.run is None when no run is active (e.g. on a fresh kernel), so guard the
# call instead of raising AttributeError and breaking "Restart & Run All".
if wandb.run is not None:
    wandb.run.finish()

In [2]:
# W&B project name and the switch that tells train() whether it runs under a sweep.
project = 'hessian_distill'
run_sweep = False

# Search space handed to wandb.sweep when a hyper-parameter sweep is launched.
sweep_config = dict(
    method='bayes',
    name='sweep',
    metric=dict(goal='maximize', name='test.accuracy'),
    early_terminate=dict(type='hyperband', min_iter=2, eta=2),
    parameters=dict(
        batch_size=dict(values=[32, 64]),
        lr=dict(max=1e-1, min=1e-5, distribution='log_uniform_values'),
        weight_decay=dict(max=1e-3, min=1e-6, distribution='log_uniform_values'),
        momentum=dict(max=0.999, min=0.7, distribution='log_uniform_values'),
    ),
)

In [3]:
# Default run settings and hyper-parameters; train() forwards this dict to wandb.init.
configs = dict(
    seed=42,
    device='cuda:0',
    pretrained=True,
    ckpt_every=10,
    ckpt_path=os.path.join('/home/jahn/ckpt', project),
    arch=efficientnet_v2_l,
    epoch=30,
    batch_size=64,
    lr=1e-3,
    weight_decay=1e-5,
    momentum=0.9,
    amp=True,
)
# The checkpoint directory is not created here; stop early with a hint if it is missing.
assert os.path.isdir(configs['ckpt_path']), f'Make directory for saving checkpoints at {configs["ckpt_path"]}'

# Per-channel statistics passed to T.Normalize in both pipelines below.
mean = [0.49139968, 0.48215827, 0.44653124]
std = [0.24703233, 0.24348505, 0.26158768]

# Training split: resized to 224x224 with random flip / rotation augmentation.
trainset = CIFAR10(
    '~/data',
    train=True,
    transform=T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(90),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]),
)

# Stratified 80/20 partition of the training indices; the larger part goes to the teacher.
teacher_idx, student_idx = train_test_split(
    range(len(trainset)),
    test_size=0.2,
    random_state=configs['seed'],
    stratify=trainset.targets,
)
teacherset = Subset(trainset, teacher_idx)

# Held-out split: same resize and normalisation, no augmentation.
testset = CIFAR10(
    '~/data',
    train=False,
    transform=T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]),
)

In [4]:
# Device on which train() places the network and every batch.
device = torch.device(configs['device'])
# Shared metric accumulator, reset by train() before each train / test pass.
# The class count is taken from the dataset rather than hard-coded so it stays in
# step with trainset.classes, which train() uses to size the classifier head and
# to label the per-class precision / recall / f1 values.
metrics = ClassificationMetrics(n_classes=len(trainset.classes), average=None)

In [5]:
def _summarize_metrics(tracker, class_names):
    """Build the nested metrics dict logged to W&B for one data split.

    Args:
        tracker: metrics accumulator exposing accuracy(), precision(), recall() and f1().
        class_names: sequence of labels; element k names the k-th per-class score.

    Returns:
        dict with the overall accuracy plus precision / recall / f1 keyed by class name.
    """
    def per_class(scores):
        return {class_names[k]: v for k, v in enumerate(scores)}

    return {
        'accuracy': tracker.accuracy(),
        'precision': per_class(tracker.precision()),
        'recall': per_class(tracker.recall()),
        'f1': per_class(tracker.f1()),
    }


def train():
    """Train the teacher network on ``teacherset`` and evaluate it on ``testset``.

    Reads the module-level ``configs``, ``run_sweep``, ``project``, ``device``,
    ``metrics`` and datasets. Every epoch logs train / test metrics to a W&B run;
    every ``ckpt_every`` epochs the model's state dict is saved under ``ckpt_path``.
    The W&B run is always finished and the CUDA cache emptied, even when training
    raises, so no dangling run is left behind for the next call.
    """
    # Initial setups
    set_random_state(configs['seed'])
    net = configs['arch'](weights=EfficientNet_V2_L_Weights.DEFAULT if configs['pretrained'] else None)
    # Replace the final linear layer when its width does not match the dataset
    n_classes = len(trainset.classes)
    if list(net.modules())[-1].out_features != n_classes:
        fc_name = list(net.named_children())[-1][0]
        assert fc_name in ['fc', 'classifier'], 'This case is not considered. Reimplement this part'
        if fc_name == 'fc':
            net.fc = nn.Linear(net.fc.in_features, n_classes, bias=True)
        elif fc_name == 'classifier':
            net.classifier[1] = nn.Linear(net.classifier[1].in_features, n_classes, bias=True)
    net.to(device)

    arch_name = net.__class__.__name__
    if run_sweep:
        run = wandb.init(config=configs)
        # NOTE(review): wandb.config presumably already refers to this run's config,
        # which would make this call a no-op -- confirm before removing.
        run.config.update(wandb.config)
        # The checkpoint name used to be assigned only in the non-sweep branch, so
        # the first checkpoint of a sweep run raised NameError. The run id keeps
        # checkpoints from different sweep runs from overwriting each other.
        name = f'{arch_name}_teacher_{run.id}'
    else:
        name = f'{arch_name}_teacher'
        run = wandb.init(
            group='train',
            job_type='train',
            project=project,
            tags=[arch_name, trainset.__class__.__name__],
            name=name,
            config=configs,
        )

    exit_code = 1  # reported to W&B unless the loop below completes
    try:
        trainloader = DataLoader(teacherset, batch_size=run.config['batch_size'], shuffle=True, num_workers=4)
        testloader = DataLoader(testset, batch_size=run.config['batch_size'], shuffle=False, num_workers=4)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.RMSprop(net.parameters(), lr=run.config['lr'], momentum=run.config['momentum'], weight_decay=run.config['weight_decay'])
        scaler = torch.cuda.amp.GradScaler(enabled=run.config['amp'])

        # Zero-pad the epoch number in checkpoint names to the width of the last epoch
        epoch_format = f'>0{len(str(run.config["epoch"]))}'

        # training
        for e in (pbar:=trange(1,run.config['epoch']+1)):
            net.train()
            metrics.reset()
            for b, (data, target) in enumerate(trainloader):
                pbar.set_description(f'[train]|[{b:>4}/{len(trainloader):>4}]')
                data, target = data.to(device), target.to(device)
                with torch.autocast(device_type=device.type, enabled=run.config['amp']):
                    output = net(data)
                    loss = criterion(output, target)
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                # Calc metrics
                targetcpu = target.cpu().numpy()
                predcpu = output.argmax(1).cpu().detach().numpy()
                metrics.update(targetcpu, predcpu)
            # commit=False holds these values back so they are committed with the test log below
            run.log({'train': _summarize_metrics(metrics, trainset.classes)}, commit=False)

            net.eval()
            metrics.reset()
            with torch.no_grad():
                for b, (data, target) in enumerate(testloader):
                    pbar.set_description(f'[test]|[{b:>4}/{len(testloader):>4}]')
                    data, target = data.to(device), target.to(device)
                    with torch.autocast(device_type=device.type, enabled=run.config['amp']):
                        output = net(data)
                    # Calc metrics
                    targetcpu = target.cpu().numpy()
                    predcpu = output.argmax(1).cpu().detach().numpy()
                    metrics.update(targetcpu, predcpu)
            run.log({'test': _summarize_metrics(metrics, trainset.classes), 'epoch': e})

            # Checkpointing
            if e%run.config['ckpt_every']==0:
                torch.save(net.state_dict(), os.path.join(run.config['ckpt_path'], f'{name}_epoch{e:{epoch_format}}.pth'))
        exit_code = 0
    finally:
        # Runs on success and on failure: drop the model, release cached GPU
        # memory and close the W&B run with the matching exit status.
        del net
        torch.cuda.empty_cache()
        run.finish(exit_code=exit_code)

In [None]:
# Switch train() into sweep mode; it reads this module-level flag on every call.
run_sweep = True
# Paste an existing sweep ID into the empty string to resume that sweep;
# with the string left empty a new sweep is created from sweep_config.
sweep_id = '' or wandb.sweep(sweep_config, project=project)
# Pass the project explicitly so the agent looks the sweep up in the project it
# was created in, including when the ID was pasted above and wandb.sweep() never
# ran in this kernel.
wandb.agent(sweep_id=sweep_id, function=train, project=project, count=100)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: t18d5uyh
Sweep URL: https://wandb.ai/iislab-official/hessian_distill/sweeps/t18d5uyh


[34m[1mwandb[0m: Agent Starting Run: uyz2gm0v with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	lr: 0.02440151032753152
[34m[1mwandb[0m: 	momentum: 0.8075485670158714
[34m[1mwandb[0m: 	weight_decay: 4.1057289263300016e-05
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdev-jahn[0m ([33miislab-official[0m). Use [1m`wandb login --relogin`[0m to force relogin


[train]|[  54/1250]:  13%|███████                                              | 4/30 [12:09<1:18:05, 180.21s/it]