In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import time
import argparse
import random
import copy # For deep copying models and states
import math # For infinity
from timm.models.vision_transformer import VisionTransformer

# Layer plans for the VGG variants, looked up by name in VGG.__init__.
# An integer is the output-channel count of one convolution stage; 'M' inserts
# a 2x2 max-pool. Each row below ends at one pooling step.
cfg = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}


class VGG(nn.Module):
    """VGG-style classifier: a conv/BN/ReLU feature stack plus one linear layer.

    `vgg_name` selects a layer plan from the module-level `cfg` table and
    `num` is the number of outputs produced by the linear classifier.
    """

    def __init__(self, vgg_name, num=10):
        super().__init__()
        plan = cfg[vgg_name]
        self.features = self._make_layers(plan)
        # The classifier takes 512 flattened features per sample.
        self.classifier = nn.Linear(512, num)

    def forward(self, x):
        """Run the feature stack, flatten each sample, then classify."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Turn a layer plan into an `nn.Sequential`.

        Every integer entry adds Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU
        with that many output channels; every 'M' adds a 2x2 max-pool. A
        final AvgPool2d with kernel 1 / stride 1 is appended.
        """
        modules = []
        prev_channels = 3  # the first convolution is built for 3 input channels
        for entry in cfg:
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1),
                nn.BatchNorm2d(entry),
                nn.ReLU(inplace=True),
            ])
            prev_channels = entry
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)


# --- CIFAR-style ResNet definitions (ResNet20/32/44/56/110) ---
def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
class LambdaLayer(nn.Module):
    """Module wrapper whose forward pass applies a stored callable."""

    def __init__(self, lambd):
        super().__init__()
        # Kept as a plain attribute; it is called as-is in forward().
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
class BasicBlock(nn.Module):
    """Two 3x3 conv/BN layers with a residual shortcut.

    The shortcut is the identity unless `stride != 1` or the channel count
    changes. In that case `option` selects it:
      'A' - subsample every second pixel and zero-pad `planes // 4` channels
            on each side of the channel axis (parameter-free);
      'B' - 1x1 strided convolution followed by BatchNorm.
    Any other `option` value leaves the identity shortcut in place.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()

        shape_changes = stride != 1 or in_planes != planes
        if shape_changes and option == 'A':
            pad = planes // 4
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, pad, pad), "constant", 0)
            )
        elif shape_changes and option == 'B':
            out_planes = self.expansion * planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)
class ResNet(nn.Module):
    """Three-stage residual network (16/32/64 channels) with a linear head.

    `block` is the residual block class, `num_blocks` gives the number of
    blocks in each of the three stages, and `num_classes` is the number of
    outputs of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the rest stride 1."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Stem -> three stages -> global average pool -> linear layer."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        return self.linear(out)
def ResNet20():
    """ ResNet-20 model configuration based on 6n+2 formula """
    # (20-2)/6 = 3
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=10)

def ResNet32():
    """Build the ResNet-32 variant: 5 BasicBlocks per stage ((32-2)/6), 10 outputs."""
    blocks_per_stage = [5] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=10)

def ResNet44():
    """Build the ResNet-44 variant: 7 BasicBlocks per stage ((44-2)/6), 10 outputs."""
    blocks_per_stage = [7] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=10)

def ResNet56():
    """Build the ResNet-56 variant: 9 BasicBlocks per stage ((56-2)/6), 10 outputs."""
    blocks_per_stage = [9] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=10)

def ResNet110():
    """Build the ResNet-110 variant: 18 BasicBlocks per stage ((110-2)/6), 10 outputs."""
    blocks_per_stage = [18] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=10)



# -------------------- CNOGNP & Script Parameters --------------------
parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Fine-tuning with CNOGNP on Weights')

# Every command-line option as (flag, add_argument keyword arguments), in the
# order in which the options are registered.
_CLI_OPTIONS = (
    # --- CNOGNP specific ---
    ('--lambda_gnp', dict(default=0.1, type=float, help='Gradient norm penalty coefficient (λ)')),
    # --- Particle-swarm (CNO) parameters ---
    ('--num_particles', dict(default=2, type=int, help='Number of particles (N)')),
    ('--cnognp_epochs', dict(default=10, type=int, help='Number of CNOGNP iterations (κ_max)')),
    ('--w', dict(default=1, type=float, help='Inertia weight (ω)')),
    ('--c1', dict(default=0.00001, type=float, help='Cognitive learning factor (c1)')),
    ('--c2', dict(default=0.00001, type=float, help='Social learning factor (c2)')),
    ('--eta', dict(default=0.001, type=float, help='Scale factor / Learning rate for inner SGD step (η)')),
    # --- Inner SGD, model selection and paths ---
    ('--initial_noise_level', dict(default=0.0001, type=float, help='Std deviation of noise added to initial particle weights')),
    ('--inner_sgd_momentum', dict(default=0, type=float, help='Momentum for the inner SGD step')),
    ('--inner_sgd_wd', dict(default=5e-4, type=float, help='Weight decay for the inner SGD step')),
    ('--model', dict(default='r110', type=str)),
    ('--load_path', dict(default='./resnet110_cifar10_final.pth', type=str, help='Path to load the pre-trained model')),
    ('--save_path', dict(default='./resnet110_cifar10_cnognp_ft.pth', type=str, help='Path to save the best model found by CNOGNP')),
    ('--batch_size', dict(default=128, type=int, help='Batch size for SGD training and evaluation')),
    ('--data_path', dict(default='./data', type=str, help='Path to dataset')),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)

# parse_known_args() tolerates the extra arguments a Jupyter kernel passes in;
# anything unrecognised is collected in `unknown`.
args, unknown = parser.parse_known_args()



# -------------------- Device Configuration --------------------
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# `device` carries an index suffix ('cuda:0'), so it never equals the bare
# string 'cuda'; match on the prefix so the cuDNN autotuner is actually
# switched on whenever a GPU is used.
if device.startswith('cuda'):
    cudnn.benchmark = True

# -------------------- Data preparation --------------------
print('==> Preparing data..')
# Per-channel statistics passed to transforms.Normalize for both splits.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop (padding 4) and horizontal flip, then
# tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

# Test pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root=args.data_path,
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root=args.data_path,
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# -------------------- Loss Function --------------------
# Shared by the SGD step and by evaluate().
criterion = nn.CrossEntropyLoss()

# -------------------- Helper Functions --------------------

# Standard evaluation function (for final results)
def evaluate(loader, model, set_name="Test"):
    """Evaluate `model` on `loader` and print a one-line loss/accuracy summary.

    The model is switched to eval mode and gradients are disabled. The loss
    comes from the module-level `criterion`; batches are moved to the
    module-level `device`.

    Returns:
        (avg_loss, accuracy): the mean of the per-batch losses and the
        percentage of correctly classified samples.
    """
    model.eval()
    running_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()
            predicted = outputs.max(1)[1]  # index of the largest logit per sample
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Averaged over batches, not over samples.
    avg_loss = running_loss / len(loader)
    accuracy = 100. * correct / total
    print(f'{set_name.ljust(5)} Eval | Loss: {avg_loss:.4f} | Acc: {accuracy:.3f}% ({correct}/{total})')
    return avg_loss, accuracy

# Function for CNO Line 7: Train ONE SGD epoch and return the *new state* and the loss
def _build_model(model_name):
    """Instantiate the (untrained) architecture registered under `model_name`.

    Raises:
        ValueError: if `model_name` is not one of the supported keys.
    """
    # Constructors are wrapped in lambdas so only the requested one runs.
    # NOTE(review): the ViT variants are built with num_classes=100 although the
    # data pipeline in this file loads CIFAR-10 -- confirm against the checkpoint.
    builders = {
        'r20': lambda: ResNet20(),
        'r32': lambda: ResNet32(),
        'r44': lambda: ResNet44(),
        'r56': lambda: ResNet56(),
        'r110': lambda: ResNet110(),
        'vit-t': lambda: VisionTransformer(img_size=32, patch_size=4, num_classes=100, embed_dim=192, depth=12, num_heads=3),
        'vit-s': lambda: VisionTransformer(img_size=32, patch_size=4, num_classes=100, embed_dim=384, depth=12, num_heads=6),
        'vgg16': lambda: VGG('VGG16'),
        'vgg11': lambda: VGG('VGG11'),
    }
    try:
        builder = builders[model_name]
    except KeyError:
        raise ValueError(
            f"Unsupported model name {model_name!r}; expected one of {sorted(builders)}"
        ) from None
    return builder()


def train_one_sgd_epoch_and_get_state(initial_state_dict, train_loader, criterion, device, lr, momentum, weight_decay, model_name):
    """Run one SGD epoch from `initial_state_dict` and return the resulting state.

    A fresh model of architecture `model_name` is built, loaded with a deep
    copy of `initial_state_dict` (the caller's tensors are not modified) and
    trained for one pass over `train_loader`.

    Args:
        initial_state_dict: state dict to start from.
        train_loader: iterable of (inputs, targets) batches that supports len().
        criterion: loss callable applied as criterion(outputs, targets).
        device: device the model and the batches are moved to.
        lr, momentum, weight_decay: forwarded to torch.optim.SGD.
        model_name: architecture key (e.g. 'r20', 'r110', 'vgg16', 'vit-t').

    Returns:
        (state_dict, avg_loss, avg_grad_norm): a deep copy of the trained
        model's state dict, the mean per-batch loss, and the mean per-batch
        global L2 norm of the parameter gradients.

    Raises:
        ValueError: if `model_name` is not a supported architecture key.
    """
    # Honour the `model_name` argument rather than a global setting, and fail
    # with a clear error for an unknown name.
    model = _build_model(model_name).to(device)
    model.load_state_dict(copy.deepcopy(initial_state_dict))
    model.train()
    train_loss = 0
    total_grad_norm = 0  # sum of the per-batch gradient norms
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    start_time = time.time()
    print(f"      Starting 1-epoch SGD (CNOGNP Line 7, η={lr})... ", end="")
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # Global L2 norm over all parameter gradients of this batch, read from
        # p.grad after the optimizer step.
        batch_grad_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                batch_grad_norm += param_norm.item() ** 2
        batch_grad_norm = batch_grad_norm ** 0.5
        total_grad_norm += batch_grad_norm
    avg_loss = train_loss / len(train_loader)
    avg_grad_norm = total_grad_norm / len(train_loader)  # mean of the per-batch norms
    epoch_time = time.time() - start_time
    print(f"Done. Avg Loss during SGD: {avg_loss:.4f} | Time: {epoch_time:.2f}s")
    return copy.deepcopy(model.state_dict()), avg_loss, avg_grad_norm

# Function for CNOGNP Line 10: Evaluate fitness (Loss + λ*||g||₂) on train set
def evaluate_fitness_loss_and_grad_norm(avg_loss, avg_grad_norm, lambda_gnp):
    """Return the CNOGNP fitness `avg_loss + lambda_gnp * avg_grad_norm`.

    The fitness and both of its inputs are also printed on one line.
    """
    penalty = lambda_gnp * avg_grad_norm
    fitness = avg_loss + penalty
    print(f"Done. Fitness: {fitness:.4f} (Avg Loss: {avg_loss:.4f}, Avg Grad Norm: {avg_grad_norm:.4f})")
    return fitness

# --- Other helper functions (add_noise, initialize_velocity, update_cno_velocity, update_particle_position) ---
# --- remain EXACTLY the same as in the CNO implementation ---
def add_noise_to_model(model, noise_level, device):
    """Perturb every parameter of `model` in place with scaled Gaussian noise.

    Each parameter receives `torch.randn_like(param) * noise_level`. The
    `device` argument is accepted but not used. Returns the same model.
    """
    print('add noise')
    with torch.no_grad():
        for weight in model.parameters():
            jitter = torch.randn_like(weight) * noise_level
            weight.add_(jitter)
    return model
def initialize_velocity(model):
    """Return a dict of zero tensors, one per trainable parameter of `model`.

    Keys are the names from `model.named_parameters()`; parameters with
    `requires_grad == False` are skipped.
    """
    with torch.no_grad():
        return {
            key: torch.zeros_like(weight)
            for key, weight in model.named_parameters()
            if weight.requires_grad
        }
def update_cno_velocity(velocity_dict, z_bar_i_state, pbest_state, gbest_state, c1, c2, device):
    """Overwrite each velocity tensor with the CNO cognitive + social pull.

    For every name in `velocity_dict` that is also present in `z_bar_i_state`:

        v = c1 * r1 * (pbest - z_bar) + c2 * r2 * (gbest - z_bar)

    where r1 and r2 are fresh `random.random()` draws per tensor. The previous
    velocity is not carried over (there is no inertia term). Tensors in
    `velocity_dict` are updated in place; nothing is returned.

    Args:
        velocity_dict: name -> velocity tensor, modified in place.
        z_bar_i_state: name -> tensor, the particle state after the SGD step.
        pbest_state: name -> tensor, the particle's personal best state.
        gbest_state: name -> tensor, the global best state.
        c1, c2: cognitive and social learning factors.
        device: device on which the update is computed.
    """
    with torch.no_grad():
        for name in velocity_dict:
            # Check the state that was actually passed in. The membership test
            # must not depend on a global variable of the calling script.
            if name not in z_bar_i_state:
                continue
            r1 = random.random()
            r2 = random.random()

            # Bring all operands onto the same device before combining them.
            z_bar_i_param = z_bar_i_state[name].to(device)
            pbest_param = pbest_state[name].to(device)
            gbest_param = gbest_state[name].to(device)

            # CNO Velocity Update (Line 8)
            cognitive_term = c1 * r1 * (pbest_param - z_bar_i_param)
            social_term = c2 * r2 * (gbest_param - z_bar_i_param)

            # copy_ keeps the existing tensor object in the dict.
            velocity_dict[name].copy_(cognitive_term + social_term)
def update_particle_position(model_to_update, z_bar_i_state, velocity_dict, device):
    """Load `z_bar_i_state + velocity` into `model_to_update`.

    Works on a deep copy of `z_bar_i_state`, so the caller's state dict is
    left unchanged. Only entries present in both the state and
    `velocity_dict` are shifted; all others are loaded as they are.
    """
    shifted = copy.deepcopy(z_bar_i_state)
    with torch.no_grad():
        for key, delta in velocity_dict.items():
            if key in shifted:
                shifted[key].add_(delta.to(device))
    model_to_update.load_state_dict(shifted)

# -------------------- Load Pre-trained Model --------------------
print('==> Loading pre-trained model...')
# Build the architecture selected by --model; its weights are loaded below.
# NOTE(review): the ViT variants are built with num_classes=100 although the
# data pipeline in this file loads CIFAR-10 -- confirm against the checkpoint.
if args.model == 'r20':
    initial_model = ResNet20().to(device)
elif args.model == 'r32':
    initial_model = ResNet32().to(device)
elif args.model == 'r44':
    initial_model = ResNet44().to(device)
elif args.model == 'r56':
    initial_model = ResNet56().to(device)
elif args.model == 'r110':
    initial_model = ResNet110().to(device)
elif args.model == 'vit-t':
    initial_model = VisionTransformer(img_size=32, patch_size=4, num_classes=100, embed_dim=192, depth=12, num_heads=3).to(device)
elif args.model == 'vit-s':
    initial_model = VisionTransformer(img_size=32, patch_size=4, num_classes=100, embed_dim=384, depth=12, num_heads=6).to(device)
elif args.model == 'vgg16':
    initial_model = VGG('VGG16').to(device)
elif args.model == 'vgg11':
    initial_model = VGG('VGG11').to(device)
else:
    # Without this branch an unknown name leaves `initial_model` undefined and
    # the failure is reported below as a checkpoint-loading error.
    raise ValueError(f"Unsupported --model value: {args.model!r}")

# Three checkpoint layouts are accepted: a dict with a 'state_dict' entry, a
# bare state dict, or a whole pickled model object.
if os.path.exists(args.load_path):
    try:
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from a trusted source. The pickled-model branch below needs full
        # unpickling, so weights_only is deliberately not forced here.
        checkpoint = torch.load(args.load_path, map_location=device)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            initial_model.load_state_dict(checkpoint['state_dict'])
        elif isinstance(checkpoint, dict):
            initial_model.load_state_dict(checkpoint)
        else:
            initial_model = checkpoint
        print(f"Loaded pre-trained weights from '{args.load_path}'")
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Exiting.")
        exit()
else:
    print(f"Pre-trained model file not found at '{args.load_path}'. Exiting.")
    exit()

# Evaluate the loaded model once (using standard evaluate)
print("\n==> Evaluating loaded pre-trained model:")
initial_test_loss, initial_test_acc = evaluate(testloader, initial_model, "Test")
initial_train_loss, initial_train_acc = evaluate(trainloader, initial_model, "Train")


# -------------------- CNOGNP Initialization --------------------
# Builds the particle swarm: each particle is a dict holding its own model
# copy, a zero velocity, and its personal-best state and fitness.
print(f"\n==> Initializing {args.num_particles} CNOGNP particles...")
particles = []
gbest_state_dict = None
gbest_fitness = float('inf') # Global best starts at +infinity; the first finite fitness replaces it

for i in range(args.num_particles):
    print(f"  Initializing particle {i+1}/{args.num_particles}...")
    particle_model = copy.deepcopy(initial_model).to(device)
    # Particle 0 keeps the loaded weights unchanged; every other particle (and
    # the sole particle when num_particles == 1) is perturbed with Gaussian noise.
    if i > 0 or args.num_particles == 1:
         particle_model = add_noise_to_model(particle_model, args.initial_noise_level, device)
    velocity = initialize_velocity(particle_model)
    pbest_state_dict = copy.deepcopy(particle_model.state_dict())
    pbest_fitness = float('inf')
    particles.append({
        'id': i, 'model': particle_model, 'velocity': velocity,
        'pbest_state_dict': pbest_state_dict, 'pbest_fitness': pbest_fitness,
        'current_fitness': float('inf')
    })

# --- Initial Fitness Evaluation (using CNOGNP fitness function) ---
print("\n==> Performing initial fitness evaluation (using CNOGNP fitness)...")
current_epoch_best_fitness = float('inf')
current_epoch_best_particle_idx = -1
for i, particle in enumerate(particles):
    print(f"  Evaluating initial fitness for particle {i+1}/{args.num_particles}:")
    # Nothing is measured here: the constant 999 is passed for both the loss
    # and the gradient norm, so every particle starts with the same fitness.
    fitness = evaluate_fitness_loss_and_grad_norm(
        999, 999, args.lambda_gnp
    )
    particle['current_fitness'] = fitness
    particle['pbest_fitness'] = fitness # Initial pbest fitness

    # Strict '<' keeps the first particle as the best when fitness values tie.
    if fitness < current_epoch_best_fitness:
        current_epoch_best_fitness = fitness
        current_epoch_best_particle_idx = i

# # Update global best (gbest) based on the initial evaluation
# if current_epoch_best_particle_idx != -1 :
#      initial_best_particle = particles[current_epoch_best_particle_idx]
#      print(f"\nInitial Global Best Fitness (particle {current_epoch_best_particle_idx+1}): {current_epoch_best_fitness:.4f}")
#      gbest_fitness = current_epoch_best_fitness
#      gbest_state_dict = copy.deepcopy(initial_best_particle['pbest_state_dict'])
# else:
#      print("\nWarning: No valid fitness found in initial evaluation.")
#      # Fallback: use the originally loaded model as gbest
#      gbest_fitness = evaluate_fitness_loss_and_grad_norm(avg_loss, avg_grad_norm, args.lambda_gnp)
#      gbest_state_dict = copy.deepcopy(initial_model.state_dict())
#      print(f"Using loaded model as initial gbest (Fitness: {gbest_fitness:.4f})")


# -------------------- CNOGNP Main Loop --------------------
# Per outer epoch and per particle: one SGD epoch, fitness from that SGD pass,
# pbest/gbest bookkeeping, then a velocity and position update.
print(f"\n==> Starting CNOGNP Fine-tuning for {args.cnognp_epochs} epochs...")
cnognp_start_time = time.time()

# Use args.cnognp_epochs here
for cnognp_epoch in range(args.cnognp_epochs):
    print(f"\n--- CNOGNP Epoch {cnognp_epoch + 1}/{args.cnognp_epochs} ---")
    epoch_start_time = time.time()
    # Best fitness seen so far among the particles processed in this epoch.
    current_epoch_best_fitness = float('inf')
    current_epoch_best_particle_idx = -1

    for i, particle in enumerate(particles):
        print(f"  Processing Particle {i+1}/{args.num_particles}:")
        # Snapshot of the particle's current weights (the SGD starting point).
        z_i_state = copy.deepcopy(particle['model'].state_dict())



        # --- CNOGNP Line 7: Perform SGD step ---
        z_bar_i_state, sgd_run_loss, avg_grad_norm = train_one_sgd_epoch_and_get_state(
            z_i_state, trainloader, criterion, device,
            args.eta, args.inner_sgd_momentum, args.inner_sgd_wd, args.model
        )
        # Load the post-SGD weights into the particle so the pbest snapshot
        # below captures them.
        new_state = copy.deepcopy(z_bar_i_state)
        particle['model'].load_state_dict(new_state)
        # Fitness is built from the loss and gradient norm averaged over the
        # SGD epoch; no separate evaluation pass is run.
        current_fitness = evaluate_fitness_loss_and_grad_norm(
            sgd_run_loss, avg_grad_norm, args.lambda_gnp
        )
        particle['current_fitness'] = current_fitness

        # --- CNOGNP Lines 11-13: Update PBest ---
        if current_fitness < particle['pbest_fitness']:
            print(f"      New pbest for particle {i+1}: {current_fitness:.4f} (was {particle['pbest_fitness']:.4f})")
            particle['pbest_fitness'] = current_fitness
            particle['pbest_state_dict'] = copy.deepcopy(particle['model'].state_dict())
        else:
            print(f"      Fitness {current_fitness:.4f} not better than pbest {particle['pbest_fitness']:.4f}")

        if current_fitness < current_epoch_best_fitness:
             current_epoch_best_fitness = current_fitness
             current_epoch_best_particle_idx = i

        # --- CNOGNP Lines 14-16: Update GBest ---
        # This runs inside the particle loop, so gbest is refreshed after every
        # particle and is already set for the velocity update below.
        print("  Updating gbest...")
        if current_epoch_best_particle_idx != -1 and current_epoch_best_fitness < gbest_fitness:
            print(f"    New Global Best! Fitness: {current_epoch_best_fitness:.4f} (was {gbest_fitness:.4f}) from particle {current_epoch_best_particle_idx+1}'s pbest")
            gbest_fitness = current_epoch_best_fitness
            gbest_state_dict = copy.deepcopy(particles[current_epoch_best_particle_idx]['pbest_state_dict'])
        else:
            print(f"    No new gbest found this epoch. Best this epoch: {current_epoch_best_fitness:.4f}, Current gbest: {gbest_fitness:.4f}")

        # --- CNOGNP Line 8: Update Velocity ---
        update_cno_velocity(
            particle['velocity'], z_bar_i_state,
            particle['pbest_state_dict'], gbest_state_dict,
            args.c1, args.c2, device
        )

        # --- CNOGNP Line 9: Update Position ---
        # The particle's model ends the iteration at z_bar + velocity.
        update_particle_position(
             particle['model'], z_bar_i_state, particle['velocity'], device
        )

        # --- CNOGNP Line 10: Evaluate Fitness (Loss + λ*||g||₂) ---
        # No evaluation is performed at this point: the fitness used above comes
        # from the SGD pass, i.e. before the velocity shift is applied.
    epoch_time = time.time() - epoch_start_time
    print(f"--- CNOGNP Epoch {cnognp_epoch + 1} finished. Time: {epoch_time:.2f}s ---")
        

    


# Wall-clock duration of the whole CNOGNP loop, reported in seconds and hours.
total_cnognp_time = time.time() - cnognp_start_time
print(f"\n==> Finished CNOGNP Fine-tuning in {total_cnognp_time:.2f} seconds ({total_cnognp_time/3600:.2f} hours).")



  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Loading pre-trained model...
Loaded pre-trained weights from './resnet110_cifar10_final.pth'

==> Evaluating loaded pre-trained model:


  checkpoint = torch.load(args.load_path, map_location=device)


Test  Eval | Loss: 0.2423 | Acc: 94.020% (9402/10000)
Train Eval | Loss: 0.0025 | Acc: 99.988% (49994/50000)

==> Initializing 2 CNOGNP particles...
  Initializing particle 1/2...
  Initializing particle 2/2...
add noise

==> Performing initial fitness evaluation (using CNOGNP fitness)...
  Evaluating initial fitness for particle 1/2:
Done. Fitness: 1098.9000 (Avg Loss: 999.0000, Avg Grad Norm: 999.0000)
  Evaluating initial fitness for particle 2/2:
Done. Fitness: 1098.9000 (Avg Loss: 999.0000, Avg Grad Norm: 999.0000)

==> Starting CNOGNP Fine-tuning for 10 epochs...

--- CNOGNP Epoch 1/10 ---
  Processing Particle 1/2:
      Starting 1-epoch SGD (CNOGNP Line 7, η=0.001)... Done. Avg Loss during SGD: 0.0040 | Time: 36.94s
Done. Fitness: 0.0452 (Avg Loss: 0.0040, Avg Grad Norm: 0.4128)
      New pbest for particle 1: 0.0452 (was 1098.9000)
  Updating gbest...
    New Global Best! Fitness: 0.0452 (was inf) from particle 1's pbest
  Processing Particle 2/2:
      Starting 1-epoch SGD (C

In [2]:
# -------------------- Final Evaluation --------------------
# Loads the global-best weights into a fresh model, evaluates it and every
# particle's final model, then prints a before/after summary.
print("\n==> Evaluating the best model found by CNOGNP...")
# Rebuild the architecture selected by --model; the gbest weights are loaded below.
if args.model == 'r20':
    final_best_model = ResNet20().to(device)
elif args.model == 'r32':
    final_best_model = ResNet32().to(device)
elif args.model == 'r44':
    final_best_model = ResNet44().to(device)
elif args.model == 'r56':
    final_best_model = ResNet56().to(device)
elif args.model == 'r110':
    final_best_model = ResNet110().to(device)
elif args.model == 'vit-t':
    final_best_model = VisionTransformer(img_size=32,patch_size=4,num_classes=100, embed_dim=192, depth=12, num_heads=3).to(device)
elif args.model == 'vit-s':
    final_best_model = VisionTransformer(img_size=32,patch_size=4,num_classes=100, embed_dim=384, depth=12, num_heads=6).to(device)
elif args.model == 'vgg16':
    final_best_model = VGG('VGG16').to(device)
elif args.model == 'vgg11':
    final_best_model = VGG('VGG11').to(device)


# gbest_state_dict is still None if no particle fitness ever improved on the
# initial +infinity global best.
if gbest_state_dict is not None:
    final_best_model.load_state_dict(gbest_state_dict)
else:
    print("Error: Global best state dictionary was not set. Cannot evaluate.")
    exit()

# Use standard evaluate for final Loss/Acc comparison
# Note: trainloader applies the random crop/flip training transforms and
# shuffles, so the "Train" figures are measured on augmented samples.
print("--- Final Training Set Evaluation (using standard evaluate) ---")
final_train_loss, final_train_acc = evaluate(trainloader, final_best_model, "Train")
print("--- Final Test Set Evaluation (using standard evaluate) ---")
final_test_loss, final_test_acc = evaluate(testloader, final_best_model, "Test")
particle_eval_results = {} # Optional: dictionary to store results per particle
for i, particle in enumerate(particles):
    print(f"\n--- Evaluating Final State of Particle {i+1}/{args.num_particles} ---")
    # The model in particle['model'] holds the final state after all updates
    particle_model = particle['model']

    # Use standard evaluate for training set
    print(f"Particle {i+1} Train Set Evaluation:")
    train_loss, train_acc = evaluate(trainloader, particle_model, f"P{i+1} Train")

    # Use standard evaluate for test set
    print(f"Particle {i+1} Test  Set Evaluation:") # Added padding for alignment
    test_loss, test_acc = evaluate(testloader, particle_model, f"P{i+1} Test ")

    # Keyed 'particle_1', 'particle_2', ... (1-based, matching the printed labels).
    particle_eval_results[f'particle_{i+1}'] = {'train_loss': train_loss, 'train_acc': train_acc, 'test_loss': test_loss, 'test_acc': test_acc}

# Metrics of the loaded checkpoint, recorded before any fine-tuning.
print("\n===== Initial Model Performance =====")
print(f"Initial Training Loss: {initial_train_loss:.4f}")
print(f"Initial Training Acc:  {initial_train_acc:.3f}%")
print(f"Initial Test Loss:     {initial_test_loss:.4f}")
print(f"Initial Test Acc:      {initial_test_acc:.3f}%")
print("====================================")

# Metrics of the global-best model; gbest_fitness is the penalised fitness,
# the remaining lines are plain loss/accuracy from evaluate().
print("\n===== CNOGNP Fine-tuned Model Performance =====")
print(f"Achieved Global Best Fitness (Min Loss+λ||g||₂ during CNOGNP): {gbest_fitness:.4f}")
print(f"Final Eval Training Loss: {final_train_loss:.4f}") # Standard Loss
print(f"Final Eval Training Acc:  {final_train_acc:.3f}%")
print(f"Final Eval Test Loss:     {final_test_loss:.4f}") # Standard Loss
print(f"Final Eval Test Acc:      {final_test_acc:.3f}%")
print("===========================================")

# -------------------- Save Final Model --------------------
# Writes the global-best state dict (weights only, not the model object) to
# args.save_path.
print(f'==> Saving final CNOGNP best model to {args.save_path}')
save_dir = os.path.dirname(args.save_path)
# exist_ok=True replaces the separate existence check, so there is no gap
# between checking for the directory and creating it. A bare file name has an
# empty dirname and needs no directory.
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
if gbest_state_dict is not None:
    torch.save(gbest_state_dict, args.save_path)
    print("Final best model saved.")
else:
    print("Error: Global best state dictionary was not set. Model not saved.")


==> Evaluating the best model found by CNOGNP...
--- Final Training Set Evaluation (using standard evaluate) ---
Train Eval | Loss: 0.0024 | Acc: 99.992% (49996/50000)
--- Final Test Set Evaluation (using standard evaluate) ---
Test  Eval | Loss: 0.2423 | Acc: 93.980% (9398/10000)

--- Evaluating Final State of Particle 1/2 ---
Particle 1 Train Set Evaluation:
P1 Train Eval | Loss: 0.0025 | Acc: 99.982% (49991/50000)
Particle 1 Test  Set Evaluation:
P1 Test  Eval | Loss: 0.2448 | Acc: 93.950% (9395/10000)

--- Evaluating Final State of Particle 2/2 ---
Particle 2 Train Set Evaluation:
P2 Train Eval | Loss: 0.0024 | Acc: 99.994% (49997/50000)
Particle 2 Test  Set Evaluation:
P2 Test  Eval | Loss: 0.2435 | Acc: 93.940% (9394/10000)

===== Initial Model Performance =====
Initial Training Loss: 0.0025
Initial Training Acc:  99.988%
Initial Test Loss:     0.2423
Initial Test Acc:      94.020%

===== CNOGNP Fine-tuned Model Performance =====
Achieved Global Best Fitness (Min Loss+λ||g||₂ du