# Fine-Tuning Vision Transformers with LoRA for Binary Classification

In this notebook, we implement a robust pipeline for fine-tuning a pre-trained **Vision Transformer (ViT)** on a custom dataset using **Low-Rank Adaptation (LoRA)**.

### Technical Objectives
1.  **Architecture**: Leverage `vit-base-patch16-224-in21k` as a feature extractor backbone.
2.  **Parameter Efficiency**: Implement LoRA to reduce trainable parameters by ~99% while maintaining performance.
3.  **Deployment**: Construct an inference pipeline capable of real-time prediction using Gradio.

### Stack
*   **Hugging Face Transformers**: Model architecture and pre-trained weights.
*   **PEFT**: Parameter-Efficient Fine-Tuning implementation.
*   **PyTorch**: Deep learning backend.

## 1. Environment Setup
Install the dependencies for model loading, LoRA fine-tuning, dataset handling, evaluation, and the Gradio demo.

In [1]:
# %pip install torch transformers peft datasets pillow numpy gradio evaluate scikit-learn kagglehub

In [2]:
import os
import torch
import numpy as np
from PIL import Image
import kagglehub
from datasets import Dataset

# Auto-select CUDA if available for accelerated training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Compute backend: {device}")

Compute backend: cuda


## 2. ELT Pipeline (Extract, Load, Transform)

We ingest the **[`louisiana-flood-2016`](https://www.kaggle.com/datasets/rahultp97/louisiana-flood-2016)** dataset sourced from Kaggle. 

The pipeline involves:
1.  **Extraction**: Programmatic download via `kagglehub`.
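2.  **Label Parsing**: Binary labels are parsed from filenames: a file whose name ends in `_1` (e.g. `_1.png`) is the positive class (Flooded), and every other file is Non-Flooded.
3.  **IO Handling**: We first collect only file paths and labels. Images are opened when the `datasets` object is built, stored in its on-disk Arrow cache, and decoded back into PIL images only when the transform runs on a batch.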
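The labeling rule can be isolated in a small helper and checked against example names. This is a sketch, not the notebook's code: `label_from_filename` is a hypothetical name, and the filenames below are made up for illustration. Comparing the file stem (rather than the full name) means the rule applies equally to `.png` and `.jpg` files.

```python
import os

def label_from_filename(filename: str) -> int:
    """Return 1 (Flooded) if the file stem ends in '_1', else 0 (Non-Flooded)."""
    stem, _ext = os.path.splitext(os.path.basename(filename))
    return 1 if stem.endswith("_1") else 0

# Hypothetical filenames, for illustration only
for name in ["0123_1.png", "0123_0.png", "scene_1.JPG", "area_11.png"]:
    print(name, label_from_filename(name))
```

Note that `area_11.png` maps to 0: the stem must end in the exact suffix `_1`, not merely the digit `1`.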

In [3]:
print("Initializing data download...")
DATA_DIR = kagglehub.dataset_download("rahultp97/louisiana-flood-2016")
print(f"Dataset cache path: {DATA_DIR}")

def extract_metadata(data_dir):
    """Parses directory structure to return image paths and corresponding binary labels."""
    image_paths = []
    labels = []
    for root, _dirs, files in os.walk(data_dir):
        for file in sorted(files):  # sorted for a deterministic ordering
            stem, ext = os.path.splitext(file)
            if ext.lower() in ('.png', '.jpg'):
                image_paths.append(os.path.join(root, file))
                # Label Encoding: 1 = Flooded (stem ends in "_1"), 0 = Non-Flooded
                labels.append(1 if stem.endswith("_1") else 0)
    return image_paths, labels

# Split extraction assumes standard train/test directory structure
train_paths, train_labels = extract_metadata(os.path.join(DATA_DIR, "train"))
test_paths, test_labels = extract_metadata(os.path.join(DATA_DIR, "test"))

print(f"Train samples: {len(train_paths)} | Test samples: {len(test_paths)}")

Initializing data download...
Dataset cache path: /root/.cache/kagglehub/datasets/rahultp97/louisiana-flood-2016/versions/4
Train samples: 270 | Test samples: 52


## 3. Model Architecture Initialization

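We use a **ViT-Base** architecture.
*   **Patch Size**: 16x16, so a 224x224 input is split into 14x14 = 196 patches, plus one `[CLS]` token.
*   **Classification Head**: The `in21k` checkpoint does not include a usable classification head, so a randomly initialized linear head with `num_labels=2` is attached (the warning below confirms `classifier.weight` and `classifier.bias` are newly initialized).
*   **Processor**: Resizes inputs to 224x224, rescales pixel values to [0, 1], and normalizes them with the checkpoint's per-channel mean/std so that tensors match what the backbone saw during pre-training.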
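The processor step and the patch split can be sketched in plain NumPy. This is an illustration under stated assumptions, not `ViTImageProcessor` itself: it skips resizing by assuming the input is already 224x224, and it assumes a per-channel mean and std of 0.5, which we believe this checkpoint's processor config uses (check `processor.image_mean` and `processor.image_std` to confirm). ViT's patch embedding is a 16x16, stride-16 convolution, which is equivalent to a linear projection of the flattened patches produced here.

```python
import numpy as np

IMAGE_SIZE = 224   # input resolution expected by vit-base-patch16-224
PATCH_SIZE = 16    # patch side length in pixels
MEAN = np.array([0.5, 0.5, 0.5])  # assumed checkpoint mean (verify via processor.image_mean)
STD = np.array([0.5, 0.5, 0.5])   # assumed checkpoint std (verify via processor.image_std)

def preprocess(image_uint8: np.ndarray) -> np.ndarray:
    """Rescale a (224, 224, 3) uint8 image to [0, 1], normalize per channel, return CHW."""
    x = image_uint8.astype(np.float64) / 255.0
    x = (x - MEAN) / STD          # with mean = std = 0.5 this maps [0, 1] to [-1, 1]
    return x.transpose(2, 0, 1)   # HWC -> CHW, the channel-first layout ViT expects

def patchify(chw: np.ndarray, patch: int = PATCH_SIZE) -> np.ndarray:
    """Split a (C, H, W) array into flattened, non-overlapping patch vectors."""
    c, h, w = chw.shape
    grid_h, grid_w = h // patch, w // patch
    patches = chw.reshape(c, grid_h, patch, grid_w, patch)
    patches = patches.transpose(1, 3, 0, 2, 4)   # (grid_h, grid_w, C, patch, patch)
    return patches.reshape(grid_h * grid_w, c * patch * patch)

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
tokens = patchify(preprocess(image))
print(tokens.shape)  # (196, 768): 14x14 patches, each 16 * 16 * 3 values
```

Adding the `[CLS]` token gives the 197-token sequence that the encoder processes.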

In [4]:
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_CKPT = "google/vit-base-patch16-224-in21k"

# Initialize the image processor (resize, rescale, and normalize with the checkpoint's mean/std)
processor = ViTImageProcessor.from_pretrained(MODEL_CKPT)

# Initialize Model with custom Classification Head
model = ViTForImageClassification.from_pretrained(
    MODEL_CKPT,
    num_labels=2,
    id2label={0: "Non-Flooded", 1: "Flooded"},
    label2id={"Non-Flooded": 0, "Flooded": 1}
)
model.to(device);

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 4. Parameter-Efficient Fine-Tuning (PEFT) Configuration

We implement **LoRA (Low-Rank Adaptation)**. Instead of full-parameter fine-tuning, we inject low-rank decomposition matrices into the attention blocks.

### LoRA Hyperparameters
*   **Rank (`r=16`)**: Determines the dimensionality of the low-rank matrices. Higher `r` allows more expressivity but increases parameters.
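*   **Target Modules**: We adapt the `query` and `value` projections of each self-attention layer, following the original LoRA paper, which found adapting these two projections effective for a fixed parameter budget.
*   **Alpha (`lora_alpha=16`)**: Scales the low-rank update by `lora_alpha / r`; with `lora_alpha = r = 16`, the scaling factor is 1.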
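A minimal NumPy sketch of what LoRA does to one 768x768 projection, assuming the `r` and `lora_alpha` values above. It omits `lora_dropout` and is not PEFT's implementation. The frozen weight `W` is left untouched, and the trainable update is the product of a down-projection `A` (r x 768) and an up-projection `B` (768 x r), scaled by `lora_alpha / r`. As in the LoRA paper, `B` starts at zero, so the adapted layer initially behaves exactly like the pre-trained one, and after training the update can be merged back into a single dense weight.

```python
import numpy as np

d_in = d_out = 768          # ViT-Base hidden size; query/value projections are 768 x 768
r, lora_alpha = 16, 16      # same values as the LoraConfig in this notebook
scaling = lora_alpha / r    # 1.0 for this configuration

rng = np.random.default_rng(0)
W = rng.normal(size=(d_out, d_in))              # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(r, d_in))      # trainable down-projection, random init
B = np.zeros((d_out, r))                        # trainable up-projection, zero init

def lora_forward(x, W, A, B):
    """Frozen path plus scaled low-rank update: x W^T + (alpha / r) * x A^T B^T."""
    return x @ W.T + scaling * (x @ A.T) @ B.T

x = rng.normal(size=(4, d_in))  # 4 token embeddings

# With B = 0 the update is zero, so training starts from the pre-trained behaviour.
assert np.allclose(lora_forward(x, W, A, B), x @ W.T)

# After training, B A can be folded into one dense weight for inference.
B = rng.normal(scale=0.01, size=(d_out, r))     # stand-in for a trained B
W_merged = W + scaling * B @ A
assert np.allclose(lora_forward(x, W, A, B), x @ W_merged.T)

print("trainable:", A.size + B.size, "| frozen:", W.size)  # trainable: 24576 | frozen: 589824
```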

In [5]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"] # We must train the new head fully
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 591,362 || all params: 86,391,556 || trainable%: 0.6845


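**Observation**: The output confirms we are training <1% (about 0.68%) of the total parameters. Because gradients and optimizer states are only kept for these ~0.59M parameters, training needs less GPU memory than full fine-tuning, and the saved adapter is only a few megabytes.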
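The printed counts can be reproduced by hand from the ViT-Base/16 shapes (hidden size 768, MLP size 3072, 12 layers). This is a sketch under two assumptions: that `ViTForImageClassification` builds its backbone without the pooler layer, and that `modules_to_save=["classifier"]` keeps the original head alongside a trainable copy, so the head is counted twice in the total. The `linear` helper is hypothetical.

```python
HIDDEN, MLP, LAYERS = 768, 3072, 12
PATCH, CHANNELS, NUM_LABELS, R = 16, 3, 2, 16
NUM_TOKENS = (224 // PATCH) ** 2 + 1  # 196 patches + [CLS]

def linear(n_in, n_out, bias=True):
    """Parameter count of a dense layer."""
    return n_in * n_out + (n_out if bias else 0)

# Patch conv (same count as a linear layer on flattened patches) + [CLS] token + position embeddings
embeddings = linear(CHANNELS * PATCH * PATCH, HIDDEN) + HIDDEN + NUM_TOKENS * HIDDEN
per_layer = (
    2 * 2 * HIDDEN                                # two LayerNorms (weight + bias)
    + 4 * linear(HIDDEN, HIDDEN)                  # query, key, value, attention output
    + linear(HIDDEN, MLP) + linear(MLP, HIDDEN)   # feed-forward block
)
backbone = embeddings + LAYERS * per_layer + 2 * HIDDEN  # + final LayerNorm, no pooler
classifier = linear(HIDDEN, NUM_LABELS)
lora = LAYERS * 2 * (R * HIDDEN + HIDDEN * R)  # A and B for query and value in every layer

trainable = lora + classifier
total = backbone + 2 * classifier + lora  # original head + trainable copy
print(f"trainable params: {trainable:,} || all params: {total:,} || "
      f"trainable%: {100 * trainable / total:.4f}")
# trainable params: 591,362 || all params: 86,391,556 || trainable%: 0.6845
```

Under these assumptions, almost all of the trainable budget (589,824 parameters) sits in the LoRA matrices, and the new head adds only 1,538.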

## 5. Training Loop Execution
We use the Hugging Face `Trainer` API to handle the training loop, checkpointing, and per-epoch evaluation.

### Key Components:
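*   **Collate Function**: Stacks the per-example `pixel_values` into one batch tensor and returns the labels under the `labels` key, which the model's forward pass uses to compute the loss.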
*   **Transform**: Applies the ViT processor to the PIL images on-the-fly.
*   **Metric**: Accuracy is used as the primary evaluation metric.
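The metric step reduces to an argmax over the two class logits followed by the fraction of correct predictions. The sketch below mirrors that computation in NumPy. `accuracy_from_logits` is a hypothetical helper, and the logits are made-up values for illustration.

```python
import numpy as np

def accuracy_from_logits(logits: np.ndarray, labels: np.ndarray) -> float:
    """Argmax over the class axis, then the fraction of predictions matching the labels."""
    preds = np.argmax(logits, axis=1)
    return float(np.mean(preds == labels))

# Hypothetical logits for 4 images (columns: Non-Flooded, Flooded)
logits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.4]])
labels = np.array([0, 1, 0, 0])
print(accuracy_from_logits(logits, labels))  # 0.75 (the third image is misclassified)
```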

In [6]:
from transformers import TrainingArguments, Trainer
import evaluate

# Data Transformation: PIL -> Tensor
def transform_batch(batch):
    # return_tensors='pt' ensures PyTorch compatibility
    inputs = processor([x for x in batch['image']], return_tensors='pt')
    inputs['label'] = batch['label']
    return inputs

# Custom Collate: Handles batch stacking
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

# Load the accuracy metric once, rather than on every evaluation call
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    preds = np.argmax(eval_pred.predictions, axis=1)
    return accuracy_metric.compute(predictions=preds, references=eval_pred.label_ids)

# Dataset Generation Wrapper
def create_hf_dataset(paths, labels):
    def gen():
        for p, l in zip(paths, labels):
            yield {"image": Image.open(p).convert("RGB"), "label": l}
    return Dataset.from_generator(gen)

# Instantiating Pipelines
train_ds = create_hf_dataset(train_paths, train_labels)
test_ds = create_hf_dataset(test_paths, test_labels)

# Apply transforms
train_ds = train_ds.with_transform(transform_batch)
test_ds = test_ds.with_transform(transform_batch)

# Training Hyperparameters
args = TrainingArguments(
    output_dir="./results",                 # Directory to store checkpoints and logs
    per_device_train_batch_size=8,           # Small batch size to fit GPU memory
    eval_strategy="epoch",                   # Run evaluation after each epoch
    save_strategy="epoch",                   # Save model checkpoints per epoch
    num_train_epochs=3,                      # Limited epochs to prevent overfitting
    learning_rate=5e-3,                      # Higher LR suitable for LoRA fine-tuning
    remove_unused_columns=False,             # Keep the raw 'image' column needed by the on-the-fly transform
    logging_steps=10                         # Log training metrics every N steps
)

trainer = Trainer(
    model=model,                             # LoRA-adapted ViT model
    args=args,                               # Training configuration
    train_dataset=train_ds,                  # Training split
    eval_dataset=test_ds,                    # Test split, also used for per-epoch evaluation
    data_collator=collate_fn,                # Custom batch collation logic
    compute_metrics=compute_metrics          # Accuracy on the evaluation split
)


trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.4074 | 0.155756 | 0.942308 |
| 2 | 0.0026 | 0.01157 | 1.0 |
| 3 | 0.0725 | 0.014748 | 1.0 |


TrainOutput(global_step=102, training_loss=0.15123242276204824, metrics={'train_runtime': 35.7454, 'train_samples_per_second': 22.66, 'train_steps_per_second': 2.854, 'total_flos': 6.320113196802048e+16, 'train_loss': 0.15123242276204824, 'epoch': 3.0})
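The step count and accuracy values in this output can be checked by hand. With 270 training images and a batch size of 8, each epoch has 34 steps (the last batch is partial), so 3 epochs give `global_step=102`. This assumes a single device and the `Trainer` default of not dropping the last incomplete batch. Because the test split has only 52 images, accuracy can only move in steps of 1/52; the epoch-1 value of 0.942308 corresponds to 49 correct predictions out of 52.

```python
import math

train_samples, test_samples = 270, 52   # from the data-loading cell
batch_size, epochs = 8, 3               # from TrainingArguments (single device assumed)

steps_per_epoch = math.ceil(train_samples / batch_size)  # 270 = 33 * 8 + 6, so 34 batches
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)  # 34 102

# Accuracy on 52 test images moves in increments of 1/52 (about 1.9 points)
print(round(49 / test_samples, 6))  # 0.942308, i.e. 3 of 52 misclassified after epoch 1
```

Given how small the test split is, the 1.0 accuracy in epochs 2 and 3 means zero errors on 52 images, which is encouraging but not a precise estimate of real-world accuracy.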

## 6. Inference Deployment
We wrap preprocessing, the forward pass, and post-processing into a single prediction function and serve it through a Gradio web interface. This serves as a rapid prototype for validating the model interactively.
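The post-processing step turns the two raw logits into a probability distribution with softmax and reports the most likely class. A NumPy sketch of that logic follows. `postprocess` is a hypothetical helper, the label mapping copies the `id2label` used when loading the model, and the logits are illustrative.

```python
import numpy as np

ID2LABEL = {0: "Non-Flooded", 1: "Flooded"}  # same mapping as model.config.id2label

def postprocess(logits: np.ndarray) -> str:
    """Numerically stable softmax over the classes, then report the most likely label."""
    shifted = logits - logits.max()          # subtracting the max avoids overflow in exp
    probs = np.exp(shifted) / np.exp(shifted).sum()
    idx = int(probs.argmax())
    return f"{ID2LABEL[idx]} (Confidence: {probs[idx]:.4f})"

print(postprocess(np.array([-1.0, 2.0])))  # Flooded (Confidence: 0.9526)
```

Note that the confidence is a softmax score, not a calibrated probability; a fine-tuned model can be highly confident and still wrong.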

In [7]:
import gradio as gr

def inference_pipeline(image):
    if image is None: return "No input."
    
    # Preprocessing
    inputs = processor(image, return_tensors="pt").to(device)
    
    # Inference: eval mode disables dropout (including LoRA dropout); no gradients needed
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax for probability distribution
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    
    # Post-processing
    score = probs.max().item()
    label_idx = probs.argmax().item()
    label = model.config.id2label[label_idx]
    
    return f"{label} (Confidence: {score:.4f})"

iface = gr.Interface(
    fn=inference_pipeline,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs="text",
    title="ViT Flood Detection Module",
    description="Real-time inference using LoRA-tuned Vision Transformer."
)

iface.launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://cd5febd8f4a2980a36.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [8]:
gr.close_all()

Closing server running on port: 7860


### Thank you for your time