In [5]:
import torch

class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One update consists of two gradient evaluations:

    1. ``first_step`` moves every parameter that has a gradient to
       ``w + e(w)`` (an ascent step scaled by ``rho / ||grad||``).
    2. ``second_step`` moves the parameters back to ``w`` and lets the base
       optimizer apply the gradients that were computed at ``w + e(w)``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        base_optimizer: Optimizer class (not an instance); it is instantiated
            here with the shared parameter groups and ``**kwargs``.
        rho: Non-negative scale of the ascent step.
        adaptive: If True, the ascent step is scaled element-wise by ``p**2``
            and the gradient norm by ``|p|``.
        **kwargs: Forwarded to ``base_optimizer`` and stored in the defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Both optimizers share the very same param_groups list, so changes to
        # a group (e.g. its lr) are seen by SAM and the base optimizer alike.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Back up the weights and climb to the local maximum ``w + e(w)``.

        Parameters without a gradient are left untouched. If ``zero_grad`` is
        True the gradients are cleared afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Return to ``w`` and run the base optimizer's update.

        Restoration is keyed on the backup made by ``first_step`` rather than
        on ``p.grad``: a parameter that was perturbed but has no gradient after
        the second backward pass must still be moved back to ``w``.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # pop() so a stale backup can never be re-applied in a later
                # iteration in which this parameter was not perturbed.
                old_p = self.state.get(p, {}).pop("old_p", None)
                if old_p is None: continue
                p.data = old_p  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual update step

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a full SAM update.

        Gradients for the first step must already be populated; ``closure``
        must run a full forward-backward pass and is used for the second
        gradient evaluation.

        Raises:
            AssertionError: If ``closure`` is None.
        """
        assert closure is not None, "SAM requires closure, please provide it."
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the global L2 norm of all (optionally |p|-scaled) gradients.

        Returns a zero scalar tensor when no parameter has a gradient, instead
        of failing on an empty ``torch.stack``.
        """
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

    def load_state_dict(self, state_dict):
        """Load the state and re-share the param groups with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import time
import argparse
import random
from timm.models.vision_transformer import VisionTransformer

# --- SAM Optimizer Definition (Paste the SAM class definition from above here) ---
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper meant for manual two-step use.

    The training loop is expected to call ``first_step`` after the first
    backward pass and ``second_step`` after the second one; no closure-based
    ``step`` is provided by this class.

    Args:
        params: Parameters (or parameter groups) to optimize.
        base_optimizer: Optimizer class (e.g. ``torch.optim.SGD``), not an
            instance; it is instantiated here with ``**kwargs``.
        rho: Non-negative scale of the ascent step.
        adaptive: If True, the ascent step is scaled element-wise by ``p**2``
            and the gradient norm by ``|p|``.
        **kwargs: Forwarded to ``base_optimizer`` and stored in the defaults.

    Raises:
        ValueError: If ``base_optimizer`` is not a class.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        # Ensure base_optimizer is a type, not an instance
        if not isinstance(base_optimizer, type):
            raise ValueError("base_optimizer must be a class type, e.g., torch.optim.SGD")
        # Instantiate the base_optimizer on the shared param groups so that both
        # optimizers always see the same group settings.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Back up the weights and climb to the local maximum ``w + e(w)``.

        Parameters without a gradient are left untouched. If ``zero_grad`` is
        True the gradients are cleared afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                # Calculate ascent direction
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights and run the base optimizer's update.

        Restoration is keyed on the backup made by ``first_step`` rather than
        on ``p.grad``: a parameter that was perturbed but has no gradient after
        the second backward pass must still be moved back to ``w``.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # pop() so a stale backup can never be re-applied in a later
                # iteration in which this parameter was not perturbed.
                old_p = self.state.get(p, {}).pop("old_p", None)
                if old_p is None: continue
                # Restore original weights before base optimizer step
                p.data = old_p
        # The gradients computed in the second forward/backward pass are already set
        self.base_optimizer.step()  # do the actual update step using gradients computed on perturbed weights
        if zero_grad: self.zero_grad()

    # step() function using closure is not the standard way SAM is used in loops
    # It's more common to manually call first_step and second_step in the training loop

    def _grad_norm(self):
        """Return the global L2 norm of all (optionally |p|-scaled) gradients.

        Returns a zero scalar tensor when no parameter has a gradient.
        """
        # Tolerant to device placement: every partial norm is moved to one device.
        shared_device = self.param_groups[0]["params"][0].device
        norms = []
        for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is not None:
                     # Work on detached tensors so the norm computation stays out of autograd
                     param_grad = p.grad.detach()
                     param_norm = ((torch.abs(p.detach()) if group["adaptive"] else 1.0) * param_grad).norm(p=2)
                     norms.append(param_norm.to(shared_device))
        if not norms: # Handle case where no parameters have gradients
            return torch.tensor(0.0, device=shared_device)
        # Stack norms before calculating the final norm
        total_norm = torch.norm(torch.stack(norms), p=2)
        return total_norm

    # Overwrite zero_grad to also zero base_optimizer's gradients
    def zero_grad(self, set_to_none: bool = False):
        """Clear the gradients through both SAM and the base optimizer."""
        super(SAM, self).zero_grad(set_to_none=set_to_none)
        self.base_optimizer.zero_grad(set_to_none=set_to_none)

    # Need to handle state dict loading/saving properly for both SAM and base_optimizer
    def state_dict(self):
        """Return ``{"sam_state": ..., "base_optimizer_state": ...}``."""
        # Combine SAM state and base optimizer state
        sam_state = super(SAM, self).state_dict()
        base_state = self.base_optimizer.state_dict()
        return {"sam_state": sam_state, "base_optimizer_state": base_state}

    def load_state_dict(self, state_dict):
        """Load a dict produced by ``state_dict`` (both nested keys are required)."""
        # Load states separately
        sam_state = state_dict["sam_state"]
        base_state = state_dict["base_optimizer_state"]
        super(SAM, self).load_state_dict(sam_state)
        self.base_optimizer.load_state_dict(base_state)
        # Ensure param_groups are synchronized after loading state
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

# VGG layer layouts consumed by VGG._make_layers: an integer is the output
# channel count of a 3x3 conv stage, 'M' inserts a 2x2 max-pool.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
        512, 512, 512, 'M', 512, 512, 512, 'M',
    ],
    VGG19=[
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M',
    ],
)


class VGG(nn.Module):
    """VGG-style network: a conv feature stack followed by one linear classifier.

    Args:
        vgg_name: Key into the module-level ``cfg`` table selecting the layout.
        num: Number of output classes of the linear classifier.
    """

    def __init__(self, vgg_name, num=100):
        super().__init__()
        layout = cfg[vgg_name]
        self.features = self._make_layers(layout)
        # The classifier consumes a flattened feature vector of length 512.
        self.classifier = nn.Linear(512, num)

    def forward(self, x):
        """Run the feature stack, flatten per sample, and classify."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layout list into an ``nn.Sequential``.

        ``'M'`` becomes a 2x2 max-pool; any other entry ``c`` becomes
        Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU with ``c`` output
        channels. A 1x1 average pool is appended at the end.
        """
        stack = []
        prev_channels = 3
        for spec in cfg:
            if spec == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.extend([
                nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ])
            prev_channels = spec
        stack.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*stack)
# -------------------- ResNet definitions --------------------
# (CIFAR-style ResNet building blocks; the ResNet-20/32/44/56/110 constructors are defined below.)
def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Module wrapper that applies a stored callable in ``forward``."""

    def __init__(self, lambd):
        super().__init__()
        # Plain callable, so it is kept as an ordinary attribute (no parameters).
        self.lambd = lambd

    def forward(self, x):
        """Return ``lambd(x)``."""
        fn = self.lambd
        return fn(x)

class BasicBlock(nn.Module):
    """Two-conv residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut).

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        option: Shortcut used when ``stride != 1`` or the channel count changes:
            ``'A'`` subsamples by 2 spatially and zero-pads ``planes // 4``
            channels on each side; ``'B'`` uses a strided 1x1 conv + BatchNorm.
            Any other value keeps the identity shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity unless the output shape differs from the input shape.
        self.shortcut = nn.Sequential()
        shape_changes = stride != 1 or in_planes != planes
        if shape_changes and option == 'A':
            channel_pad = planes // 4

            def _subsample_and_pad(x):
                # Take every second pixel, then zero-pad the channel dimension.
                return F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, channel_pad, channel_pad), "constant", 0)

            self.shortcut = LambdaLayer(_subsample_and_pad)
        elif shape_changes and option == 'B':
            out_planes = self.expansion * planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet(nn.Module):
    """Three-stage ResNet (16/32/64 planes) with a linear classification head.

    Args:
        block: Residual block class; called as ``block(in_planes, planes, stride)``
            and expected to expose an ``expansion`` attribute.
        num_blocks: Sequence with the number of blocks in each of the three stages.
        num_classes: Number of classifier outputs.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; advanced by _make_layer as blocks are stacked.
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)
        # Don't apply weight init here if loading pretrained weights
        # self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Stem -> three stages -> global average pool -> linear head."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        # Use adaptive average pooling for robustness
        pooled = F.adaptive_avg_pool2d(out, (1, 1))
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

def ResNet20(num_classes=100):
    """ResNet-20 model configuration based on the 6n+2 formula.

    Args:
        num_classes: Number of classifier outputs; defaults to 100, the value
            that used to be hard-coded here.
    """
    # (20-2)/6 = 3
    return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes)
def ResNet32(num_classes=100):
    """ResNet-32 model configuration based on the 6n+2 formula.

    Args:
        num_classes: Number of classifier outputs; defaults to 100, the value
            that used to be hard-coded here.
    """
    # (32-2)/6 = 5
    return ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes)

def ResNet44(num_classes=100):
    """ResNet-44 model configuration based on the 6n+2 formula.

    Args:
        num_classes: Number of classifier outputs; defaults to 100, the value
            that used to be hard-coded here.
    """
    # (44-2)/6 = 7
    return ResNet(BasicBlock, [7, 7, 7], num_classes=num_classes)

def ResNet56(num_classes=100):
    """ResNet-56 model configuration based on the 6n+2 formula.

    Args:
        num_classes: Number of classifier outputs; defaults to 100, the value
            that used to be hard-coded here.
    """
    # (56-2)/6 = 9
    return ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes)

def ResNet110(num_classes=100):
    """ResNet-110 model configuration based on the 6n+2 formula.

    Args:
        num_classes: Number of classifier outputs; defaults to 100, the value
            that used to be hard-coded here.
    """
    # (110-2)/6 = 18
    return ResNet(BasicBlock, [18, 18, 18], num_classes=num_classes)

# Hardcode the values instead:
class Args:
    """Hard-coded run configuration, read by the script below as ``args.<name>``."""

    # Model selection and checkpoint locations
    model = 'vit-t'
    load_path = './vit-t_cifar100_final_290.pth'
    save_path = './vit-t_cifar100_sam_ft_290.pth'

    # Fine-tuning hyper-parameters
    ft_epochs = 10
    ft_lr = 0.001
    sam_rho = 0.01

    # Data loading
    data_path = './data'
    batch_size = 128


args = Args()

# -------------------- Device configuration --------------------
# Prefer the second GPU when the host has one; 'cuda:1' is not a valid device
# on a single-GPU machine, so fall back to 'cuda:0' there and to CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
print(f"Using device: {device}")

# -------------------- Data preparation --------------------
print('==> Preparing data..')
# Per-channel statistics used to normalise CIFAR-100 images.
cifar100_mean = (0.5071, 0.4867, 0.4408)
cifar100_std = (0.2675, 0.2565, 0.2761)

# Training pipeline: random crop + horizontal flip augmentation, then normalise.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ]
)

# Evaluation pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ]
)


def _cifar100(train, transform):
    """Open (downloading if needed) one CIFAR-100 split under args.data_path."""
    return torchvision.datasets.CIFAR100(
        root=args.data_path, train=train, download=True, transform=transform)


trainset = _cifar100(True, transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
)

testset = _cifar100(False, transform_test)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# -------------------- Model construction and checkpoint loading --------------------
print('==> Building and loading pre-trained model..')


def _build_vit(embed_dim, num_heads):
    """Build a depth-12 VisionTransformer for 32x32 inputs, 4x4 patches and 100 classes."""
    return VisionTransformer(img_size=32, patch_size=4, num_classes=100,
                             embed_dim=embed_dim, depth=12, num_heads=num_heads)


# Maps every supported args.model value to a zero-argument constructor.
_MODEL_BUILDERS = {
    'r20': ResNet20,
    'r32': ResNet32,
    'r44': ResNet44,
    'r56': ResNet56,
    'r110': ResNet110,
    'vit-t': lambda: _build_vit(embed_dim=192, num_heads=3),
    'vit-s': lambda: _build_vit(embed_dim=384, num_heads=6),
    'vgg16': lambda: VGG('VGG16'),
    'vgg11': lambda: VGG('VGG11'),
}

# Fail early with a clear message instead of a NameError on `net` further down.
if args.model not in _MODEL_BUILDERS:
    raise ValueError(
        f"Unknown model '{args.model}'; expected one of {sorted(_MODEL_BUILDERS)}")
net = _MODEL_BUILDERS[args.model]().to(device)

_checkpoint_loaded = False
if os.path.exists(args.load_path):
    try:
        print(f"Loading checkpoint from '{args.load_path}'")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        # weights_only=True cannot be used here because the last branch below
        # supports checkpoints that contain a whole pickled model object.
        checkpoint = torch.load(args.load_path, map_location=device)
        # Adjust based on how the model was saved (state_dict vs full model)
        if isinstance(checkpoint, dict):
            # Either a wrapper dict with a 'state_dict' entry or a bare state_dict.
            net.load_state_dict(checkpoint.get('state_dict', checkpoint))
        else:  # Saved the entire model object
            net = checkpoint
        _checkpoint_loaded = True
        print("Pre-trained model loaded successfully.")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
else:
    print(f"Checkpoint file not found at '{args.load_path}'.")

if not _checkpoint_loaded:
    # Report the model that is actually being trained, not a fixed "ResNet20".
    print(f"Proceeding with initialized {args.model} (training from scratch).")
    # Apply weight initialization if not loading weights
    net.apply(_weights_init)


# Optional DataParallel
# if device == 'cuda' and torch.cuda.device_count() > 1:
#     print(f"Let's use {torch.cuda.device_count()} GPUs!")
#     net = torch.nn.DataParallel(net)
#     cudnn.benchmark = True # Good if input sizes don't change

# -------------------- Loss function and SAM optimizer --------------------
criterion = nn.CrossEntropyLoss()

# SAM instantiates the base optimizer itself, so it is given the SGD class
# (not an instance) together with the fine-tuning hyper-parameters.
base_optimizer = torch.optim.SGD
optimizer = SAM(
    net.parameters(),
    base_optimizer,
    rho=args.sam_rho,
    adaptive=False,  # Set adaptive=True if needed
    lr=args.ft_lr,
    momentum=0,
    weight_decay=5e-4,
)

# No learning rate scheduler needed for fixed LR fine-tuning

# -------------------- Evaluation function --------------------
def evaluate(loader, set_name="Test", model=net): # Pass model explicitly
    """Evaluate ``model`` on every batch of ``loader`` and print a summary line.

    Uses the module-level ``device`` and ``criterion``. The model is switched
    to eval mode and is left in eval mode on return.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        set_name: Label printed at the start of the summary line.
        model: Model to evaluate. The default is bound to the ``net`` object
            that existed when this function was defined.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent. Both are NaN if ``loader`` yields no data, rather
        than raising ZeroDivisionError.
    """
    model.eval() # Set model to evaluation mode
    eval_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted while iterating, so len(loader) is not required
    start_time = time.time()
    with torch.no_grad(): # Disable gradient calculation
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs) # Use the passed model
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    epoch_time = time.time() - start_time
    avg_loss = eval_loss / num_batches if num_batches else float('nan')
    accuracy = 100. * correct / total if total else float('nan')
    print(f'{set_name.ljust(5)} | Loss: {avg_loss:.4f} | Acc: {accuracy:.3f}% ({correct}/{total}) | Time: {epoch_time:.2f}s')
    return avg_loss, accuracy


# -------------------- Initial evaluation (of the loaded model) --------------------
# Baseline numbers recorded before any fine-tuning; reported again in the final summary.
print("\n==> Evaluating loaded model before fine-tuning...")
initial_train_loss, initial_train_acc = evaluate(
    loader=trainloader, set_name="Train", model=net)
initial_test_loss, initial_test_acc = evaluate(
    loader=testloader, set_name="Test", model=net)
print("--------------------------------------------------")


# -------------------- SAM fine-tuning training function --------------------
def train_sam(epoch):
    """Fine-tune the global ``net`` with SAM for one pass over ``trainloader``.

    Each batch needs two forward/backward passes: one at the current weights
    (followed by ``optimizer.first_step``) and one at the perturbed weights
    (followed by ``optimizer.second_step``). Loss and accuracy are reported
    from the first pass.

    Args:
        epoch: Zero-based epoch index, used only for the progress banner.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-batch loss and accuracy in percent.
    """
    print(f'\n--- SAM Fine-tuning Epoch: {epoch+1}/{args.ft_epochs} ---')
    net.train() # Set model to training mode
    running_loss, correct, total = 0, 0, 0
    start_time = time.time()
    current_lr = optimizer.param_groups[0]['lr'] # Get LR (should be fixed)

    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Pass 1: gradients at the original weights, then perturb and zero grads.
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # Pass 2: gradients at the perturbed weights, then restore and update.
        with torch.enable_grad():
            perturbed_loss = criterion(net(inputs), targets)
            perturbed_loss.backward()
        optimizer.second_step(zero_grad=True)

        # Bookkeeping uses the loss and predictions of the first pass.
        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_time = time.time() - start_time
    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    print(f'Train | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.3f}% ({correct}/{total}) | LR: {current_lr:.5f} | Time: {epoch_time:.2f}s')
    return epoch_loss, epoch_acc

# -------------------- Fine-tuning main loop --------------------
print("==> Starting SAM Fine-tuning...")
finetuning_start_time = time.time()

# The learning rate is fixed, so there is no scheduler to step.
for epoch in range(args.ft_epochs):
    # One SAM fine-tuning epoch, then a test-set evaluation of the updated model.
    train_loss, train_acc = train_sam(epoch)
    test_loss, test_acc = evaluate(loader=testloader, set_name="Test", model=net)

finetuning_end_time = time.time()
total_finetuning_time = finetuning_end_time - finetuning_start_time
print(f"\n==> Finished SAM Fine-tuning in {total_finetuning_time:.2f} seconds.")


# -------------------- Save the final fine-tuned model --------------------
print(f'==> Saving final fine-tuned model to {args.save_path}')
save_dir = os.path.dirname(args.save_path)
if save_dir:  # dirname is '' for a bare filename, and makedirs('') would raise
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_dir, exist_ok=True)
# Save only the model state_dict is usually preferred
torch.save(net.state_dict(), args.save_path)
# If you need to save optimizer state as well (e.g., to resume SAM training):
# torch.save({
#     'epoch': args.ft_epochs, # Or the actual last epoch number
#     'model_state_dict': net.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(), # Save SAM state
# }, args.save_path)
print("Final fine-tuned model saved.")

# -------------------- Evaluate the final fine-tuned model --------------------
print("\n==> Evaluating final fine-tuned model (after {} epochs)...".format(args.ft_epochs))
print("--- Final Training Set Evaluation ---")
final_train_loss, final_train_acc = evaluate(
    loader=trainloader, set_name="Train", model=net)
print("--- Final Test Set Evaluation ---")
final_test_loss, final_test_acc = evaluate(
    loader=testloader, set_name="Test", model=net)

# Before/after summary; each report is emitted with a single print call.
print("\n".join([
    "\n===== Initial Model Performance =====",
    f"Initial Training Loss: {initial_train_loss:.4f}",
    f"Initial Training Acc:  {initial_train_acc:.3f}%",
    f"Initial Test Loss:     {initial_test_loss:.4f}",
    f"Initial Test Acc:      {initial_test_acc:.3f}%",
    "====================================",
]))

print("\n".join([
    "\n===== Final Fine-tuned Model Performance =====",
    f"Final Training Loss: {final_train_loss:.4f}",
    f"Final Training Acc:  {final_train_acc:.3f}%",
    f"Final Test Loss:     {final_test_loss:.4f}",
    f"Final Test Acc:      {final_test_acc:.3f}%",
    "==========================================",
]))

Using device: cuda:1
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building and loading pre-trained model..
Loading checkpoint from './vit-t_cifar100_final_290.pth'
Pre-trained model loaded successfully.

==> Evaluating loaded model before fine-tuning...


  checkpoint = torch.load(args.load_path, map_location=device)


Train | Loss: 1.2419 | Acc: 64.724% (32362/50000) | Time: 11.57s
Test  | Loss: 1.7478 | Acc: 53.700% (5370/10000) | Time: 1.64s
--------------------------------------------------
==> Starting SAM Fine-tuning...

--- SAM Fine-tuning Epoch: 1/10 ---
Train | Loss: 1.2518 | Acc: 64.312% (32156/50000) | LR: 0.00100 | Time: 38.82s
Test  | Loss: 1.7473 | Acc: 53.820% (5382/10000) | Time: 1.69s

--- SAM Fine-tuning Epoch: 2/10 ---
Train | Loss: 1.2435 | Acc: 64.676% (32338/50000) | LR: 0.00100 | Time: 38.56s
Test  | Loss: 1.7487 | Acc: 53.810% (5381/10000) | Time: 1.63s

--- SAM Fine-tuning Epoch: 3/10 ---
Train | Loss: 1.2463 | Acc: 64.198% (32099/50000) | LR: 0.00100 | Time: 39.00s
Test  | Loss: 1.7473 | Acc: 53.810% (5381/10000) | Time: 1.63s

--- SAM Fine-tuning Epoch: 4/10 ---
Train | Loss: 1.2472 | Acc: 64.446% (32223/50000) | LR: 0.00100 | Time: 39.03s
Test  | Loss: 1.7473 | Acc: 53.650% (5365/10000) | Time: 1.61s

--- SAM Fine-tuning Epoch: 5/10 ---
Train | Loss: 1.2497 | Acc: 64.550% 