In [20]:
from transformers import ViTForImageClassification
from torch.optim import AdamW
from datasets import load_dataset
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn.functional as F

In [4]:
# Preprocessing pipeline applied to every image, in this order:
# resize -> center crop (224) -> tensor conversion -> per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # The same mean and std (0.5) are used for each of the three channels.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [5]:
import kagglehub

# Kaggle handle of the flower dataset used throughout this notebook.
DATASET_HANDLE = "junkal/flowerdatasets"

# Download the latest version and remember the location that kagglehub
# reports for the downloaded files.
path = kagglehub.dataset_download(DATASET_HANDLE)

print("Path to dataset files:", path)

Mounting files to /kaggle/input/flowerdatasets...
Path to dataset files: /kaggle/input/flowerdatasets


In [6]:
import os

# Build the split directories from the location kagglehub returned in `path`
# instead of a fixed '/kaggle/input/...' prefix. On Kaggle the resulting
# strings are unchanged; elsewhere they follow wherever the dataset was
# actually downloaded.
flowers_root = os.path.join(path, 'flowers')
train_dir = os.path.join(flowers_root, 'train')
test_dir = os.path.join(flowers_root, 'test')
val_dir = os.path.join(flowers_root, 'val')

In [11]:
# Training and validation splits, both read from class-named sub-folders and
# run through the same preprocessing pipeline.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# Each ImageFolder derives its own class -> index mapping from the folders it
# finds. If the two splits disagree, validation labels would silently refer to
# different classes than the ones the model is trained on, so fail loudly here.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    "train and val splits expose different class folders: "
    f"{train_dataset.classes} vs {val_dataset.classes}"
)

In [12]:
# Settings shared by both loaders; only the shuffling differs between splits.
loader_kwargs = dict(batch_size=32, num_workers=4)

# Training batches are reshuffled; validation batches keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [13]:
# Number of target categories, taken from the class names exposed by the
# training dataset.
class_names = train_dataset.classes
num_classes = len(class_names)

In [16]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [17]:
# Pretrained ViT checkpoint to fine-tune. Its classification head is sized to
# this dataset via `num_labels`; the loader's own warning (shown in the cell
# output) reports that the classifier weights are newly initialized.
checkpoint_name = 'google/vit-base-patch16-224-in21k'
model = ViTForImageClassification.from_pretrained(
    checkpoint_name,
    num_labels=num_classes,
)

# Move the model to the selected device.
model = model.to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
# AdamW over every model parameter, with a learning rate of 3e-5.
optimizer = AdamW(params=model.parameters(), lr=3e-5)

In [21]:
# Fine-tune the model for a fixed number of epochs, reporting the average
# per-sample training loss and the training accuracy after each epoch.
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0      # sum of per-sample losses seen this epoch
    total_correct = 0     # number of correctly classified training samples
    for batch in train_loader:
        images, labels = batch[0].to(device), batch[1].to(device)

        # Clear gradients *before* the forward/backward pass. If an earlier run
        # of this cell was interrupted between backward() and the old trailing
        # zero_grad(), stale gradients would otherwise be added to this step.
        optimizer.zero_grad()

        outputs = model(images).logits
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is the mean over the batch; weight it by the batch size so the
        # epoch figure is a true per-sample average even when the final batch
        # is smaller than the others.
        total_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss/len(train_dataset):.4f} - Accuracy: {total_correct/len(train_dataset):.4f}")

# Optional: validation
# Switch to eval mode and, with gradient tracking disabled, count how many
# validation samples receive the correct argmax prediction.
model.eval()
val_correct = 0
with torch.no_grad():
    for batch in val_loader:
        # Each batch is an (images, labels) pair; move both to the device.
        images, labels = (tensor.to(device) for tensor in batch)
        outputs = model(images).logits
        preds = torch.argmax(outputs, dim=1)
        val_correct += int(preds.eq(labels).sum())
print(f"Validation Accuracy: {val_correct/len(val_dataset):.4f}")

Epoch 1/3 - Loss: 0.4646 - Accuracy: 0.9751
Epoch 2/3 - Loss: 0.0918 - Accuracy: 0.9994
Epoch 3/3 - Loss: 0.0505 - Accuracy: 1.0000
Validation Accuracy: 0.9991
