In [49]:
import timm
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from transformers import DeiTFeatureExtractor, DeiTForImageClassification, DeiTImageProcessor

In [50]:
# Sanity check: report whether PyTorch can see a CUDA device (result shown below).
torch.cuda.is_available()

True

In [51]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNGs (CPU and CUDA) so random weight initialisation and
# DataLoader shuffling in later cells are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

## Model

In [52]:
# Load the distilled DeiT-Base checkpoint and attach a fresh classification head
# for CIFAR-10. The checkpoint has no `classifier` weights, so the head is newly
# initialised either way (see the warning in the output below).
NUM_CLASSES = 10  # CIFAR-10
model = DeiTForImageClassification.from_pretrained(
    "facebook/deit-base-distilled-patch16-224",
    attn_implementation="sdpa",
    torch_dtype=torch.float32,
)
model.classifier = nn.Linear(model.classifier.in_features, NUM_CLASSES)
# Keep the config in sync with the replaced head; otherwise it still advertises the
# checkpoint's original label set, which breaks save_pretrained/from_pretrained round-trips.
model.config.num_labels = NUM_CLASSES
# Use the `device` chosen earlier rather than hard-coding "cuda": the training/eval
# loops move their inputs to `device`, so both must agree (and CPU-only machines work).
model = model.to(device)

Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [53]:
# Preprocessor for the same checkpoint as the model; the train/test loops call it on
# lists of PIL images and feed the result straight into `model(**inputs)`.
# DeiTImageProcessor replaces the deprecated DeiTFeatureExtractor; the name
# `extractor` is kept so the later cells are unchanged.
extractor = DeiTImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")

## Dataset

In [54]:
# Per-image preprocessing for both CIFAR-10 splits: resize to the 224x224
# resolution named in the DeiT checkpoint, then convert to a float tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

In [55]:
# CIFAR-10 training split (fetched into ./data when missing), served in
# shuffled mini-batches of 16.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)

## Training

In [56]:
import torch.optim as optim

In [57]:
# AdamW over every parameter, so the whole network is fine-tuned, not just the new head;
# cross-entropy on the model's logits is the classification loss.
optimizer = optim.AdamW(params=model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss()

In [58]:
# Fine-tune the model. Each batch of tensors is converted back to PIL images so the
# DeiT preprocessor (`extractor`) can produce the model inputs.
NUM_EPOCHS = 3
to_pil = transforms.ToPILImage()  # build the converter once instead of once per image

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    num_samples = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}"):
        pil_images = [to_pil(img) for img in images]
        inputs = extractor(images=pil_images, return_tensors="pt").to(device)
        labels = labels.to(device)

        outputs = model(**inputs)
        loss = criterion(outputs.logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is a per-batch mean; weight it by batch size so the epoch figure is a
        # true per-sample average even when the final batch is smaller.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size
    print(f"Epoch {epoch+1}: Loss = {running_loss/num_samples:.4f}")

100%|██████████| 3125/3125 [10:30<00:00,  4.96it/s]


Epoch 1: Loss = 0.1434


100%|██████████| 3125/3125 [10:31<00:00,  4.94it/s]


Epoch 2: Loss = 0.0551


100%|██████████| 3125/3125 [10:32<00:00,  4.94it/s]

Epoch 3: Loss = 0.0398





## Test

In [61]:
# Held-out CIFAR-10 test split (fetched into ./data when missing), iterated in a
# fixed order, without shuffling, in batches of 32.
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)

In [60]:
# Top-1 accuracy on the CIFAR-10 test split, using the same
# tensor -> PIL -> `extractor` preprocessing path as the training loop.
to_pil = transforms.ToPILImage()  # build the converter once instead of once per image

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        pil_images = [to_pil(img) for img in images]
        inputs = extractor(images=pil_images, return_tensors="pt").to(device)
        labels = labels.to(device)

        outputs = model(**inputs)
        preds = outputs.logits.argmax(dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

# Fail with a clear message instead of a bare ZeroDivisionError if the loader was empty.
if total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")
print(f"\n🎯 Accuracy on CIFAR-10 test set: {100 * correct / total:.2f}%")

100%|██████████| 313/313 [00:59<00:00,  5.24it/s]


🎯 Accuracy on CIFAR-10 test set: 96.74%



