# DeiT full fine-tuning on CIFAR-100

In [3]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from transformers import DeiTForImageClassification
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from torch.utils.data import Subset
from tqdm import tqdm

# Record the wall-clock start; the elapsed total is printed at the end of the run.
start = time.time()
print('Program starts...')
print("Running Full Fine-Tuning with CutMix (30 Epochs)")

# Set seeds. np.random drives the train/val shuffle below and the CutMix draws
# during training; torch.manual_seed also fixes the random initialisation of
# the replacement classifier head created below and the SubsetRandomSampler order.
np.random.seed(78)
torch.manual_seed(78)

# Load CIFAR-100 dataset.
# Images are resized to 224x224 to match the input size of the
# "patch16-224" DeiT checkpoint, then normalised per channel (the constants
# look like CIFAR-100 channel statistics -- NOTE(review): confirm source).
# The same transform is used for train, validation and test, so CutMix is the
# only training-time augmentation in this script.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Split training set into train and validation (80/20).
# The first 20% of the shuffled indices become the validation set; both
# subsets index the same train_dataset object (and therefore share its transform).
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
np.random.shuffle(indices)
split = int(np.floor(0.2 * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
assert len(set(train_indices) & set(val_indices)) == 0, "Train-validation overlap detected"

# SubsetRandomSampler draws a new permutation of the training indices every
# epoch; the validation and test loaders iterate in a fixed order.
train_sampler = SubsetRandomSampler(train_indices)
val_dataset = Subset(train_dataset, val_indices)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load DeiT and modify classifier.
# The checkpoint's head is replaced by a freshly initialised 100-way linear
# layer. This assignment does not update model.config (e.g. num_labels).
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 100)

# All parameters are trainable for full fine-tuning. nn.Parameter already
# defaults to requires_grad=True; the loop makes the intent explicit.
for param in model.parameters():
    param.requires_grad = True

# Verify trainable parameters
print("Trainable parameters before training:")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable params: {trainable_params}")

# Validation function
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Args:
        model: Module whose forward output exposes a ``.logits`` tensor.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to return the mean loss over a batch
            (``reduction='mean'``, the CrossEntropyLoss default).
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy_pct, probs, labels, logits)``. Here ``avg_loss``
        is the per-sample mean loss over the whole loader and ``accuracy_pct``
        is the top-1 accuracy in percent. ``probs`` and ``labels`` are NumPy
        arrays concatenated over all batches, and ``logits`` is a CPU tensor.

    Raises:
        ValueError: if ``loader`` yields no samples.

    The model's train/eval mode is restored on exit. Calling this inside a
    training loop therefore does not leave the model in eval mode.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_probs = []
    all_labels = []
    all_logits = []
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images).logits
                batch_size = labels.size(0)
                # Weight the batch-mean loss by batch size so a smaller final
                # batch does not skew the dataset-level average.
                loss_sum += criterion(outputs, labels).item() * batch_size
                probs = F.softmax(outputs, dim=1)
                all_probs.append(probs.cpu().numpy())
                all_logits.append(outputs.cpu())
                all_labels.append(labels.cpu().numpy())
                _, predicted = torch.max(outputs, 1)
                total += batch_size
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("validate() received a loader with no samples")
    avg_val_loss = loss_sum / total
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy, np.concatenate(all_probs), np.concatenate(all_labels), torch.cat(all_logits)

# Expected calibration error (ECE)
def compute_ece(probs, labels, n_bins=10):
    """Return the expected calibration error of the top-1 predictions.

    The probabilities are first clipped to ``[1e-5, 1 - 1e-5]``. Each row's
    maximum is its confidence, and the prediction counts as a hit when the
    arg-max equals the label. Confidences fall into ``n_bins`` equal-width bins
    over ``[0, 1]``, each bin including its lower edge and excluding its upper
    edge. The result is the sum over non-empty bins of
    ``(fraction of samples in bin) * |mean confidence - hit rate|``.
    """
    clipped = np.clip(probs, 1e-5, 1 - 1e-5)
    top_conf = clipped.max(axis=1)
    hits = clipped.argmax(axis=1) == labels
    edges = np.linspace(0, 1, n_bins + 1)

    total_gap = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (top_conf >= lo) & (top_conf < hi)
        weight = mask.mean()
        if weight > 0:
            gap = np.abs(top_conf[mask].mean() - hits[mask].mean())
            total_gap += weight * gap
    return total_gap

# CutMix function
def cutmix(images, labels, alpha=1.0):
    """Apply CutMix to a batch by pasting a random box from a shuffled copy.

    A mixing ratio ``lam`` is drawn from ``Beta(alpha, alpha)``. A box of about
    ``1 - lam`` of the image area is cut from a batch-permuted copy and pasted
    into every image at the same location. ``lam`` is then recomputed from the
    area of the box actually pasted (it may be clipped at the borders).

    Args:
        images: Batch tensor indexed as ``(batch, channels, height, width)``.
        labels: Per-sample targets aligned with ``images``.
        alpha: Beta-distribution concentration parameter.

    Returns:
        ``(mixed_images, labels, shuffled_labels, lam)``. ``lam`` is the
        fraction of each image still coming from its own sample, so the loss is
        ``lam * loss(labels) + (1 - lam) * loss(shuffled_labels)``.

    The input ``images`` tensor is not modified; a mixed copy is returned.
    """
    batch_size = images.size(0)
    # Draw the permutation from the CPU generator (same random stream as a
    # plain randperm) and move it next to the data so indexing stays on-device.
    indices = torch.randperm(batch_size).to(images.device)
    shuffled_images = images[indices]
    shuffled_labels = labels[indices.to(labels.device)]

    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    # Paste into a copy so the caller's tensor is left untouched.
    mixed = images.clone()
    mixed[:, :, bby1:bby2, bbx1:bbx2] = shuffled_images[:, :, bby1:bby2, bbx1:bbx2]

    # Correct lam for the box area actually pasted after border clipping.
    box_area = (bbx2 - bbx1) * (bby2 - bby1)
    lam = 1 - box_area / (images.size(-1) * images.size(-2))
    return mixed, labels, shuffled_labels, float(lam)

def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of an image.

    Args:
        size: Tensor shape; index 2 is read as the height and index 3 as the
            width.
        lam: Mixing ratio; each box side is the image side scaled by
            ``sqrt(1 - lam)`` (so ``lam`` is expected to lie in ``[0, 1]``).

    Returns:
        ``(x1, y1, x2, y2)`` box corners clipped to the image, with
        ``0 <= x1 <= x2 <= width`` and ``0 <= y1 <= y2 <= height``.
    """
    W = size[3]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # The box centre is uniform over the image. The box may extend past the
    # border and is clipped, so the real area can be smaller than requested.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Label smoothing applies to both the plain and the CutMix-mixed loss terms.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# NOTE(review): scheduler.step() also runs during the 5 manual warmup epochs
# below, so the cosine curve's position advances while the LR is overridden by
# hand; the post-warmup decay therefore starts part-way along the curve.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-5)

# Training loop
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_epoch = 1
best_model_path = "deit_cifar100_full_finetune_cutmix_best_seed78_ep30_slow_LRpt"

for epoch in range(30):
    # validate() switches the model to eval mode, so training mode must be
    # re-enabled at the start of every epoch rather than once before the loop.
    model.train()
    start_time = time.time()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    # Linear warmup: the LR is set to 2e-5, 4e-5, ..., 1e-4 for epochs 1-5.
    if epoch < 5:
        lr = 1e-4 * (epoch + 1) / 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Add tqdm progress bar for training loop
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/30", leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        # CutMix is applied to roughly half of the batches.
        if np.random.rand() < 0.5:
            images, labels_a, labels_b, lam = cutmix(images, labels, alpha=1.0)
            outputs = model(images).logits
            # Mix the two targets' losses in proportion to the pasted area.
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(images).logits
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss
        # Training accuracy is measured against the original labels, also on
        # CutMix batches.
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

        # Update progress bar with current loss
        progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})

    progress_bar.close()
    scheduler.step()
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total

    val_loss, val_accuracy, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
    epoch_time = time.time() - start_time

    train_losses.append(avg_train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Test metrics are only reported here; checkpoint selection uses validation accuracy.
    if (epoch + 1) % 5 == 0:
        test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    # Checkpointing starts once the warmup epochs are over (epoch index >= 5).
    if epoch > 4:
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved at epoch {best_epoch} with Val Accuracy: {best_val_accuracy:.2f}%")

    # The LR shown is scheduler.get_last_lr() after this epoch's scheduler.step().
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%, Time: {epoch_time:.2f}s, LR: {scheduler.get_last_lr()[0]:.6f}")

# Plot metrics: loss curves in the left panel, accuracy curves in the right
# panel, then write the figure to disk and release it.
plt.figure(figsize=(12, 5))
metric_panels = [
    (train_losses, val_losses, 'Train Loss', 'Validation Loss', 'Loss', 'Training and Validation Loss'),
    (train_accuracies, val_accuracies, 'Train Accuracy', 'Validation Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
]
for panel_no, (train_curve, val_curve, train_label, val_label, y_label, panel_title) in enumerate(metric_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(range(1, len(train_curve) + 1), train_curve, label=train_label)
    plt.plot(range(1, len(val_curve) + 1), val_curve, label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.savefig('full_finetune_cutmix_metrics_seed78_ep_30_slow_LR.png')
plt.close()

# Save the last-epoch weights now. `model` is rebound to the best checkpoint
# below, so saving at the end of the script would store the best-epoch weights
# under this "final" file name instead.
torch.save(model.state_dict(), "deit_cifar100_full_finetune_cutmix_final_seed78_ep30_slow_lr.pt")

# Load best model for final evaluation. map_location keeps the load working if
# the checkpoint was written on a different device than `device`.
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 100)
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)

# Evaluate test set with calibration
test_loss, test_accuracy, test_probs, test_labels, test_logits = validate(model, test_loader, criterion, device)
ece = compute_ece(test_probs, test_labels)


# Temperature scaling
def find_optimal_temperature(val_logits, val_labels, temps=None):
    """Pick the softmax temperature with the lowest validation ECE.

    Args:
        val_logits: CPU tensor of raw logits, one row per sample.
        val_labels: Integer labels aligned with ``val_logits``.
        temps: Optional iterable of candidate temperatures (expected > 0).
            Defaults to the grid ``linspace(0.1, 5.0, 20)`` plus ``1.0``.

    Returns:
        The candidate temperature (as a float) whose scaled probabilities
        give the smallest ECE. Ties resolve to the earliest candidate.

    Raises:
        ValueError: if ``temps`` is empty.
    """
    if temps is None:
        # 1.0 (no scaling) is not on the linspace grid. Including it means the
        # chosen temperature never has a higher validation ECE than the raw logits.
        temps = np.union1d(np.linspace(0.1, 5.0, 20), [1.0])
    temps = np.asarray(temps, dtype=float)
    if temps.size == 0:
        raise ValueError("temps must contain at least one candidate temperature")

    def ece_with_temp(temp):
        # compute_ece clips the probabilities itself.
        scaled_probs = F.softmax(val_logits / temp, dim=1).numpy()
        return compute_ece(scaled_probs, val_labels)

    eces = [ece_with_temp(float(t)) for t in temps]
    return float(temps[int(np.argmin(eces))])

_, _, val_probs, val_labels, val_logits = validate(model, val_loader, criterion, device)
optimal_temp = find_optimal_temperature(val_logits, val_labels)
scaled_test_probs = F.softmax(test_logits / optimal_temp, dim=1).numpy()
# Dividing logits by a positive temperature keeps the arg-max, so this should
# match test_accuracy; it is kept as a consistency check.
scaled_test_accuracy = accuracy_score(test_labels, np.argmax(scaled_test_probs, axis=1)) * 100
scaled_ece = compute_ece(scaled_test_probs, test_labels)

# Per-class top-1 accuracy on the unscaled test predictions, as fractions in [0, 1].
class_accuracies = []
for i in range(100):
    class_mask = test_labels == i
    if class_mask.sum() > 0:
        class_acc = accuracy_score(test_labels[class_mask], np.argmax(test_probs[class_mask], axis=1))
        class_accuracies.append(class_acc)
class_acc_mean = np.mean(class_accuracies)
class_acc_std = np.std(class_accuracies)

print(f"Best Model (Epoch {best_epoch}) - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print(f"ECE: {ece:.4f}, Scaled ECE: {scaled_ece:.4f}, Scaled Test Accuracy: {scaled_test_accuracy:.2f}%")
print(f"Class-wise Accuracy: Mean {class_acc_mean:.2f}, Std {class_acc_std:.2f}")
print(f"Total training time: {(time.time() - start):.2f} seconds")

Program starts...
Running Full Fine-Tuning with CutMix (30 Epochs)
Files already downloaded and verified
Files already downloaded and verified


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before training:
Total trainable params: 85877092


Epoch 1/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=1.0811]


Epoch 1, Train Loss: 2.2618, Val Loss: 1.2296, Train Accuracy: 61.52%, Val Accuracy: 85.76%, Time: 220.24s, LR: 0.000020


Epoch 2/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.23it/s, Loss=1.2700]


Epoch 2, Train Loss: 1.6632, Val Loss: 1.1461, Train Accuracy: 73.97%, Val Accuracy: 87.59%, Time: 222.12s, LR: 0.000040


Epoch 3/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=0.9391]


Epoch 3, Train Loss: 1.5108, Val Loss: 1.1632, Train Accuracy: 79.11%, Val Accuracy: 87.09%, Time: 220.91s, LR: 0.000059


Epoch 4/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.25it/s, Loss=0.9713]


Epoch 4, Train Loss: 1.4212, Val Loss: 1.1688, Train Accuracy: 80.47%, Val Accuracy: 87.07%, Time: 221.60s, LR: 0.000079


Epoch 5/30: 100%|██████████| 1250/1250 [03:17<00:00,  6.32it/s, Loss=1.5374]


Epoch 5 - Test Loss: 1.2072, Test Accuracy: 85.41%
Epoch 5, Train Loss: 1.4219, Val Loss: 1.1986, Train Accuracy: 80.07%, Val Accuracy: 85.96%, Time: 219.02s, LR: 0.000098


Epoch 6/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=1.2247]


New best model saved at epoch 6 with Val Accuracy: 87.48%
Epoch 6, Train Loss: 1.3679, Val Loss: 1.1696, Train Accuracy: 81.34%, Val Accuracy: 87.48%, Time: 220.64s, LR: 0.000095


Epoch 7/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=1.1226]


Epoch 7, Train Loss: 1.3142, Val Loss: 1.2018, Train Accuracy: 83.46%, Val Accuracy: 86.82%, Time: 220.68s, LR: 0.000092


Epoch 8/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.24it/s, Loss=1.6309]


Epoch 8, Train Loss: 1.2876, Val Loss: 1.1841, Train Accuracy: 85.13%, Val Accuracy: 87.34%, Time: 221.65s, LR: 0.000089


Epoch 9/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=0.8716]


New best model saved at epoch 9 with Val Accuracy: 87.85%
Epoch 9, Train Loss: 1.2340, Val Loss: 1.1861, Train Accuracy: 85.53%, Val Accuracy: 87.85%, Time: 220.61s, LR: 0.000085


Epoch 10/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.26it/s, Loss=0.7920]


Epoch 10 - Test Loss: 1.2176, Test Accuracy: 87.26%
Epoch 10, Train Loss: 1.2129, Val Loss: 1.2002, Train Accuracy: 86.32%, Val Accuracy: 87.69%, Time: 221.07s, LR: 0.000081


Epoch 11/30: 100%|██████████| 1250/1250 [03:18<00:00,  6.29it/s, Loss=0.7885]


New best model saved at epoch 11 with Val Accuracy: 87.93%
Epoch 11, Train Loss: 1.2142, Val Loss: 1.2124, Train Accuracy: 87.13%, Val Accuracy: 87.93%, Time: 220.25s, LR: 0.000076


Epoch 12/30: 100%|██████████| 1250/1250 [03:18<00:00,  6.30it/s, Loss=1.7682]


Epoch 12, Train Loss: 1.1940, Val Loss: 1.2064, Train Accuracy: 87.06%, Val Accuracy: 87.87%, Time: 219.72s, LR: 0.000072


Epoch 13/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.26it/s, Loss=1.6958]


New best model saved at epoch 13 with Val Accuracy: 87.96%
Epoch 13, Train Loss: 1.1698, Val Loss: 1.2021, Train Accuracy: 87.78%, Val Accuracy: 87.96%, Time: 221.02s, LR: 0.000067


Epoch 14/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.25it/s, Loss=0.7825]


New best model saved at epoch 14 with Val Accuracy: 88.80%
Epoch 14, Train Loss: 1.1485, Val Loss: 1.1889, Train Accuracy: 88.93%, Val Accuracy: 88.80%, Time: 221.38s, LR: 0.000062


Epoch 15/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.23it/s, Loss=1.6923]


Epoch 15 - Test Loss: 1.1939, Test Accuracy: 88.22%
Epoch 15, Train Loss: 1.1536, Val Loss: 1.1954, Train Accuracy: 88.25%, Val Accuracy: 88.64%, Time: 221.75s, LR: 0.000057


Epoch 16/30: 100%|██████████| 1250/1250 [03:18<00:00,  6.29it/s, Loss=1.6843]


New best model saved at epoch 16 with Val Accuracy: 88.98%
Epoch 16, Train Loss: 1.1330, Val Loss: 1.1758, Train Accuracy: 88.62%, Val Accuracy: 88.98%, Time: 220.14s, LR: 0.000052


Epoch 17/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.24it/s, Loss=1.3786]


Epoch 17, Train Loss: 1.1418, Val Loss: 1.1899, Train Accuracy: 87.81%, Val Accuracy: 88.86%, Time: 221.90s, LR: 0.000047


Epoch 18/30: 100%|██████████| 1250/1250 [03:21<00:00,  6.21it/s, Loss=0.7805]


Epoch 18, Train Loss: 1.1323, Val Loss: 1.1753, Train Accuracy: 88.92%, Val Accuracy: 88.93%, Time: 222.78s, LR: 0.000042


Epoch 19/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=1.4176]


New best model saved at epoch 19 with Val Accuracy: 89.18%
Epoch 19, Train Loss: 1.1262, Val Loss: 1.1661, Train Accuracy: 88.49%, Val Accuracy: 89.18%, Time: 220.57s, LR: 0.000038


Epoch 20/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.28it/s, Loss=1.4257]


Epoch 20 - Test Loss: 1.1680, Test Accuracy: 89.25%
New best model saved at epoch 20 with Val Accuracy: 89.25%
Epoch 20, Train Loss: 1.1247, Val Loss: 1.1708, Train Accuracy: 89.45%, Val Accuracy: 89.25%, Time: 220.28s, LR: 0.000034


Epoch 21/30: 100%|██████████| 1250/1250 [03:22<00:00,  6.17it/s, Loss=1.5956]


New best model saved at epoch 21 with Val Accuracy: 89.77%
Epoch 21, Train Loss: 1.0988, Val Loss: 1.1550, Train Accuracy: 89.53%, Val Accuracy: 89.77%, Time: 224.07s, LR: 0.000029


Epoch 22/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.26it/s, Loss=1.4501]


New best model saved at epoch 22 with Val Accuracy: 89.93%
Epoch 22, Train Loss: 1.1095, Val Loss: 1.1490, Train Accuracy: 89.11%, Val Accuracy: 89.93%, Time: 221.32s, LR: 0.000026


Epoch 23/30: 100%|██████████| 1250/1250 [03:21<00:00,  6.19it/s, Loss=0.7789]


New best model saved at epoch 23 with Val Accuracy: 89.94%
Epoch 23, Train Loss: 1.0946, Val Loss: 1.1506, Train Accuracy: 88.95%, Val Accuracy: 89.94%, Time: 223.52s, LR: 0.000022


Epoch 24/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=1.8814]


Epoch 24, Train Loss: 1.0818, Val Loss: 1.1480, Train Accuracy: 90.64%, Val Accuracy: 89.77%, Time: 220.66s, LR: 0.000019


Epoch 25/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.25it/s, Loss=1.5058]


Epoch 25 - Test Loss: 1.1487, Test Accuracy: 89.70%
New best model saved at epoch 25 with Val Accuracy: 90.17%
Epoch 25, Train Loss: 1.0741, Val Loss: 1.1363, Train Accuracy: 90.44%, Val Accuracy: 90.17%, Time: 221.20s, LR: 0.000016


Epoch 26/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.25it/s, Loss=0.7789]


New best model saved at epoch 26 with Val Accuracy: 90.29%
Epoch 26, Train Loss: 1.0819, Val Loss: 1.1323, Train Accuracy: 90.32%, Val Accuracy: 90.29%, Time: 221.21s, LR: 0.000014


Epoch 27/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.24it/s, Loss=0.7787]


Epoch 27, Train Loss: 1.0856, Val Loss: 1.1367, Train Accuracy: 89.30%, Val Accuracy: 90.19%, Time: 221.67s, LR: 0.000012


Epoch 28/30: 100%|██████████| 1250/1250 [03:20<00:00,  6.22it/s, Loss=0.7786]


Epoch 28, Train Loss: 1.0843, Val Loss: 1.1350, Train Accuracy: 89.22%, Val Accuracy: 90.19%, Time: 222.11s, LR: 0.000011


Epoch 29/30: 100%|██████████| 1250/1250 [03:19<00:00,  6.27it/s, Loss=1.1290]


New best model saved at epoch 29 with Val Accuracy: 90.33%
Epoch 29, Train Loss: 1.0707, Val Loss: 1.1307, Train Accuracy: 90.90%, Val Accuracy: 90.33%, Time: 220.73s, LR: 0.000010


Epoch 30/30: 100%|██████████| 1250/1250 [03:21<00:00,  6.20it/s, Loss=0.7785]


Epoch 30 - Test Loss: 1.1354, Test Accuracy: 90.19%
Epoch 30, Train Loss: 1.0785, Val Loss: 1.1352, Train Accuracy: 90.83%, Val Accuracy: 90.30%, Time: 223.13s, LR: 0.000010


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Model (Epoch 29) - Test Loss: 1.1345, Test Accuracy: 90.18%
ECE: 0.0560, Scaled ECE: 0.0266, Scaled Test Accuracy: 90.18%
Class-wise Accuracy: Mean 0.90, Std 0.07
Total training time: 6820.82 seconds
