In [None]:


# ============================================================================
# Install Required Packages
# ============================================================================
!pip install -q transformers datasets accelerate evaluate

# ============================================================================
# Import Libraries
# ============================================================================
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

from transformers import (
    ViTImageProcessor, 
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Environment report: library version, CUDA availability and, when a GPU is
# present, the name of device 0.
env_lines = [
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    env_lines.append(f"GPU: {torch.cuda.get_device_name(0)}")
print("\n".join(env_lines))

# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Central settings for the ViT fine-tuning run.

    Every value is a plain class attribute; the rest of the script reads them
    through a single ``Config()`` instance.
    """

    # Reproducibility and data loading
    SEED = 42
    NUM_WORKERS = 2

    # Fractions of the full image list held out for testing and validation
    TEST_SIZE = 0.1
    VAL_SIZE = 0.1

    # Optimisation schedule
    NUM_EPOCHS = 5
    BATCH_SIZE = 32
    GRADIENT_ACCUMULATION_STEPS = 2
    LEARNING_RATE = 2e-5
    WEIGHT_DECAY = 0.01
    WARMUP_STEPS = 500
    FP16 = True

    # Pretrained checkpoint and the number of outputs of the classifier head
    MODEL_NAME = 'google/vit-base-patch16-224'
    NUM_LABELS = 101

    # Where images are read from and where the trained model is written
    DATA_DIR = '/kaggle/input/food41/images'
    OUTPUT_DIR = '/kaggle/working/vit_food41_model'


config = Config()

# Seed PyTorch first and NumPy second with the same configured value so that
# augmentation, sampling and the random picks made below are repeatable.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(config.SEED)

# ============================================================================
# Data Exploration
# ============================================================================
# Discover the category folders under DATA_DIR and report per-category image
# counts. Defines `categories` (sorted folder names) and `category_counts`
# (folder name -> number of image files), which later steps rely on.
print("Exploring dataset...")
print(f"Dataset location: {config.DATA_DIR}")

if not os.path.isdir(config.DATA_DIR):
    print(f"ERROR: Dataset not found at {config.DATA_DIR}")
    print("Please add the Food-41 dataset to your Kaggle notebook:")
    print("1. Click 'Add Data' in the right sidebar")
    print("2. Search for 'food41'")
    print("3. Add the dataset")
    # Stop here: every later step needs `categories` and `category_counts`,
    # so continuing would only fail further down with an unrelated NameError.
    raise FileNotFoundError(f"Dataset not found at {config.DATA_DIR}")

# One sub-directory per category; sorted so label indices are stable.
categories = sorted(
    d for d in os.listdir(config.DATA_DIR)
    if os.path.isdir(os.path.join(config.DATA_DIR, d))
)
if not categories:
    raise ValueError(f"No category folders found in {config.DATA_DIR}")

print(f"\nFound {len(categories)} food categories")
print(f"Sample categories: {categories[:5]}")

# Count images per category (same extension filter as the split step below).
category_counts = {
    cat: sum(
        1 for f in os.listdir(os.path.join(config.DATA_DIR, cat))
        if f.endswith(('.jpg', '.jpeg', '.png'))
    )
    for cat in categories
}

per_category = list(category_counts.values())
print(f"\nTotal images: {sum(per_category)}")
print(f"Average per category: {np.mean(per_category):.0f}")
print(f"Min per category: {min(per_category)}")
print(f"Max per category: {max(per_category)}")

# ============================================================================
# Visualize Dataset Distribution
# ============================================================================
# Plot category distribution
# Bar chart of images per category, largest first, with the mean drawn as a
# dashed reference line. The figure is written to dataset_distribution.png.
fig_dist, ax_dist = plt.subplots(figsize=(20, 6))

# Negating the count sorts descending while keeping ties in their input order.
sorted_cats = sorted(category_counts.items(), key=lambda pair: -pair[1])
cats, counts = zip(*sorted_cats)
bar_positions = range(len(cats))

ax_dist.bar(bar_positions, counts, color='steelblue', alpha=0.7)
ax_dist.axhline(
    y=np.mean(counts),
    color='r',
    linestyle='--',
    label=f'Average: {np.mean(counts):.0f}',
)
ax_dist.set_xlabel('Food Category', fontsize=12, fontweight='bold')
ax_dist.set_ylabel('Number of Images', fontsize=12, fontweight='bold')
ax_dist.set_title('Food-41 Dataset Distribution', fontsize=14, fontweight='bold')
ax_dist.set_xticks(bar_positions)
ax_dist.set_xticklabels(cats, rotation=90, ha='right')
ax_dist.legend()
fig_dist.tight_layout()
fig_dist.savefig('dataset_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print("Distribution plot saved!")

# ============================================================================
# Visualize Sample Images
# ============================================================================
# Show sample images from different categories
# Show one random image from each of up to 15 random categories in a 3x5 grid
# and save the grid to sample_images.png.
fig, axes = plt.subplots(3, 5, figsize=(15, 9))
axes = axes.flatten()

# Hide every frame up front so cells that end up without an image (fewer
# categories than cells, empty folder, unreadable file) stay blank.
for ax in axes:
    ax.axis('off')

sample_categories = np.random.choice(
    categories,
    size=min(len(axes), len(categories)),
    replace=False,
)

for ax, cat in zip(axes, sample_categories):
    cat_path = os.path.join(config.DATA_DIR, cat)
    images = [f for f in os.listdir(cat_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
    if not images:
        continue

    img_path = os.path.join(cat_path, np.random.choice(images))
    try:
        # The context manager closes the file handle once pixels are decoded.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
    except (OSError, ValueError) as exc:
        # Only I/O and decode failures are tolerated; anything else is a bug
        # and should surface instead of being swallowed.
        print(f"Skipping unreadable image {img_path}: {exc}")
        continue

    ax.imshow(img)
    ax.set_title(cat.replace('_', ' ').title(), fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig('sample_images.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================================
# Prepare Data Splits
# ============================================================================
# Build the (path, label) lists, split them into stratified train/val/test
# subsets and write the index -> category mapping to label_mapping.json.
print("\nPreparing train/val/test splits...")

# Category name -> integer label, following the sorted order of `categories`.
label_to_idx = {category: idx for idx, category in enumerate(categories)}

# Collect all image paths and labels
all_images = []
all_labels = []
for category, idx in label_to_idx.items():
    category_path = os.path.join(config.DATA_DIR, category)
    # os.listdir returns entries in arbitrary order; sorting makes the input
    # order, and therefore the seeded splits below, reproducible across runs.
    for img_file in sorted(os.listdir(category_path)):
        if img_file.endswith(('.jpg', '.jpeg', '.png')):
            all_images.append(os.path.join(category_path, img_file))
            all_labels.append(idx)

print(f"Total images collected: {len(all_images)}")

# Create stratified splits: first carve out the test set ...
train_val_images, test_images, train_val_labels, test_labels = train_test_split(
    all_images, all_labels,
    test_size=config.TEST_SIZE,
    random_state=config.SEED,
    stratify=all_labels
)

# ... then the validation set. VAL_SIZE is a fraction of the full data, so it
# is rescaled to a fraction of what remains after the test split.
train_images, val_images, train_labels, val_labels = train_test_split(
    train_val_images, train_val_labels,
    test_size=config.VAL_SIZE / (1 - config.TEST_SIZE),
    random_state=config.SEED,
    stratify=train_val_labels
)

print(f"Train: {len(train_images)} ({len(train_images)/len(all_images)*100:.1f}%)")
print(f"Val: {len(val_images)} ({len(val_images)/len(all_images)*100:.1f}%)")
print(f"Test: {len(test_images)} ({len(test_images)/len(all_images)*100:.1f}%)")

# Save label mapping (JSON turns the integer keys into strings)
idx_to_label = {v: k for k, v in label_to_idx.items()}
with open('label_mapping.json', 'w') as f:
    json.dump(idx_to_label, f, indent=2)

print("\nLabel mapping saved!")

# ============================================================================
# Custom Dataset Class
# ============================================================================
class Food41Dataset(Dataset):
    """Map-style dataset that yields ViT-ready tensors for a list of image files.

    Each item is a dict with ``pixel_values`` (the processor output with its
    batch dimension removed) and ``labels`` (a ``torch.long`` scalar tensor).
    """

    def __init__(self, image_paths, labels, processor, is_train=False):
        """Store the samples and choose the PIL-level transform pipeline.

        Args:
            image_paths: Sequence of image file paths.
            labels: Sequence of integer class ids, parallel to ``image_paths``.
            processor: Callable invoked as
                ``processor(images=..., return_tensors="pt")``.
            is_train: When true, use random augmentation; otherwise use a
                deterministic resize followed by a centre crop.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.is_train = is_train

        if not is_train:
            # Evaluation: deterministic resize + centre crop only.
            self.aug_transforms = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ])
            return

        # Training: random crop, flip, small rotation and colour jitter.
        colour_jitter = transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )
        self.aug_transforms = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            colour_jitter,
        ])

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and encode the sample at position ``idx``.

        If the image cannot be opened, the error is printed and a blank white
        224x224 image is used in its place (the label is left unchanged).
        """
        try:
            image = Image.open(self.image_paths[idx]).convert('RGB')
        except Exception as e:
            print(f"Error loading {self.image_paths[idx]}: {e}")
            image = Image.new('RGB', (224, 224), color='white')

        # PIL-level transform first, then the ViT processor.
        transformed = self.aug_transforms(image)
        encoding = self.processor(images=transformed, return_tensors="pt")

        # The processor returns a batch of one; drop that leading dimension.
        return {
            'pixel_values': encoding['pixel_values'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

print("Dataset class defined!")

# ============================================================================
# Load Model and Processor
# ============================================================================
# Load the pretrained processor and a ViT classifier whose head has
# NUM_LABELS outputs, then attach the dataset's label names to its config.
print("Loading ViT model and processor...")

# The head needs exactly one output per discovered category. Checking here
# gives a clear message instead of a failure later in training or evaluation.
if config.NUM_LABELS != len(label_to_idx):
    raise ValueError(
        f"Config.NUM_LABELS is {config.NUM_LABELS} but {len(label_to_idx)} "
        f"categories were found in {config.DATA_DIR}"
    )

processor = ViTImageProcessor.from_pretrained(config.MODEL_NAME)
model = ViTForImageClassification.from_pretrained(
    config.MODEL_NAME,
    num_labels=config.NUM_LABELS,
    # Presumably the checkpoint's classifier head has a different size than
    # NUM_LABELS, so the size mismatch must be tolerated.
    ignore_mismatched_sizes=True
)

# Update model config so saved checkpoints carry the category names
model.config.id2label = idx_to_label
model.config.label2id = label_to_idx

# Single pass over the parameters for both totals.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)

print(f"Model loaded: {config.MODEL_NAME}")
print(f"Parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# ============================================================================
# Create Datasets
# ============================================================================
print("\nCreating datasets...")

# One dataset per split; only the training split gets random augmentation.
train_dataset, val_dataset, test_dataset = (
    Food41Dataset(split_paths, split_targets, processor, is_train=augment)
    for split_paths, split_targets, augment in (
        (train_images, train_labels, True),
        (val_images, val_labels, False),
        (test_images, test_labels, False),
    )
)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

# Smoke test: fetch one item to confirm loading and preprocessing work.
sample = train_dataset[0]
print(f"\nSample data shape: {sample['pixel_values'].shape}")
print(f"Sample label: {sample['labels']}")

# ============================================================================
# Define Metrics
# ============================================================================
def compute_metrics(eval_pred):
    """Return top-1 accuracy for a ``(predictions, labels)`` pair.

    The predicted class of each row is the argmax along axis 1 of the
    predictions array. The result is a dict with the single key ``accuracy``.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return {'accuracy': accuracy_score(references, predicted_classes)}

print("Metrics function defined!")

# ============================================================================
# Training Arguments
# ============================================================================
# Mixed precision is only requested when a CUDA device is actually present;
# asking for fp16 on a CPU-only session makes TrainingArguments raise.
use_fp16 = config.FP16 and torch.cuda.is_available()

# Evaluate and checkpoint once per epoch, keep the two most recent
# checkpoints, and reload the checkpoint with the best validation accuracy
# when training finishes.
training_args = TrainingArguments(
    output_dir=config.OUTPUT_DIR,
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    learning_rate=config.LEARNING_RATE,
    warmup_steps=config.WARMUP_STEPS,
    weight_decay=config.WEIGHT_DECAY,
    logging_dir='./logs',
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    fp16=use_fp16,
    gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
    report_to="none",
    save_total_limit=2,
    dataloader_num_workers=config.NUM_WORKERS,
)

print("Training arguments configured!")
# train_batch_size already accounts for the number of visible GPUs, so the
# reported figure stays correct on multi-GPU sessions.
effective_batch_size = (
    training_args.train_batch_size * training_args.gradient_accumulation_steps
)
print(f"Effective batch size: {effective_batch_size}")

# ============================================================================
# Initialize Trainer
# ============================================================================
# Wire the model, arguments, train/validation datasets and the accuracy
# metric into a single Trainer.
trainer_kwargs = {
    'model': model,
    'args': training_args,
    'train_dataset': train_dataset,
    'eval_dataset': val_dataset,
    'compute_metrics': compute_metrics,
}
trainer = Trainer(**trainer_kwargs)

print("Trainer initialized!")
print("Ready to start training...")

# ============================================================================
# Train Model
# ============================================================================
banner = "="*60

print("\n" + banner)
print("STARTING TRAINING")
print(banner + "\n")

# Run fine-tuning; the returned object exposes the training loss printed below.
train_result = trainer.train()

print("\n" + banner)
print("TRAINING COMPLETE!")
print(banner)
print(f"\nTraining metrics:")
print(f"  Final loss: {train_result.training_loss:.4f}")

# ============================================================================
# Save Model
# ============================================================================
print("\nSaving model...")

# Model first, then the processor, both into the same output directory.
for save_artifact in (trainer.save_model, processor.save_pretrained):
    save_artifact(config.OUTPUT_DIR)

# Keep a copy of the index -> category mapping next to the saved model.
mapping_path = os.path.join(config.OUTPUT_DIR, 'label_mapping.json')
with open(mapping_path, 'w') as f:
    f.write(json.dumps(idx_to_label, indent=2))

print(f"Model saved to: {config.OUTPUT_DIR}")

# ============================================================================
# Evaluate on Validation Set
# ============================================================================
val_banner = "="*60
print("\n" + val_banner)
print("EVALUATING ON VALIDATION SET")
print(val_banner + "\n")

# Evaluate the current model on the validation split and list each metric
# with four decimals.
val_results = trainer.evaluate(val_dataset)
print("Validation Results:")
for key in val_results:
    value = val_results[key]
    print(f"  {key}: {value:.4f}")

# ============================================================================
# Evaluate on Test Set
# ============================================================================
test_banner = "="*60
print("\n" + test_banner)
print("EVALUATING ON TEST SET")
print(test_banner + "\n")

# Run inference over the held-out test split; the predicted class of each
# sample is the argmax of its score row.
predictions = trainer.predict(test_dataset)
true_labels = test_labels
pred_labels = np.argmax(predictions.predictions, axis=1)

# Top-1 accuracy on the test split
test_accuracy = accuracy_score(true_labels, pred_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")

# Calculate top-k accuracy
def top_k_accuracy(y_true, y_pred_probs, k=5):
    """Return the fraction of samples whose true label is among the k best scores.

    Args:
        y_true: Sequence of integer class ids, one per sample.
        y_pred_probs: 2-D array-like of per-class scores, one row per sample.
        k: Number of highest-scoring classes to accept. If it exceeds the
            number of classes, every class is accepted.

    Returns:
        The mean hit rate as a NumPy float.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        # A slice of [-0:] would select every column and report 100%.
        raise ValueError(f"k must be a positive integer, got {k}")

    # Column indices of the k largest scores in each row.
    top_k_preds = np.argsort(np.asarray(y_pred_probs), axis=1)[:, -k:]
    # Broadcast each true label against its row of candidates.
    targets = np.asarray(y_true).reshape(-1, 1)
    return np.mean(np.any(top_k_preds == targets, axis=1))

# Top-3 and top-5 accuracy over the same test-set scores.
top3_acc, top5_acc = (
    top_k_accuracy(true_labels, predictions.predictions, k=cutoff)
    for cutoff in (3, 5)
)

print(f"Top-3 Accuracy: {top3_acc:.4f}")
print(f"Top-5 Accuracy: {top5_acc:.4f}")

# ============================================================================
# Classification Report
# ============================================================================
print("\n" + "="*60)
print("DETAILED CLASSIFICATION REPORT")
print("="*60 + "\n")

# Category names in label-index order
category_names = [idx_to_label[i] for i in range(len(categories))]

# Passing the full label list keeps target_names aligned with the class ids
# even if some class never occurs in the test labels or the predictions.
report = classification_report(
    true_labels,
    pred_labels,
    labels=list(range(len(category_names))),
    target_names=category_names,
    output_dict=True
)

# Convert to DataFrame and save (the CSV keeps the summary rows as well)
report_df = pd.DataFrame(report).transpose()
report_df.to_csv('classification_report.csv')
print("Classification report saved to: classification_report.csv")

# Rank per-class rows only: the report also contains summary rows
# (accuracy / macro avg / weighted avg) that are not classes and would
# otherwise show up in the best/worst lists.
report_df_sorted = report_df.loc[category_names].sort_values('f1-score', ascending=False)
metric_columns = ['precision', 'recall', 'f1-score', 'support']

print("\nTop 10 Best Performing Classes:")
print(report_df_sorted.head(10)[metric_columns])

print("\nBottom 10 Classes:")
print(report_df_sorted.tail(10)[metric_columns])

# ============================================================================
# Confusion Matrix
# ============================================================================
print("\nGenerating confusion matrix...")

# Fix the label set so the matrix always has one row/column per category and
# lines up with the tick labels, even if a class is absent from the test
# labels and the predictions.
cm = confusion_matrix(
    true_labels,
    pred_labels,
    labels=list(range(len(category_names))),
)

# Plot confusion matrix (cell annotations are off: too many classes to read)
plt.figure(figsize=(20, 18))
sns.heatmap(cm, annot=False, fmt='d', cmap='Blues',
            xticklabels=category_names,
            yticklabels=category_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - ViT Food-41 Classification',
          fontsize=16, fontweight='bold', pad=20)
plt.ylabel('True Label', fontsize=12, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=12, fontweight='bold')
plt.xticks(rotation=90, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print("Confusion matrix saved!")

# Analyze major confusions: every non-zero off-diagonal cell is one
# (true class, predicted class) error pair.
print("\nTop 10 Most Confused Class Pairs:")
off_diagonal = cm.copy()
np.fill_diagonal(off_diagonal, 0)
true_ids, pred_ids = np.nonzero(off_diagonal)

confusions = [
    (category_names[i], category_names[j], int(off_diagonal[i, j]))
    for i, j in zip(true_ids, pred_ids)
]
confusions.sort(key=lambda pair: pair[2], reverse=True)

for true_class, pred_class, count in confusions[:10]:
    print(f"  {true_class} → {pred_class}: {count} times")

# ============================================================================
# Sample Predictions
# ============================================================================
print("\n" + "="*60)
print("SAMPLE PREDICTIONS")
print("="*60 + "\n")

# Pick up to six distinct test images (fewer if the test split is smaller).
num_samples = min(6, len(test_images))
sample_indices = np.random.choice(len(test_images), num_samples, replace=False)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Hide all frames first so unused cells stay blank.
for ax in axes:
    ax.axis('off')

# Make inference mode explicit instead of relying on whatever mode the
# Trainer last left the model in.
model.eval()

for ax, sample_idx in zip(axes, sample_indices):
    # Load image; the context manager closes the file after decoding.
    img_path = test_images[sample_idx]
    with Image.open(img_path) as raw:
        image = raw.convert('RGB')
    true_label = idx_to_label[test_labels[sample_idx]]

    # NOTE(review): the image goes straight to the processor here, without the
    # Resize(256)/CenterCrop(224) step the evaluation dataset applies —
    # confirm this difference is intended.
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()
        pred_label = idx_to_label[pred_idx]

    # Plot: green title for a correct prediction, red for a wrong one
    ax.imshow(image)
    color = 'green' if pred_label == true_label else 'red'
    title = f"True: {true_label}\nPred: {pred_label}\nConf: {confidence:.2%}"
    ax.set_title(title, fontsize=10, color=color, fontweight='bold')

plt.tight_layout()
plt.savefig('sample_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print("Sample predictions saved!")

# ============================================================================
# Save Results Summary
# ============================================================================
# Create results summary
# Assemble the summary from three groups; merging them in this order keeps
# the key order of the written JSON unchanged.
run_settings = {
    'model_name': config.MODEL_NAME,
    'num_epochs': config.NUM_EPOCHS,
    'batch_size': config.BATCH_SIZE,
    'learning_rate': config.LEARNING_RATE,
}
test_scores = {
    'test_accuracy': float(test_accuracy),
    'top3_accuracy': float(top3_acc),
    'top5_accuracy': float(top5_acc),
}
dataset_sizes = {
    'num_classes': len(categories),
    'train_samples': len(train_images),
    'val_samples': len(val_images),
    'test_samples': len(test_images),
}
results_summary = {**run_settings, **test_scores, **dataset_sizes}

# Serialise once and reuse the text for both the file and the console.
summary_text = json.dumps(results_summary, indent=2)
with open('results_summary.json', 'w') as f:
    f.write(summary_text)

print("\nResults Summary:")
print(summary_text)
print("\nResults saved to: results_summary.json")

# ============================================================================
# Single Image Prediction Function
# ============================================================================
def predict_single_image(image_path, top_k=5):
    """Classify one image file and return its highest-probability labels.

    Uses the module-level ``processor``, ``model`` and ``idx_to_label``.

    Args:
        image_path: Path of the image to classify.
        top_k: Maximum number of predictions to return; capped at the number
            of classes the model actually outputs.

    Returns:
        A list of ``{'label': str, 'confidence': float}`` dicts ordered from
        most to least probable.
    """
    # Load image; the context manager closes the file after decoding.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')

    # Preprocess and move the tensors to the model's device
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Predict in inference mode, without tracking gradients.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)

    # Cap k by the real width of the output rather than a config constant,
    # so torch.topk can never be asked for more classes than exist.
    top_probs, top_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))

    return [
        {'label': idx_to_label[idx.item()], 'confidence': prob.item()}
        for prob, idx in zip(top_probs[0], top_indices[0])
    ]

# Test the function
# Smoke-test the helper on the first test image. The result gets its own
# name so it does not overwrite `predictions`, which holds the test-set
# output of trainer.predict.
print("\nTesting prediction function with the first test image...")
test_img_path = test_images[0]
single_image_preds = predict_single_image(test_img_path)

print(f"Image: {test_img_path}")
print("Predictions:")
for i, pred in enumerate(single_image_preds, 1):
    print(f"  {i}. {pred['label']}: {pred['confidence']:.2%}")

print("\n" + "="*60)
print("ALL TASKS COMPLETE!")
print("="*60)
print("\nGenerated files:")
print("  - vit_food41_model/ (trained model)")
print("  - label_mapping.json")
print("  - classification_report.csv")
print("  - confusion_matrix.png")
print("  - sample_predictions.png")
print("  - sample_images.png")
print("  - dataset_distribution.png")
print("  - results_summary.json")