# import

In [1]:
# 标准库导入
import argparse
import json
import os
import pickle
import random
import sys
from collections import OrderedDict
from functools import partial

# Third-party compatibility / imaging imports (six and Pillow are not part of the standard library)
import six
from PIL import Image

# 第三方库导入
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.nn.functional import one_hot

# 本地模块导入（如果有需要，可以在这里添加）
# from your_local_module import your_function


# config

In [2]:
# Manually configured run parameters (stands in for command-line arguments).
class Args:
    """Static container for the experiment settings read by the cells below."""
    # Backbone name; "ViT_B32" switches the input size to 384 and selects a
    # different attribute-network configuration in get_config.
    network = "resnet18"# choices=["resnet18", "resnet50", "ViT_B32"]
    # Dataset name; also used as the sub-directory appended to data_path.
    dataset = "flowers102" # choices=["cifar10", "cifar100", "gtsrb", "svhn", "food101", "eurosat", "sun397", "UCF101", "flowers102", "DTD", "oxfordpets"]
    # Batch size handed to the train and test DataLoaders.
    batchsize = 224
    # Random seed; in the visible code it only appears in save_path
    # (presumably also passed to set_seed elsewhere - TODO confirm).
    seed = 42
    # Patch size and channel count of the attribute mask; they are part of
    # save_path (presumably forwarded to InstancewiseVisualPrompt - TODO confirm).
    patch_size = 8
    attribute_channels = 3
    # Label-mapping strategy tag; only used in save_path in the visible code.
    mapping_method = "ilm"
    # Root directory that contains one folder per dataset.
    data_path = "/dataset/"
    # Directory under which the per-run result folder is created.
    results_path = "./results"
    # Directory for model checkpoints (not referenced in the visible code).
    model_dir = "/model_pth/"

# Single shared instance read by the rest of the notebook.
args = Args()

def set_seed(seed):
    """Seed every random number generator used in this file.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)
    with the same value and switches cuDNN to deterministic kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # Ask cuDNN to pick deterministic algorithms.
    torch.backends.cudnn.deterministic = True

def get_config(pretrained):
    """Return the fixed hyper-parameters for a given pretrained backbone.

    Args:
        pretrained: name of the pretrained model; only 'ViT_B32' is treated
            specially, every other value gets the default configuration.

    Returns:
        A 5-tuple ``(attribute_layers, epochs, lr, attr_lr, attr_gamma)``:
            attribute_layers: number of layers of the attribute network
            epochs: number of training epochs (always 200)
            lr: learning rate of the reprogramming pattern (always 0.01)
            attr_lr: learning rate of the attribute network
            attr_gamma: decay factor associated with the attribute network
                (how it is consumed is not shown in this part of the file)
    """
    epochs, lr = 200, 0.01

    # ViT uses a deeper attribute network with a smaller, undecayed step size.
    if pretrained == 'ViT_B32':
        return 6, epochs, lr, 0.001, 1

    return 5, epochs, lr, 0.01, 0.1

# const

In [3]:
GTSRB_LABEL_MAP = {
    '0': '20_speed',
    '1': '30_speed',
    '2': '50_speed',
    '3': '60_speed',
    '4': '70_speed',
    '5': '80_speed',
    '6': '80_lifted',
    '7': '100_speed',
    '8': '120_speed',
    '9': 'no_overtaking_general',
    '10': 'no_overtaking_trucks',
    '11': 'right_of_way_crossing',
    '12': 'right_of_way_general',
    '13': 'give_way',
    '14': 'stop',
    '15': 'no_way_general',
    '16': 'no_way_trucks',
    '17': 'no_way_one_way',
    '18': 'attention_general',
    '19': 'attention_left_turn',
    '20': 'attention_right_turn',
    '21': 'attention_curvy',
    '22': 'attention_bumpers',
    '23': 'attention_slippery',
    '24': 'attention_bottleneck',
    '25': 'attention_construction',
    '26': 'attention_traffic_light',
    '27': 'attention_pedestrian',
    '28': 'attention_children',
    '29': 'attention_bikes',
    '30': 'attention_snowflake',
    '31': 'attention_deer',
    '32': 'lifted_general',
    '33': 'turn_right',
    '34': 'turn_left',
    '35': 'turn_straight',
    '36': 'turn_straight_right',
    '37': 'turn_straight_left',
    '38': 'turn_right_down',
    '39': 'turn_left_down',
    '40': 'turn_circle',
    '41': 'lifted_no_overtaking_general',
    '42': 'lifted_no_overtaking_trucks'
}

IMAGENETCLASSES = ['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'robin', 'bulbul', 'jay', 'magpie', 'chickadee', 'water ouzel', 'kite', 'bald eagle', 'vulture', 'great grey owl', 'European fire salamander', 'common newt', 'eft', 'spotted salamander', 'axolotl', 'bullfrog', 'tree frog', 'tailed frog', 'loggerhead', 'leatherback turtle', 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'common iguana', 'American chameleon', 'whiptail', 'agama', 'frilled lizard', 'alligator lizard', 'Gila monster', 'green lizard', 'African chameleon', 'Komodo dragon', 'African crocodile', 'American alligator', 'triceratops', 'thunder snake', 'ringneck snake', 'hognose snake', 'green snake', 'king snake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor', 'rock python', 'Indian cobra', 'green mamba', 'sea snake', 'horned viper', 'diamondback', 'sidewinder', 'trilobite', 'harvestman', 'scorpion', 'black and gold garden spider', 'barn spider', 'garden spider', 'black widow', 'tarantula', 'wolf spider', 'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie chicken', 'peacock', 'quail', 'partridge', 'African grey', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill', 'hummingbird', 'jacamar', 'toucan', 'drake', 'red-breasted merganser', 'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', 'king crab', 'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'American egret', 'bittern', 'crane', 'limpkin', 'European gallinule', 'American coot', 'bustard', 
'ruddy turnstone', 'red-backed sandpiper', 'redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese spaniel', 'Maltese dog', 'Pekinese', 'Shih-Tzu', 'Blenheim spaniel', 'papillon', 'toy terrier', 'Rhodesian ridgeback', 'Afghan hound', 'basset', 'beagle', 'bloodhound', 'bluetick', 'black-and-tan coonhound', 'Walker hound', 'English foxhound', 'redbone', 'borzoi', 'Irish wolfhound', 'Italian greyhound', 'whippet', 'Ibizan hound', 'Norwegian elkhound', 'otterhound', 'Saluki', 'Scottish deerhound', 'Weimaraner', 'Staffordshire bullterrier', 'American Staffordshire terrier', 'Bedlington terrier', 'Border terrier', 'Kerry blue terrier', 'Irish terrier', 'Norfolk terrier', 'Norwich terrier', 'Yorkshire terrier', 'wire-haired fox terrier', 'Lakeland terrier', 'Sealyham terrier', 'Airedale', 'cairn', 'Australian terrier', 'Dandie Dinmont', 'Boston bull', 'miniature schnauzer', 'giant schnauzer', 'standard schnauzer', 'Scotch terrier', 'Tibetan terrier', 'silky terrier', 'soft-coated wheaten terrier', 'West Highland white terrier', 'Lhasa', 'flat-coated retriever', 'curly-coated retriever', 'golden retriever', 'Labrador retriever', 'Chesapeake Bay retriever', 'German short-haired pointer', 'vizsla', 'English setter', 'Irish setter', 'Gordon setter', 'Brittany spaniel', 'clumber', 'English springer', 'Welsh springer spaniel', 'cocker spaniel', 'Sussex spaniel', 'Irish water spaniel', 'kuvasz', 'schipperke', 'groenendael', 'malinois', 'briard', 'kelpie', 'komondor', 'Old English sheepdog', 'Shetland sheepdog', 'collie', 'Border collie', 'Bouvier des Flandres', 'Rottweiler', 'German shepherd', 'Doberman', 'miniature pinscher', 'Greater Swiss Mountain dog', 'Bernese mountain dog', 'Appenzeller', 'EntleBucher', 'boxer', 'bull mastiff', 'Tibetan mastiff', 'French bulldog', 'Great Dane', 'Saint Bernard', 'Eskimo dog', 'malamute', 'Siberian husky', 'dalmatian', 'affenpinscher', 
'basenji', 'pug', 'Leonberg', 'Newfoundland', 'Great Pyrenees', 'Samoyed', 'Pomeranian', 'chow', 'keeshond', 'Brabancon griffon', 'Pembroke', 'Cardigan', 'toy poodle', 'miniature poodle', 'standard poodle', 'Mexican hairless', 'timber wolf', 'white wolf', 'red wolf', 'coyote', 'dingo', 'dhole', 'African hunting dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby', 'tiger cat', 'Persian cat', 'Siamese cat', 'Egyptian cat', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', 'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', 'ice bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', 'ladybug', 'ground beetle', 'long-horned beetle', 'leaf beetle', 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', 'grasshopper', 'cricket', 'walking stick', 'cockroach', 'mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', 'damselfly', 'admiral', 'ringlet', 'monarch', 'cabbage butterfly', 'sulphur butterfly', 'lycaenid', 'starfish', 'sea urchin', 'sea cucumber', 'wood rabbit', 'hare', 'Angora', 'hamster', 'porcupine', 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'sorrel', 'zebra', 'hog', 'wild boar', 'warthog', 'hippopotamus', 'ox', 'water buffalo', 'bison', 'ram', 'bighorn', 'ibex', 'hartebeest', 'impala', 'gazelle', 'Arabian camel', 'llama', 'weasel', 'mink', 'polecat', 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', 'siamang', 'guenon', 'patas', 'baboon', 'macaque', 'langur', 'colobus', 'proboscis monkey', 'marmoset', 'capuchin', 'howler monkey', 'titi', 'spider monkey', 'squirrel monkey', 'Madagascar cat', 'indri', 'Indian elephant', 'African elephant', 'lesser panda', 'giant panda', 'barracouta', 'eel', 'coho', 'rock beauty', 'anemone fish', 'sturgeon', 'gar', 'lionfish', 'puffer', 'abacus', 'abaya', 'academic gown', 'accordion', 'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', 'amphibian', 'analog 
clock', 'apiary', 'apron', 'ashcan', 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', 'ballpoint', 'Band Aid', 'banjo', 'bannister', 'barbell', 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'barrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'bathing cap', 'bath towel', 'bathtub', 'beach wagon', 'beacon', 'beaker', 'bearskin', 'beer bottle', 'beer glass', 'bell cote', 'bib', 'bicycle-built-for-two', 'bikini', 'binder', 'binoculars', 'birdhouse', 'boathouse', 'bobsled', 'bolo tie', 'bonnet', 'bookcase', 'bookshop', 'bottlecap', 'bow', 'bow tie', 'brass', 'brassiere', 'breakwater', 'breastplate', 'broom', 'bucket', 'buckle', 'bulletproof vest', 'bullet train', 'butcher shop', 'cab', 'caldron', 'candle', 'cannon', 'canoe', 'can opener', 'cardigan', 'car mirror', 'carousel', "carpenter's kit", 'carton', 'car wheel', 'cash machine', 'cassette', 'cassette player', 'castle', 'catamaran', 'CD player', 'cello', 'cellular telephone', 'chain', 'chainlink fence', 'chain mail', 'chain saw', 'chest', 'chiffonier', 'chime', 'china cabinet', 'Christmas stocking', 'church', 'cinema', 'cleaver', 'cliff dwelling', 'cloak', 'clog', 'cocktail shaker', 'coffee mug', 'coffeepot', 'coil', 'combination lock', 'computer keyboard', 'confectionery', 'container ship', 'convertible', 'corkscrew', 'cornet', 'cowboy boot', 'cowboy hat', 'cradle', 'crane', 'crash helmet', 'crate', 'crib', 'Crock Pot', 'croquet ball', 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', 'dial telephone', 'diaper', 'digital clock', 'digital watch', 'dining table', 'dishrag', 'dishwasher', 'disk brake', 'dock', 'dogsled', 'dome', 'doormat', 'drilling platform', 'drum', 'drumstick', 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', 'electric locomotive', 'entertainment center', 'envelope', 'espresso maker', 'face powder', 'feather boa', 'file', 'fireboat', 'fire engine', 'fire screen', 'flagpole', 'flute', 'folding chair', 'football helmet', 'forklift', 'fountain', 
'fountain pen', 'four-poster', 'freight car', 'French horn', 'frying pan', 'fur coat', 'garbage truck', 'gasmask', 'gas pump', 'goblet', 'go-kart', 'golf ball', 'golfcart', 'gondola', 'gong', 'gown', 'grand piano', 'greenhouse', 'grille', 'grocery store', 'guillotine', 'hair slide', 'hair spray', 'half track', 'hammer', 'hamper', 'hand blower', 'hand-held computer', 'handkerchief', 'hard disc', 'harmonica', 'harp', 'harvester', 'hatchet', 'holster', 'home theater', 'honeycomb', 'hook', 'hoopskirt', 'horizontal bar', 'horse cart', 'hourglass', 'iPod', 'iron', "jack-o'-lantern", 'jean', 'jeep', 'jersey', 'jigsaw puzzle', 'jinrikisha', 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', 'laptop', 'lawn mower', 'lens cap', 'letter opener', 'library', 'lifeboat', 'lighter', 'limousine', 'liner', 'lipstick', 'Loafer', 'lotion', 'loudspeaker', 'loupe', 'lumbermill', 'magnetic compass', 'mailbag', 'mailbox', 'maillot', 'maillot', 'manhole cover', 'maraca', 'marimba', 'mask', 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine chest', 'megalith', 'microphone', 'microwave', 'military uniform', 'milk can', 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', 'mobile home', 'Model T', 'modem', 'monastery', 'monitor', 'moped', 'mortar', 'mortarboard', 'mosque', 'mosquito net', 'motor scooter', 'mountain bike', 'mountain tent', 'mouse', 'mousetrap', 'moving van', 'muzzle', 'nail', 'neck brace', 'necklace', 'nipple', 'notebook', 'obelisk', 'oboe', 'ocarina', 'odometer', 'oil filter', 'organ', 'oscilloscope', 'overskirt', 'oxcart', 'oxygen mask', 'packet', 'paddle', 'paddlewheel', 'padlock', 'paintbrush', 'pajama', 'palace', 'panpipe', 'paper towel', 'parachute', 'parallel bars', 'park bench', 'parking meter', 'passenger car', 'patio', 'pay-phone', 'pedestal', 'pencil box', 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'pick', 'pickelhaube', 'picket fence', 'pickup', 'pier', 'piggy bank', 'pill bottle', 'pillow', 
'ping-pong ball', 'pinwheel', 'pirate', 'pitcher', 'plane', 'planetarium', 'plastic bag', 'plate rack', 'plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', 'pool table', 'pop bottle', 'pot', "potter's wheel", 'power drill', 'prayer rug', 'printer', 'prison', 'projectile', 'projector', 'puck', 'punching bag', 'purse', 'quill', 'quilt', 'racer', 'racket', 'radiator', 'radio', 'radio telescope', 'rain barrel', 'recreational vehicle', 'reel', 'reflex camera', 'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle', 'rocking chair', 'rotisserie', 'rubber eraser', 'rugby ball', 'rule', 'running shoe', 'safe', 'safety pin', 'saltshaker', 'sandal', 'sarong', 'sax', 'scabbard', 'scale', 'school bus', 'schooner', 'scoreboard', 'screen', 'screw', 'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe shop', 'shoji', 'shopping basket', 'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski', 'ski mask', 'sleeping bag', 'slide rule', 'sliding door', 'slot', 'snorkel', 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', 'solar dish', 'sombrero', 'soup bowl', 'space bar', 'space heater', 'space shuttle', 'spatula', 'speedboat', 'spider web', 'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', 'steel arch bridge', 'steel drum', 'stethoscope', 'stole', 'stone wall', 'stopwatch', 'stove', 'strainer', 'streetcar', 'stretcher', 'studio couch', 'stupa', 'submarine', 'suit', 'sundial', 'sunglass', 'sunglasses', 'sunscreen', 'suspension bridge', 'swab', 'sweatshirt', 'swimming trunks', 'swing', 'switch', 'syringe', 'table lamp', 'tank', 'tape player', 'teapot', 'teddy', 'television', 'tennis ball', 'thatch', 'theater curtain', 'thimble', 'thresher', 'throne', 'tile roof', 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toyshop', 'tractor', 'trailer truck', 'tray', 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', 'trolleybus', 'trombone', 'tub', 'turnstile', 
'typewriter keyboard', 'umbrella', 'unicycle', 'upright', 'vacuum', 'vase', 'vault', 'velvet', 'vending machine', 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', 'wallet', 'wardrobe', 'warplane', 'washbasin', 'washer', 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', 'wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', 'wing', 'wok', 'wooden spoon', 'wool', 'worm fence', 'wreck', 'yawl', 'yurt', 'web site', 'comic book', 'crossword puzzle', 'street sign', 'traffic light', 'book jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'ice lolly', 'French loaf', 'bagel', 'pretzel', 'cheeseburger', 'hotdog', 'mashed potato', 'head cabbage', 'broccoli', 'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith', 'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', 'custard apple', 'pomegranate', 'hay', 'carbonara', 'chocolate sauce', 'dough', 'meat loaf', 'pizza', 'potpie', 'burrito', 'red wine', 'espresso', 'cup', 'eggnog', 'alp', 'bubble', 'cliff', 'coral reef', 'geyser', 'lakeside', 'promontory', 'sandbar', 'seashore', 'valley', 'volcano', 'ballplayer', 'groom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'hip', 'buckeye', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn', 'earthstar', 'hen-of-the-woods', 'bolete', 'ear', 'toilet tissue']

CIFAR10CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

CIFAR100CLASSES = ['beaver', 'dolphin', 'otter', 'seal', 'whale', 'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',
                   'orchids', 'poppies', 'roses', 'sunflowers', 'tulips', 'bottles', 'bowls', 'cans', 'cups', 'plates',
                   'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers', 'clock', 'computer keyboard', 'lamp', 'telephone', 'television',
                   'bed', 'chair', 'couch', 'table', 'wardrobe', 'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',
                   'bear', 'leopard', 'lion', 'tiger', 'wolf', 'bridge', 'castle', 'house', 'road', 'skyscraper',
                   'cloud', 'forest', 'mountain', 'plain', 'sea', 'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',
                   'fox', 'porcupine', 'possum', 'raccoon', 'skunk', 'crab', 'lobster', 'snail', 'spider', 'worm', 'baby',
                   'boy', 'girl', 'man', 'woman', 'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle', 'hamster', 'mouse',
                   'rabbit', 'shrew', 'squirrel', 'maple', 'oak', 'palm', 'pine', 'willow', 'bicycle', 'bus', 'motorcycle',
                   'pickup truck', 'train', 'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor']


IMAGENETNORMALIZE = {
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
}


# prepare dataset

In [4]:
from collections import OrderedDict
import six
def loads_data(buf):
    """Deserialize a pickled byte string back into a Python object.

    Args:
        buf: bytes produced by the matching ``dumps`` step (pickle format).

    Returns:
        The object stored in ``buf``.
    """
    # NOTE: pickle executes arbitrary code on load; only use with trusted data.
    restored = pickle.loads(buf)
    return restored
def prepare_additive_data(dataset, data_path, preprocess, test_process=None, batchsize = 256):
    """Build the train/test datasets and loaders for one supported dataset.

    Args:
        dataset: dataset name; one of "cifar10", "cifar100", "svhn", "gtsrb",
            "food101", "eurosat", "sun397", "UCF101", "flowers102", "DTD",
            "oxfordpets".
        data_path: root directory; the dataset name is appended to it.
        preprocess: transform applied to training images.
        test_process: transform applied to test images.
        batchsize: batch size used by both loaders.

    Returns:
        ``(loaders, class_names, train_data, test_data)`` where ``loaders`` is
        a dict with 'train' (shuffled) and 'test' (unshuffled) DataLoaders and
        ``class_names`` is a list of refined class-name strings.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """
    data_path = os.path.join(data_path, dataset)
    num_workers = 8

    if dataset == "cifar10":
        train_data = datasets.CIFAR10(root = data_path, train = True, download = False, transform = preprocess)
        test_data = datasets.CIFAR10(root = data_path, train = False, download = False, transform = test_process)
        # No trailing comma here: class_names must be the list itself, not a
        # 1-tuple wrapping it.
        class_names = refine_classnames(test_data.classes)
        num_workers = 2

    elif dataset == "cifar100":
        train_data = datasets.CIFAR100(root = data_path, train = True, download = False, transform = preprocess)
        test_data = datasets.CIFAR100(root = data_path, train = False, download = False, transform = test_process)
        class_names = refine_classnames(test_data.classes)

    elif dataset == "svhn":
        train_data = datasets.SVHN(root = data_path, split="train", download = False, transform = preprocess)
        test_data = datasets.SVHN(root = data_path, split="test", download = False, transform = test_process)
        class_names = [f'{i}' for i in range(10)]

    elif dataset == "gtsrb":
        # GTSRB is the only dataset fetched on demand (download=True).
        train_data = datasets.GTSRB(root = data_path, split="train", download = True, transform = preprocess)
        test_data = datasets.GTSRB(root = data_path, split="test", download = True, transform = test_process)
        class_names = refine_classnames(list(GTSRB_LABEL_MAP.values()))

    elif dataset in ["food101", "eurosat", "sun397", "UCF101", "flowers102", "DTD", "oxfordpets"]:
        # All of these are read from pre-built LMDB files plus a split.json.
        train_data = COOPLMDBDataset(root = data_path, split="train", transform = preprocess)
        test_data = COOPLMDBDataset(root = data_path, split="test", transform = test_process)
        class_names = refine_classnames(test_data.classes)

    else:
        raise NotImplementedError(f"{dataset} not supported")

    loaders = {
        'train': DataLoader(train_data, batchsize, shuffle = True, num_workers=num_workers),
        'test': DataLoader(test_data, batchsize, shuffle = False, num_workers=num_workers),
    }
    return loaders, class_names, train_data, test_data

class LMDBDataset(data.Dataset):
    """Image dataset backed by a read-only LMDB file ``<root>/<split>.lmdb``.

    The database is expected to hold two bookkeeping entries, ``__len__`` and
    ``__keys__``, plus one pickled record per sample whose first element is
    the encoded image bytes and whose second element is the label.

    NOTE: every record is unpickled on access; only open LMDB files that come
    from a trusted source.
    """

    def __init__(self, root, split='train', transform=None, target_transform=None):
        """Open the LMDB environment and read the sample count and key list.

        Args:
            root: directory containing the ``<split>.lmdb`` database.
            split: name of the split, used to build the database file name.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each label.
        """
        super().__init__()
        # Kept on the instance so __repr__ can report which database is open.
        self.db_path = os.path.join(root, f"{split}.lmdb")
        self.env = lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_data(txn.get(b'__len__'))
            self.keys = loads_data(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored under the index-th key."""
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])

        unpacked = loads_data(byteflow)

        # Decode the image from its in-memory encoded bytes.
        img = Image.open(six.BytesIO(unpacked[0]))
        target = unpacked[1]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples recorded in the database."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
    
class COOPLMDBDataset(LMDBDataset):
    """LMDB dataset that also exposes class names read from ``split.json``.

    ``<root>/split.json`` must contain a "test" list; for every entry the
    second-to-last element is used as the key and the last element as the
    value (presumably label index and class name). ``self.classes`` holds the
    values ordered by their keys.
    """

    def __init__(self, root, split="train", transform=None) -> None:
        super().__init__(root, split, transform=transform)
        split_path = os.path.join(root, "split.json")
        with open(split_path) as handle:
            splits = json.load(handle)

        # Later entries with the same key overwrite earlier ones.
        name_by_label = {}
        for record in splits["test"]:
            name_by_label[record[-2]] = record[-1]

        self.classes = [name_by_label[label] for label in sorted(name_by_label)]

def refine_classnames(class_names):
    """Normalize class names in place and return the same list.

    Each name is lower-cased and every '_' or '-' is replaced by a space.

    Args:
        class_names: mutable sequence of strings; it is modified in place.

    Returns:
        The same ``class_names`` object, after normalization.
    """
    separators_to_space = str.maketrans({'_': ' ', '-': ' '})
    for position, raw_name in enumerate(class_names):
        class_names[position] = raw_name.lower().translate(separators_to_space)
    return class_names

# label mapping

## flm_calculation

In [5]:
def get_dist_matrix(fx, y):
    """Count how often each source class is predicted for each target label.

    Args:
        fx: tensor of logits whose last dimension indexes the source classes;
            one row per sample.
        y: 1-D integer tensor of target labels, one per row of ``fx``.

    Returns:
        Integer tensor of shape ``(num_source_classes, num_target_classes)``
        where entry ``[s, t]`` is the number of samples with label ``t`` whose
        arg-max prediction is source class ``s``. ``num_target_classes`` is
        ``y.max() + 1`` so that column ``t`` always corresponds to label ``t``.
    """
    predictions = one_hot(torch.argmax(fx, dim = -1), num_classes=fx.size(-1))
    # Use the largest label id rather than the number of distinct labels:
    # if some label id is absent from ``y`` the columns would otherwise be
    # shifted and the highest labels silently dropped.
    num_targets = int(y.max().item()) + 1
    columns = [predictions[y == target].sum(0).unsqueeze(1) for target in range(num_targets)]
    return torch.cat(columns, dim=1)

def predictive_distribution_based_multi_label_mapping(dist_matrix, mlm_num: int):
    """Greedily assign ``mlm_num`` source classes to every target class.

    The largest remaining count in ``dist_matrix`` is picked repeatedly; the
    chosen source row is then retired, and a target column is retired once it
    has received ``mlm_num`` source classes.

    Args:
        dist_matrix: tensor of shape (num_source, num_target) with prediction
            counts. It is not modified.
        mlm_num: number of source classes mapped to each target class.

    Returns:
        int64 tensor with the same shape as ``dist_matrix`` holding 1 where a
        source class is assigned to a target class and 0 elsewhere.

    Raises:
        AssertionError: if there are fewer than ``mlm_num * num_target``
            source classes.
    """
    num_targets = dist_matrix.size(1)
    assert mlm_num * num_targets <= dist_matrix.size(0), "source label number not enough for mapping"
    mapping_matrix = torch.zeros_like(dist_matrix, dtype=torch.int64)

    # Work on a private contiguous copy: the flat view below must alias the
    # 2-D scores so that retiring rows/columns is seen by argmax, and the
    # caller's tensor must not be overwritten with -1.
    scores = dist_matrix.clone(memory_format=torch.contiguous_format)
    flat_scores = scores.view(-1)

    for _ in range(mlm_num * num_targets):
        source, target = divmod(flat_scores.argmax().item(), num_targets)
        mapping_matrix[source, target] = 1
        # A source class can be used only once.
        scores[source] = -1
        # A target class is complete once it has mlm_num sources.
        if mapping_matrix[:, target].sum() == mlm_num:
            scores[:, target] = -1
    return mapping_matrix

def generate_label_mapping_by_frequency(visual_prompt, network, data_loader, mapping_num = 1): # mapping_num=1: 1V1 match
    """Derive a source-to-target label mapping from prediction frequencies.

    Every batch of ``data_loader`` is passed through ``visual_prompt`` and
    then ``network`` without gradients; the resulting predictions are counted
    per target label and matched greedily.

    Args:
        visual_prompt: module applied to the inputs before the network; its
            first parameter determines the device used.
        network: callable producing logits; switched to eval mode if it has
            an ``eval`` attribute.
        data_loader: iterable of ``(x, y)`` batches that supports ``len``.
        mapping_num: number of source classes assigned to each target class.

    Returns:
        1-D tensor of source class indices, ordered by the target class each
        one was assigned to.
    """
    device = next(visual_prompt.parameters()).device
    if hasattr(network, "eval"):
        network.eval()

    # Only show a progress bar for loaders with more than 20 batches.
    batches = data_loader
    if len(data_loader) > 20:
        batches = tqdm(data_loader, total=len(data_loader), desc="Frequency Label Mapping", ncols=100)

    logit_chunks, label_chunks = [], []
    for inputs, labels in batches:
        inputs, labels = inputs.to(device), labels.to(device)
        with torch.no_grad():
            logit_chunks.append(network(visual_prompt(inputs)))
        label_chunks.append(labels)

    all_logits = torch.cat(logit_chunks).cpu().float()
    all_labels = torch.cat(label_chunks).cpu().int()

    # The network may emit several rows per sample; tile the labels to match.
    if all_labels.size(0) != all_logits.size(0):
        assert all_logits.size(0) % all_labels.size(0) == 0
        all_labels = all_labels.repeat(all_logits.size(0) // all_labels.size(0))

    counts = get_dist_matrix(all_logits, all_labels)
    assignment = predictive_distribution_based_multi_label_mapping(counts, mapping_num)
    # Each row of `pairs` is (source class i, target class j).
    pairs = torch.nonzero(assignment)
    order_by_target = torch.sort(pairs[:, 1]).indices
    return pairs[:, 0][order_by_target]

## label_mapping

In [6]:
def label_mapping_base(logits, mapping_sequence):
    """Select and reorder logit columns according to a label mapping.

    Args:
        logits: 2-D tensor indexed as (sample, source class).
        mapping_sequence: indices of the source-class columns to keep, in the
            order they should appear in the output.

    Returns:
        Tensor whose column ``j`` is column ``mapping_sequence[j]`` of
        ``logits``.
    """
    return logits[:, mapping_sequence]

# data loaders

In [7]:
# Device selection: first CUDA device when available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Input resolution: 384 for the ViT backbone, 224 for everything else.
imgsize = 384 if args.network == "ViT_B32" else 224

# Preprocessing pipelines.
# Training: resize 32 px larger than the target, random-crop back to imgsize,
# random horizontal flip, force RGB for PIL inputs, then ImageNet normalization.
train_preprocess = transforms.Compose([
    transforms.Resize((imgsize + 32, imgsize + 32)),
    transforms.RandomCrop(imgsize),
    transforms.RandomHorizontalFlip(),
    transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std']),
])
# Testing: deterministic resize only, same RGB conversion and normalization.
test_preprocess = transforms.Compose([
    transforms.Resize((imgsize, imgsize)),
    transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std']),
])

# Build the datasets and their loaders.
loaders, class_names,train_dataset,test_dataset = prepare_additive_data(
    dataset=args.dataset,
    data_path=args.data_path,         # root folder; the dataset name is appended inside prepare_additive_data
    preprocess=train_preprocess,
    test_process=test_preprocess,
    batchsize = args.batchsize
#     download=True
)
# class_names may arrive wrapped in a one-element tuple; unwrap it so the rest
# of the notebook always sees the plain list of names.
if len(class_names) == 1:
    class_names = class_names[0]

print("类别名称:", class_names)

# Example: iterate over the training data.
for images, labels in loaders['train']:
    print(f"图像批次尺寸: {images.size()}, 标签尺寸: {labels.size()}")
    break  # only show the first batch

# Hyper-parameters that depend on the chosen backbone.
attribute_layers, epochs, lr, attr_lr, attr_gamma = get_config(args.network)

# Result folder name: all run settings concatenated without separators.
save_path = os.path.join(args.results_path, args.dataset + args.network + args.mapping_method + str(args.seed) + str(args.attribute_channels) + str(attribute_layers) + str(args.patch_size))

类别名称: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
图像批次尺寸: torch.Size([224, 3, 224, 224]), 标签尺寸: torch.Size([224])


# model

## VisualPrompt 

In [8]:
class AttributeNet(nn.Module):
    """Small CNN that turns an RGB image into a low-resolution attribute map.

    The network is a stack of 3x3 conv -> batch-norm -> ReLU stages (four
    stages, plus a fifth when ``layers == 6``). A 2x2 max-pool follows stage
    ``k`` when ``patch_size`` is one of 2, 4, 8, 16, 32 and at least ``2**k``.
    With ``channels == 3`` a final 3x3 convolution reduces the features to
    three channels; with ``channels == 1`` the features are averaged over the
    channel dimension instead.
    """

    # Patch sizes that trigger pooling; stage k (1-based) pools when the patch
    # size is among the entries from position k-1 onwards.
    _POOLED_PATCH_SIZES = (2, 4, 8, 16, 32)
    # Channel widths: input followed by the output width of every stage.
    _STAGE_WIDTHS = (3, 8, 16, 32, 64, 128)

    def __init__(self, layers=5, patch_size=8, channels=3):
        super().__init__()
        self.layers = layers
        self.patch_size = patch_size
        self.channels = channels

        self.pooling = nn.MaxPool2d(2, 2)
        for stage in range(1, 5):
            self._add_stage(stage)

        if self.layers == 5 and self.channels == 3:
            self.conv6 = nn.Conv2d(64, 3, 3, 1, 1)
        elif self.layers == 6:
            self._add_stage(5)
            if self.channels == 3:
                self.conv6 = nn.Conv2d(128, 3, 3, 1, 1)

    def _add_stage(self, stage):
        """Register conv<stage>, bn<stage> and relu<stage> in that order."""
        in_width = self._STAGE_WIDTHS[stage - 1]
        out_width = self._STAGE_WIDTHS[stage]
        setattr(self, f"conv{stage}", nn.Conv2d(in_width, out_width, 3, 1, 1))
        setattr(self, f"bn{stage}", nn.BatchNorm2d(out_width))
        setattr(self, f"relu{stage}", nn.ReLU(inplace=True))

    def forward(self, x):
        """Compute the attribute map for a batch of 3-channel images."""
        stage_count = 5 if self.layers == 6 else 4
        y = x
        for stage in range(1, stage_count + 1):
            y = getattr(self, f"conv{stage}")(y)
            y = getattr(self, f"bn{stage}")(y)
            y = getattr(self, f"relu{stage}")(y)
            if self.patch_size in self._POOLED_PATCH_SIZES[stage - 1:]:
                y = self.pooling(y)

        if self.channels == 3:
            return self.conv6(y)
        if self.channels == 1:
            return torch.mean(y, dim=1)
        return y

class InstancewiseVisualPrompt(nn.Module):
    """Adds a learnable pattern to an image, weighted by a per-image mask.

    An ``AttributeNet`` predicts one mask value per patch; the mask is blown
    up so that every pixel of a patch shares its value, multiplied with the
    learnable ``program`` tensor and added to the input.
    """

    def __init__(self, size, layers=5, patch_size=8, channels=3):
        """Validate the configuration and create the mask network and pattern.

        Args:
            size: input image size (images are size x size)
            layers: the number of layers of the mask-training CNN (5 or 6)
            patch_size: the size of patches that share one mask value
            channels: 3 gives each RGB channel its own mask value, 1 shares
                one value across the channels

        Raises:
            ValueError: if ``layers``, ``patch_size`` or ``channels`` is not a
                supported value, or if ``patch_size`` is 32 without 6 layers.
        """
        super().__init__()
        if layers not in (5, 6):
            raise ValueError("Input layer number is not supported")
        if patch_size not in (1, 2, 4, 8, 16, 32):
            raise ValueError("Input patch size is not supported")
        if channels not in (1, 3):
            raise ValueError("Input channel number is not supported")
        if patch_size == 32 and layers != 6:
            raise ValueError("Input layer number and patch size are conflict with each other")

        # Mask-predicting CNN and the geometry it works on.
        self.patch_num = int(size / patch_size)
        self.imagesize = size
        self.patch_size = patch_size
        self.channels = channels
        self.priority = AttributeNet(layers, patch_size, channels)

        # Learnable additive pattern (delta), one value per pixel and channel,
        # initialised to zero.
        self.size = size
        self.program = torch.nn.Parameter(data=torch.zeros(3, size, size))

    def forward(self, x):
        """Return ``x`` plus the learnable pattern scaled by the patch mask."""
        patches = self.patch_num * self.patch_num
        pixels_per_patch = self.patch_size * self.patch_size

        # One value per patch (and per mask channel).
        mask = self.priority(x).view(-1, self.channels, patches, 1)
        # Repeat each value for every pixel of its patch and for all 3 channels.
        mask = mask.expand(-1, 3, -1, pixels_per_patch)
        mask = mask.view(-1, 3, self.patch_num, self.patch_num, self.patch_size, self.patch_size)
        # Interleave patch rows with in-patch rows so the result is a full image.
        mask = mask.transpose(3, 4).reshape(-1, 3, self.imagesize, self.imagesize)

        return x + self.program * mask

##  VisualPrompt-compensation

In [9]:
import torch.nn as nn

class AttributeNet_activation(nn.Module):
    """Small CNN that produces the patch-wise mask, with a configurable activation.

    The network stacks conv -> batch-norm -> activation stages and applies
    2x2 max-pooling after as many stages as needed to shrink the input by a
    factor of ``patch_size``. With ``channels == 3`` a final convolution maps
    to 3 output channels; with ``channels == 1`` the feature channels are
    averaged instead.
    """

    def __init__(self, layers=5, patch_size=8, channels=3, activation=nn.ReLU):
        """
        Args:
            layers: 5 or 6; 6 adds one extra conv stage (64 -> 128 channels).
            patch_size: downsampling factor; decides how many stages are pooled.
            channels: 3 for a per-RGB-channel mask, 1 for a shared mask.
            activation: activation class (e.g. nn.ReLU, nn.LeakyReLU, nn.ELU,
                nn.GELU, nn.SiLU, nn.Mish) or an already constructed
                activation module. Classes whose constructor accepts
                ``inplace`` are built with ``inplace=True``.
        """
        super(AttributeNet_activation, self).__init__()
        self.layers = layers
        self.patch_size = patch_size
        self.channels = channels

        # One activation module is shared by every stage (it holds no parameters)
        self.activation = self._build_activation(activation)

        self.pooling = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = self.activation
        self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = self.activation
        self.conv3 = nn.Conv2d(16, 32, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(32)
        self.relu3 = self.activation
        self.conv4 = nn.Conv2d(32, 64, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = self.activation
        if self.layers == 5 and self.channels == 3:
            self.conv6 = nn.Conv2d(64, 3, 3, 1, 1)
        elif self.layers == 6:
            self.conv5 = nn.Conv2d(64, 128, 3, 1, 1)
            self.bn5 = nn.BatchNorm2d(128)
            self.relu5 = self.activation

            if self.channels == 3:
                self.conv6 = nn.Conv2d(128, 3, 3, 1, 1)

    @staticmethod
    def _build_activation(activation):
        """Instantiate ``activation``, using ``inplace=True`` when it is supported."""
        import inspect

        # An already constructed module is used as-is
        if isinstance(activation, nn.Module):
            return activation
        # `inplace` is only an annotation on the torch activation classes, so
        # hasattr() on the class never finds it; inspect the constructor instead.
        try:
            accepts_inplace = 'inplace' in inspect.signature(activation).parameters
        except (TypeError, ValueError):
            accepts_inplace = False
        return activation(inplace=True) if accepts_inplace else activation()

    def forward(self, x):
        """Map an image batch to the (downsampled) mask described in the class docstring."""
        y = self.relu1(self.bn1(self.conv1(x)))
        if self.patch_size in [2, 4, 8, 16, 32]:
            y = self.pooling(y)
        y = self.relu2(self.bn2(self.conv2(y)))
        if self.patch_size in [4, 8, 16, 32]:
            y = self.pooling(y)
        y = self.relu3(self.bn3(self.conv3(y)))
        if self.patch_size in [8, 16, 32]:
            y = self.pooling(y)
        y = self.relu4(self.bn4(self.conv4(y)))
        if self.patch_size in [16, 32]:
            y = self.pooling(y)
        if self.layers == 6:
            y = self.relu5(self.bn5(self.conv5(y)))
            if self.patch_size == 32:
                y = self.pooling(y)

        if self.channels == 3:
            y = self.conv6(y)
        elif self.channels == 1:
            # Shared mask: average over the feature channels (drops dim 1)
            y = torch.mean(y, dim=1)
        return y

class VisualPrompt_compensation(nn.Module):
    """Compensation prompt: returns only the masked pattern, not the prompted image.

    The mask network (``self.priority``) produces one value per patch; that
    value is broadcast over the whole ``patch_size x patch_size`` patch and
    multiplied with the learnable pattern ``self.program``. Unlike an additive
    prompt, ``forward`` does not add the input back.
    """

    def __init__(self, size, layers=5, patch_size=8, channels=3, activation=nn.Mish):
        """
        Args:
            size: side length of the (square) input image; must be a
                multiple of ``patch_size``
            layers: number of layers of the mask CNN (5 or 6)
            patch_size: the size of patches with the same mask value
            channels: 3 for a separate mask per RGB channel, 1 for a shared mask
            activation: activation passed on to ``AttributeNet_activation``

        Raises:
            ValueError: if ``layers``, ``patch_size`` or ``channels`` is not
                supported, if ``patch_size == 32`` is combined with
                ``layers != 6``, or if ``size`` is not divisible by
                ``patch_size``.
        """
        super(VisualPrompt_compensation, self).__init__()
        if layers not in [5, 6]:
            raise ValueError("Input layer number is not supported")
        if patch_size not in [1, 2, 4, 8, 16, 32]:
            raise ValueError("Input patch size is not supported")
        if channels not in [1, 3]:
            raise ValueError("Input channel number is not supported")
        if patch_size == 32 and layers != 6:
            raise ValueError("Input layer number and patch size are conflict with each other")
        # forward() reshapes the mask to (size, size); that only works when the
        # image is tiled exactly by the patches, so fail here with a clear message
        # instead of with an opaque reshape error later.
        if size % patch_size != 0:
            raise ValueError("Input image size must be divisible by the patch size")

        # Set the attribute mask CNN
        self.patch_num = size // patch_size
        self.imagesize = size
        self.patch_size = patch_size
        self.channels = channels
        self.priority = AttributeNet_activation(layers, patch_size, channels, activation=activation)

        # Set reprogram (\delta) according to the image size
        self.size = size
        self.program = torch.nn.Parameter(data=torch.zeros(3, size, size))

    def forward(self, x):
        """Return ``program * mask`` with the mask upsampled patch-wise to image size."""
        patches_per_side = self.patch_num
        pixels_per_patch = self.patch_size * self.patch_size

        # One mask value per patch: (B, channels, patch_num^2, 1)
        mask = self.priority(x).view(-1, self.channels, patches_per_side * patches_per_side, 1)
        # Repeat every value over its patch (and over RGB when channels == 1)
        mask = mask.expand(-1, 3, -1, pixels_per_patch)
        mask = mask.view(-1, 3, patches_per_side, patches_per_side, self.patch_size, self.patch_size)
        # (B, 3, row, col, dy, dx) -> (B, 3, row, dy, col, dx) so the reshape
        # below lays the patches out as a contiguous image
        mask = mask.transpose(3, 4)
        attention = mask.reshape(-1, 3, self.imagesize, self.imagesize)
        Prompt_compensation = self.program * attention
        return Prompt_compensation

## Proportional Adjustment Controller

In [10]:
class PAC(nn.Module):
    """Proportional Adjustment Controller: a small CNN emitting one scalar per image.

    Three conv -> batch-norm -> ReLU -> 2x2 max-pool stages (3 -> 16 -> 32 -> 64
    channels) are followed by global average pooling and a linear layer with
    a single output. The output is the raw linear response (no squashing).
    """

    def __init__(self):
        super(PAC, self).__init__()
        # Stage 1
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        # Stage 2
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        # Stage 3
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        # Shared down-sampling and the regression head
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor for an image batch ``x``."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        features = x
        for conv, norm in stages:
            features = self.pool(F.relu(norm(conv(features))))

        pooled = self.global_avg_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = PAC().to(device)
# x = torch.rand(256, 3, 224, 224).to(device)
# output = model(x)
# print(output.shape)  # 输出: torch.Size([256, 1])


##  SMM_Compensation

In [11]:
class SMM_compensation(nn.Module):
    """Combine a main visual prompt, a compensation prompt and a mixing controller.

    Holds the backbone ``network``, the main prompt ``visual_prompt`` (returns
    the prompted image), the compensation prompt ``visual_prompt_comp``
    (returns only the added pattern) and ``pac`` (stored as ``mininet``),
    which predicts the per-image mixing coefficient alpha.
    """

    def __init__(self, network, visual_prompt, visual_prompt_comp, pac):
        super(SMM_compensation, self).__init__()
        self.network = network
        self.visual_prompt = visual_prompt
        self.visual_prompt_comp = visual_prompt_comp
        self.mininet = pac

    def forward(self, x, label_mapping, label_mapping_comp) -> torch.Tensor:
        """Run the main, compensation and blended branches.

        Args:
            x: input batch.
            label_mapping: callable applied to the backbone logits of the main
                and of the blended branch.
            label_mapping_comp: callable applied to the backbone logits of the
                compensation branch.

        Returns:
            A 4-tuple ``(output_class, X_main, X_comp, alpha)`` (despite the
            single-tensor annotation): the mapped outputs of the blended,
            main and compensation branches, and alpha flattened to 1-D.
        """
        batch_size = x.size(0)

        # Main branch: the prompt module already returns x + delta
        prompted = self.visual_prompt(x)
        X_main = label_mapping(self.network(prompted))

        # Compensation branch: the module returns the delta only, so add x here
        comp_delta = self.visual_prompt_comp(x)
        X_comp = label_mapping_comp(self.network(comp_delta + x))

        # Per-image mixing weight, shaped [B, 1, 1, 1] so it broadcasts over C, H, W
        alpha = self.mininet(x).view(batch_size, 1, 1, 1)

        # Blend the two deltas and re-apply the result to the clean input
        main_delta = prompted - x
        blended = (alpha * main_delta + (1 - alpha) * comp_delta) + x
        output_class = label_mapping(self.network(blended))

        return output_class, X_main, X_comp, alpha.view(-1)

# Loss

In [12]:
class CustomLoss(nn.Module):
    """Phase-dependent training loss for the prompt / compensation / alpha model."""

    def __init__(self, lambda_ce=1.0, lambda_mse=1.0, lambda_alpha=1.0):
        """
        Args:
            lambda_ce (float): weight of the cross-entropy term(s).
            lambda_mse (float): weight of the compensation (MSE) term.
            lambda_alpha (float): weight of the alpha term.
        """
        super(CustomLoss, self).__init__()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()
        self.lambda_ce = lambda_ce
        self.lambda_mse = lambda_mse
        self.lambda_alpha = lambda_alpha
        # Learnable scale applied inside the sigmoid that defines the alpha target
        self.c = nn.Parameter(torch.tensor(1.0))

    def _joint_terms(self, output_class, X_main, X_comp, alpha, y):
        """Compute the three terms shared by the 'visual_prompt_comp' and default modes.

        Returns:
            (loss_ce_all, loss_mse, loss_alpha)
        """
        # 1. Cross-entropy of the blended output
        loss_ce_all = self.cross_entropy(output_class, y)

        # 2. Compensation term: the compensation branch's probabilities are
        #    regressed onto (one-hot label - main-branch probabilities)
        Y_main = F.one_hot(y, num_classes=output_class.size(1)).float()
        X_main_probs = F.softmax(X_main, dim=1)
        Y_diff = Y_main - X_main_probs
        X_comp_probs = F.softmax(X_comp, dim=1)
        loss_mse = self.mse(X_comp_probs, Y_diff)

        # 3. Alpha term: the target is sigmoid(c * ||p_out - p_main||_2) per sample
        output_class_probs = F.softmax(output_class, dim=1)
        diff = torch.norm(output_class_probs - X_main_probs, p=2, dim=1)
        alpha_standard = torch.sigmoid(self.c * diff)
        loss_alpha = self.mse(alpha, alpha_standard)

        return loss_ce_all, loss_mse, loss_alpha

    def forward(self, output_class, X_main, X_comp, alpha, y, updating_part='all'):
        """
        Compute the loss for the current training phase.

        Args:
            output_class (torch.Tensor): blended-branch output, (batch_size, num_classes).
            X_main (torch.Tensor): main-branch output, (batch_size, num_classes).
            X_comp (torch.Tensor): compensation-branch output, (batch_size, num_classes).
            alpha (torch.Tensor): predicted alpha values, (batch_size,).
            y (torch.Tensor): ground-truth class indices, (batch_size,).
            updating_part (str): 'visual_prompt' -> only the main-branch
                cross-entropy; 'visual_prompt_comp' -> blended cross-entropy +
                compensation + alpha terms; any other value (default 'all')
                -> the same as 'visual_prompt_comp' plus the main-branch
                cross-entropy.

        Returns:
            torch.Tensor: the scalar loss.
        """
        if updating_part == 'visual_prompt':
            # Main branch only
            return self.cross_entropy(X_main, y)

        if updating_part == 'visual_prompt_comp':
            loss_ce_all, loss_mse, loss_alpha = self._joint_terms(
                output_class, X_main, X_comp, alpha, y)
            ce_term = loss_ce_all
        else:
            loss_ce_main = self.cross_entropy(X_main, y)
            loss_ce_all, loss_mse, loss_alpha = self._joint_terms(
                output_class, X_main, X_comp, alpha, y)
            ce_term = loss_ce_main + loss_ce_all

        total_loss = (self.lambda_ce * ce_term +
                      self.lambda_mse * loss_mse +
                      self.lambda_alpha * loss_alpha)
        return total_loss

In [13]:
# Create the results and model directories (no error if they already exist)
os.makedirs(save_path, exist_ok=True)
os.makedirs(args.model_dir, exist_ok=True)
# TensorBoard writer; event files go into save_path
logger = SummaryWriter(save_path)

# Seed the Python, NumPy and PyTorch RNGs
set_seed(args.seed)

In [14]:
# Build the backbone architecture and pick the local weight file that goes with it
if args.network == "resnet18":
    from torchvision.models import resnet18, ResNet18_Weights
    network = resnet18(weights=None)  # architecture only; weights are loaded from disk below
    weight_path = os.path.join(args.model_dir, "resnet18_weights.pth")
elif args.network == "resnet50":
    from torchvision.models import resnet50, ResNet50_Weights
    network = resnet50(weights=None)  # architecture only; weights are loaded from disk below
    weight_path = os.path.join(args.model_dir, "resnet50_weights.pth")
elif args.network == "ViT_B32":
    from pytorch_pretrained_vit import ViT
    network = ViT('B_32_imagenet1k', pretrained=False)  # architecture only; weights are loaded from disk below
    weight_path = os.path.join(args.model_dir, "ViT_B32_weights.pth")
else:
    raise NotImplementedError(f"{args.network} is not supported")

# Fail early when the weight file is missing
if not os.path.exists(weight_path):
    raise FileNotFoundError(f"权重文件未找到: {weight_path}")

# Load the weights. strict=False tolerates key mismatches, so report them
# explicitly: otherwise a wrong file would silently leave layers at their
# random initialisation.
load_result = network.load_state_dict(
    torch.load(weight_path, map_location=device, weights_only=True), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Warning: state_dict mismatch for {weight_path} - "
          f"missing keys: {load_result.missing_keys}, "
          f"unexpected keys: {load_result.unexpected_keys}")
network.to(device)

print(f"模型权重已加载至: {weight_path}")
print(f"模型设备: {device}")

# Freeze the backbone: no gradients, evaluation mode
network.requires_grad_(False)
network.eval()

# 4. Build the trainable components and the combined model
visual_prompt = InstancewiseVisualPrompt(imgsize, attribute_layers, args.patch_size, args.attribute_channels).to(device)
visual_prompt_comp = VisualPrompt_compensation(imgsize, attribute_layers, args.patch_size, args.attribute_channels).to(device)

pac = PAC().to(device)
model = SMM_compensation(network, visual_prompt, visual_prompt_comp, pac).to(device)

模型权重已加载至: /home/yyh/Desktop/2022/ly/ICML/Code/model_pth/resnet18_weights.pth
模型设备: cuda:0


In [27]:
# Build the checkpoint file names; the "best" file name carries a hard-coded run suffix (_2)
dataset_name = args.dataset  # e.g. 'DTD'
checkpoint_path = os.path.join(save_path, f'best_AL_mish_{args.network}_{dataset_name}_2.pth')
checkpoint_path_ckpt = os.path.join(save_path, f'ckpt_AL_mish_{args.network}_{dataset_name}.pth')

# Becomes True only after a checkpoint has been loaded; the evaluation cell checks it
if_test = False

# 10. Load the checkpoint if it exists
if os.path.exists(checkpoint_path):
    print(f"加载检查点：{checkpoint_path}")
    # Restore the parameters of the three trainable sub-modules
    checkpoint = torch.load(checkpoint_path, map_location=device,weights_only=True)
    model.visual_prompt.load_state_dict(checkpoint['visual_prompt_dict'])
    model.visual_prompt_comp.load_state_dict(checkpoint['visual_prompt_comp_dict'])
    model.mininet.load_state_dict(checkpoint['mininet_dict'])
    mapping_sequence = checkpoint.get("mapping_sequence", None)
    mapping_sequence_comp = checkpoint.get("mapping_sequence_comp", None)
    # Epoch following the one stored in the checkpoint
    epoch_start = checkpoint.get("epoch", 0) + 1  
    best_acc = checkpoint.get("best_acc", 0.0)
    train_loss_history = checkpoint.get("train_loss_history", None)
    test_acc_history = checkpoint.get("test_acc_history", None)
    # Bind the stored mapping sequences to label_mapping_base (defined elsewhere in this file)
    label_mapping_comp = partial(label_mapping_base, mapping_sequence=mapping_sequence_comp)
    label_mapping = partial(label_mapping_base, mapping_sequence=mapping_sequence)  
    if_test = True
    
# Evaluation mode for the whole model, whether or not a checkpoint was found
model.eval()

加载检查点：./results/cifar10resnet18ilm42358/best_AL_mish_resnet18_cifar10_1.pth


In [None]:
if if_test:
    # Evaluate the loaded checkpoint on the test loader
    model.eval()
    total_num_test = 0
    true_num_main_test = 0
    true_num_fx_test = 0
    pbar_test = tqdm(loaders['test'], total=len(loaders['test']), desc=f"Epo {0} Testing", ncols=100)
    with torch.no_grad():
        for x, y in pbar_test:
            # Always move both tensors: .to() is a no-op for a tensor that is
            # already on `device`, and this also covers the case where only
            # one of the two lives on another device.
            x, y = x.to(device), y.to(device)
            fx, X_main, X_comp, alpha = model(x, label_mapping, label_mapping_comp)

            # Predictions of the main branch (X_main) and of the blended output (fx)
            pred_labels_main = torch.argmax(X_main, dim=1)
            pred_labels_fx = torch.argmax(fx, dim=1)

            # Number of correct predictions in this batch
            acc_main = pred_labels_main.eq(y).sum().item()
            acc_fx = pred_labels_fx.eq(y).sum().item()

            # Running totals
            total_num_test += y.size(0)
            true_num_main_test += acc_main
            true_num_fx_test += acc_fx

            # Running accuracies (in percent), shown on the progress bar
            overall_acc_main = 100 * true_num_main_test / total_num_test
            overall_acc_fx = 100 * true_num_fx_test / total_num_test
            pbar_test.set_postfix_str(f" Acc_main: {overall_acc_main:.2f}% | Acc_fx: {overall_acc_fx:.2f}%")

# code for train

In [29]:
# Training driver: repeats the full training procedure `num_runs` times.
# Each run (1) rebuilds the frozen backbone and the trainable components,
# (2) optionally resumes from that run's checkpoint, (3) trains only the main
# prompt until its test accuracy stops improving, then (4) alternates between
# the compensation branch, the main prompt and a joint update in fixed cycles.
num_runs = 3  # number of independent training runs

for run in range(1, num_runs + 1):
    print(f"\n=== 开始第 {run} 次训练 ===\n")
    
    # 3. Build the backbone architecture and pick the matching local weight file
    if args.network == "resnet18":
        from torchvision.models import resnet18
        network = resnet18(weights=None)  # architecture only; weights are loaded from disk below
        weight_path = os.path.join(args.model_dir, "resnet18_weights.pth")
    elif args.network == "resnet50":
        from torchvision.models import resnet50
        network = resnet50(weights=None)  # architecture only; weights are loaded from disk below
        weight_path = os.path.join(args.model_dir, "resnet50_weights.pth")
    elif args.network == "ViT_B32":
        from pytorch_pretrained_vit import ViT
        network = ViT('B_32_imagenet1k', pretrained=False)  # architecture only; weights are loaded from disk below
        weight_path = os.path.join(args.model_dir, "ViT_B32_weights.pth")
    else:
        raise NotImplementedError(f"{args.network} is not supported")
    
    # Fail early when the weight file is missing
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"权重文件未找到: {weight_path}")
    
    # Load the weights. strict=False tolerates key mismatches, so report them
    # explicitly instead of silently keeping randomly initialised layers.
    load_result = network.load_state_dict(
        torch.load(weight_path, map_location=device, weights_only=True), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: state_dict mismatch for {weight_path} - "
              f"missing keys: {load_result.missing_keys}, "
              f"unexpected keys: {load_result.unexpected_keys}")
    network.to(device)
    
    print(f"模型权重已加载至: {weight_path}")
    print(f"模型设备: {device}")
    
    # Freeze the backbone: no gradients, evaluation mode
    network.requires_grad_(False)
    network.eval()
    
    # 4. Build the trainable components and the combined model
    visual_prompt = InstancewiseVisualPrompt(imgsize, attribute_layers, args.patch_size, args.attribute_channels).to(device)
    visual_prompt_comp = VisualPrompt_compensation(imgsize, attribute_layers, args.patch_size, args.attribute_channels).to(device)
    
    mininet = PAC().to(device)
    model = SMM_compensation(network, visual_prompt, visual_prompt_comp, mininet).to(device)
    
    # 5. Hyper-parameters
    lr_visual_prompt_comp = 1e-2  # learning rate of optimizer1
    lr_visual_prompt = 1e-2       # learning rate of optimizer2
    gamma_prog = 0.8              # LR decay factor of scheduler1
    gamma_prompt = 0.8            # LR decay factor of scheduler2
    cycle_length = 55             # one alternating cycle = 55 epochs (25 + 25 + 5)
    total_epochs = 600            # total number of epochs (adjust as needed)
    criterion = CustomLoss()
    milestones = list(range(5, total_epochs + 1, 50))
    
    # 6. optimizer1 updates visual_prompt_comp, the criterion's parameters and mininet
    optimizer1 = torch.optim.Adam([
        {'params': model.visual_prompt_comp.parameters()},
        {'params': criterion.parameters()},
        {'params': model.mininet.parameters()}
    ], lr=lr_visual_prompt_comp)
    
    scheduler1 = torch.optim.lr_scheduler.MultiStepLR(
        optimizer1,
        milestones=milestones,
        gamma=gamma_prog,
        verbose=True
    )
    
    # 7. optimizer2 updates visual_prompt
    optimizer2 = torch.optim.Adam([
        {'params': model.visual_prompt.parameters()},
    ], lr=lr_visual_prompt)
    
    scheduler2 = torch.optim.lr_scheduler.MultiStepLR(
        optimizer2,
        milestones=milestones,
        gamma=gamma_prompt,
        verbose=True
    )
    
    # 8. Training history and convergence-detection state.
    # Placeholders only; the real values are assigned below depending on
    # whether a checkpoint is loaded.
    train_loss_history = None
    test_acc_history = None
    best_acc = 0.0
    scaler = None
    patience = 20
    threshold = 0.01
    convergence_counter = None
    converged = None
    previous_best_acc = None
    epoch_start_convergence = None
    
    # 9. Per-run checkpoint file names
    dataset_name = args.dataset  # e.g. 'DTD'
    checkpoint_path = os.path.join(save_path, f'best_AL_mish_{args.network}_{dataset_name}_{run}.pth')
    checkpoint_path_ckpt = os.path.join(save_path, f'ckpt_AL_mish_{args.network}_{dataset_name}_{run}.pth')
    
    # 10. Resume from a checkpoint if one exists.
    # NOTE(review): the resume reads the *best* checkpoint, not the per-epoch
    # one written to checkpoint_path_ckpt - confirm this is intended.
    if os.path.exists(checkpoint_path):
        print(f"加载检查点：{checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.visual_prompt.load_state_dict(checkpoint['visual_prompt_dict'])
        model.visual_prompt_comp.load_state_dict(checkpoint['visual_prompt_comp_dict'])
        model.mininet.load_state_dict(checkpoint['mininet_dict'])
        mapping_sequence = checkpoint.get("mapping_sequence", None)
        mapping_sequence_comp = checkpoint.get("mapping_sequence_comp", None)
        # Continue with the epoch after the stored one
        epoch_start = checkpoint.get("epoch", 0) + 1  
        best_acc = checkpoint.get("best_acc", 0.0)
        optimizer1.load_state_dict(checkpoint['optimizer1_state_dict'])
        optimizer2.load_state_dict(checkpoint['optimizer2_state_dict'])
        scheduler1.load_state_dict(checkpoint['scheduler1_state_dict'])
        scheduler2.load_state_dict(checkpoint['scheduler2_state_dict'])
        criterion.load_state_dict(checkpoint['criterion_state_dict'])
        
        # Restore the history and the convergence-detection state
        train_loss_history = checkpoint.get("train_loss_history", [])
        test_acc_history = checkpoint.get("test_acc_history", [])
        converged = checkpoint.get("converged", False)
        convergence_counter = checkpoint.get("convergence_counter", 0)
        previous_best_acc = checkpoint.get("previous_best_acc", 0.0)
        epoch_start_convergence = checkpoint.get("epoch_start_convergence", 0)
        
        # Restore the gradient scaler state when it was saved
        scaler = GradScaler()
        if 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        
        print(f"成功加载检查点，继续从 epoch {epoch_start} 开始训练。")
    else:
        print("没有找到检查点，开始从头训练。")
        # Fresh history and convergence-detection state
        train_loss_history = []
        test_acc_history = []
        best_acc = 0.0
        scaler = GradScaler()
        
        patience = 20            # epochs without significant improvement before "converged"
        threshold = 0.01         # minimum accuracy gain that counts as an improvement
        convergence_counter = 0  # consecutive epochs without improvement
        converged = False        # whether the main branch has converged
        previous_best_acc = 0.0  # reference accuracy for the improvement test
        epoch_start_convergence = 0  # first epoch of the alternating phase
        epoch_start = 0  # start from epoch 0
    
    # 11. Training loop
    for epoch in range(epoch_start, total_epochs):  # starts at the restored epoch on resume
        # NOTE(review): label_mapping / label_mapping_comp are only (re)built here
        # for the 'ilm' mapping method; other methods rely on them existing already.
        if args.mapping_method == 'ilm':
            mapping_sequence_comp = generate_label_mapping_by_frequency(
                model.visual_prompt_comp, model.network, loaders['train']
            )
            label_mapping_comp = partial(label_mapping_base, mapping_sequence=mapping_sequence_comp)
            mapping_sequence = generate_label_mapping_by_frequency(
                model.visual_prompt, model.network, loaders['train']
            )
            label_mapping = partial(label_mapping_base, mapping_sequence=mapping_sequence)
    
        # Only the trainable parts go to train mode; the backbone is left untouched
        model.visual_prompt.train()
        model.visual_prompt_comp.train()
        model.mininet.train()
    
        total_num = 0
        true_num_main_train = 0
        true_num_fx_train = 0
        loss_sum = 0
    
        # Select what is updated in this epoch
        if not converged:
            # Pre-training phase: main branch only
            updating_part = 'visual_prompt'
            optimizer_main = [optimizer2]
            scheduler_main = [scheduler2]
        else:
            # Alternating phase: position inside the 55-epoch cycle decides
            epoch_in_cycle = (epoch - epoch_start_convergence) % cycle_length
            if epoch_in_cycle < 25:
                updating_part = 'visual_prompt_comp'
                optimizer_main = [optimizer1]
                scheduler_main = [scheduler1]
            elif epoch_in_cycle < 50:
                updating_part = 'visual_prompt'
                optimizer_main = [optimizer2]
                scheduler_main = [scheduler2]
            else:
                updating_part = 'all'
                optimizer_main = [optimizer1, optimizer2]
                scheduler_main = [scheduler1, scheduler2]
    
        pbar = tqdm(loaders['train'], total=len(loaders['train']),
                    desc=f"Epo {epoch}", ncols=100)
        for x, y in pbar:
            # .to() is a no-op for tensors that already live on `device`
            x, y = x.to(device), y.to(device)
            pbar.set_description_str(f"Epo {epoch}", refresh=True)
    
            # Clear the gradients of the optimizers active in this phase
            for opt in optimizer_main:
                opt.zero_grad()
    
            with autocast():  # mixed-precision forward pass and loss
                fx, X_main, X_comp, alpha = model(x, label_mapping, label_mapping_comp)
                loss = criterion(fx, X_main, X_comp, alpha, y, updating_part=updating_part)
    
            scaler.scale(loss).backward()
    
            # Step every active optimizer, then update the scale once
            for opt in optimizer_main:
                scaler.step(opt)
            scaler.update()
    
            # Predictions of the main branch (X_main) and of the blended output (fx)
            pred_labels_main = torch.argmax(X_main, dim=1)
            pred_labels_fx = torch.argmax(fx, dim=1)
    
            # Number of correct predictions in this batch
            acc_main = pred_labels_main.eq(y).sum().item()
            acc_fx = pred_labels_fx.eq(y).sum().item()
    
            # Running totals
            total_num += y.size(0)
            true_num_main_train += acc_main
            true_num_fx_train += acc_fx
    
            # Sample-weighted running loss
            loss_sum += loss.item() * y.size(0)
    
            # Running accuracies (in percent), shown on the progress bar
            overall_acc_main = 100 * true_num_main_train / total_num
            overall_acc_fx = 100 * true_num_fx_train / total_num
            pbar.set_postfix_str(f"Updating: {updating_part} | Acc_main: {overall_acc_main:.2f}% | Acc_fx: {overall_acc_fx:.2f}%")
    
        # Step the schedulers of the active optimizers
        for sch in scheduler_main:
            sch.step()
    
        # Record the training loss and accuracies (blended accuracy only after convergence)
        train_loss = loss_sum / total_num
        train_acc_main = true_num_main_train / total_num
        train_loss_history.append(train_loss)
        logger.add_scalar("train/acc_main", train_acc_main, epoch)
        if converged:
            train_acc_fx = true_num_fx_train / total_num
            logger.add_scalar("train/acc_fx", train_acc_fx, epoch)
        logger.add_scalar("train/loss", train_loss, epoch)

        # Evaluation on the test loader
        model.eval()
        total_num_test = 0
        true_num_main_test = 0
        true_num_fx_test = 0
        pbar_test = tqdm(loaders['test'], total=len(loaders['test']), desc=f"Epo {epoch} Testing", ncols=100)
        with torch.no_grad():
            for x, y in pbar_test:
                x, y = x.to(device), y.to(device)
                fx, X_main, X_comp, alpha = model(x, label_mapping, label_mapping_comp)
    
                # Predictions of the main branch (X_main) and of the blended output (fx)
                pred_labels_main = torch.argmax(X_main, dim=1)
                pred_labels_fx = torch.argmax(fx, dim=1)
    
                # Number of correct predictions in this batch
                acc_main = pred_labels_main.eq(y).sum().item()
                acc_fx = pred_labels_fx.eq(y).sum().item()
    
                # Running totals
                total_num_test += y.size(0)
                true_num_main_test += acc_main
                true_num_fx_test += acc_fx
    
                # Running accuracies (in percent), shown on the progress bar
                overall_acc_main = 100 * true_num_main_test / total_num_test
                overall_acc_fx = 100 * true_num_fx_test / total_num_test
                pbar_test.set_postfix_str(f"Updating: {updating_part} | Acc_main: {overall_acc_main:.2f}% | Acc_fx: {overall_acc_fx:.2f}%")
    
        if not converged:
            test_acc_history.append((overall_acc_main, None))
            logger.add_scalar("test/acc_main", overall_acc_main, epoch)
        else:
            test_acc_history.append((overall_acc_main, overall_acc_fx))
            logger.add_scalar("test/acc_main", overall_acc_main, epoch)
            logger.add_scalar("test/acc_fx", overall_acc_fx, epoch)
    
        # Metric for best-model selection, chosen with the phase this epoch was
        # trained under (i.e. before the convergence flag is updated below)
        if not converged:
            current_best_acc = overall_acc_main
        else:
            current_best_acc = overall_acc_fx
        is_new_best = current_best_acc > best_acc
        if is_new_best:
            best_acc = current_best_acc
    
        # Convergence detection. It runs BEFORE the checkpoint is written so the
        # saved convergence state describes the end of this epoch, which is what
        # a resume at `epoch + 1` needs.
        if not converged:
            # Initialise the reference only at the very first epoch of training;
            # on a resumed run the restored previous_best_acc must be kept.
            if epoch == 0:
                previous_best_acc = overall_acc_main
                print(f"Epoch {epoch+1}: 初始最佳准确率 = {previous_best_acc:.2f}%")
            else:
                if (overall_acc_main - previous_best_acc) > threshold:
                    previous_best_acc = overall_acc_main
                    convergence_counter = 0
                    print(f"Epoch {epoch+1}: 准确率提升到 {overall_acc_main:.2f}%，更新最佳准确率。")
                else:
                    convergence_counter += 1
                    print(f"Epoch {epoch+1}: 准确率未提升，收敛计数器 = {convergence_counter}/{patience}")
                    if convergence_counter >= patience:
                        converged = True
                        epoch_start_convergence = epoch + 1  # first epoch of the alternating phase
                        print(f"主分支准确率已收敛（连续 {patience} 个epoch无显著提升），开始交替训练阶段。")
    
        # Everything needed to resume after this epoch
        state_dict = {
            "visual_prompt_dict": model.visual_prompt.state_dict(),
            "visual_prompt_comp_dict": model.visual_prompt_comp.state_dict(),
            "mininet_dict": model.mininet.state_dict(),
            "mapping_sequence": mapping_sequence,
            "mapping_sequence_comp": mapping_sequence_comp,
            "epoch": epoch,
            "best_acc": best_acc,
            "optimizer1_state_dict": optimizer1.state_dict(),
            "optimizer2_state_dict": optimizer2.state_dict(),
            "scheduler1_state_dict": scheduler1.state_dict(),
            "scheduler2_state_dict": scheduler2.state_dict(),
            "criterion_state_dict": criterion.state_dict(),
            "train_loss_history": train_loss_history,
            "test_acc_history": test_acc_history,
            # Convergence-detection state
            "converged": converged,
            "convergence_counter": convergence_counter,
            "previous_best_acc": previous_best_acc,
            "epoch_start_convergence": epoch_start_convergence,
            "scaler_state_dict": scaler.state_dict(),  # gradient scaler state
        }
    
        # Save the best model
        if is_new_best:
            torch.save(state_dict, checkpoint_path)
            print(f"保存最佳模型，准确率: {best_acc:.2f}%")
    
        # Save the per-epoch checkpoint
        torch.save(state_dict, checkpoint_path_ckpt)
    
        # Epoch summary: updated part, both accuracies, convergence status,
        # dataset name and the best accuracy so far
        convergence_status = "已收敛" if converged else "未收敛"
        print(f"Epoch {epoch+1} 完成，数据集: {dataset_name}, 收敛状态: {convergence_status}, 更新部分: {updating_part}, 验证准确率 - Acc_main: {overall_acc_main:.2f}%, Acc_fx: {overall_acc_fx:.2f}%, 最高准确率: {best_acc:.2f}%")
    
    print(f"第 {run} 次训练完成！\n")