# Evaluation of the LoRA Model on ImageNet Validation Set

## Load the LoRA Fine-Tuned Model

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification
from peft import PeftModel
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
import os
import sys
import warnings

# Route Hugging Face Hub requests through a mirror.
# huggingface_hub reads HF_ENDPOINT when it is first imported, and both
# transformers and peft import it. The first cell of this notebook already
# imports transformers and peft, so assigning the variable here cannot affect
# the current kernel (the saved output of the model-loading cell still shows
# requests to huggingface.co). For the mirror to work, set HF_ENDPOINT before
# those imports (top of the first cell, or in the shell environment).
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Fail loudly instead of silently: tell the user the setting came too late.
if "huggingface_hub" in sys.modules:
    warnings.warn(
        "HF_ENDPOINT was set after huggingface_hub had been imported, so it "
        "will not take effect in this kernel. Set it before importing "
        "transformers/peft (e.g. at the top of the first cell) and restart.",
        stacklevel=1,
    )

In [None]:
import torch
from transformers import AutoModelForImageClassification, AutoConfig
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- 1. Base model name (must be the same checkpoint used for LoRA training) ----
base_model_name = "google/vit-large-patch16-224-in21k"

# ---- 2. Local weight paths ----
lora_path = "./vit_lora_r256_ema_best"                # a directory, not a single file
classifier_path = "./vit_lora_r256_ema_best/vit_classifier_r256_ema_best.pt"

# ---- 3. Load the base model (no LoRA yet) ----
# base_model_name is a Hub repo id, so this resolves through the Hub (or the
# local Hugging Face cache). The classification head created here is replaced
# by the trained head in step 5.
config = AutoConfig.from_pretrained(base_model_name)
config.num_labels = 1000                # ImageNet-1k
base_model = AutoModelForImageClassification.from_pretrained(
    base_model_name,
    config=config
)

# ---- 4. Attach the LoRA adapter ----
model = PeftModel.from_pretrained(
    base_model,
    lora_path,
    is_trainable=False                  # inference mode
)

# ---- 5. Load the custom classification head ----
# Assumes the file was written during training with
#   torch.save(model.base_model.classifier.state_dict(), path)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, instead of executing arbitrary pickle code.
state_dict = torch.load(classifier_path, map_location="cpu", weights_only=True)
model.base_model.classifier.load_state_dict(state_dict)

model.to(device)
model.eval()

print("üöÄ Local model loaded successfully!")


'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /google/vit-large-patch16-224-in21k/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f3d52b45b50>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 84c73255-69a0-4794-928b-0a182d1a8440)')' thrown while requesting HEAD https://huggingface.co/google/vit-large-patch16-224-in21k/resolve/main/config.json
Retrying in 1s [Retry 1/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /google/vit-large-patch16-224-in21k/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f3d52aa3560>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: ca41b264-b080-4de9-ab90-aee0efd9e5b6)')' thrown while requesting HEAD https://huggingface.co/google/vit-large-patch16-2

üöÄ Local model loaded successfully!


In [4]:
VAL_DATASET_PATH = "/root/autodl-tmp/imagenet/val"

BATCH_SIZE = 64
NUM_WORKERS = 8
NUM_CLASSES = 1000  # must match config.num_labels used when building the model

# Evaluation preprocessing: resize the short side to 256, center-crop 224x224,
# convert to a tensor, then normalize with the standard ImageNet RGB mean/std.
# NOTE(review): these must be the same transforms used during LoRA fine-tuning;
# the base checkpoint's own image processor may use different statistics —
# confirm against the training script.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# ImageFolder expects one sub-directory per class and assigns label indices in
# sorted directory-name order. Those indices must line up with the classifier's
# output indices, so a missing or extra class folder would silently shift labels.
val_dataset = datasets.ImageFolder(VAL_DATASET_PATH, transform=val_transform)
if len(val_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders in {VAL_DATASET_PATH}, "
        f"found {len(val_dataset.classes)}"
    )

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host->GPU copies; on a CPU-only machine
    # PyTorch warns and skips pinning, so request it only when CUDA exists.
    pin_memory=torch.cuda.is_available(),
)



In [5]:
# ============================================================
# 6. Top-1 / Top-5 evaluation
# ============================================================
def accuracy(output, target, topk=(1, 5)):
    """Compute top-k accuracy for every k in ``topk``.

    Args:
        output: score tensor of shape [B, C], one row of class scores per sample.
        target: ground-truth class indices, one per sample (B values).
        topk: the k values to evaluate; ``max(topk)`` must not exceed C.

    Returns:
        A list of Python floats in the same order as ``topk``: for each k, the
        fraction of samples whose target is among their k highest scores.
    """
    n_samples = target.size(0)
    # Indices of the max(topk) best-scoring classes per row, best first: [B, maxk].
    ranked = output.topk(max(topk), dim=1, largest=True, sorted=True).indices
    # hits[i, j] is True when sample i's j-th best guess equals its label.
    hits = ranked == target.reshape(-1, 1)
    return [(hits[:, :k].float().sum() / n_samples).item() for k in topk]

# Evaluation loop: accumulate per-sample top-1 / top-5 hits over the val set.
top1_total = 0
top5_total = 0
total = 0

print("Evaluating...")
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        # non_blocking lets the host->device copy overlap with compute when the
        # loader hands out pinned memory; otherwise it is an ordinary copy.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        outputs = model(pixel_values=images).logits
        top1, top5 = accuracy(outputs, labels)

        # accuracy() returns per-batch fractions; weight them by batch size so
        # the final figure is a per-sample average (the last batch may be smaller).
        batch_n = images.size(0)
        top1_total += top1 * batch_n
        top5_total += top5 * batch_n
        total += batch_n

if total == 0:
    raise RuntimeError("Validation loader produced no samples; nothing to evaluate.")

top1_acc = top1_total / total
top5_acc = top5_total / total

print(f"\nüî• Top-1 Accuracy: {top1_acc * 100:.2f}%")
print(f"üî• Top-5 Accuracy: {top5_acc * 100:.2f}%")

Evaluating...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 782/782 [03:01<00:00,  4.30it/s]


üî• Top-1 Accuracy: 82.82%
üî• Top-5 Accuracy: 93.93%



