# DEIT full fine tune CIFAR-10

In [1]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from transformers import DeiTForImageClassification
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from torch.utils.data import Subset

start = time.time()
print('Program starts...')
print("Running Classifier-Only Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)")

# Set seeds
np.random.seed(78)
torch.manual_seed(78)

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split training set into train and validation (80/20)
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
np.random.shuffle(indices)
split = int(np.floor(0.2 * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
assert len(set(train_indices) & set(val_indices)) == 0, "Train-validation overlap detected"

train_sampler = SubsetRandomSampler(train_indices)
val_dataset = Subset(train_dataset, val_indices)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load DeiT and modify classifier
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

# Freeze all parameters except classifier
for name, param in model.named_parameters():
    if 'classifier' not in name:
        param.requires_grad = False
    else:
        param.requires_grad = True

# Verify trainable parameters
print("Trainable parameters before training:")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable params: {trainable_params}")

# Validation function
def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    all_probs = []
    all_labels = []
    all_logits = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            all_probs.append(probs.cpu().numpy())
            all_logits.append(outputs.cpu())
            all_labels.append(labels.cpu().numpy())
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_val_loss = val_loss / len(loader)
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy, np.concatenate(all_probs), np.concatenate(all_labels), torch.cat(all_logits)

# Compute ECE
def compute_ece(probs, labels, n_bins=10):
    probs = np.clip(probs, 1e-5, 1-1e-5)
    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    accuracies = predictions == labels
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences >= bin_lower) & (confidences < bin_upper)
        prop_in_bin = np.mean(in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += prop_in_bin * np.abs(avg_confidence_in_bin - accuracy_in_bin)
    return ece

# CutMix function
def cutmix(images, labels, alpha=1.0):
    batch_size = images.size(0)
    indices = torch.randperm(batch_size)
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]
    
    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bby1:bby2, bbx1:bbx2] = shuffled_images[:, :, bby1:bby2, bbx1:bbx2]
    
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(-1) * images.size(-2)))
    return images, labels, shuffled_labels, lam

def rand_bbox(size, lam):
    W = size[3]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    
    return bbx1, bby1, bbx2, bby2

# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=1, eta_min=1e-6)

# Training loop
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_epoch = 1
best_model_path = "deit_cifar10_classifier_only_cutmix_best_seed78.pt"

model.train()
for epoch in range(10):
    start_time = time.time()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    if epoch < 5:
        lr = 5e-4 * (epoch + 1) / 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = cutmix(images, labels, alpha=1.0)
            outputs = model(images).logits
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(images).logits
            loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    
    scheduler.step(epoch + 1)
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    
    val_loss, val_accuracy, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
    epoch_time = time.time() - start_time
    
    train_losses.append(avg_train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    
    if (epoch + 1) % 5 == 0:
        test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {best_epoch} with Val Accuracy: {best_val_accuracy:.2f}%")
    
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%, Time: {epoch_time:.2f}s, LR: {scheduler.get_last_lr()[0]:.6f}")

# Plot metrics
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('classifier_only_cutmix_metrics_cifar10_seed78.png')
plt.close()

# Load best model for final evaluation
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)
for name, param in model.named_parameters():
    if 'classifier' not in name:
        param.requires_grad = False
    else:
        param.requires_grad = True
model.load_state_dict(torch.load(best_model_path))
model.to(device)

# Evaluate test set with calibration
test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
ece = compute_ece(test_probs, test_labels)

# Temperature scaling
def find_optimal_temperature(val_logits, val_labels):
    def ece_with_temp(temp):
        scaled_probs = F.softmax(val_logits / temp, dim=1).numpy()
        scaled_probs = np.clip(scaled_probs, 1e-5, 1-1e-5)
        return compute_ece(scaled_probs, val_labels)
    
    temps = np.linspace(0.1, 5.0, 20)
    eces = [ece_with_temp(t) for t in temps]
    return temps[np.argmin(eces)]

_, _, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
optimal_temp = find_optimal_temperature(val_logits, val_labels)
scaled_test_probs = F.softmax(test_logits / optimal_temp, dim=1).numpy()
scaled_test_accuracy = accuracy_score(test_labels, np.argmax(scaled_test_probs, axis=1)) * 100
scaled_ece = compute_ece(scaled_test_probs, test_labels)

class_accuracies = []
for i in range(10):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = accuracy_score(test_labels[class_mask], np.argmax(test_probs[class_mask], axis=1))
        class_accuracies.append(class_acc)
class_acc_mean = np.mean(class_accuracies)
class_acc_std = np.std(class_accuracies)

print(f"Best Model (Epoch {best_epoch}) - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print(f"ECE: {ece:.4f}, Scaled ECE: {scaled_ece:.4f}, Scaled Test Accuracy: {scaled_test_accuracy:.2f}%")
print(f"Class-wise Accuracy: Mean {class_acc_mean:.2f}, Std {class_acc_std:.2f}")
print(f"Total training time: {(time.time() - start):.2f} seconds")

# Save final model
torch.save(model.state_dict(), "deit_cifar10_classifier_only_cutmix_final_seed78.pt")

Program starts...
Running Classifier-Only Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 91460144.39it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before training:
Total trainable params: 7690
New best model saved at epoch 1 with Val Accuracy: 94.13%
Epoch 1, Train Loss: 1.1643, Val Loss: 0.6939, Train Accuracy: 76.69%, Val Accuracy: 94.13%, Time: 100.80s, LR: 0.000499
New best model saved at epoch 2 with Val Accuracy: 95.22%
Epoch 2, Train Loss: 0.9845, Val Loss: 0.6637, Train Accuracy: 81.86%, Val Accuracy: 95.22%, Time: 98.87s, LR: 0.000495
New best model saved at epoch 3 with Val Accuracy: 95.46%
Epoch 3, Train Loss: 0.9465, Val Loss: 0.6524, Train Accuracy: 84.70%, Val Accuracy: 95.46%, Time: 99.91s, LR: 0.000488
New best model saved at epoch 4 with Val Accuracy: 95.69%
Epoch 4, Train Loss: 0.9258, Val Loss: 0.6497, Train Accuracy: 84.20%, Val Accuracy: 95.69%, Time: 99.26s, LR: 0.000478
Epoch 5 - Test Loss: 0.6576, Test Accuracy: 95.24%
Epoch 5, Train Loss: 0.9424, Val Loss: 0.6512, Train Accuracy: 83.56%, Val Accuracy: 95.61%, Time: 99.90s, LR: 0.000467
Epoch 6, Train Loss: 0.9433, Val Loss: 0.6475, Tr

Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Model (Epoch 10) - Test Loss: 0.6525, Test Accuracy: 95.26%
ECE: 0.1274, Scaled ECE: 0.0228, Scaled Test Accuracy: 95.26%
Class-wise Accuracy: Mean 0.95, Std 0.02
Total training time: 1088.29 seconds


# FULL FT start LR 5e-4

In [3]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from transformers import DeiTForImageClassification
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from torch.utils.data import Subset
from tqdm import tqdm

start = time.time()
print('Program starts...')
print("Running Full Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)")

# Set seeds
np.random.seed(78)
torch.manual_seed(78)

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split training set into train and validation (80/20)
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
np.random.shuffle(indices)
split = int(np.floor(0.2 * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
assert len(set(train_indices) & set(val_indices)) == 0, "Train-validation overlap detected"

train_sampler = SubsetRandomSampler(train_indices)
val_dataset = Subset(train_dataset, val_indices)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load DeiT and modify classifier
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

# All parameters are trainable for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Verify trainable parameters
print("Trainable parameters before training:")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable params: {trainable_params}")

# Validation function
def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    all_probs = []
    all_labels = []
    all_logits = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            all_probs.append(probs.cpu().numpy())
            all_logits.append(outputs.cpu())
            all_labels.append(labels.cpu().numpy())
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_val_loss = val_loss / len(loader)
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy, np.concatenate(all_probs), np.concatenate(all_labels), torch.cat(all_logits)

# Compute ECE
def compute_ece(probs, labels, n_bins=10):
    probs = np.clip(probs, 1e-5, 1-1e-5)
    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    accuracies = predictions == labels
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences >= bin_lower) & (confidences < bin_upper)
        prop_in_bin = np.mean(in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += prop_in_bin * np.abs(avg_confidence_in_bin - accuracy_in_bin)
    return ece

# CutMix function
def cutmix(images, labels, alpha=1.0):
    batch_size = images.size(0)
    indices = torch.randperm(batch_size)
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]
    
    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bby1:bby2, bbx1:bbx2] = shuffled_images[:, :, bby1:bby2, bbx1:bbx2]
    
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(-1) * images.size(-2)))
    return images, labels, shuffled_labels, lam

def rand_bbox(size, lam):
    W = size[3]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    
    return bbx1, bby1, bbx2, bby2

# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=1, eta_min=1e-6)

# Training loop
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_epoch = 1
best_model_path = "deit_cifar10_full_finetune_cutmix_best_seed78.pt"

model.train()
for epoch in range(10):
    start_time = time.time()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    if epoch < 5:
        lr = 5e-4 * (epoch + 1) / 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Add tqdm progress bar for training loop
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{10}", leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = cutmix(images, labels, alpha=1.0)
            outputs = model(images).logits
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(images).logits
            loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        # Update progress bar with current loss
        progress_bar.set_postfix({"loss": loss.item()})
    
    progress_bar.close()
    scheduler.step(epoch + 1)
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    
    val_loss, val_accuracy, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
    epoch_time = time.time() - start_time
    
    train_losses.append(avg_train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    
    if (epoch + 1) % 5 == 0:
        test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {best_epoch} with Val Accuracy: {best_val_accuracy:.2f}%")
    
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%, Time: {epoch_time:.2f}s, LR: {scheduler.get_last_lr()[0]:.6f}")

# Plot metrics
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('full_finetune_cutmix_metrics_cifar10_seed78.png')
plt.close()

# Load best model for final evaluation
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)
model.load_state_dict(torch.load(best_model_path))
model.to(device)

# Evaluate test set with calibration
test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
ece = compute_ece(test_probs, test_labels)

# Temperature scaling
def find_optimal_temperature(val_logits, val_labels):
    def ece_with_temp(temp):
        scaled_probs = F.softmax(val_logits / temp, dim=1).numpy()
        scaled_probs = np.clip(scaled_probs, 1e-5, 1-1e-5)
        return compute_ece(scaled_probs, val_labels)
    
    temps = np.linspace(0.1, 5.0, 20)
    eces = [ece_with_temp(t) for t in temps]
    return temps[np.argmin(eces)]

_, _, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
optimal_temp = find_optimal_temperature(val_logits, val_labels)
scaled_test_probs = F.softmax(test_logits / optimal_temp, dim=1).numpy()
scaled_test_accuracy = accuracy_score(test_labels, np.argmax(scaled_test_probs, axis=1)) * 100
scaled_ece = compute_ece(scaled_test_probs, test_labels)

class_accuracies = []
for i in range(10):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = accuracy_score(test_labels[class_mask], np.argmax(test_probs[class_mask], axis=1))
        class_accuracies.append(class_acc)
class_acc_mean = np.mean(class_accuracies)
class_acc_std = np.std(class_accuracies)

print(f"Best Model (Epoch {best_epoch}) - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print(f"ECE: {ece:.4f}, Scaled ECE: {scaled_ece:.4f}, Scaled Test Accuracy: {scaled_test_accuracy:.2f}%")
print(f"Class-wise Accuracy: Mean {class_acc_mean:.2f}, Std {class_acc_std:.2f}")
print(f"Total training time: {(time.time() - start):.2f} seconds")

# Save final model
torch.save(model.state_dict(), "deit_cifar10_full_finetune_cutmix_final_seed78.pt")

Program starts...
Running Full Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)
Files already downloaded and verified
Files already downloaded and verified


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before training:
Total trainable params: 85807882


Epoch 1/10: 100%|██████████| 1250/1250 [03:15<00:00,  6.38it/s, loss=0.635]


New best model saved at epoch 1 with Val Accuracy: 95.62%
Epoch 1, Train Loss: 0.8902, Val Loss: 0.6104, Train Accuracy: 85.06%, Val Accuracy: 95.62%, Time: 217.26s, LR: 0.000499


Epoch 2/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.34it/s, loss=0.544]


Epoch 2, Train Loss: 0.9066, Val Loss: 0.6239, Train Accuracy: 83.39%, Val Accuracy: 94.63%, Time: 218.36s, LR: 0.000495


Epoch 3/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.34it/s, loss=0.651]


Epoch 3, Train Loss: 0.9291, Val Loss: 0.6661, Train Accuracy: 83.78%, Val Accuracy: 92.67%, Time: 217.88s, LR: 0.000488


Epoch 4/10: 100%|██████████| 1250/1250 [03:13<00:00,  6.45it/s, loss=0.664]


Epoch 4, Train Loss: 0.9610, Val Loss: 0.7175, Train Accuracy: 81.23%, Val Accuracy: 91.01%, Time: 214.84s, LR: 0.000478


Epoch 5/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.30it/s, loss=0.94] 


Epoch 5 - Test Loss: 0.6960, Test Accuracy: 91.59%
Epoch 5, Train Loss: 1.0133, Val Loss: 0.6935, Train Accuracy: 78.81%, Val Accuracy: 91.77%, Time: 219.39s, LR: 0.000467


Epoch 6/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.34it/s, loss=0.737]


Epoch 6, Train Loss: 0.9771, Val Loss: 0.7191, Train Accuracy: 80.28%, Val Accuracy: 91.26%, Time: 218.55s, LR: 0.000452


Epoch 7/10: 100%|██████████| 1250/1250 [03:15<00:00,  6.40it/s, loss=0.638]


Epoch 7, Train Loss: 0.9449, Val Loss: 0.7458, Train Accuracy: 82.06%, Val Accuracy: 89.47%, Time: 216.24s, LR: 0.000436


Epoch 8/10: 100%|██████████| 1250/1250 [03:15<00:00,  6.40it/s, loss=1.06] 


Epoch 8, Train Loss: 0.9210, Val Loss: 0.6843, Train Accuracy: 84.17%, Val Accuracy: 92.28%, Time: 216.18s, LR: 0.000417


Epoch 9/10: 100%|██████████| 1250/1250 [03:15<00:00,  6.38it/s, loss=0.573]


Epoch 9, Train Loss: 0.8894, Val Loss: 0.6784, Train Accuracy: 84.44%, Val Accuracy: 92.62%, Time: 216.81s, LR: 0.000397


Epoch 10/10: 100%|██████████| 1250/1250 [03:16<00:00,  6.37it/s, loss=0.516]


Epoch 10 - Test Loss: 0.6914, Test Accuracy: 92.70%
Epoch 10, Train Loss: 0.8663, Val Loss: 0.6860, Train Accuracy: 85.45%, Val Accuracy: 92.93%, Time: 217.21s, LR: 0.000375


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Model (Epoch 1) - Test Loss: 0.6158, Test Accuracy: 95.33%
ECE: 0.0726, Scaled ECE: 0.0215, Scaled Test Accuracy: 95.33%
Class-wise Accuracy: Mean 0.95, Std 0.04
Total training time: 2260.39 seconds


# FULL FT start LR 1e-4

In [4]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from transformers import DeiTForImageClassification
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from torch.utils.data import Subset
from tqdm import tqdm

start = time.time()
print('Program starts...')
print("Running Full Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)")

# Set seeds
np.random.seed(78)
torch.manual_seed(78)

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split training set into train and validation (80/20)
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
np.random.shuffle(indices)
split = int(np.floor(0.2 * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
assert len(set(train_indices) & set(val_indices)) == 0, "Train-validation overlap detected"

train_sampler = SubsetRandomSampler(train_indices)
val_dataset = Subset(train_dataset, val_indices)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load DeiT and modify classifier
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

# All parameters are trainable for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Verify trainable parameters
print("Trainable parameters before training:")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable params: {trainable_params}")

# Validation function
def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    all_probs = []
    all_labels = []
    all_logits = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            all_probs.append(probs.cpu().numpy())
            all_logits.append(outputs.cpu())
            all_labels.append(labels.cpu().numpy())
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_val_loss = val_loss / len(loader)
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy, np.concatenate(all_probs), np.concatenate(all_labels), torch.cat(all_logits)

# Compute ECE
def compute_ece(probs, labels, n_bins=10):
    probs = np.clip(probs, 1e-5, 1-1e-5)
    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    accuracies = predictions == labels
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences >= bin_lower) & (confidences < bin_upper)
        prop_in_bin = np.mean(in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += prop_in_bin * np.abs(avg_confidence_in_bin - accuracy_in_bin)
    return ece

# CutMix function
def cutmix(images, labels, alpha=1.0):
    batch_size = images.size(0)
    indices = torch.randperm(batch_size)
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]
    
    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bby1:bby2, bbx1:bbx2] = shuffled_images[:, :, bby1:bby2, bbx1:bbx2]
    
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(-1) * images.size(-2)))
    return images, labels, shuffled_labels, lam

def rand_bbox(size, lam):
    W = size[3]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    
    return bbx1, bby1, bbx2, bby2

# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=1, eta_min=1e-6)

# Training loop
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_epoch = 1
best_model_path = "deit_cifar10_full_finetune_cutmix_best_seed78_slowLR.pt"

model.train()
for epoch in range(10):
    start_time = time.time()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    if epoch < 5:
        lr = 1e-4 * (epoch + 1) / 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Add tqdm progress bar for training loop
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{10}", leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = cutmix(images, labels, alpha=1.0)
            outputs = model(images).logits
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(images).logits
            loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        # Update progress bar with current loss
        progress_bar.set_postfix({"loss": loss.item()})
    
    progress_bar.close()
    scheduler.step(epoch + 1)
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    
    val_loss, val_accuracy, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
    epoch_time = time.time() - start_time
    
    train_losses.append(avg_train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    
    if (epoch + 1) % 5 == 0:
        test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {best_epoch} with Val Accuracy: {best_val_accuracy:.2f}%")
    
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%, Time: {epoch_time:.2f}s, LR: {scheduler.get_last_lr()[0]:.6f}")

# Plot metrics
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('full_finetune_cutmix_metrics_cifar10_seed78_slowLR.png')
plt.close()

# Load best model for final evaluation
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)
model.load_state_dict(torch.load(best_model_path))
model.to(device)

# Evaluate test set with calibration
test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
ece = compute_ece(test_probs, test_labels)

# Temperature scaling
def find_optimal_temperature(val_logits, val_labels):
    def ece_with_temp(temp):
        scaled_probs = F.softmax(val_logits / temp, dim=1).numpy()
        scaled_probs = np.clip(scaled_probs, 1e-5, 1-1e-5)
        return compute_ece(scaled_probs, val_labels)
    
    temps = np.linspace(0.1, 5.0, 20)
    eces = [ece_with_temp(t) for t in temps]
    return temps[np.argmin(eces)]

_, _, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
optimal_temp = find_optimal_temperature(val_logits, val_labels)
scaled_test_probs = F.softmax(test_logits / optimal_temp, dim=1).numpy()
scaled_test_accuracy = accuracy_score(test_labels, np.argmax(scaled_test_probs, axis=1)) * 100
scaled_ece = compute_ece(scaled_test_probs, test_labels)

class_accuracies = []
for i in range(10):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = accuracy_score(test_labels[class_mask], np.argmax(test_probs[class_mask], axis=1))
        class_accuracies.append(class_acc)
class_acc_mean = np.mean(class_accuracies)
class_acc_std = np.std(class_accuracies)

print(f"Best Model (Epoch {best_epoch}) - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print(f"ECE: {ece:.4f}, Scaled ECE: {scaled_ece:.4f}, Scaled Test Accuracy: {scaled_test_accuracy:.2f}%")
print(f"Class-wise Accuracy: Mean {class_acc_mean:.2f}, Std {class_acc_std:.2f}")
print(f"Total training time: {(time.time() - start):.2f} seconds")

# Save final model
torch.save(model.state_dict(), "deit_cifar10_full_finetune_cutmix_final_seed78_slowLR.pt")

Program starts...
Running Full Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)
Files already downloaded and verified
Files already downloaded and verified


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before training:
Total trainable params: 85807882


Epoch 1/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.32it/s, loss=0.517]


New best model saved at epoch 1 with Val Accuracy: 97.86%
Epoch 1, Train Loss: 0.8796, Val Loss: 0.5542, Train Accuracy: 85.75%, Val Accuracy: 97.86%, Time: 218.50s, LR: 0.000100


Epoch 2/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.31it/s, loss=0.533]


New best model saved at epoch 2 with Val Accuracy: 97.93%
Epoch 2, Train Loss: 0.8114, Val Loss: 0.5542, Train Accuracy: 87.19%, Val Accuracy: 97.93%, Time: 219.05s, LR: 0.000099


Epoch 3/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.34it/s, loss=0.567]


Epoch 3, Train Loss: 0.7837, Val Loss: 0.5625, Train Accuracy: 89.86%, Val Accuracy: 97.68%, Time: 218.43s, LR: 0.000098


Epoch 4/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.31it/s, loss=0.508]


Epoch 4, Train Loss: 0.7745, Val Loss: 0.5657, Train Accuracy: 88.67%, Val Accuracy: 97.35%, Time: 218.81s, LR: 0.000096


Epoch 5/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.33it/s, loss=0.797]


Epoch 5 - Test Loss: 0.5759, Test Accuracy: 97.16%
Epoch 5, Train Loss: 0.7977, Val Loss: 0.5743, Train Accuracy: 87.64%, Val Accuracy: 97.12%, Time: 218.59s, LR: 0.000093


Epoch 6/10: 100%|██████████| 1250/1250 [03:16<00:00,  6.36it/s, loss=0.507]


Epoch 6, Train Loss: 0.7815, Val Loss: 0.5871, Train Accuracy: 88.20%, Val Accuracy: 96.70%, Time: 217.14s, LR: 0.000091


Epoch 7/10: 100%|██████████| 1250/1250 [03:14<00:00,  6.42it/s, loss=0.564]


Epoch 7, Train Loss: 0.7695, Val Loss: 0.5696, Train Accuracy: 89.35%, Val Accuracy: 97.41%, Time: 215.81s, LR: 0.000087


Epoch 8/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.31it/s, loss=0.942]


Epoch 8, Train Loss: 0.7668, Val Loss: 0.5636, Train Accuracy: 90.98%, Val Accuracy: 97.71%, Time: 218.97s, LR: 0.000084


Epoch 9/10: 100%|██████████| 1250/1250 [03:16<00:00,  6.36it/s, loss=0.501]


Epoch 9, Train Loss: 0.7532, Val Loss: 0.5672, Train Accuracy: 89.81%, Val Accuracy: 97.46%, Time: 217.84s, LR: 0.000080


Epoch 10/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.29it/s, loss=0.501]


Epoch 10 - Test Loss: 0.5603, Test Accuracy: 97.85%
Epoch 10, Train Loss: 0.7469, Val Loss: 0.5643, Train Accuracy: 90.18%, Val Accuracy: 97.71%, Time: 220.07s, LR: 0.000075


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Model (Epoch 2) - Test Loss: 0.5547, Test Accuracy: 97.87%
ECE: 0.0891, Scaled ECE: 0.0051, Scaled Test Accuracy: 97.87%
Class-wise Accuracy: Mean 0.98, Std 0.01
Total training time: 2271.24 seconds


# FULL FT start LR 1e-5

In [5]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from transformers import DeiTForImageClassification
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from torch.utils.data import Subset
from tqdm import tqdm

start = time.time()
print('Program starts...')
print("Running Full Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)")

# Set seeds
np.random.seed(78)
torch.manual_seed(78)

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split training set into train and validation (80/20)
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
np.random.shuffle(indices)
split = int(np.floor(0.2 * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
assert len(set(train_indices) & set(val_indices)) == 0, "Train-validation overlap detected"

train_sampler = SubsetRandomSampler(train_indices)
val_dataset = Subset(train_dataset, val_indices)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load DeiT and modify classifier
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)

# All parameters are trainable for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Verify trainable parameters
print("Trainable parameters before training:")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable params: {trainable_params}")

# Validation function
def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    all_probs = []
    all_labels = []
    all_logits = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            all_probs.append(probs.cpu().numpy())
            all_logits.append(outputs.cpu())
            all_labels.append(labels.cpu().numpy())
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_val_loss = val_loss / len(loader)
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy, np.concatenate(all_probs), np.concatenate(all_labels), torch.cat(all_logits)

# Compute ECE
def compute_ece(probs, labels, n_bins=10):
    probs = np.clip(probs, 1e-5, 1-1e-5)
    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    accuracies = predictions == labels
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = (confidences >= bin_lower) & (confidences < bin_upper)
        prop_in_bin = np.mean(in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accuracies[in_bin])
            avg_confidence_in_bin = np.mean(confidences[in_bin])
            ece += prop_in_bin * np.abs(avg_confidence_in_bin - accuracy_in_bin)
    return ece

# CutMix function
def cutmix(images, labels, alpha=1.0):
    batch_size = images.size(0)
    indices = torch.randperm(batch_size)
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]
    
    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bby1:bby2, bbx1:bbx2] = shuffled_images[:, :, bby1:bby2, bbx1:bbx2]
    
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size(-1) * images.size(-2)))
    return images, labels, shuffled_labels, lam

def rand_bbox(size, lam):
    W = size[3]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    
    return bbx1, bby1, bbx2, bby2

# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=1, eta_min=1e-6)

# Training loop
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_epoch = 1
best_model_path = "deit_cifar10_full_finetune_cutmix_best_seed78_slowerLR.pt"

model.train()
for epoch in range(10):
    start_time = time.time()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    if epoch < 5:
        lr = 1e-5 * (epoch + 1) / 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Add tqdm progress bar for training loop
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{10}", leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = cutmix(images, labels, alpha=1.0)
            outputs = model(images).logits
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(images).logits
            loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        # Update progress bar with current loss
        progress_bar.set_postfix({"loss": loss.item()})
    
    progress_bar.close()
    scheduler.step(epoch + 1)
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    
    val_loss, val_accuracy, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
    epoch_time = time.time() - start_time
    
    train_losses.append(avg_train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    
    if (epoch + 1) % 5 == 0:
        test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {best_epoch} with Val Accuracy: {best_val_accuracy:.2f}%")
    
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%, Time: {epoch_time:.2f}s, LR: {scheduler.get_last_lr()[0]:.6f}")

# Plot metrics
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('full_finetune_cutmix_metrics_cifar10_seed78_slowerLR.png')
plt.close()

# Load best model for final evaluation
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)
model.load_state_dict(torch.load(best_model_path))
model.to(device)

# Evaluate test set with calibration
test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
ece = compute_ece(test_probs, test_labels)

# Temperature scaling
def find_optimal_temperature(val_logits, val_labels):
    def ece_with_temp(temp):
        scaled_probs = F.softmax(val_logits / temp, dim=1).numpy()
        scaled_probs = np.clip(scaled_probs, 1e-5, 1-1e-5)
        return compute_ece(scaled_probs, val_labels)
    
    temps = np.linspace(0.1, 5.0, 20)
    eces = [ece_with_temp(t) for t in temps]
    return temps[np.argmin(eces)]

_, _, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
optimal_temp = find_optimal_temperature(val_logits, val_labels)
scaled_test_probs = F.softmax(test_logits / optimal_temp, dim=1).numpy()
scaled_test_accuracy = accuracy_score(test_labels, np.argmax(scaled_test_probs, axis=1)) * 100
scaled_ece = compute_ece(scaled_test_probs, test_labels)

class_accuracies = []
for i in range(10):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = accuracy_score(test_labels[class_mask], np.argmax(test_probs[class_mask], axis=1))
        class_accuracies.append(class_acc)
class_acc_mean = np.mean(class_accuracies)
class_acc_std = np.std(class_accuracies)

print(f"Best Model (Epoch {best_epoch}) - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print(f"ECE: {ece:.4f}, Scaled ECE: {scaled_ece:.4f}, Scaled Test Accuracy: {scaled_test_accuracy:.2f}%")
print(f"Class-wise Accuracy: Mean {class_acc_mean:.2f}, Std {class_acc_std:.2f}")
print(f"Total training time: {(time.time() - start):.2f} seconds")

# Save final model
torch.save(model.state_dict(), "deit_cifar10_full_finetune_cutmix_final_seed78_slowerLR.pt")

Program starts...
Running Full Fine-Tuning with CutMix on CIFAR-10 (15 Epochs)
Files already downloaded and verified
Files already downloaded and verified


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before training:
Total trainable params: 85807882


Epoch 1/10: 100%|██████████| 1250/1250 [03:19<00:00,  6.28it/s, loss=0.585]


New best model saved at epoch 1 with Val Accuracy: 96.55%
Epoch 1, Train Loss: 1.2312, Val Loss: 0.6105, Train Accuracy: 72.90%, Val Accuracy: 96.55%, Time: 220.34s, LR: 0.000010


Epoch 2/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.33it/s, loss=0.53] 


New best model saved at epoch 2 with Val Accuracy: 97.66%
Epoch 2, Train Loss: 0.8684, Val Loss: 0.5623, Train Accuracy: 85.44%, Val Accuracy: 97.66%, Time: 218.50s, LR: 0.000010


Epoch 3/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.30it/s, loss=0.574]


New best model saved at epoch 3 with Val Accuracy: 98.20%
Epoch 3, Train Loss: 0.8018, Val Loss: 0.5479, Train Accuracy: 89.33%, Val Accuracy: 98.20%, Time: 219.25s, LR: 0.000010


Epoch 4/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.31it/s, loss=0.509]


New best model saved at epoch 4 with Val Accuracy: 98.32%
Epoch 4, Train Loss: 0.7647, Val Loss: 0.5413, Train Accuracy: 89.22%, Val Accuracy: 98.32%, Time: 219.18s, LR: 0.000010


Epoch 5/10: 100%|██████████| 1250/1250 [03:18<00:00,  6.30it/s, loss=0.801]


Epoch 5 - Test Loss: 0.5450, Test Accuracy: 98.20%
New best model saved at epoch 5 with Val Accuracy: 98.49%
Epoch 5, Train Loss: 0.7675, Val Loss: 0.5391, Train Accuracy: 88.92%, Val Accuracy: 98.49%, Time: 219.43s, LR: 0.000009


Epoch 6/10: 100%|██████████| 1250/1250 [03:15<00:00,  6.39it/s, loss=0.502]


New best model saved at epoch 6 with Val Accuracy: 98.75%
Epoch 6, Train Loss: 0.7596, Val Loss: 0.5346, Train Accuracy: 89.03%, Val Accuracy: 98.75%, Time: 216.67s, LR: 0.000009


Epoch 7/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.33it/s, loss=0.563]


Epoch 7, Train Loss: 0.7503, Val Loss: 0.5375, Train Accuracy: 89.87%, Val Accuracy: 98.58%, Time: 218.58s, LR: 0.000009


Epoch 8/10: 100%|██████████| 1250/1250 [03:19<00:00,  6.28it/s, loss=0.88] 


Epoch 8, Train Loss: 0.7496, Val Loss: 0.5374, Train Accuracy: 91.47%, Val Accuracy: 98.59%, Time: 220.17s, LR: 0.000009


Epoch 9/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.33it/s, loss=0.501]


New best model saved at epoch 9 with Val Accuracy: 98.76%
Epoch 9, Train Loss: 0.7371, Val Loss: 0.5334, Train Accuracy: 90.50%, Val Accuracy: 98.76%, Time: 218.40s, LR: 0.000008


Epoch 10/10: 100%|██████████| 1250/1250 [03:17<00:00,  6.34it/s, loss=0.501]


Epoch 10 - Test Loss: 0.5386, Test Accuracy: 98.52%
Epoch 10, Train Loss: 0.7330, Val Loss: 0.5360, Train Accuracy: 90.72%, Val Accuracy: 98.68%, Time: 218.17s, LR: 0.000008


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Model (Epoch 9) - Test Loss: 0.5381, Test Accuracy: 98.51%
ECE: 0.0840, Scaled ECE: 0.0051, Scaled Test Accuracy: 98.51%
Class-wise Accuracy: Mean 0.99, Std 0.01
Total training time: 2278.70 seconds
