# Attention only FC-LoRA Cifar-10

In [2]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
from transformers import DeiTForImageClassification
from peft import LoraConfig, get_peft_model
import torchvision
import torchvision.transforms as transforms
import numpy as np
from collections import defaultdict
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from tqdm import tqdm

start = time.time()
np.random.seed(78)
torch.manual_seed(78)
BATCH_SIZE = 32

# DATA
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
indices = np.arange(len(train_dataset))
np.random.shuffle(indices)
split = int(0.2 * len(train_dataset))
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
val_dataset = Subset(train_dataset, val_indices)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# MODEL
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)
model.classifier.requires_grad_(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# VALIDATION, ECE, TEMP SCALING
def validate(model, loader, criterion, device):
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    all_probs, all_labels, all_logits = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            all_probs.append(probs.cpu().numpy())
            all_logits.append(outputs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_val_loss = val_loss / len(loader)
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy, np.concatenate(all_probs), np.concatenate(all_labels), np.concatenate(all_logits)

def compute_ece(probs, labels, n_bins=15):
    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    accuracies = (predictions == labels)
    ece = 0.0
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = np.mean(in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
    return ece

def find_optimal_temperature(model, loader, device):
    model.eval()
    all_logits, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            all_logits.append(model(images).logits.cpu())
            all_labels.append(labels.cpu())
    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels).numpy()
    temperatures = np.linspace(0.2, 3.0, 100)
    best_temp, min_ece = 1.0, float('inf')
    for t in temperatures:
        scaled_probs = F.softmax(all_logits / t, dim=1).numpy()
        ece = compute_ece(scaled_probs, all_labels)
        if ece < min_ece:
            min_ece, best_temp = ece, t
    return best_temp

# ADAPTIVE RANK SELECTION 
fisher_sums = defaultdict(float)
grad_sums = defaultdict(float)
mean_out = defaultdict(float)
sq_mean_out = defaultdict(float)
n_batches = 0
hook_handles = []
target_layers = []

def hook_fn(name):
    def hook(module, input, output):
        output_tensor = output[0] if isinstance(output, tuple) else output
        batch_out_flat = output_tensor.detach().view(-1, output_tensor.shape[-1])
        mean_out[name] += batch_out_flat.mean(dim=0)
        sq_mean_out[name] += (batch_out_flat ** 2).mean(dim=0)
    return hook

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and any(t in name for t in ["query", "key", "value"]):
        target_layers.append(name)
        handle = module.register_forward_hook(hook_fn(name))
        hook_handles.append(handle)

temp_optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()
model.train()
for i, (images, labels) in enumerate(tqdm(train_loader, desc="Compute Importance", leave=False), 1):
    if i > 100: break
    images, labels = images.to(device), labels.to(device)
    temp_optimizer.zero_grad()
    loss = criterion(model(images).logits, labels)
    loss.backward()
    n_batches += 1
    for name, module in model.named_modules():
        if name in target_layers:
            for p_name, p in [("weight", module.weight), ("bias", getattr(module, 'bias', None))]:
                if p is not None and p.grad is not None:
                    fisher_sums[f"{name}.{p_name}"] += (p.grad ** 2).mean().item()
                    grad_sums[f"{name}.{p_name}"] += p.grad.abs().mean().item()
for handle in hook_handles: handle.remove()

cov_trace = {name: ((sq_mean_out[name] / n_batches) - ((mean_out[name] / n_batches) ** 2)).sum().item() for name in mean_out}
combined_importance = {}
fisher_min, fisher_max = min(fisher_sums.values()), max(fisher_sums.values())
grad_min, grad_max = min(grad_sums.values()), max(grad_sums.values())
cov_min, cov_max = (min(cov_trace.values()), max(cov_trace.values())) if cov_trace else (0, 0)
for name in target_layers:
    fisher_score = sum(fisher_sums.get(p, 0) for p in [f"{name}.weight", f"{name}.bias"])
    grad_score = sum(grad_sums.get(p, 0) for p in [f"{name}.weight", f"{name}.bias"])
    cov_score = cov_trace.get(name, 0)
    fisher_z = (fisher_score - fisher_min) / (fisher_max - fisher_min + 1e-6)
    grad_z = (grad_score - grad_min) / (grad_max - grad_min + 1e-6)
    cov_z = (cov_score - cov_min) / (cov_max - cov_min + 1e-6) if cov_trace else 0
    score = 0.6 * fisher_z + 0.2 * grad_z + 0.2 * cov_z
    combined_importance[name] = score

r_max, r_min, r_total = 16, 1, 50
adaptive_ranks = {}
sorted_layers = sorted(combined_importance.items(), key=lambda x: x[1], reverse=True)
remaining_budget = r_total
for name, score in sorted_layers:
    if remaining_budget <= 0:
        adaptive_ranks[name] = r_min
    else:
        rank = min(max(r_min, int(score * r_total)), r_max, remaining_budget)
        adaptive_ranks[name] = rank
        remaining_budget -= rank
if remaining_budget > 0:
    for name in [n for n, s in sorted_layers][:remaining_budget]:
        adaptive_ranks[name] = min(adaptive_ranks.get(name, 0) + 1, r_max)
print("\nAdaptive Ranks (Attention Only):", adaptive_ranks)

# PEFT INJECTION + MANUAL RANK ADJUSTMENT 
lora_config = LoraConfig(
    r=r_max,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=list(adaptive_ranks.keys()),
)
model = get_peft_model(model, lora_config)
model.to(device)

# Manually set per-layer rank
for name, rank in adaptive_ranks.items():
    module = model.get_submodule(name)
    if hasattr(module, 'r'):
        if isinstance(module.r, dict) and 'default' in module.r:
            module.r['default'] = rank
        else:
            module.r = rank
        in_features = module.in_features
        out_features = module.out_features
        module.lora_A['default'].weight.data = torch.randn(rank, in_features).to(device) * 0.02
        module.lora_B['default'].weight.data = torch.zeros(out_features, rank).to(device)

print("\nFinal trainable model structure:")
model.print_trainable_parameters()

# TRAINING 
for name, param in model.named_parameters():
    if 'lora_' in name or 'classifier' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-4)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
def cutmix(images, labels, alpha=1.0):
    batch_size = images.size(0)
    indices = torch.randperm(batch_size)
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]
    lam = np.random.beta(alpha, alpha)
    W, H = images.size(3), images.size(2)
    cut_rat = np.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1, bby1 = np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H)
    bbx2, bby2 = np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H)
    images[:, :, bby1:bby2, bbx1:bbx2] = shuffled_images[:, :, bby1:bby2, bbx1:bbx2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(2) * images.size(3)))
    return images, labels, shuffled_labels, lam

train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []
best_val_accuracy, best_epoch = 0.0, 1
best_model_path = "deit_cifar10_hybrid_LoRA_PEFT.pt"

model.train()
for epoch in range(10):
    train_loss, train_correct, train_total = 0.0, 0, 0
    if epoch < 5:
        lr = 5e-4 * (epoch + 1) / 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/10", leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = cutmix(images, labels, alpha=1.0)
            outputs = model(images).logits
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(images).logits
            loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})
    progress_bar.close()
    scheduler.step()
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    val_loss, val_accuracy, val_probs, val_labels, _ = validate(model, val_loader, criterion, device)
    train_losses.append(avg_train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    if (epoch + 1) % 5 == 0:
        test_loss, test_accuracy, *_ = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1}: Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"Epoch {epoch+1}: New best model saved with Val Accuracy: {best_val_accuracy:.2f}%")
    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%")

# EVAL/LOGGING 
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('deit_cifar10_hybrid_LoRA_PEFT_fastlr.png')
plt.close()

# FINAL EVAL
base_model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
base_model.classifier = torch.nn.Linear(base_model.classifier.in_features, 10)
model = get_peft_model(base_model, lora_config)
model.to(device)
for name, rank in adaptive_ranks.items():
    module = model.get_submodule(name)
    if hasattr(module, 'r'):
        if isinstance(module.r, dict) and 'default' in module.r:
            module.r['default'] = rank
        else:
            module.r = rank
        in_features = module.in_features
        out_features = module.out_features
        module.lora_A['default'].weight.data = torch.randn(rank, in_features).to(device) * 0.02
        module.lora_B['default'].weight.data = torch.zeros(out_features, rank).to(device)
model.load_state_dict(torch.load(best_model_path))
model.to(device)
test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
ece = compute_ece(test_probs, test_labels)
class_accuracies = []
for i in range(10):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = accuracy_score(test_labels[class_mask], np.argmax(test_probs[class_mask], axis=1))
        class_accuracies.append(class_acc)
class_acc_mean = np.mean(class_accuracies)
class_acc_std = np.std(class_accuracies)
optimal_temp = find_optimal_temperature(model, val_loader, device)
scaled_test_probs = F.softmax(torch.tensor(test_logits) / optimal_temp, dim=1).numpy()
scaled_test_accuracy = accuracy_score(test_labels, np.argmax(scaled_test_probs, axis=1)) * 100
scaled_ece = compute_ece(scaled_test_probs, test_labels)

print(f"Best Model (Epoch {best_epoch}): Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print(f"ECE: {ece:.4f}, Scaled ECE: {scaled_ece:.4f}, Scaled Test Accuracy: {scaled_test_accuracy:.2f}%")
print(f"Class-wise Accuracy: Mean {class_acc_mean:.2f}, Std {class_acc_std:.2f}")
print(f"Total training time: {(time.time() - start):.2f} seconds")
# torch.save(model.state_dict(), "deit_cifar10_hybrid_LoRA_PEFT_final.pt")

Files already downloaded and verified
Files already downloaded and verified


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                      


Adaptive Ranks (Attention Only): {'deit.encoder.layer.0.attention.attention.value': 16, 'deit.encoder.layer.9.attention.attention.value': 16, 'deit.encoder.layer.11.attention.attention.value': 16, 'deit.encoder.layer.8.attention.attention.value': 2, 'deit.encoder.layer.4.attention.attention.value': 1, 'deit.encoder.layer.10.attention.attention.value': 1, 'deit.encoder.layer.7.attention.attention.value': 1, 'deit.encoder.layer.3.attention.attention.value': 1, 'deit.encoder.layer.6.attention.attention.value': 1, 'deit.encoder.layer.5.attention.attention.value': 1, 'deit.encoder.layer.2.attention.attention.value': 1, 'deit.encoder.layer.1.attention.attention.value': 1, 'deit.encoder.layer.9.attention.attention.query': 1, 'deit.encoder.layer.8.attention.attention.query': 1, 'deit.encoder.layer.7.attention.attention.query': 1, 'deit.encoder.layer.5.attention.attention.query': 1, 'deit.encoder.layer.3.attention.attention.key': 1, 'deit.encoder.layer.6.attention.attention.query': 1, 'deit.en

Epoch 1/10: 100%|██████████| 1250/1250 [02:34<00:00,  8.09it/s, Loss=0.5312]


Epoch 1: New best model saved with Val Accuracy: 97.23%
Epoch 1: Train Loss: 1.0324, Val Loss: 0.5877, Train Acc: 80.37%, Val Acc: 97.23%


Epoch 2/10: 100%|██████████| 1250/1250 [02:34<00:00,  8.10it/s, Loss=0.5535]


Epoch 2: New best model saved with Val Accuracy: 97.72%
Epoch 2: Train Loss: 0.8545, Val Loss: 0.5654, Train Acc: 85.84%, Val Acc: 97.72%


Epoch 3/10: 100%|██████████| 1250/1250 [02:33<00:00,  8.14it/s, Loss=0.5181]


Epoch 3: New best model saved with Val Accuracy: 97.98%
Epoch 3: Train Loss: 0.8144, Val Loss: 0.5561, Train Acc: 88.75%, Val Acc: 97.98%


Epoch 4/10: 100%|██████████| 1250/1250 [02:30<00:00,  8.28it/s, Loss=0.5924]


Epoch 4: New best model saved with Val Accuracy: 98.15%
Epoch 4: Train Loss: 0.7908, Val Loss: 0.5536, Train Acc: 88.22%, Val Acc: 98.15%


Epoch 5/10: 100%|██████████| 1250/1250 [02:33<00:00,  8.16it/s, Loss=0.8673]


Epoch 5: Test Loss: 0.5504, Test Accuracy: 98.12%
Epoch 5: New best model saved with Val Accuracy: 98.24%
Epoch 5: Train Loss: 0.7990, Val Loss: 0.5505, Train Acc: 87.73%, Val Acc: 98.24%


Epoch 6/10: 100%|██████████| 1250/1250 [02:33<00:00,  8.15it/s, Loss=0.5201]


Epoch 6: New best model saved with Val Accuracy: 98.38%
Epoch 6: Train Loss: 0.7900, Val Loss: 0.5455, Train Acc: 87.82%, Val Acc: 98.38%


Epoch 7/10: 100%|██████████| 1250/1250 [02:37<00:00,  7.94it/s, Loss=0.5643]


Epoch 7: Train Loss: 0.7829, Val Loss: 0.5464, Train Acc: 88.73%, Val Acc: 98.20%


Epoch 8/10: 100%|██████████| 1250/1250 [02:34<00:00,  8.07it/s, Loss=0.8801]


Epoch 8: Train Loss: 0.7795, Val Loss: 0.5450, Train Acc: 90.20%, Val Acc: 98.24%


Epoch 9/10: 100%|██████████| 1250/1250 [02:35<00:00,  8.06it/s, Loss=0.5088]


Epoch 9: Train Loss: 0.7606, Val Loss: 0.5430, Train Acc: 89.37%, Val Acc: 98.33%


Epoch 10/10: 100%|██████████| 1250/1250 [02:35<00:00,  8.05it/s, Loss=0.5071]


Epoch 10: Test Loss: 0.5432, Test Accuracy: 98.27%
Epoch 10: New best model saved with Val Accuracy: 98.48%
Epoch 10: Train Loss: 0.7596, Val Loss: 0.5416, Train Acc: 89.64%, Val Acc: 98.48%


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Model (Epoch 10): Test Loss: 0.5432, Test Accuracy: 98.27%
ECE: 0.0886, Scaled ECE: 0.0045, Scaled Test Accuracy: 98.27%
Class-wise Accuracy: Mean 0.98, Std 0.01
Total training time: 1902.77 seconds
