In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import ViTForImageClassification, ViTFeatureExtractor, Trainer, TrainingArguments, ViTImageProcessor,AutoImageProcessor
import os
from PIL import Image
from sklearn.metrics import accuracy_score
import numpy as np
import pyarrow.parquet as pq
import io
from sklearn.model_selection import train_test_split
from peft import LoraConfig, TaskType, get_peft_model,PeftModel, PeftConfig
#from transformers import LoRAConfig, LoRAAdapter
from utils import get_mnist_data, get_EuroSAT_data,CustomTensorDataset, get_cifar10_data, get_car_data,get_fruits_data, get_GTSRB_data, get_DTD_data, get_resis_data, get_grabage_data, get_plants_data
import random
from tqdm import tqdm

def set_random():
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # 如果使用 GPU，还需固定 GPU 相关随机数
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)  # 用于多 GPU 情况
    # 确保卷积操作等确定性（针对特定卷积算法）
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def test_result():
    sum1 = 0
    set_random()
    for i in tqdm(range(1000)):
        with torch.no_grad():
            outputs = model(test_dataset[i]["pixel_values"].unsqueeze(0).cuda())
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        if predicted_class != int(test_dataset[i]['labels']):
            #print(predicted_class, int(test_dataset[i]['labels']))
            sum1 += 1
    print(1-sum1/1000)
    return 1-sum1/1000
    

def compute_accuracy(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, predictions)}
model_dir = "./vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_dir, torch_dtype="auto")
feature_extractor = ViTImageProcessor.from_pretrained(model_dir)

# lora_config = LoraConfig(
#         #task_type=TaskType.CAUSAL_LM,
#         #task_type=TaskType.SEQ_CLS,
#         task_type=TaskType.FEATURE_EXTRACTION,
#         target_modules=[ "intermediate.dense", "output.dense"],
#         inference_mode=False,  # 训练模式
#         r=8,  # Lora 秩
#         lora_alpha=32,  # 等效于lr=lr*lora_alpha/r
#         lora_dropout=0.1
    # )

lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=[ "intermediate.dense", "output.dense"],  #["query", "value"]
        lora_dropout=0.1,
        bias="none",
        modules_to_save=["classifier"],
    )

model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CustomTensorDataset(Dataset):
    def __init__(self, tensors, labels, mode= "PNG_code", transform=None): #"JPG"  Tensor
        self.tensors = tensors
        self.labels = labels
        self.transform = transform
        self.mode = mode
        assert len(self.tensors) == len(self.labels)  # "张量数量和标签数量必须相同"
    def __len__(self):
        return len(self.tensors)
    def __getitem__(self, idx):
        label = self.labels[idx]
        return {"pixel_values": self.tensors[idx], "labels": torch.tensor(label)}

In [7]:
task_functions = {"mnist": get_mnist_data,"EuroSAT": get_EuroSAT_data,"cifar10": get_cifar10_data,"car": get_car_data,"fruits": get_fruits_data,"GTSRB": get_GTSRB_data,"DTD": get_DTD_data,"resis": get_resis_data,"grabage": get_grabage_data, "plants": get_plants_data}
train_list = []
train_label = []
# test_list = []
# test_label = []
for j, task in enumerate([ "mnist", "EuroSAT", "cifar10", "car", "fruits","GTSRB", "DTD","resis", "grabage","plants"]):
    train_dataset, test_dataset = task_functions[task]()
    train_list += [train_dataset[i]['pixel_values'] for i in range(min(5000, len(train_dataset)))]
    train_label += [j] * min(5000, len(train_dataset))
    # test_list += [test_dataset[i]['pixel_values'] for i in range(min(1000, len(test_dataset)))]
    # test_label += [j] * min(1000, len(test_dataset))

In [8]:

train_dataset = CustomTensorDataset(train_list, train_label)
test_dataset = CustomTensorDataset(test_list, test_label)

In [10]:
from torchvision import datasets, transforms, models
model1 = models.resnet18(pretrained=False, num_classes=10).cuda()  # 修改分类数
model1.load_state_dict(torch.load("./router_model.pth"))

  model1.load_state_dict(torch.load("./router_model.pth"))


<All keys matched successfully>

In [14]:
# 2. 加载训练/验证数据
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
train_dataset_ = [[i['pixel_values'],i["labels"]] for i in train_dataset]
test_dataset_ = [[i['pixel_values'],i["labels"]] for i in test_dataset]
train_loader = DataLoader(train_dataset_+test_dataset_, batch_size=32, shuffle=True)
val_loader = DataLoader(test_dataset_, batch_size=32, shuffle=False)

# 3. 加载 ResNet18 模型（不使用 ImageNet 预训练）
# model = models.resnet18(pretrained=False, num_classes=10)  # 修改分类数
# model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# 4. 损失函数 + 优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 5. 训练
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_acc = 0
for epoch in tqdm(range(epochs)):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        # print(images,labels)
        # print(images["pixel_values"],labels["labels"])
        # break
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss:.4f}")

    # 验证
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = 100 * correct / total
    print(f"Validation Accuracy: {acc:.2f}%")
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "./router_model.pth")
        print(f"✅ Saved new best model with accuracy {acc:.2f}%")

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch [1/50], Loss: 50.5148


  2%|▏         | 1/50 [01:09<57:09, 69.98s/it]

Validation Accuracy: 99.42%
✅ Saved new best model with accuracy 99.42%
Epoch [2/50], Loss: 42.6233


  4%|▍         | 2/50 [02:19<55:56, 69.93s/it]

Validation Accuracy: 99.59%
✅ Saved new best model with accuracy 99.59%
Epoch [3/50], Loss: 39.4879


  6%|▌         | 3/50 [03:29<54:36, 69.72s/it]

Validation Accuracy: 99.58%
Epoch [4/50], Loss: 34.2911


  8%|▊         | 4/50 [04:40<53:44, 70.11s/it]

Validation Accuracy: 99.63%
✅ Saved new best model with accuracy 99.63%
Epoch [5/50], Loss: 31.7832


 10%|█         | 5/50 [05:50<52:45, 70.34s/it]

Validation Accuracy: 99.52%
Epoch [6/50], Loss: 34.2322


 12%|█▏        | 6/50 [07:01<51:41, 70.49s/it]

Validation Accuracy: 99.70%
✅ Saved new best model with accuracy 99.70%
Epoch [7/50], Loss: 29.3944


 14%|█▍        | 7/50 [08:12<50:42, 70.75s/it]

Validation Accuracy: 99.77%
✅ Saved new best model with accuracy 99.77%
Epoch [8/50], Loss: 29.8419


 16%|█▌        | 8/50 [09:23<49:33, 70.79s/it]

Validation Accuracy: 99.75%
Epoch [9/50], Loss: 24.3104


 18%|█▊        | 9/50 [10:32<47:56, 70.15s/it]

Validation Accuracy: 99.62%
Epoch [10/50], Loss: 30.4025


 20%|██        | 10/50 [11:42<46:38, 69.97s/it]

Validation Accuracy: 99.57%
Epoch [11/50], Loss: 22.1049


 22%|██▏       | 11/50 [12:51<45:24, 69.87s/it]

Validation Accuracy: 99.64%
Epoch [12/50], Loss: 26.2783


 24%|██▍       | 12/50 [14:01<44:08, 69.71s/it]

Validation Accuracy: 99.79%
✅ Saved new best model with accuracy 99.79%
Epoch [13/50], Loss: 21.0047


 26%|██▌       | 13/50 [15:11<43:06, 69.91s/it]

Validation Accuracy: 99.41%
Epoch [14/50], Loss: 21.5771


 28%|██▊       | 14/50 [16:22<42:04, 70.13s/it]

Validation Accuracy: 99.47%
Epoch [15/50], Loss: 20.3882


 30%|███       | 15/50 [17:32<41:01, 70.32s/it]

Validation Accuracy: 99.84%
✅ Saved new best model with accuracy 99.84%
Epoch [16/50], Loss: 21.9488


 32%|███▏      | 16/50 [18:43<39:53, 70.40s/it]

Validation Accuracy: 99.85%
✅ Saved new best model with accuracy 99.85%
Epoch [17/50], Loss: 19.9341


 34%|███▍      | 17/50 [19:53<38:40, 70.33s/it]

Validation Accuracy: 99.94%
✅ Saved new best model with accuracy 99.94%
Epoch [18/50], Loss: 20.3127


 36%|███▌      | 18/50 [21:03<37:22, 70.07s/it]

Validation Accuracy: 99.78%
Epoch [19/50], Loss: 18.4443


 38%|███▊      | 19/50 [22:13<36:17, 70.25s/it]

Validation Accuracy: 99.90%
Epoch [20/50], Loss: 17.8759


 40%|████      | 20/50 [23:23<35:04, 70.16s/it]

Validation Accuracy: 99.86%
Epoch [21/50], Loss: 18.5726


 42%|████▏     | 21/50 [24:34<34:03, 70.45s/it]

Validation Accuracy: 99.94%
Epoch [22/50], Loss: 14.0819


 44%|████▍     | 22/50 [25:44<32:49, 70.33s/it]

Validation Accuracy: 99.91%
Epoch [23/50], Loss: 18.8460


 46%|████▌     | 23/50 [26:54<31:35, 70.20s/it]

Validation Accuracy: 99.75%
Epoch [24/50], Loss: 14.8107


 48%|████▊     | 24/50 [28:02<30:09, 69.61s/it]

Validation Accuracy: 99.94%
Epoch [25/50], Loss: 14.5296


 50%|█████     | 25/50 [29:11<28:55, 69.42s/it]

Validation Accuracy: 99.62%
Epoch [26/50], Loss: 19.6297


 52%|█████▏    | 26/50 [30:21<27:49, 69.55s/it]

Validation Accuracy: 99.74%
Epoch [27/50], Loss: 10.2360


 54%|█████▍    | 27/50 [31:33<26:57, 70.32s/it]

Validation Accuracy: 99.21%
Epoch [28/50], Loss: 14.7380


 56%|█████▌    | 28/50 [32:43<25:44, 70.22s/it]

Validation Accuracy: 99.90%
Epoch [29/50], Loss: 14.8460


 58%|█████▊    | 29/50 [33:53<24:29, 69.97s/it]

Validation Accuracy: 99.96%
✅ Saved new best model with accuracy 99.96%
Epoch [30/50], Loss: 15.2014


 60%|██████    | 30/50 [35:02<23:13, 69.66s/it]

Validation Accuracy: 99.69%
Epoch [31/50], Loss: 12.8336


 62%|██████▏   | 31/50 [36:11<22:03, 69.66s/it]

Validation Accuracy: 99.31%
Epoch [32/50], Loss: 11.4480


 64%|██████▍   | 32/50 [37:21<20:55, 69.73s/it]

Validation Accuracy: 99.98%
✅ Saved new best model with accuracy 99.98%
Epoch [33/50], Loss: 14.3290


 66%|██████▌   | 33/50 [38:31<19:43, 69.59s/it]

Validation Accuracy: 99.89%
Epoch [34/50], Loss: 12.5065


 68%|██████▊   | 34/50 [39:39<18:29, 69.34s/it]

Validation Accuracy: 99.81%
Epoch [35/50], Loss: 11.1200


 70%|███████   | 35/50 [40:49<17:20, 69.38s/it]

Validation Accuracy: 99.80%
Epoch [36/50], Loss: 14.5078


 72%|███████▏  | 36/50 [41:58<16:11, 69.37s/it]

Validation Accuracy: 99.90%
Epoch [37/50], Loss: 11.4677


 74%|███████▍  | 37/50 [43:07<15:01, 69.36s/it]

Validation Accuracy: 99.85%
Epoch [38/50], Loss: 11.0377


 76%|███████▌  | 38/50 [44:18<13:58, 69.84s/it]

Validation Accuracy: 99.90%
Epoch [39/50], Loss: 10.6018


 78%|███████▊  | 39/50 [45:28<12:45, 69.63s/it]

Validation Accuracy: 99.74%
Epoch [40/50], Loss: 12.2402


 80%|████████  | 40/50 [46:38<11:37, 69.76s/it]

Validation Accuracy: 99.91%
Epoch [41/50], Loss: 11.9961


 82%|████████▏ | 41/50 [47:48<10:28, 69.84s/it]

Validation Accuracy: 99.72%
Epoch [42/50], Loss: 11.5267


 84%|████████▍ | 42/50 [48:57<09:17, 69.68s/it]

Validation Accuracy: 99.81%
Epoch [43/50], Loss: 7.1949


 86%|████████▌ | 43/50 [50:08<08:10, 70.08s/it]

Validation Accuracy: 99.91%
Epoch [44/50], Loss: 10.1764


 88%|████████▊ | 44/50 [51:17<06:59, 69.84s/it]

Validation Accuracy: 99.90%
Epoch [45/50], Loss: 12.5272


 90%|█████████ | 45/50 [52:27<05:48, 69.70s/it]

Validation Accuracy: 99.90%
Epoch [46/50], Loss: 9.3731


 92%|█████████▏| 46/50 [53:37<04:39, 69.76s/it]

Validation Accuracy: 99.91%
Epoch [47/50], Loss: 11.0547


 94%|█████████▍| 47/50 [54:46<03:29, 69.68s/it]

Validation Accuracy: 99.93%
Epoch [48/50], Loss: 8.9699


 96%|█████████▌| 48/50 [55:55<02:19, 69.62s/it]

Validation Accuracy: 99.94%
Epoch [49/50], Loss: 6.9718


 98%|█████████▊| 49/50 [57:05<01:09, 69.57s/it]

Validation Accuracy: 99.90%
Epoch [50/50], Loss: 10.4520


100%|██████████| 50/50 [58:15<00:00, 69.90s/it]

Validation Accuracy: 99.95%





In [None]:
model.eval()
correct = 0
total = 0

test_dataset_ = [[i['pixel_values'],i["labels"]] for i in test_dataset]
train_loader = DataLoader(train_dataset_+test_dataset_, batch_size=32, shuffle=True)
val_loader = DataLoader(test_dataset_, batch_size=32, shuffle=False)
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        for i in range(len(labels)):
            if preds[i] != labels[i]:
                print(preds[i], labels[i])
        # break
        total += labels.size(0)

acc = 100 * correct / total
print(f"Validation Accuracy: {acc:.2f}%")