In [1]:
!pip install transformers accelerate evaluate datasets peft -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m471.0/480.6 kB[0m [31m20.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [6]:
import os
import torch
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, get_scheduler
from peft import AdaLoraConfig, get_peft_model

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and dataset configuration
model_name_or_path = "google/vit-base-patch16-224-in21k"
dataset_name = "food101"
batch_size = 16
num_epochs = 8
learning_rate = 1e-4

# Load the dataset
dataset = load_dataset(dataset_name)
processor = AutoImageProcessor.from_pretrained(model_name_or_path)
normalize = transforms.Normalize(mean=processor.image_mean, std=processor.image_std)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

def preprocess_function(examples):
    examples["image"] = [transform(img.convert("RGB")) for img in examples["image"]]
    return examples

train_dataset = dataset["train"].with_transform(preprocess_function)
eval_dataset = dataset["validation"].with_transform(preprocess_function)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=4)

total_training_steps = len(train_dataloader) * num_epochs

# Configure AdaLoRA
peft_config = AdaLoraConfig(
    init_r=12,
    target_r=8,
    beta1=0.85,
    beta2=0.85,
    tinit=200,
    tfinal=1000,
    deltaT=10,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
    target_modules = ['query', 'key', 'value']
)

peft_config.total_step = total_training_steps

# Load the model and apply AdaLoRA
model = AutoModelForImageClassification.from_pretrained(model_name_or_path, num_labels=101)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.to(device)

# Optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * num_epochs
)

# Training loop with AdaLoRA update
global_step = 0
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}"):
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Update AdaLoRA matrices and allocate budget
        model.base_model.update_and_allocate(global_step)
        global_step += 1
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_dataloader)}")

    # Evaluate
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs, labels = batch["image"].to(device), batch["label"].to(device)
            outputs = model(pixel_values=inputs)
            predictions = outputs.logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    print(f"Validation Accuracy: {correct / total * 100:.2f}%")

# Save the fine-tuned model
model.save_pretrained(f"{model_name_or_path}_adalora_fine_tuned")

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 663,984 || all params: 86,540,345 || trainable%: 0.7673


Training Epoch 1: 100%|██████████| 4735/4735 [07:39<00:00, 10.31it/s]


Epoch 1: Loss = 4.2067440153804725


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.85it/s]


Validation Accuracy: 50.86%


Training Epoch 2: 100%|██████████| 4735/4735 [07:38<00:00, 10.32it/s]


Epoch 2: Loss = 3.7980484324498565


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.85it/s]


Validation Accuracy: 62.47%


Training Epoch 3: 100%|██████████| 4735/4735 [07:38<00:00, 10.32it/s]


Epoch 3: Loss = 3.6365246208819317


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.85it/s]


Validation Accuracy: 67.12%


Training Epoch 4: 100%|██████████| 4735/4735 [07:38<00:00, 10.32it/s]


Epoch 4: Loss = 3.543002602970965


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.84it/s]


Validation Accuracy: 70.24%


Training Epoch 5: 100%|██████████| 4735/4735 [07:38<00:00, 10.33it/s]


Epoch 5: Loss = 3.482455226511487


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.86it/s]


Validation Accuracy: 71.79%


Training Epoch 6: 100%|██████████| 4735/4735 [07:38<00:00, 10.32it/s]


Epoch 6: Loss = 3.442574838600038


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.87it/s]


Validation Accuracy: 72.93%


Training Epoch 7: 100%|██████████| 4735/4735 [07:38<00:00, 10.32it/s]


Epoch 7: Loss = 3.4176211204045424


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.85it/s]


Validation Accuracy: 73.10%


Training Epoch 8: 100%|██████████| 4735/4735 [07:39<00:00, 10.30it/s]


Epoch 8: Loss = 3.4043755804221005


Evaluating: 100%|██████████| 1579/1579 [01:15<00:00, 20.84it/s]


Validation Accuracy: 73.60%
