# 1️⃣ Install Libraries

In [1]:
# Install required packages quietly
!pip install transformers datasets torch torchvision --quiet


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m105.9 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m77.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

### Explanation:
Install Hugging Face Transformers, Datasets, PyTorch, and Torchvision. The --quiet flag keeps output cleaner.

# 2️⃣ Import Libraries and Set Device

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from datasets import load_dataset
from PIL import Image

# Set device: GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


2025-09-24 19:02:34.598963: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758740554.780799      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758740554.840805      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda


### Explanation:
All libraries needed for loading data, preprocessing, model, and training. The device line ensures GPU usage if available.

# 3️⃣ Load Dataset

In [3]:
# Load Hugging Face dataset
dataset = load_dataset("aaronqg/golden-foot-football-players")
print(dataset)


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/400M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/98.2M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/57.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5175 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1294 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/719 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5175
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1294
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 719
    })
})


### Explanation:
Loads the Golden Foot Football Players dataset. It has train, validation, and test splits.

# 4️⃣ Create Custom PyTorch Dataset

In [4]:
class FootballDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        # Convert string labels to integer if needed
        label = int(item['label'].split()[0]) if isinstance(item['label'], str) else item['label']

        if self.transform:
            image = self.transform(image)
        return image, label


### Explanation:
Wraps the Hugging Face dataset into a PyTorch Dataset with optional transforms. Converts string labels to integers if necessary.

# 5️⃣ Define Image Transformations and DataLoaders

In [5]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224 input
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
])

train_dataset = FootballDataset(dataset['train'], transform=transform)
val_dataset = FootballDataset(dataset['validation'], transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)


# Explanation:
Preprocess images for ViT and create PyTorch DataLoaders for batching and shuffling.

# 6️⃣ Initialize ViT Model for Classification

In [6]:
num_classes = 22  # Number of players

# Create configuration with correct number of labels
config = ViTConfig.from_pretrained("google/vit-base-patch16-224", num_labels=num_classes)

# Load pre-trained ViT and automatically re-initialize classifier if sizes mismatch
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    config=config,
    ignore_mismatched_sizes=True
)

# Freeze ViT backbone (only fine-tune classifier)
for param in model.vit.parameters():
    param.requires_grad = False

model.to(device)


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([22]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([22, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

### Explanation:
Uses ignore_mismatched_sizes=True to replace the default 1000-class classifier with our 22-class one. Only the classifier layer will be trained.

# 7️⃣ Define Loss Function and Optimizer

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-4)


### Explanation:
Cross-entropy for multi-class classification. Only the classifier’s parameters are optimized

# 8️⃣ Training Loop

In [8]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")


Epoch [1/100], Loss: 2.8397
Epoch [2/100], Loss: 2.3649
Epoch [3/100], Loss: 2.0749
Epoch [4/100], Loss: 1.8695
Epoch [5/100], Loss: 1.7097
Epoch [6/100], Loss: 1.5892
Epoch [7/100], Loss: 1.4870
Epoch [8/100], Loss: 1.4024
Epoch [9/100], Loss: 1.3302
Epoch [10/100], Loss: 1.2652
Epoch [11/100], Loss: 1.2096
Epoch [12/100], Loss: 1.1584
Epoch [13/100], Loss: 1.1137
Epoch [14/100], Loss: 1.0711
Epoch [15/100], Loss: 1.0341
Epoch [16/100], Loss: 0.9996
Epoch [17/100], Loss: 0.9646
Epoch [18/100], Loss: 0.9367
Epoch [19/100], Loss: 0.9073
Epoch [20/100], Loss: 0.8839
Epoch [21/100], Loss: 0.8578
Epoch [22/100], Loss: 0.8317
Epoch [23/100], Loss: 0.8112
Epoch [24/100], Loss: 0.7914
Epoch [25/100], Loss: 0.7722
Epoch [26/100], Loss: 0.7534
Epoch [27/100], Loss: 0.7343
Epoch [28/100], Loss: 0.7173
Epoch [29/100], Loss: 0.7011
Epoch [30/100], Loss: 0.6850
Epoch [31/100], Loss: 0.6692
Epoch [32/100], Loss: 0.6539
Epoch [33/100], Loss: 0.6417
Epoch [34/100], Loss: 0.6255
Epoch [35/100], Loss: 0

# Explanation:
Standard PyTorch training loop. Computes loss, backpropagates, updates weights, and prints average loss per epoch.

# 9️⃣ Validation Accuracy

In [9]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs).logits
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Validation Accuracy: {100 * correct / total:.2f}%')


Validation Accuracy: 78.28%


### Explanation:
Evaluate the model on validation data without updating weights. Calculates and prints accuracy.

# 🔟 Save the Trained Model

In [10]:
# Save the entire model and tokenizer/config (optional)
model_save_path = "./vit_football_classifier"
model.save_pretrained(model_save_path)
print(f"Model saved to {model_save_path}")


Model saved to ./vit_football_classifier


### Explanation:

save_pretrained saves both the model weights and configuration.

You can later reload it with ViTForImageClassification.from_pretrained(model_save_path).

Useful for deployment, sharing, or continuing training.