# 🚀 Fine-Tuning a Vision Transformer (ViT) on a Custom Dataset

This notebook allows you to fine-tune a pre-trained Vision Transformer model from Hugging Face on your own image dataset.

### How to Use This Notebook:
1.  **Prepare Your Data**: Package your training images (and, optionally, your validation images) as `.zip` files.
2.  **Follow the Steps**: Run each cell in order from top to bottom.
3.  **Upload Your Data**: When prompted, upload your zipped dataset files.
4.  **Train**: The model will train on your data.
5.  **Download**: A final cell will let you download your trained model.

### ⚠️ Important: Data Structure

Your `.zip` files must contain folders where **each folder's name is the class label**. For example, if you have two classes, `cats` and `dogs`, your `train_data.zip` should have this structure:

```
train_data.zip
├── cats/
│   ├── cat_image_1.jpg
│   ├── cat_image_2.png
│   └── ...
└── dogs/
    ├── dog_image_1.jpeg
    ├── dog_image_2.jpg
    └── ...
```

The same structure is required for your validation data `.zip` file, if you provide one, and its class folder names must match those in the training set exactly.
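Before uploading, it can help to confirm that an archive actually follows this layout. The following is a minimal sketch using only the standard library. The helper name `summarize_zip_classes` and the list of image extensions are illustrative assumptions, not part of the notebook.

```python
import zipfile
from collections import Counter
from pathlib import PurePosixPath

# Assumed set of image extensions; adjust it to match your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}


def summarize_zip_classes(zip_path):
    """Count images per top-level folder (class label) inside a zip archive."""
    counts = Counter()
    with zipfile.ZipFile(zip_path) as archive:
        for name in archive.namelist():
            path = PurePosixPath(name)
            # Only entries shaped like <class_name>/<image_file> are counted;
            # loose files and deeper nesting are ignored.
            if len(path.parts) == 2 and path.suffix.lower() in IMAGE_EXTENSIONS:
                counts[path.parts[0]] += 1
    return dict(counts)
```

Calling `summarize_zip_classes("train_data.zip")` on a correctly structured archive should return one entry per class. An empty result usually means the class folders are wrapped in an extra top-level folder.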

## Step 0: Setup and Installations

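First, we'll install the libraries this notebook uses: Hugging Face `transformers`, `datasets`, and `accelerate`, plus `torchvision`, `scikit-learn`, and `tensorboard`, which later steps import. We also check for GPU availability, as a GPU is highly recommended for this task.

In [None]:
%pip install -q transformers[torch] datasets accelerate torchvision scikit-learn tensorboard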

import torch

# Check if a GPU is available and print the device name
if torch.cuda.is_available():
    print(f"GPU is available. Using device: {torch.cuda.get_device_name(0)}")
    !nvidia-smi # Display GPU stats
else:
    print("GPU not available. Training will run on CPU, which will be very slow.")

## Step 1: Configuration

Here, you can set the key parameters for the training process. You can change the pre-trained model, the number of training epochs, and the batch size. Adjust the batch size based on your GPU's VRAM (if you get an 'Out of Memory' error, try reducing it).

In [None]:
# --- Main Configuration ---
MODEL_NAME = "google/vit-base-patch16-224-in21k"
# No trailing slash: later cells append sub-paths such as "/logs" and "/final_model".
OUTPUT_DIR = "./output"

# --- Training Hyperparameters ---
NUM_TRAIN_EPOCHS = 6
# Adjust batch size based on your GPU's VRAM. 
# If you have a T4 (16GB), 32 should be safe. For an A100 (40GB), you can go higher.
PER_DEVICE_TRAIN_BATCH_SIZE = 32 
PER_DEVICE_EVAL_BATCH_SIZE = 32
LEARNING_RATE = 5e-5
# Counted in optimizer steps, not epochs. Lower this for small datasets.
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 50
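Because `WARMUP_STEPS` is counted in optimizer steps, its effect depends on how many images you have. With the `Trainer`'s default linear schedule, the learning rate ramps up during warmup and then decays, so a run that is shorter than the warmup never reaches `LEARNING_RATE`. The sketch below gives a rough estimate of the run length. It assumes no gradient accumulation and that the last partial batch is kept. The function name and the 2,000-image figure are hypothetical.

```python
import math


def total_optimizer_steps(num_images, batch_size, num_epochs, num_devices=1):
    """Estimate how many optimizer steps a training run will take."""
    # Each step consumes one batch per device; the final partial batch still counts.
    steps_per_epoch = math.ceil(num_images / (batch_size * num_devices))
    return steps_per_epoch * num_epochs


# Hypothetical dataset of 2,000 images with the defaults from the cell above.
steps = total_optimizer_steps(num_images=2000, batch_size=32, num_epochs=6)
print(steps)  # 378, which is fewer than the 500 warmup steps
```

In a case like this one, reducing `WARMUP_STEPS` to a small fraction of the total (for example around 10%) is a common starting point.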

## Step 2: Upload and Unzip Datasets

Run the cell below to upload your training zip file and, optionally, a validation zip file. The code unzips each archive so the next step can load it.

In [None]:
import os
import zipfile
from google.colab import files

# --- Upload and Unzip Training Data ---
print("Please upload your training data zip file.")
uploaded_train = files.upload()

if not uploaded_train:
    raise RuntimeError("No training file was uploaded. Re-run this cell and try again.")

train_zip_name = list(uploaded_train.keys())[0]
TRAIN_DATA_DIR = os.path.splitext(train_zip_name)[0]

# Extract into a folder named after the zip, so the class folders
# (e.g. cats/, dogs/) end up inside TRAIN_DATA_DIR where Step 3 looks for them.
with zipfile.ZipFile(train_zip_name, 'r') as zip_ref:
    zip_ref.extractall(TRAIN_DATA_DIR)
print(f"Successfully unzipped training data to '{TRAIN_DATA_DIR}'")

# --- Upload and Unzip Validation Data (Optional) ---
VAL_DATA_DIR = None # Default to None
print("\n(Optional) Please upload your validation data zip file. If you don't have one, just press 'Cancel upload'.")
uploaded_val = files.upload()

if uploaded_val:
    val_zip_name = list(uploaded_val.keys())[0]
    VAL_DATA_DIR = os.path.splitext(val_zip_name)[0]
    # Extract into its own folder so validation class folders
    # do not mix with the training ones.
    with zipfile.ZipFile(val_zip_name, 'r') as zip_ref:
        zip_ref.extractall(VAL_DATA_DIR)
    print(f"Successfully unzipped validation data to '{VAL_DATA_DIR}'")
else:
    print("No validation file uploaded. Training will proceed without a validation set.")

## Step 3: Load Data and Prepare for Training

Now we load the images from the directories you just created. We'll use `torchvision`'s `ImageFolder`, which derives the class labels from the folder names. We'll also load the model and its image processor (the vision counterpart of a tokenizer), and define the functions for data collation and metrics calculation.

In [None]:
from torchvision import datasets
from transformers import ViTImageProcessor, AutoModelForImageClassification
from sklearn.metrics import accuracy_score
from PIL import Image # Pillow is needed for image handling by torchvision

# --- 1. Load Data with torchvision ---
if not os.path.exists(TRAIN_DATA_DIR):
    raise FileNotFoundError(f"Training data directory not found: {TRAIN_DATA_DIR}")

# Use torchvision's ImageFolder to load the datasets
train_dataset = datasets.ImageFolder(TRAIN_DATA_DIR)

# Get class names and mappings from the training dataset
class_names = train_dataset.classes
label2id = {name: i for i, name in enumerate(class_names)}
id2label = {i: name for i, name in enumerate(class_names)}
num_labels = len(class_names)

print(f"Found {num_labels} classes: {class_names}")
print(f"Training data found: {len(train_dataset)} images.")

val_dataset = None
if VAL_DATA_DIR and os.path.exists(VAL_DATA_DIR):
    val_dataset = datasets.ImageFolder(VAL_DATA_DIR)
    print(f"Validation data found: {len(val_dataset)} images.")
else:
    print("Warning: Validation data directory not provided or not found. Skipping evaluation during training.")


# --- 2. Load Model and Processor ---
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True # Allows loading a pre-trained model with a different head
)

print("\nModel and Processor loaded.")

# --- 3. Data Collator ---
# This function processes batches of (PIL Image, label) tuples from ImageFolder
# and prepares them for the model using the Hugging Face processor.
def collate_fn(batch):
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # The processor handles resizing, normalization, and tensor conversion.
    batch_processor_output = processor(images, return_tensors="pt")
    batch_processor_output['labels'] = torch.tensor(labels)
    return batch_processor_output

print("Collate function defined.")

# --- 4. Define Metrics ---
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predicted_labels = predictions.argmax(axis=1)
    return {"accuracy": accuracy_score(labels, predicted_labels)}

print("Metrics function defined.")
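`ImageFolder` assigns label indices from the sorted names of the class folders, separately for each dataset. If the validation folder has a different set of class folders than the training folder, the same index can refer to different classes and the reported accuracy becomes meaningless. The following is a minimal sketch of a consistency check. The function names are illustrative and not part of the notebook.

```python
import os


def list_class_folders(data_dir):
    """Return sub-directory names in sorted order, the order ImageFolder uses for label indices."""
    return sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())


def check_matching_classes(train_dir, val_dir):
    """Raise ValueError if the two directories do not contain the same class folders."""
    train_classes = list_class_folders(train_dir)
    val_classes = list_class_folders(val_dir)
    if train_classes != val_classes:
        missing = sorted(set(train_classes) - set(val_classes))
        extra = sorted(set(val_classes) - set(train_classes))
        raise ValueError(
            f"Class folders differ. Missing from validation: {missing}; "
            f"unexpected in validation: {extra}"
        )
    return train_classes
```

In this notebook it could be called as `check_matching_classes(TRAIN_DATA_DIR, VAL_DATA_DIR)` whenever a validation set was uploaded.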

## Step 4: Configure Training

Now we set up the `TrainingArguments` and the `Trainer`. 

- **`TrainingArguments`**: This object holds all the hyperparameters for the training run (like learning rate, batch size, etc.). We will use the parameters defined in the configuration step.
- **`Trainer`**: This is the main Hugging Face class that orchestrates the entire training and evaluation loop.

We also enable TensorBoard for real-time monitoring of the training loss and evaluation accuracy.

In [None]:
from transformers import TrainingArguments, Trainer

# Configure Training Arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    weight_decay=WEIGHT_DECAY,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_strategy="steps",
    logging_steps=LOGGING_STEPS,
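    # Mixed precision: use bf16 on GPUs that support it (like A100s in Colab Pro),
    # fall back to fp16 on other GPUs, and disable both on a CPU-only runtime,
    # where requesting fp16 makes TrainingArguments fail.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),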
    # Evaluate and save once per epoch only when a validation set is present.
    # `eval_strategy` replaces the older `evaluation_strategy` name (transformers >= 4.41).
    eval_strategy="epoch" if val_dataset else "no",
    save_strategy="epoch" if val_dataset else "no",
    load_best_model_at_end=val_dataset is not None,
    metric_for_best_model="accuracy" if val_dataset else None,
    greater_is_better=True if val_dataset else None,
    # Report logs to TensorBoard
    report_to="tensorboard",
    # Helps speed up data loading
    dataloader_num_workers=2,
)

print("Training arguments configured.")

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `processing_class` replaces the deprecated `tokenizer` argument (transformers >= 4.46).
    # Passing the processor here saves it alongside the model checkpoints.
    processing_class=processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics if val_dataset else None,
)

print("Trainer initialized.")

# Launch TensorBoard in the background (optional)
%load_ext tensorboard
%tensorboard --logdir '{OUTPUT_DIR}/logs'

## Step 5: Start Training!

This is the moment of truth. Running the cell below will start the fine-tuning process. You can monitor the progress here and in the TensorBoard panel above.

In [None]:
print("Starting training...")
trainer.train()
print("Training finished.")

## Step 6: Evaluate and Save the Final Model

After training, we'll run a final evaluation on the validation set (if you provided one) to see how the best model performs. Then, we save the final model and its processor to the output directory. This saved model can be easily loaded later for inference.

In [None]:
# --- Evaluate (Optional) ---
if val_dataset:
    print("\nEvaluating the best model on the validation set...")
    metrics = trainer.evaluate()
    print("Final evaluation metrics:")
    print(metrics)

# --- Save the final model ---
FINAL_MODEL_DIR = f"{OUTPUT_DIR}/final_model"
print(f"\nSaving the final (or best) model and processor to {FINAL_MODEL_DIR}")
trainer.save_model(FINAL_MODEL_DIR)
processor.save_pretrained(FINAL_MODEL_DIR)

print("Fine-tuning complete. Model saved successfully.")
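As a rough illustration of loading the saved model for inference, the sketch below reads the model and processor back from the output directory and returns the most likely labels for one image. The function names `top_k_predictions` and `predict_image` are hypothetical helpers. The sketch assumes the directory contains the files written by `trainer.save_model` and `processor.save_pretrained`.

```python
import math


def top_k_predictions(logits, id2label, k=3):
    """Turn the raw logits for one image into the k most likely (label, probability) pairs."""
    # Numerically stable softmax in plain Python.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


def predict_image(image_path, model_dir, k=3):
    """Load a saved model and processor from model_dir and classify a single image."""
    # Imported here so the helper above works without these packages installed.
    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = AutoModelForImageClassification.from_pretrained(model_dir)
    model.eval()

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0].tolist()
    return top_k_predictions(logits, model.config.id2label, k=k)
```

For example, `predict_image("my_photo.jpg", FINAL_MODEL_DIR)` would return a list of `(label, probability)` pairs, where `my_photo.jpg` stands in for any image you have uploaded to the runtime.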

## Step 7: Download Your Trained Model

The final step is to download the model you just trained. The cell below will zip the contents of the `final_model` directory and start a download in your browser.

In [None]:
import shutil
from google.colab import files

# Zip the final model directory
output_filename = 'fine_tuned_vit_model'
shutil.make_archive(output_filename, 'zip', FINAL_MODEL_DIR)

print(f"Model files zipped into '{output_filename}.zip'")

# Download the zipped file
print("Starting download...")
files.download(f'{output_filename}.zip')