In [None]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"


In [None]:
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224-in21k", use_fast=True)
model = AutoModelForImageClassification.from_pretrained("google/vit-large-patch16-224-in21k", use_auth_token=os.environ['HF_TOKEN'])

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
import torch
import torch.nn as nn

model.classifier = torch.nn.Linear(model.config.hidden_size, 1000)  # ImageNet 有 1000 个类别
model.config.num_labels = 1000

In [4]:
print(model.classifier.weight.shape)
print(model.classifier.bias.shape)


torch.Size([1000, 1024])
torch.Size([1000])


In [5]:
print(model)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
           

In [6]:
# print the model 
print(model)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
           

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
           

In [None]:
from peft import LoraConfig, get_peft_model, TaskType   

target_modules = ['query', 'value']


lora_config = LoraConfig(
    # 分类任务类型
    task_type=TaskType.FEATURE_EXTRACTION, # 或 TaskType.IMAGE_CLASSIFICATION (取决于 PEFT 版本)
    r=32,          # LoRA 秩 (Rank): 决定引入的参数量和表达能力
    lora_alpha=32, # 缩放因子 (通常设为 r 的两倍或相等)
    target_modules=target_modules,
    lora_dropout=0.05,
    bias="none"
)


model_peft = get_peft_model(model, lora_config)

print(f"--- LoRA Model Summary ---")
model_peft.print_trainable_parameters()

--- LoRA Model Summary ---
trainable params: 1,572,864 || all params: 305,899,496 || trainable%: 0.5142


In [9]:
for name, param in model_peft.named_parameters():
    if "lora_" not in name and "classifier" not in name:
        param.requires_grad = False


## Benchmark the base model performance on ImageNet-Mini

In [10]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os

# --- 1. 路径配置 ---
# 假设您的训练集和验证集路径如下
TRAIN_DIR = "/root/autodl-tmp/imagenet/train"
VAL_DIR = "/root/autodl-tmp/imagenet/val"

# --- 2. 预处理 (Transforms) ---
# ImageNet 标准归一化参数
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# 模型输入尺寸与归一化保持一致
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(processor.size["height"]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

val_transform = transforms.Compose([
    transforms.Resize((processor.size["height"], processor.size["width"])),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

# --- 3. 辅助函数 (解决 .JPEG 大小写问题) ---
def is_valid_image(path):
    """
    自定义函数，确保只加载有效图片文件 (包括大写 .JPEG)
    """
    return path.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))

# --- 4. 加载 Datasets (使用 ImageFolder) ---

print("正在加载训练集...")
try:
    train_dataset = ImageFolder(
        root=TRAIN_DIR,
        transform=train_transform,
        is_valid_file=is_valid_image # 使用自定义函数处理大小写
    )
    
    print(f"训练集加载成功: {len(train_dataset)} 张图片, {len(train_dataset.classes)} 个类别。")

    print("\n正在加载验证集...")
    val_dataset = ImageFolder(
        root=VAL_DIR,
        transform=val_transform,
        is_valid_file=is_valid_image # 同样使用
    )
    
    print(f"验证集加载成功: {len(val_dataset)} 张图片, {len(val_dataset.classes)} 个类别。")

    # --- 5. 创建 DataLoaders ---
    BATCH_SIZE = 128 # 根据您的 GPU 显存调整 (例如 64, 128, 256)
    NUM_WORKERS = 8  # 根据您的 CPU 核心数调整 (例如 4, 8, 16)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,      # 训练集需要打乱
        num_workers=NUM_WORKERS,
        pin_memory=True,   # 锁定内存以加快 GPU 传输
        drop_last=True     # 丢弃最后一个不完整的批次
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,     # 验证集不需要打乱
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    print(f"\n✅ DataLoader 准备就绪 (Batch Size: {BATCH_SIZE})")

    # (可选) 检查类别映射是否一致
    # print(train_dataset.class_to_idx)
    
except FileNotFoundError as e:
    print(f"❌ 错误: 找不到数据目录。请检查路径配置: {e}")
except Exception as e:
    print(f"❌ 加载数据时发生错误: {e}")

正在加载训练集...
训练集加载成功: 1281167 张图片, 1000 个类别。

正在加载验证集...
验证集加载成功: 50000 张图片, 1000 个类别。

✅ DataLoader 准备就绪 (Batch Size: 128)


In [11]:
# benchmark 

from tqdm import tqdm

def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating", unit="batch"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(pixel_values = images)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

In [12]:
# acc = evaluate_model(model_peft, val_loader, device)
# print(f"\n--- LoRA Model on ImageNet Validation Set ---")
# print(f"Accuracy: {acc:.2f}%")

```shell
--- LoRA Model on ImageNet Validation Set ---
Accuracy: 0.072%
```

## Train model to improve performance further with ImageNet dataset

In [13]:
import os
import math
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# --------------------------------------
# 1️⃣ 环境优化选项
# --------------------------------------
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/root/.torch_inductor"
os.environ["TORCH_LOGS"] = "nothing"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# --------------------------------------
# 2️⃣ 设备和日志
# --------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter(log_dir="/root/tf-logs")

# --------------------------------------
# 3️⃣ 模型与优化器
# --------------------------------------
model_peft = torch.compile(model_peft)
criterion = nn.CrossEntropyLoss().to(device)

optimizer = AdamW(
    filter(lambda p: p.requires_grad, model_peft.parameters()),
    lr=1e-4,
    betas=(0.85, 0.98),
    weight_decay=0.01
)

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=5,
    T_mult=2,
    eta_min=1e-5
)

# AMP 自动切换
use_bf16 = torch.cuda.is_bf16_supported()
scaler = torch.amp.GradScaler("cuda", enabled=not use_bf16)


num_epochs = 20

# --------------------------------------
# 5️⃣ 训练循环
# --------------------------------------
global_step = 0
for epoch in range(num_epochs):
    model_peft.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for step, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16 if use_bf16 else torch.float16):
            outputs = model_peft(pixel_values=images)
            loss = criterion(outputs.logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_peft.parameters(), 1.0)
        optimizer.step()
        scheduler.step(epoch + step / len(train_loader))  # ✅ 每步更新学习率

        running_loss += loss.item()
        writer.add_scalar("Train/Loss", loss.item(), global_step)
        writer.add_scalar("LR/base", optimizer.param_groups[0]["lr"], global_step)
        global_step += 1

        pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{optimizer.param_groups[0]['lr']:.2e}")

    avg_loss = running_loss / len(train_loader)
    print(f"\nEpoch [{epoch+1}/{num_epochs}] | Loss: {avg_loss:.4f}")

    acc = evaluate_model(model_peft, val_loader, device)
    print(f"Validation Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    writer.add_scalar("Val/Acc", acc, epoch)

writer.close()


Epoch 1/20: 100%|██████████| 10009/10009 [38:56<00:00,  4.28batch/s, loss=5.4130, lr=9.14e-05]



Epoch [1/20] | Loss: 5.9252


Evaluating: 100%|██████████| 1563/1563 [01:36<00:00, 16.24batch/s]


Validation Accuracy after Epoch 1: 48.72%


Epoch 2/20: 100%|██████████| 10009/10009 [38:48<00:00,  4.30batch/s, loss=5.0408, lr=6.89e-05]



Epoch [2/20] | Loss: 5.1413


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.16batch/s]


Validation Accuracy after Epoch 2: 55.90%


Epoch 3/20: 100%|██████████| 10009/10009 [38:48<00:00,  4.30batch/s, loss=4.8073, lr=4.11e-05]



Epoch [3/20] | Loss: 4.8610


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.13batch/s]


Validation Accuracy after Epoch 3: 58.92%


Epoch 4/20: 100%|██████████| 10009/10009 [38:47<00:00,  4.30batch/s, loss=4.5924, lr=1.86e-05]



Epoch [4/20] | Loss: 4.7357


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.16batch/s]


Validation Accuracy after Epoch 4: 59.89%


Epoch 5/20: 100%|██████████| 10009/10009 [38:47<00:00,  4.30batch/s, loss=4.8518, lr=1.00e-05]



Epoch [5/20] | Loss: 4.6816


Evaluating: 100%|██████████| 1563/1563 [01:30<00:00, 17.18batch/s]


Validation Accuracy after Epoch 5: 60.22%


Epoch 6/20: 100%|██████████| 10009/10009 [38:44<00:00,  4.31batch/s, loss=4.5591, lr=9.78e-05]



Epoch [6/20] | Loss: 4.6137


Evaluating: 100%|██████████| 1563/1563 [01:30<00:00, 17.20batch/s]


Validation Accuracy after Epoch 6: 61.59%


Epoch 7/20: 100%|██████████| 10009/10009 [38:43<00:00,  4.31batch/s, loss=4.2288, lr=9.14e-05]



Epoch [7/20] | Loss: 4.4708


Evaluating: 100%|██████████| 1563/1563 [01:30<00:00, 17.18batch/s]


Validation Accuracy after Epoch 7: 62.82%


Epoch 8/20: 100%|██████████| 10009/10009 [38:45<00:00,  4.30batch/s, loss=4.1850, lr=8.15e-05]



Epoch [8/20] | Loss: 4.3675


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.16batch/s]


Validation Accuracy after Epoch 8: 63.49%


Epoch 9/20: 100%|██████████| 10009/10009 [38:48<00:00,  4.30batch/s, loss=4.2849, lr=6.89e-05]



Epoch [9/20] | Loss: 4.2910


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.13batch/s]


Validation Accuracy after Epoch 9: 64.36%


Epoch 10/20: 100%|██████████| 10009/10009 [38:41<00:00,  4.31batch/s, loss=4.1636, lr=5.50e-05]



Epoch [10/20] | Loss: 4.2330


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.16batch/s]


Validation Accuracy after Epoch 10: 64.55%


Epoch 11/20: 100%|██████████| 10009/10009 [38:46<00:00,  4.30batch/s, loss=4.0951, lr=4.11e-05]



Epoch [11/20] | Loss: 4.1913


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.14batch/s]


Validation Accuracy after Epoch 11: 64.98%


Epoch 12/20: 100%|██████████| 10009/10009 [38:49<00:00,  4.30batch/s, loss=4.3349, lr=2.86e-05]



Epoch [12/20] | Loss: 4.1600


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.12batch/s]


Validation Accuracy after Epoch 12: 65.22%


Epoch 13/20: 100%|██████████| 10009/10009 [38:49<00:00,  4.30batch/s, loss=4.1201, lr=1.86e-05]



Epoch [13/20] | Loss: 4.1376


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.13batch/s]


Validation Accuracy after Epoch 13: 65.51%


Epoch 14/20: 100%|██████████| 10009/10009 [38:49<00:00,  4.30batch/s, loss=3.9671, lr=1.22e-05]



Epoch [14/20] | Loss: 4.1227


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.13batch/s]


Validation Accuracy after Epoch 14: 65.52%


Epoch 15/20: 100%|██████████| 10009/10009 [38:50<00:00,  4.30batch/s, loss=4.0330, lr=1.00e-05]



Epoch [15/20] | Loss: 4.1138


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.14batch/s]


Validation Accuracy after Epoch 15: 65.47%


Epoch 16/20: 100%|██████████| 10009/10009 [38:49<00:00,  4.30batch/s, loss=4.4089, lr=9.94e-05]



Epoch [16/20] | Loss: 4.1317


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.12batch/s]


Validation Accuracy after Epoch 16: 65.39%


Epoch 17/20: 100%|██████████| 10009/10009 [38:47<00:00,  4.30batch/s, loss=4.1983, lr=9.78e-05]



Epoch [17/20] | Loss: 4.0904


Evaluating: 100%|██████████| 1563/1563 [01:31<00:00, 17.17batch/s]


Validation Accuracy after Epoch 17: 65.56%


Epoch 18/20:   8%|▊         | 833/10009 [03:14<35:42,  4.28batch/s, loss=3.8758, lr=9.76e-05]


KeyboardInterrupt: 

In [None]:
model_peft.save_pretrained("./vit_lora_imagenet_model")
