# Vision Transformer (ViT) Fine-tuning on Food-101 Dataset

Fine-tuning `google/vit-base-patch16-224` for 101-class food image classification using LoRA adapters and FP16 mixed-precision training.

In [None]:
# Install notebook dependencies. The leading `!` is IPython/Jupyter shell-escape syntax,
# so this line only runs inside a notebook, not as plain Python.
!pip install -q transformers datasets accelerate peft bitsandbytes pillow scikit-learn matplotlib seaborn torch torchvision

In [None]:
import os
import torch
import pickle
import numpy as np
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from torch.utils.data import Dataset
from peft import LoraConfig, get_peft_model, TaskType

# Report the runtime so the logs record which torch build and accelerator this run used.
cuda_available = torch.cuda.is_available()
for caption, info in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", cuda_available),
    ("Number of GPUs", torch.cuda.device_count()),
):
    print(f"{caption}: {info}")
if cuda_available:
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")

In [None]:
# --- Input / output locations ---
# Expected layout: images/<class>/<id>.jpg and meta/meta/{classes,train,test}.txt
BASE_DIR = Path('/kaggle/input/food41')
IMAGES_DIR = BASE_DIR / 'images'
META_DIR = BASE_DIR.joinpath('meta', 'meta')
OUTPUT_DIR = Path('/kaggle/working')
MODEL_NAME = 'google/vit-base-patch16-224'

# --- Artifacts written by later cells ---
CHECKPOINT_FILE, FINAL_MODEL_FILE = (
    OUTPUT_DIR / artifact for artifact in ('training_checkpoint.pkl', 'final_model.pkl')
)

# --- Hyperparameters ---
DATA_SUBSET_RATIO = 0.2  # fraction of the official train and test splits actually used
BATCH_SIZE = 32
NUM_EPOCHS = 5
LEARNING_RATE = 2e-4
SEED = 42

OUTPUT_DIR.mkdir(exist_ok=True)
# Seed the torch and NumPy global RNGs for reproducibility.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

print("Configuration:")
for setting, value in (
    ("BASE_DIR", BASE_DIR),
    ("IMAGES_DIR", IMAGES_DIR),
    ("META_DIR", META_DIR),
    ("OUTPUT_DIR", OUTPUT_DIR),
):
    print(f"  {setting}: {value}")
print(f"  DATA_SUBSET_RATIO: {DATA_SUBSET_RATIO * 100}%")

In [None]:
# Ordered class names: a class's position in this list is its integer label id.
with open(META_DIR / 'classes.txt', 'r', encoding='utf-8') as f:
    # Ignore blank lines so a stray empty line cannot become an empty-string class
    # and shift every later label id.
    classes = [name for name in (line.strip() for line in f) if name]

print(f"Total number of classes: {len(classes)}")
print(f"First 10 classes: {classes[:10]}")

# Bidirectional name <-> id lookups, passed to the model config and used for reporting.
label2id = {label: idx for idx, label in enumerate(classes)}
id2label = {idx: label for label, idx in label2id.items()}

In [None]:
def load_image_paths(split='train', meta_dir=None, images_dir=None, label_map=None):
    """Collect image file paths and integer labels for one dataset split.

    Reads ``<meta_dir>/<split>.txt``, where each non-blank line has the form
    ``<class_name>/<image_stem>``, and maps it to
    ``<images_dir>/<class_name>/<image_stem>.jpg``. Entries whose image file
    does not exist are skipped, and the number skipped is reported.

    Args:
        split: Split name, e.g. ``'train'`` or ``'test'``.
        meta_dir: Directory holding the split files; defaults to ``META_DIR``.
        images_dir: Root image directory; defaults to ``IMAGES_DIR``.
        label_map: Class-name -> label-id mapping; defaults to ``label2id``.

    Returns:
        ``(image_paths, labels)``: parallel lists of path strings and int ids.

    Raises:
        KeyError: if an existing image's class is missing from ``label_map``.
        ValueError: if a non-blank line contains no ``'/'``.
    """
    meta_dir = META_DIR if meta_dir is None else Path(meta_dir)
    images_dir = IMAGES_DIR if images_dir is None else Path(images_dir)
    label_map = label2id if label_map is None else label_map

    file_path = meta_dir / f'{split}.txt'
    image_paths = []
    labels = []
    missing = 0

    print(f"Loading {split} data from {file_path}...")

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # Split on the first '/' only: the class directory, then the image stem.
            class_name, img_name = line.split('/', 1)
            img_path = images_dir / class_name / f'{img_name}.jpg'

            if not img_path.exists():
                # Count missing files instead of dropping them silently.
                missing += 1
                continue

            image_paths.append(str(img_path))
            labels.append(label_map[class_name])

    if missing:
        print(f"  Warning: skipped {missing} {split} entries whose image file was not found.")

    return image_paths, labels

# Resolve both official splits; each call returns parallel (paths, labels) lists.
train_paths, train_labels = load_image_paths(split='train')
test_paths, test_labels = load_image_paths(split='test')

for split_name, split_paths in (("training", train_paths), ("test", test_paths)):
    print(f"Total {split_name} images: {len(split_paths)}")

In [None]:
def _pick(seq, positions):
    """Return the elements of *seq* at *positions*, in that order."""
    return [seq[p] for p in positions]


# Randomly keep DATA_SUBSET_RATIO of the official train split (global NumPy RNG,
# seeded in the configuration cell), then carve out a stratified 80/20 train/val split.
subset_size = int(len(train_paths) * DATA_SUBSET_RATIO)
indices = np.random.choice(len(train_paths), subset_size, replace=False)
train_paths_subset = _pick(train_paths, indices)
train_labels_subset = _pick(train_labels, indices)

train_paths_final, val_paths, train_labels_final, val_labels = train_test_split(
    train_paths_subset,
    train_labels_subset,
    test_size=0.2,
    stratify=train_labels_subset,
    random_state=SEED,
)

# Keep the same fraction of the official test split (plain random sample, no stratification).
test_subset_size = int(len(test_paths) * DATA_SUBSET_RATIO)
test_indices = np.random.choice(len(test_paths), test_subset_size, replace=False)
test_paths_subset = _pick(test_paths, test_indices)
test_labels_subset = _pick(test_labels, test_indices)

subset_pct = int(DATA_SUBSET_RATIO * 100)
print(f"Training samples ({subset_pct}% of dataset): {len(train_paths_final)}")
print(f"Validation samples: {len(val_paths)}")
print(f"Test samples ({subset_pct}% of test): {len(test_paths_subset)}")

In [None]:
# Load the preprocessing configuration that ships with the pretrained checkpoint.
processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

class FoodDataset(Dataset):
    """Map-style dataset that turns image files into ViT model inputs.

    Each item is the processor's output for one image, with the leading batch
    dimension squeezed away, plus a ``'labels'`` long tensor holding the class id.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class ids, parallel to ``image_paths``.
        processor: Image processor called as ``processor(images=..., return_tensors='pt')``.
        augment: If true, apply random flip / rotation / colour jitter before processing.
    """

    def __init__(self, image_paths, labels, processor, augment=False):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.augment = augment
        # Augmentation pipeline, built lazily on first use and then reused
        # (the original rebuilt it, and re-imported torchvision, on every item).
        self._transform = None

    def _get_transform(self):
        """Build (once) and return the training-time augmentation pipeline."""
        if self._transform is None:
            # Local import: torchvision is only needed when augmentation is used.
            from torchvision import transforms
            self._transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
            ])
        return self._transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open the file in a context manager so its handle is released immediately;
        # convert() returns a fully loaded, independent image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        if self.augment:
            image = self._get_transform()(image)

        encoding = self.processor(images=image, return_tensors='pt')
        # The processor returns a batch of one; drop that leading dimension.
        item = {key: tensor.squeeze(0) for key, tensor in encoding.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)

        return item

# Only the training split is augmented; validation and test use preprocessing alone.
train_dataset = FoodDataset(train_paths_final, train_labels_final, processor, augment=True)
val_dataset = FoodDataset(val_paths, val_labels, processor, augment=False)
test_dataset = FoodDataset(test_paths_subset, test_labels_subset, processor, augment=False)

for split_name, split_ds in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_name} dataset size: {len(split_ds)}")

In [None]:
# Load the pretrained ViT with a classification head sized for our classes.
# ignore_mismatched_sizes lets loading succeed even though the checkpoint's head
# has a different output size; that head is then freshly initialized.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    ignore_mismatched_sizes=True,
    num_labels=len(classes),
    label2id=label2id,
    id2label=id2label,
)

# LoRA adapters on the attention query/value projections. The new classifier head is
# trained in full (modules_to_save) because it is not covered by the adapters.
lora_config = LoraConfig(
    target_modules=['query', 'value'],
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias='none',
    modules_to_save=['classifier'],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for the Hugging Face Trainer.

    Args:
        eval_pred: ``(predictions, labels)`` pair (an ``EvalPrediction`` unpacks
            this way). ``predictions`` are logits with classes on the last axis,
            or a tuple whose first element is those logits.

    Returns:
        ``{'accuracy': float}``: the fraction of argmax predictions equal to the labels.
    """
    predictions, labels = eval_pred
    # When a model returns more than logits, the Trainer passes a tuple; logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(predictions, axis=-1)
    accuracy = float(np.mean(predicted_ids == np.asarray(labels)))
    return {'accuracy': accuracy}

# Trainer configuration: evaluate and checkpoint on the validation split once per epoch.
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR / 'checkpoints'),
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir=str(OUTPUT_DIR / 'logs'),
    logging_steps=50,
    logging_first_step=True,
    eval_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    # The final-epoch weights are kept, not the best-by-eval checkpoint.
    load_best_model_at_end=False,
    # FP16 mixed precision needs a CUDA device; fall back to FP32 so the run
    # does not abort on a CPU-only runtime.
    fp16=torch.cuda.is_available(),
    # Two micro-batches per optimizer step.
    gradient_accumulation_steps=2,
    dataloader_num_workers=2,
    # Keep every key FoodDataset yields (pixel_values, labels) instead of filtering
    # them against the wrapped model's forward signature.
    remove_unused_columns=False,
    seed=SEED,
    report_to='none',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

print(f"Trainer configured with {int(DATA_SUBSET_RATIO * 100)}% dataset and loss computation enabled.")

In [None]:
# Optionally warm-start from weights pickled by a previous run of the training cell.
# Training below still runs for the full NUM_EPOCHS with a fresh optimizer.
if not CHECKPOINT_FILE.exists():
    print("No checkpoint found. Starting fresh training.")
else:
    print("Loading checkpoint...")
    # NOTE(review): pickle.load is only safe because this file is written by this notebook.
    with CHECKPOINT_FILE.open('rb') as f:
        checkpoint_data = pickle.load(f)
    model.load_state_dict(checkpoint_data['model_state'])
    print("Checkpoint loaded successfully!")

In [None]:
print("Starting training...")
train_result = trainer.train()
final_train_loss = train_result.training_loss

print("\nTraining completed!")
print(f"Training loss: {final_train_loss:.4f}")

# Persist the trained weights plus the final loss so a re-run can warm-start.
checkpoint_data = dict(
    model_state=model.state_dict(),
    training_loss=final_train_loss,
)
with CHECKPOINT_FILE.open('wb') as f:
    pickle.dump(checkpoint_data, f)
print(f"Checkpoint saved to {CHECKPOINT_FILE}")

In [None]:
print("Evaluating on validation set...")
val_results = trainer.evaluate(eval_dataset=val_dataset)

print("\nValidation Results:")
# Floats are shown to 4 decimals; anything else (e.g. integer epoch/step counts) as-is.
for metric_name, metric_value in val_results.items():
    shown = f"{metric_value:.4f}" if isinstance(metric_value, float) else metric_value
    print(f"  {metric_name}: {shown}")

In [None]:
print("Evaluating on test set...")
test_results = trainer.evaluate(eval_dataset=test_dataset)

print("\nTest Results:")
# Choose a format spec per value: 4 decimals for floats, default (str-like) otherwise.
for metric_name, metric_value in test_results.items():
    spec = '.4f' if isinstance(metric_value, float) else ''
    print(f"  {metric_name}: {metric_value:{spec}}")

In [None]:
# Batched inference on the test subset through the Trainer. This reuses the evaluation
# path (eval batch size, device, precision) that produced test_results, instead of a slow
# one-image-at-a-time loop, so the report below should agree with the reported accuracy.
prediction_output = trainer.predict(test_dataset)
test_logits = prediction_output.predictions
# When a model returns more than logits, the Trainer hands back a tuple; logits come first.
if isinstance(test_logits, tuple):
    test_logits = test_logits[0]
pred_labels = np.argmax(test_logits, axis=-1)
true_labels = test_labels_subset

print(f"Predicted labels shape: {pred_labels.shape}")
print(f"True labels length: {len(true_labels)}")

# Pin the label set to every class id so the matrix is always len(classes) x len(classes)
# with rows/columns aligned to class ids, and so classification_report does not reject
# target_names when some class is absent from this random subset.
all_label_ids = list(range(len(classes)))
cm = confusion_matrix(true_labels, pred_labels, labels=all_label_ids)
print("\nClassification Report:")
print(classification_report(
    true_labels,
    pred_labels,
    labels=all_label_ids,
    target_names=classes,
    zero_division=0,
))

In [None]:
# Heatmap of the confusion matrix: rows are true classes, columns predicted classes.
# Cell annotations are off because one row/column per class makes them unreadable.
fig, ax = plt.subplots(figsize=(15, 12))
sns.heatmap(cm, annot=False, fmt='d', cmap='Blues', cbar=True, ax=ax)
ax.set_title('Confusion Matrix', fontsize=16, pad=20)
ax.set_ylabel('True Label', fontsize=12)
ax.set_xlabel('Predicted Label', fontsize=12)
fig.tight_layout()

cm_path = OUTPUT_DIR / 'confusion_matrix.png'
fig.savefig(cm_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Confusion matrix saved to {cm_path}")

In [None]:
num_samples = 10
n_cols = 5

# Never request more distinct samples than exist (np.random.choice with replace=False
# raises otherwise).
num_shown = min(num_samples, len(test_dataset))
sample_indices = np.random.choice(len(test_dataset), num_shown, replace=False)

# Size the grid from the sample count instead of hard-coding 2x5 (10 samples -> 2x5, 20x8 in).
n_rows = max(1, (num_shown + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
axes = axes.ravel()

for ax, sample_idx in zip(axes, sample_indices):
    # Close the file handle right away; convert() returns a loaded copy.
    with Image.open(test_paths_subset[sample_idx]) as img:
        image = img.convert('RGB')
    true_label = id2label[test_labels_subset[sample_idx]]
    pred_label = id2label[pred_labels[sample_idx]]

    ax.imshow(image)
    ax.axis('off')
    # Green title for a correct prediction, red for a miss.
    color = 'green' if true_label == pred_label else 'red'
    ax.set_title(f'True: {true_label}\nPred: {pred_label}', fontsize=10, color=color)

# Hide grid cells left empty when num_shown is not a multiple of n_cols.
for ax in axes[num_shown:]:
    ax.axis('off')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'sample_predictions.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Sample predictions saved to {OUTPUT_DIR / 'sample_predictions.png'}")

In [None]:
# Evaluate() prefixes metrics with 'eval_'; fall back to the bare key, then to 0.0.
test_accuracy = test_results.get('eval_accuracy', test_results.get('accuracy', 0.0))

# Bundle the weights with everything needed to run inference later.
final_model_data = dict(
    model_state=model.state_dict(),
    processor=processor,
    id2label=id2label,
    label2id=label2id,
    classes=classes,
    test_accuracy=test_accuracy,
)

with FINAL_MODEL_FILE.open('wb') as f:
    pickle.dump(final_model_data, f)

size_mb = FINAL_MODEL_FILE.stat().st_size / (1024**2)
print(f"Final model saved to {FINAL_MODEL_FILE}")
print(f"Model size: {size_mb:.2f} MB")

In [None]:
def predict_image(image_path, model, processor, id2label, top_k=5):
    """Return the top-k class predictions for a single image file.

    The caller is responsible for putting ``model`` in eval mode; this
    function does not change the model's train/eval state.

    Args:
        image_path: Path to the image file.
        model: Classification model whose output exposes ``.logits``.
        processor: Image processor called as ``processor(images=..., return_tensors='pt')``.
        id2label: Mapping from class id (int) to class name.
        top_k: Number of predictions to return; capped at the number of classes.

    Returns:
        List of ``{'class': str, 'probability': float}`` dicts, highest probability first.
    """
    # Close the file handle promptly; convert() returns a fully loaded copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    inputs = processor(images=image, return_tensors='pt').to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return [
        {'class': id2label[idx.item()], 'probability': prob.item()}
        for prob, idx in zip(top_probs, top_indices)
    ]

# Eval mode disables dropout so the demo prediction is deterministic.
model.eval()
sample_image = test_paths_subset[0]
predictions = predict_image(sample_image, model, processor, id2label)

print(f"\nSample prediction for: {sample_image}")
print(f"True label: {id2label[test_labels_subset[0]]}\n")
print("Top 5 predictions:")
for rank, entry in enumerate(predictions, start=1):
    class_name, probability = entry['class'], entry['probability']
    print(f"{rank}. {class_name}: {probability:.4f}")

In [None]:
# Headline metrics. evaluate() prefixes keys with 'eval_'; fall back to the bare key, then 0.0.
val_accuracy = val_results.get('eval_accuracy', val_results.get('accuracy', 0.0))
val_loss = val_results.get('eval_loss', 0.0)
test_accuracy = test_results.get('eval_accuracy', test_results.get('accuracy', 0.0))
test_loss = test_results.get('eval_loss', 0.0)

# Report settings from the actual TrainingArguments so the summary cannot drift from them.
fp16_status = 'Enabled' if training_args.fp16 else 'Disabled'
accum_steps = training_args.gradient_accumulation_steps
# train_batch_size = per-device batch size x number of GPUs used by this process (min 1).
effective_batch_size = training_args.train_batch_size * accum_steps

print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"Model: {MODEL_NAME}")
print(f"Dataset: Food-101 ({int(DATA_SUBSET_RATIO*100)}% subset)")
print(f"Number of classes: {len(classes)}")
print(f"Training samples: {len(train_paths_final)}")
print(f"Validation samples: {len(val_paths)}")
print(f"Test samples: {len(test_paths_subset)}")
print(f"\nTraining configuration:")
print(f"  - Epochs: {NUM_EPOCHS}")
print(f"  - Batch size (per device): {BATCH_SIZE}")
print(f"  - Gradient accumulation steps: {accum_steps}")
print(f"  - Effective batch size: {effective_batch_size}")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - LoRA rank: {lora_config.r}")
print(f"  - FP16 training: {fp16_status}")
print(f"\nFinal Results:")
print(f"  - Validation Loss: {val_loss:.4f}")
print(f"  - Validation Accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")
print(f"  - Test Loss: {test_loss:.4f}")
print(f"  - Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"\nDeliverables Completed:")
print(f"  ✓ Data preprocessing and augmentation")
print(f"  ✓ Fine-tuning pipeline with LoRA")
print(f"  ✓ Training and validation")
print(f"  ✓ Evaluation (accuracy, confusion matrix, sample predictions)")
print("\nFiles saved:")
print(f"  - Checkpoint: {CHECKPOINT_FILE}")
print(f"  - Final model: {FINAL_MODEL_FILE}")
print(f"  - Confusion matrix: {OUTPUT_DIR / 'confusion_matrix.png'}")
print(f"  - Sample predictions: {OUTPUT_DIR / 'sample_predictions.png'}")
print("="*60)