# Import

In [1]:
import argparse
import csv
import os
from statistics import mean

import torch
from torchvision.models import efficientnet_v2_m
from torchvision.models import EfficientNet_V2_M_Weights
from torchvision.models import vit_b_16
from torchvision.models import ViT_B_16_Weights

from model.wide_res_net_cifar import WideResNet_cifar
from model.wide_res_net_fashionmnist import WideResNet_fashionmnist
# from model.wide_res_net_food101 import WideResNet_food101
from model.PyramidNet import PyramidNet

from model.smooth_cross_entropy import smooth_crossentropy

from data_cifar100.cifar import Cifar100
from data_cifar100_224.cifar import Cifar100_224
from data_cifar10.cifar import Cifar10
from data_fashionmnist.fashionmnist import fashionmnist
from data_food101.food101 import Food101

from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from adversarial_cross_entropy import AdaptiveAdversrialCrossEntropy


# SAM

In [2]:
class SAM(torch.optim.Optimizer):
    """Two-step optimizer wrapper built around a base optimizer.

    Usage per batch: run a forward/backward pass, call :meth:`first_step`
    (which moves the weights to ``w - e(w)``), run a second forward/backward
    pass at the perturbed weights, then call :meth:`second_step` (which
    restores ``w`` and lets the base optimizer apply the update).

    Note that, unlike the original SAM ascent step (``w + e(w)``), this
    variant *descends* along the gradient in the first step (AACE).

    Args:
        params: iterable of parameters or parameter-group dicts.
        base_optimizer: optimizer class (e.g. ``torch.optim.SGD``) that is
            instantiated on the same parameter groups.
        rho: non-negative size of the perturbation.
        adaptive: if true, the perturbation is scaled element-wise by the
            squared parameter values (and the gradient norm by ``|p|``).
        **kwargs: forwarded to ``base_optimizer`` (``lr``, ``momentum``, ...).
    """

    def __init__(self, params, base_optimizer, rho, adaptive, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Both optimizers share the very same param_groups list, so learning
        # rate changes made through either one are seen by the other.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, use_grad_norm=True, zero_grad=False):
        """Perturb the weights in place to ``w - e(w)`` and remember ``w``.

        Args:
            use_grad_norm: if true, ``e(w)`` is the gradient rescaled by
                ``rho / ||grad||``; otherwise it is ``rho * grad``.
            zero_grad: if true, clear the gradients afterwards.

        Returns:
            ``(grad_norm, ew_norm)`` as Python floats: the global gradient
            norm and the global L2 norm of the applied perturbation. Both are
            0.0 when no parameter has a gradient.
        """
        grad_norm = self._grad_norm()
        shared_device = self.param_groups[0]["params"][0].device
        ew_norm_squared_total = 0.0  # running sum of ||e_w||^2 over all parameters

        for group in self.param_groups:
            if use_grad_norm:
                scale = group["rho"] / (grad_norm + 1e-12)
            else:
                scale = torch.tensor(group["rho"]).to(shared_device)

            # Each 'p' is normally the parameter tensor of one layer, not a single scalar weight.
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                # p.add_(e_w)  # climb to the local maximum "w + e(w)" (SAM)
                p.sub_(e_w)  # descend to the local minimum "w - e(w)" (AACE)

                # Accumulate on one device so model-parallel setups can be summed.
                ew_norm_squared_total = ew_norm_squared_total + e_w.norm(p=2).pow(2).to(shared_device)

        # as_tensor keeps this valid when no parameter contributed (plain float 0.0).
        ew_norm = (torch.as_tensor(ew_norm_squared_total) ** 0.5).item()

        if zero_grad:
            self.zero_grad()
        return grad_norm.item(), ew_norm

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights and apply the base optimizer update.

        The gradients present at call time (computed at the perturbed
        weights) are the ones the base optimizer uses.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w - e(w)"

        self.base_optimizer.step()  # do the actual update with the second-pass gradients

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Closure-based single call: first step, re-evaluate, second step.

        Gradients for the first step must already be populated by the caller.
        The gradient-normalised perturbation is used here.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        # first_step() needs to know how to scale the perturbation; the
        # previous call omitted the argument and raised TypeError.
        self.first_step(use_grad_norm=True, zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the global L2 norm of all gradients as a 0-dim tensor.

        With ``adaptive`` groups each gradient is first multiplied by ``|p|``.
        Returns a zero tensor when no parameter currently has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        per_param_norms = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        if not per_param_norms:
            # torch.stack([]) would raise; an all-None gradient set has norm 0.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        """Load optimizer state and re-link the base optimizer's param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

# Training Script

In [3]:
def sam_train(rho, iteration=0):
    """Train one model with the two-step SAM/AACE optimizer and log metrics per epoch.

    The model/hyper-parameters and the dataset are selected by the hard-coded
    ``experiment`` and ``data`` settings at the top of the body. When selecting
    an experiment:
    - WideResNet: pick the variant corresponding to the data, since some
      layers in the model need to be modified.
    - PyramidNet: certainly works only with Cifar.

    One CSV row is appended per epoch, and the model and optimizer state dicts
    are saved after the last epoch.

    Args:
        rho: perturbation size passed to the SAM optimizer; also part of the
            output file names.
        iteration: repeat index, used only in the output file names.

    Raises:
        ValueError: if ``experiment`` or ``data`` is not one of the handled names.
    """
    # Handled experiments: "WideResNet_cifar100", "WideResNet_cifar10", "PyramidNet", "EfficientNetV2_m", "vit_b_16"
    experiment = "PyramidNet"
    # Handled datasets: "Cifar100", "Cifar100_224", "Cifar10", "fashionmnist", "food101"
    data = "Cifar100"
    # DON'T FORGET TO CHECK 'num_classes'

    # define parameters
    gpu = "cuda:3"
    threads = 36
    adaptive = False
    label_smoothing = 0
    use_grad_norm = False
    mode = 'ada'  # 'ada', 'rand', 'const' -- only used to tag the output file names here
    result_dir = "/home/tratchatorn/SAM/sam/example/result_repeat_" + experiment + "_" + data
    file_name = f"repeat_AACE_{mode}_GradNorm:{use_grad_norm}_rho:{rho}_{iteration}"
    optimizer_save_path = result_dir + "/" + file_name + "_opt.pth"
    model_save_path = result_dir + "/" + file_name + "_model.pth"
    csv_path = result_dir + "/" + file_name + ".csv"

    # Make sure the output directory exists before the first write.
    os.makedirs(result_dir, exist_ok=True)

    header = ["epoch", "lr", "avg_ce_loss", "avg_aace_loss", "avg_grad_norm", "avg_perturbation", "train_loss", "train_acc", "val_loss", "val_acc"]
    # newline='' is what the csv module expects for files it writes to.
    with open(csv_path, 'a', newline='') as file:
        csv.writer(file).writerow(header)

    initialize(seed=42)
    device = torch.device(gpu)

    log = Log(log_each=10)

    base_optimizer = torch.optim.SGD

    # ---- model and hyper-parameters ------------------------------------
    if experiment in ("WideResNet_cifar100", "WideResNet_cifar10"):
        batch_size = 256
        learning_rate = 0.1
        momentum = 0.9
        weight_decay = 0.0005
        epochs = 200
        depth = 28
        width_factor = 10
        dropout = 0
        labels = 100 if experiment == "WideResNet_cifar100" else 10
        model = WideResNet_cifar(depth, width_factor, dropout, in_channels=3, labels=labels).to(device)
    elif experiment == "PyramidNet":
        batch_size = 64
        learning_rate = 0.1
        momentum = 0.9
        weight_decay = 0.0005
        epochs = 200
        depth = 272
        alpha = 200
        num_classes = 100
        _data = "cifar100"  # for "food101" use dataset="cifar100"; alternative: "fashionmnist"
        model = PyramidNet(dataset=_data, depth=depth, alpha=alpha, num_classes=num_classes).to(device)
    elif experiment == "EfficientNetV2_m":
        # To override dropout, patch torchvision's efficientnet.py
        # (anaconda3/envs/Jin_AACE/lib/python3.8/site-packages/torchvision/models/efficientnet.py)
        # according to https://github.com/pytorch/vision/commit/5785e2b05cdffeb39678914b8308a260e7e757db
        # or update torchvision to 0.17 or newer.
        batch_size = 32
        learning_rate = 0.001
        momentum = 0.9
        weight_decay = 0
        epochs = 50
        dropout = 0
        num_classes = 100
        pre_trained_weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1  # or None to train from scratch
        model = efficientnet_v2_m(weights=pre_trained_weights, dropout=dropout)
        # Replace the last dense layer so the output size matches the training data.
        model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
        model = model.to(device)
    elif experiment == "vit_b_16":
        batch_size = 32
        learning_rate = 0.001
        momentum = 0.9
        weight_decay = 0
        epochs = 100
        num_classes = 100
        image_size = 224
        pre_trained_weights = ViT_B_16_Weights.IMAGENET1K_V1  # or None to train from scratch
        model = vit_b_16(weights=pre_trained_weights, image_size=image_size)
        # Replace the last dense layer so the output size matches the training data.
        model.heads.head = torch.nn.Linear(model.heads.head.in_features, num_classes)
        model = model.to(device)
    else:
        raise ValueError(f"Unknown experiment: {experiment!r}")

    optimizer = SAM(model.parameters(), base_optimizer, rho=rho, adaptive=adaptive, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    # The fine-tuning experiments use torch's cosine schedule; the others use
    # the project's StepLR, which is called with the epoch index.
    uses_torch_scheduler = experiment in ("EfficientNetV2_m", "vit_b_16")
    if uses_torch_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
    else:
        scheduler = StepLR(optimizer, learning_rate, epochs)

    # ---- dataset -------------------------------------------------------
    if data == "Cifar100":
        dataset = Cifar100(batch_size, threads)
    elif data == "Cifar100_224":
        dataset = Cifar100_224(batch_size, threads)
    elif data == "Cifar10":
        dataset = Cifar10(batch_size, threads)
    elif data == "fashionmnist":
        dataset = fashionmnist(batch_size, threads)
    elif data == "food101":
        dataset = Food101(batch_size, threads)
    else:
        raise ValueError(f"Unknown data: {data!r}")

    # ---- training ------------------------------------------------------
    for epoch in range(epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        grad_norm_list = []
        ew_norm_list = []

        sum_ce_loss = 0
        sum_perturb_loss = 0
        sum_loss_num = 0

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)

            # ce_loss is not used for training; it is computed for observation only.
            ce_loss = smooth_crossentropy(predictions, targets, smoothing=label_smoothing)
            sum_ce_loss += ce_loss.sum().item()

            # NOTE(review): the loss object is rebuilt every batch as before;
            # it could be hoisted out of the loop if it is stateless -- confirm.
            loss_function = AdaptiveAdversrialCrossEntropy()
            perturb_loss = loss_function(predictions, targets)
            perturb_loss.mean().backward()
            sum_perturb_loss += perturb_loss.sum().item()

            sum_loss_num += perturb_loss.size(0)

            grad_norm, ew_norm = optimizer.first_step(use_grad_norm=use_grad_norm, zero_grad=True)
            ew_norm_list.append(ew_norm)
            grad_norm_list.append(grad_norm)

            # second forward-backward step (at the perturbed weights)
            disable_running_stats(model)
            loss = smooth_crossentropy(model(inputs), targets, smoothing=label_smoothing)
            loss.mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                if uses_torch_scheduler:
                    log(model, loss.cpu(), correct.cpu(), optimizer.param_groups[0]['lr'])
                else:
                    log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                    scheduler(epoch)

        # T_max is expressed in epochs, so the cosine schedule advances once
        # per epoch; stepping it per batch would cycle the learning rate
        # every 2 * epochs batches instead of annealing over the whole run.
        if uses_torch_scheduler:
            scheduler.step()

        avg_ce_loss = sum_ce_loss / sum_loss_num
        avg_perturb_loss = sum_perturb_loss / sum_loss_num
        avg_ew_norm = mean(ew_norm_list)
        avg_grad_norm = mean(grad_norm_list)

        # ---- validation ------------------------------------------------
        model.eval()
        train_loss, train_acc, lr = log.output_0()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets, smoothing=label_smoothing)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

        val_loss, val_acc = log.output_1()
        log.flush()

        # Column order must match `header` above.
        epoch_row = [epoch, lr, avg_ce_loss, avg_perturb_loss, avg_grad_norm, avg_ew_norm, train_loss, train_acc, val_loss, val_acc]
        with open(csv_path, 'a', newline='') as file:
            csv.writer(file).writerow(epoch_row)

    torch.save(model.state_dict(), model_save_path)
    torch.save(optimizer.state_dict(), optimizer_save_path)
    
#     torch.cuda.empty_cache()

# sam_train()

In [4]:
# Perturbation sizes to train with; every run is written out as repeat number 4.
rho_list = [0.2]

for i, rho in enumerate(rho_list):
    sam_train(rho=rho, iteration=4)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨
┃           0  ┃      3.9092  │      9.18 %  ┃   1.000e-01  │   12:35 min  ┃┈███████████████████████████┈┨      3.6182  │     13.41 %  ┃
      3.6182  │     13.41 %  ┃
┃           1  ┃      3.3521  │     18.33 %  ┃   1.000e-01  │   12:33 min  ┃┈███████████████████████████┈┨      3.1234  │     21.81 %  ┃
      3.1234  │     21.81 %  ┃
┃           2  ┃      2.9342  │     26.45 %  ┃   1.000e-01  │   12:35 min  ┃┈███████████████████████████┈┨      2.7005  │ 