# 1️⃣ Install Libraries

In [3]:
# Install required packages quietly
!pip install transformers datasets torch torchvision --quiet


### Explanation:
Install Hugging Face Transformers, Datasets, PyTorch, and Torchvision. The --quiet flag keeps output cleaner.

# 2️⃣ Import Libraries and Set Device

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
from datasets import load_dataset
from PIL import Image

# Choose the compute device: CUDA when a GPU is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


### Explanation:
All libraries needed for loading data, preprocessing, model, and training. The device line ensures GPU usage if available.

# 3️⃣ Load Dataset

In [None]:
# Load the dataset by its Hugging Face repository id
dataset = load_dataset(path="aaronqg/golden-foot-football-players")
# Print the loaded object to inspect what was returned
print(dataset)


### Explanation:
Loads the Golden Foot Football Players dataset. It has train, validation, and test splits.

# 4️⃣ Create Custom PyTorch Dataset

In [None]:
class FootballDataset(Dataset):
    """PyTorch ``Dataset`` adapter around an indexable Hugging Face split.

    Every item of the wrapped split is read as a mapping with an ``'image'``
    entry and a ``'label'`` entry.

    Args:
        hf_dataset: Collection supporting ``len()`` and integer indexing.
        transform: Optional callable applied to the image before it is returned.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of examples in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Returns:
            tuple: the (optionally transformed) image and its integer label.

        Raises:
            ValueError: if a string label does not start with an integer token.
        """
        item = self.dataset[idx]
        image = item['image']
        # Images exposing a non-RGB ``mode`` (grayscale, RGBA, palette, ...) are
        # converted so that every sample has three channels; objects without a
        # ``mode`` attribute are passed through untouched.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        label = item['label']
        if isinstance(label, str):
            # String labels are parsed from their first whitespace-separated token
            label = int(label.split()[0])

        if self.transform is not None:
            image = self.transform(image)
        return image, label


### Explanation:
Wraps the Hugging Face dataset into a PyTorch Dataset with optional transforms. Converts string labels to integers if necessary.

# 5️⃣ Define Image Transformations and DataLoaders

In [None]:
# Training pipeline: resize, random flip augmentation, tensor conversion, normalisation
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224 input
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
])

# Evaluation pipeline: identical preprocessing but WITHOUT random augmentation,
# so validation metrics are deterministic and comparable between runs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
])

train_dataset = FootballDataset(dataset['train'], transform=transform)
val_dataset = FootballDataset(dataset['validation'], transform=eval_transform)

# Only the training loader shuffles; validation order does not affect accuracy
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)


### Explanation:
Preprocess images for ViT and create PyTorch DataLoaders for batching and shuffling.

# 6️⃣ Initialize ViT Model for Classification

In [None]:
num_classes = 22  # one output class per player

# Take the checkpoint's configuration and override only the number of labels
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=num_classes,
)

# Load the pretrained weights; ignore_mismatched_sizes=True lets loading proceed
# even though the classifier shape differs from the one stored in the checkpoint
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    ignore_mismatched_sizes=True,
    config=config,
)

# Freeze every backbone parameter so that only the classifier head stays trainable
model.vit.requires_grad_(False)

model.to(device)


### Explanation:
Uses ignore_mismatched_sizes=True to replace the default 1000-class classifier with our 22-class one. Only the classifier layer will be trained.

# 7️⃣ Define Loss Function and Optimizer

In [None]:
# Cross-entropy objective for the multi-class problem
criterion = nn.CrossEntropyLoss()
# Adam receives only the classifier head's parameters
optimizer = torch.optim.Adam(params=model.classifier.parameters(), lr=1e-4)


### Explanation:
Cross-entropy for multi-class classification. Only the classifier’s parameters are optimized.

# 8️⃣ Training Loop

In [None]:
num_epochs = 100

for epoch in range(num_epochs):
    # Training mode is enabled at the start of every epoch
    model.train()
    running_loss = 0.0  # sum of the per-batch losses seen in this epoch

    for inputs, labels in train_loader:
        # Move the batch onto the selected device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Reset gradients, then forward pass -> loss -> backward pass -> update
        optimizer.zero_grad()
        loss = criterion(model(inputs).logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss for the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")


### Explanation:
Standard PyTorch training loop. Computes loss, backpropagates, updates weights, and prints average loss per epoch.

# 9️⃣ Validation Accuracy

In [None]:
# Switch the model to evaluation mode before measuring accuracy
model.eval()
correct = 0
total = 0

# Gradients are not needed for evaluation
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs).logits
        # Index of the largest logit per row is the predicted class.
        # argmax avoids the legacy `.data` accessor and avoids binding `_`,
        # which IPython uses for the last cell output.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Validation Accuracy: {100 * correct / total:.2f}%')


### Explanation:
Evaluate the model on validation data without updating weights. Calculates and prints accuracy.

# 🔟 Save the Trained Model

In [None]:
# Write the fine-tuned model to a local directory
model_save_path = "./vit_football_classifier"
model.save_pretrained(save_directory=model_save_path)
print(f"Model saved to {model_save_path}")


### Explanation:

save_pretrained saves both the model weights and configuration.

You can later reload it with ViTForImageClassification.from_pretrained(model_save_path).

Useful for deployment, sharing, or continuing training.