In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the read-only input mount and echo the full path of every file found.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        full_path = os.path.join(root, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/food-dataset/food41


In [None]:
!pip install transformers datasets accelerate pillow scikit-learn matplotlib seaborn -q

[0m

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
import warnings
# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')

# Pick the GPU when PyTorch can see CUDA, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:
data_path = '/kaggle/input/food41/images'

# Fail fast with an actionable message instead of an opaque error further down.
if not os.path.isdir(data_path):
    raise FileNotFoundError(
        f"Image directory not found: {data_path!r}. "
        "Check the dataset mount point listed under /kaggle/input."
    )

# Extensions are matched case-insensitively so files such as 'IMG_001.JPG' are kept.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# One record per image; the parent directory name is used as the class label.
image_data = []

# os.listdir() returns entries in arbitrary order. Sorting makes the DataFrame
# row order, and therefore the seeded train/val/test split, reproducible.
for category in sorted(os.listdir(data_path)):
    category_path = os.path.join(data_path, category)
    if not os.path.isdir(category_path):
        continue
    for img_name in sorted(os.listdir(category_path)):
        if img_name.lower().endswith(IMAGE_EXTENSIONS):
            img_path = os.path.join(category_path, img_name)
            image_data.append({'image_path': img_path, 'label': category})

# Create DataFrame (explicit columns keep the schema stable even with no records)
df = pd.DataFrame(image_data, columns=['image_path', 'label'])
if df.empty:
    raise ValueError(f"No images with extensions {IMAGE_EXTENSIONS} found under {data_path!r}")

print(f"Total images: {len(df)}")
print(f"\nNumber of categories: {df['label'].nunique()}")
print(f"\nClass distribution:\n{df['label'].value_counts()}")

# Create label mappings: class names are sorted so ids are assigned deterministically.
label_to_id = {label: idx for idx, label in enumerate(sorted(df['label'].unique()))}
id_to_label = {idx: label for label, idx in label_to_id.items()}
df['label_id'] = df['label'].map(label_to_id)

print(f"\nLabel mappings (first 10):")
for label, idx in list(label_to_id.items())[:10]:
    print(f"{idx}: {label}")

In [None]:
# Show the first image of up to 15 categories in a 3x5 grid.
fig, axes = plt.subplots(3, 5, figsize=(15, 9))
categories = df['label'].unique()[:15]

for ax, category in zip(axes.ravel(), categories):
    sample = df[df['label'] == category].iloc[0]
    # Draw inside the context manager so the file handle is released right
    # after the pixel data has been handed to matplotlib.
    with Image.open(sample['image_path']) as img:
        ax.imshow(img)
    ax.set_title(category, fontsize=10)

# Hide the frame of every panel, including ones left empty when there are
# fewer than 15 categories.
for ax in axes.ravel():
    ax.axis('off')

plt.tight_layout()
plt.savefig('sample_images.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# First carve off 30% of the rows, then halve that hold-out into validation
# and test. Both splits are stratified on the class id and use a fixed seed.
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    stratify=df['label_id'],
    random_state=42,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['label_id'],
    random_state=42,
)

for split_name, split_frame in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"{split_name} set: {len(split_frame)} images")

# Give every split a fresh 0..n-1 index so positional and label lookups agree.
train_df, val_df, test_df = (
    split_frame.reset_index(drop=True) for split_frame in (train_df, val_df, test_df)
)


In [None]:
# Checkpoint identifier shared by the image processor and the model.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

print(
    f"Image processor loaded: {model_name}",
    f"Expected image size: {processor.size}",
    sep="\n",
)


In [None]:
class FoodDataset(Dataset):
    """Map-style dataset that loads food images listed in a DataFrame.

    Args:
        dataframe: Frame with an 'image_path' column (file path of each image)
            and a 'label_id' column (integer class id). Rows are accessed
            positionally via ``iloc``.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result exposes a 'pixel_values' tensor.
        augment: When True, a random crop/flip/rotation/colour-jitter pipeline
            is applied; otherwise the image is only resized to 224x224.
    """

    def __init__(self, dataframe, processor, augment=False):
        self.dataframe = dataframe
        self.processor = processor
        self.augment = augment

        # Both variants are stored under the same attribute, so __getitem__
        # does not need to branch on `augment`.
        if augment:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            ])
        else:
            self.transform = transforms.Resize((224, 224))

    def __len__(self):
        """Return the number of rows (images) in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load, transform and preprocess the image at position ``idx``.

        Returns:
            dict with 'pixel_values' (the processor output with its leading
            batch dimension removed) and 'labels' (a scalar long tensor).
        """
        # Fetch the row once instead of indexing the frame per column.
        row = self.dataframe.iloc[idx]

        # convert() returns an independent copy, so the file handle can be
        # closed immediately rather than lingering until garbage collection.
        with Image.open(row['image_path']) as raw_image:
            image = raw_image.convert('RGB')

        image = self.transform(image)

        # Process image using ViT processor
        inputs = self.processor(images=image, return_tensors="pt")

        return {
            # squeeze(0) removes only the batch axis; a bare squeeze() would
            # also drop any other dimension that happens to have size 1.
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(int(row['label_id']), dtype=torch.long),
        }

# Build one dataset per split; only the training split is augmented.
train_dataset, val_dataset, test_dataset = (
    FoodDataset(split_frame, processor, augment=use_augmentation)
    for split_frame, use_augmentation in (
        (train_df, True),
        (val_df, False),
        (test_df, False),
    )
)

for split_title, split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_title} dataset size: {len(split_dataset)}")

In [None]:
# One output unit per class discovered in the dataset.
num_labels = len(label_to_id)

# ignore_mismatched_sizes=True lets loading proceed when a checkpoint tensor
# (presumably the classification head) does not match the requested shape.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id_to_label,
    label2id=label_to_id,
    ignore_mismatched_sizes=True,
)
model.to(device)

# Tally parameter counts in a single pass over the model.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Model loaded with {num_labels} output classes")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for an evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``; the class with the highest
            score along axis 1 of ``predictions`` is taken as the prediction.

    Returns:
        dict with a single 'accuracy' entry.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return {'accuracy': accuracy_score(references, predicted_ids)}

In [None]:
# Mixed precision is only requested when a CUDA device is present: the
# notebook explicitly supports a CPU fallback, and fp16 training is rejected
# by TrainingArguments when no supported accelerator is available.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir='./vit-food-classifier',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    # Evaluate and checkpoint once per epoch; the two strategies must match
    # for load_best_model_at_end to work.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=use_fp16,  # Enable mixed precision training (GPU only)
    gradient_accumulation_steps=2,  # Accumulate gradients to simulate larger batch
    learning_rate=2e-4,
    save_total_limit=2,
    # Kept False because FoodDataset returns ready-made dicts rather than
    # dataset columns the Trainer could prune.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='none',
    dataloader_num_workers=2,
)

print("Training arguments configured:")
print(f"- Epochs: {training_args.num_train_epochs}")
print(f"- Batch size: {training_args.per_device_train_batch_size}")
print(f"- Learning rate: {training_args.learning_rate}")
print(f"- FP16: {training_args.fp16}")
print(f"- Gradient accumulation steps: {training_args.gradient_accumulation_steps}")


In [None]:
# Wire the model, datasets and metric callback into a Trainer.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_kwargs)

print("Starting training...")
train_result = trainer.train()

# Report the final training loss and wall-clock runtime.
final_loss = train_result.training_loss
runtime_seconds = train_result.metrics['train_runtime']
print("\nTraining completed!")
print(f"Training loss: {final_loss:.4f}")
print(f"Training time: {runtime_seconds:.2f} seconds")

In [None]:
# Score the validation split with the trainer's current model.
print("\nEvaluating on validation set...")
val_results = trainer.evaluate(val_dataset)
for metric_title, metric_key in (("Accuracy", "eval_accuracy"), ("Loss", "eval_loss")):
    print(f"Validation {metric_title}: {val_results[metric_key]:.4f}")

In [None]:
# Score the held-out test split with the trainer's current model.
print("\nEvaluating on test set...")
test_results = trainer.evaluate(test_dataset)
for metric_title, metric_key in (("Accuracy", "eval_accuracy"), ("Loss", "eval_loss")):
    print(f"Test {metric_title}: {test_results[metric_key]:.4f}")


In [None]:
# Run inference over the test split and keep true / predicted class ids.
predictions = trainer.predict(test_dataset)
y_true = predictions.label_ids
test_scores = predictions.predictions
y_pred = np.argmax(test_scores, axis=1)

n_test_images = len(y_true)
print(f"Generated predictions for {n_test_images} test images")

# ============================================================================
# CELL 15: Plot Confusion Matrix
# ============================================================================
# Compute confusion matrix
# Pass the full id range explicitly: without `labels`, a class missing from
# both y_true and y_pred is dropped and the matrix no longer lines up with
# the num_labels tick labels below.
class_ids = list(range(num_labels))
class_names = [id_to_label[i] for i in class_ids]
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

# Plot confusion matrix
plt.figure(figsize=(20, 18))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Food Classification', fontsize=16, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=90, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()


In [None]:
print("\nClassification Report:")
print("=" * 80)
# `labels` pins the report to every known class id so that `target_names`
# always has a matching length, even if a class never appears in the test data.
report_class_ids = list(range(num_labels))
report = classification_report(
    y_true,
    y_pred,
    labels=report_class_ids,
    target_names=[id_to_label[i] for i in report_class_ids],
    digits=4
)
print(report)

In [None]:
# Fraction of correctly predicted samples for each class that appears in y_true.
class_accuracy = {}
for class_id in range(num_labels):
    in_class = (y_true == class_id)
    n_in_class = in_class.sum()
    if n_in_class == 0:
        # Class absent from the test labels: nothing to score.
        continue
    n_correct = (y_pred[in_class] == class_id).sum()
    class_accuracy[id_to_label[class_id]] = n_correct / n_in_class

# Sort by accuracy
sorted_classes = sorted(class_accuracy.items(), key=lambda item: item[1], reverse=True)

table_header = f"{'Class':<30} {'Accuracy':<10}"
print("\nPer-Class Accuracy:")
print("=" * 60)
print(table_header)
print("-" * 60)
for class_name, acc in sorted_classes[:10]:
    print(f"{class_name:<30} {acc:.4f}")
print("\n... (showing top 10)")


In [None]:
def predict_image(img_path, model, processor, device):
    """Classify a single image file and return the top-1 prediction.

    Args:
        img_path: Path of the image to classify.
        model: Model called as ``model(**inputs)`` whose output exposes ``.logits``.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``.
        device: Device the processed inputs are moved to before the forward pass.

    Returns:
        Tuple ``(pred_class, confidence, image)``: the predicted class id (int),
        its softmax probability (float), and the RGB image that was classified.

    Note:
        Puts ``model`` into eval mode as a side effect.
    """
    # convert() yields an independent copy, so the underlying file can be
    # closed right away instead of staying open until garbage collection.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # .item() requires a single element, i.e. exactly one image per call.
        pred_class = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][pred_class].item()

    return pred_class, confidence, image

# Visualize predictions
fig, axes = plt.subplots(4, 4, figsize=(16, 16))
axes = axes.ravel()

# Sample random images from the test set. A seeded generator keeps the figure
# reproducible across runs, and the sample size is capped so that a test set
# with fewer than 16 images does not raise with replace=False.
n_preview = min(len(axes), len(test_df))
preview_rng = np.random.default_rng(42)
sample_indices = preview_rng.choice(len(test_df), n_preview, replace=False)

for ax, sample_idx in zip(axes, sample_indices):
    record = test_df.iloc[sample_idx]
    img_path = record['image_path']
    true_label = record['label']

    pred_class, confidence, image = predict_image(img_path, model, processor, device)
    pred_label = id_to_label[pred_class]

    ax.imshow(image)
    # Green title for a correct prediction, red for a wrong one.
    color = 'green' if pred_label == true_label else 'red'
    ax.set_title(f"True: {true_label}\nPred: {pred_label}\nConf: {confidence:.2f}",
                 color=color, fontsize=9)

# Hide the frame of every panel, including any left empty.
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.savefig('sample_predictions.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# Positions in the test set where the prediction differs from the true label.
misclassified_indices = np.where(y_true != y_pred)[0]
n_wrong = len(misclassified_indices)
n_total = len(y_true)
print(f"\nTotal misclassifications: {n_wrong} out of {n_total}")
print(f"Misclassification rate: {n_wrong/n_total*100:.2f}%")

# Count how often each (true label, predicted label) confusion occurs.
misclass_pairs = {}
for true_id, pred_id in zip(y_true[misclassified_indices], y_pred[misclassified_indices]):
    pair = (id_to_label[true_id], id_to_label[pred_id])
    if pair in misclass_pairs:
        misclass_pairs[pair] += 1
    else:
        misclass_pairs[pair] = 1

# Most frequent confusions first.
sorted_misclass = sorted(misclass_pairs.items(), key=lambda item: item[1], reverse=True)

column_header = f"{'True Label':<25} {'Predicted Label':<25} {'Count':<10}"
print("\nMost common misclassifications:")
print("=" * 80)
print(column_header)
print("-" * 80)
for (true_label, pred_label), count in sorted_misclass[:15]:
    print(f"{true_label:<25} {pred_label:<25} {count:<10}")

# ============================================================================
# CELL 20: Visualize Misclassified Examples
# ============================================================================
# Visualize some misclassified examples
# Show up to 9 misclassified test images. Previously nothing was drawn when
# there were between 1 and 8 errors; now whatever is available is shown.
n_shown = min(9, len(misclassified_indices))
if n_shown > 0:
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    axes = axes.ravel()

    # Seeded generator so the same examples are picked on every run.
    misclass_rng = np.random.default_rng(42)
    sample_misclass = misclass_rng.choice(misclassified_indices, n_shown, replace=False)

    for ax, misclass_idx in zip(axes, sample_misclass):
        # misclass_idx is a position in the prediction arrays, which follow
        # the row order of test_df.
        img_path = test_df.iloc[misclass_idx]['image_path']
        true_label = id_to_label[y_true[misclass_idx]]
        pred_label = id_to_label[y_pred[misclass_idx]]

        # Only the confidence and the decoded image are needed here.
        _, confidence, image = predict_image(img_path, model, processor, device)

        ax.imshow(image)
        ax.set_title(f"True: {true_label}\nPred: {pred_label}\nConf: {confidence:.2f}",
                     color='red', fontsize=9)

    # Hide the frame of every panel, including any left empty.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('misclassifications.png', dpi=100, bbox_inches='tight')
    plt.show()

# ============================================================================
# CELL 21: Save the Model
# ============================================================================
# Save the fine-tuned model
# Persist the model and its processor side by side in one directory.
for artifact in (model, processor):
    artifact.save_pretrained('./vit-food-classifier-final')

# Tell the reader where the artifacts went and how to reload them.
for message in (
    "\nModel and processor saved to './vit-food-classifier-final'",
    "You can load it later with:",
    "model = ViTForImageClassification.from_pretrained('./vit-food-classifier-final')",
    "processor = ViTImageProcessor.from_pretrained('./vit-food-classifier-final')",
):
    print(message)

# ============================================================================
# CELL 22: Final Summary
# ============================================================================
# Recap of the run: configuration, split sizes and headline test metrics.
summary_rule = "=" * 80
misclassification_pct = len(misclassified_indices)/len(y_true)*100

print("\n" + summary_rule)
print("FINAL RESULTS SUMMARY")
print(summary_rule)
for summary_line in (
    f"Model: {model_name}",
    f"Dataset: Food-41",
    f"Number of classes: {num_labels}",
    f"Training samples: {len(train_df)}",
    f"Validation samples: {len(val_df)}",
    f"Test samples: {len(test_df)}",
    f"\nTest Accuracy: {test_results['eval_accuracy']:.4f}",
    f"Test Loss: {test_results['eval_loss']:.4f}",
    f"Misclassification rate: {misclassification_pct:.2f}%",
):
    print(summary_line)
print(summary_rule)


def classify_food_image(image_path, model, processor, device, id_to_label, top_k=5):
    """
    Classify a food image and return top-k predictions.

    Args:
        image_path: Path to the image
        model: Trained ViT model
        processor: ViT image processor
        device: Device (cuda/cpu)
        id_to_label: Dictionary mapping class IDs to labels
        top_k: Number of top predictions to return. Values larger than the
            number of classes are clamped to the number of classes.

    Returns:
        Dictionary with two keys: 'predictions', a list of
        {'label', 'confidence'} dicts ordered from most to least likely, and
        'image', the RGB image that was classified.

    Note:
        Puts ``model`` into eval mode as a side effect.
    """
    # Load and preprocess image. convert() yields an independent copy, so the
    # file handle can be closed right away.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Make prediction
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

    # torch.topk raises when k exceeds the size of the dimension, so cap the
    # request at the number of classes the model actually outputs.
    k = min(top_k, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs[0], k)

    results = {
        'predictions': [
            {
                'label': id_to_label[idx.item()],
                'confidence': prob.item()
            }
            for prob, idx in zip(top_probs, top_indices)
        ],
        'image': image
    }

    return results

# Print a short, copy-pasteable usage example for classify_food_image().
for usage_line in (
    "Inference function defined: classify_food_image()",
    "Usage example:",
    "results = classify_food_image('path/to/image.jpg', model, processor, device, id_to_label)",
    "for pred in results['predictions']:",
    "    print(f\"{pred['label']}: {pred['confidence']:.4f}\")",
):
    print(usage_line)