In [None]:
# ============================================================
# TOP 10 POPULAR SNEAKERS CLASSIFIER - GOOGLE COLAB VERSION
# 10 Most Iconic Sneaker Models
# ============================================================

# ===== STEP 1: SETUP KAGGLE AND DOWNLOAD DATASET =====
print("="*60)
print("STEP 1: KAGGLE SETUP & DATASET DOWNLOAD")
print("="*60)

# Install kaggle
!pip install -q kaggle

# Upload kaggle.json
print("\n‚ö†Ô∏è  UPLOAD YOUR kaggle.json FILE NOW!")
print("    Get it from: https://www.kaggle.com/settings -> Create New API Token\n")
from google.colab import files
uploaded = files.upload()

# Setup Kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download dataset
print("\nüì• Downloading sneakers dataset...")
!kaggle datasets download -d nikolasgegenava/sneakers-classification

# Unzip the dataset
print("\nüì¶ Unzipping dataset...")
!unzip -o sneakers-classification.zip -d sneakers_data
print("‚úÖ Dataset extracted to sneakers_data/\n")

# Show what's inside
print("üìÇ Contents of sneakers_data:")
!ls sneakers_data/

print("\n‚úÖ Download complete!\n")

In [None]:


# ===== STEP 2: SELECT TOP 10 SNEAKERS & ORGANIZE =====
import os
import shutil
from sklearn.model_selection import train_test_split
from collections import defaultdict

print("="*60)
print("STEP 2: FILTERING TOP 10 SNEAKERS")
print("="*60)

# Define the top 10 most popular sneakers we want to classify.
# Dataset folders are matched after replacing "_" with " " and title-casing the
# folder name, so entries here must be written in that normalized form.
TOP_10_SNEAKERS = [
    "Nike Air Force 1 Low",
    "Nike Air Jordan 1 High",
    "Nike Dunk Low",
    "Adidas Stan Smith",
    "Adidas Superstar",
    "Converse Chuck Taylor All-Star High",
    "Vans Old Skool",
    "New Balance 550",
    "Yeezy Boost 350 V2",
    "Nike Air Max 90"
]

# File extensions (compared case-insensitively) treated as images
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

print("\n🎯 Selected Top 10 Sneakers:")
for i, sneaker in enumerate(TOP_10_SNEAKERS, 1):
    print(f"{i:2d}. {sneaker}")

# Base dataset folder (where step 1 unzipped the Kaggle archive)
base_dir = "sneakers_data/sneakers-dataset/sneakers-dataset"

# First, let's see ALL folders in the dataset.
# Sorted so the scan - and the split derived from it - is reproducible.
print(f"\n📁 Scanning {base_dir} folder...")
all_folders = sorted(f for f in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, f)))
print(f"Found {len(all_folders)} total folders in dataset")
if all_folders:
    print(f"Sample folders: {all_folders[:5]}")

# Find and collect only the top 10 categories
wanted = set(TOP_10_SNEAKERS)
all_images = defaultdict(list)

for folder in all_folders:
    # Normalize folder names for comparison ("nike_dunk_low" -> "Nike Dunk Low")
    standard_name = folder.replace('_', ' ').title()
    if standard_name not in wanted:
        continue
    folder_path = os.path.join(base_dir, folder)
    # os.listdir order depends on the filesystem; sort it so the seeded
    # train_test_split below produces the same split on every run.
    images = [os.path.join(folder_path, f) for f in sorted(os.listdir(folder_path))
              if f.lower().endswith(IMAGE_EXTENSIONS)]
    # extend (not assign) so folders that normalize to the same name are merged
    all_images[standard_name].extend(images)

# Keep categories that have at least one image, in TOP_10 order
categories = [cat for cat in TOP_10_SNEAKERS if all_images.get(cat)]

print(f"\n✅ Found {len(categories)} categories from our top 10 list:")
total_images = 0
for i, cat in enumerate(categories, 1):
    count = len(all_images[cat])
    total_images += count
    print(f"{i:2d}. {cat}: {count} images")
print(f"\nTotal images: {total_images}")

if not categories:
    print("\n❌ ERROR: No matching categories found!")
    print("\nLet me show you what folders exist:")
    print(f"All folders in {base_dir}:")
    for folder in all_folders[:20]:
        print(f"  - {folder}")
    print("\nPlease check if folder names match exactly with TOP_10_SNEAKERS list")
    # exit() in an IPython kernel only requests a kernel shutdown and does not stop
    # the remaining statements of this cell; raise so nothing below runs.
    raise RuntimeError(f"No TOP_10_SNEAKERS categories found in {base_dir}")

# Split into train/validation (80/20) for each category
train_images = {}
val_images = {}

for category in categories:
    images = all_images[category]
    if len(images) >= 2:
        train_images[category], val_images[category] = train_test_split(
            images, test_size=0.2, random_state=42)
    else:
        # A single image cannot be split: keep it for training
        train_images[category] = images
        val_images[category] = []

# Create new directory structure.
# Wipe output from earlier runs first: leftover files would otherwise stay in
# train/ or validation/ and could put the same image on both sides of the split.
train_base = "dataset/train"
val_base = "dataset/validation"

for split_dir in (train_base, val_base):
    shutil.rmtree(split_dir, ignore_errors=True)

for category in categories:
    os.makedirs(os.path.join(train_base, category), exist_ok=True)
    os.makedirs(os.path.join(val_base, category), exist_ok=True)

# Copy files to train
print("\nCopying training images...")
for category in categories:
    for img in train_images[category]:
        shutil.copy(img, os.path.join(train_base, category))

# Copy files to validation
print("Copying validation images...")
for category in categories:
    for img in val_images[category]:
        shutil.copy(img, os.path.join(val_base, category))

train_total = sum(len(train_images[cat]) for cat in categories)
val_total = sum(len(val_images[cat]) for cat in categories)

print(f"\n📊 Dataset Split:")
print(f"   Train set: {train_total} images")
print(f"   Validation set: {val_total} images")
print("\nDataset structure created successfully!\n")

# ===== STEP 3: SETUP MODEL AND TRAINING =====
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
from tqdm.auto import tqdm

# Section banner, then a short report on whether CUDA can be used.
for banner_line in ("="*60, "STEP 3: MODEL SETUP", "="*60):
    print(banner_line)

cuda_ready = torch.cuda.is_available()
print(f"\nGPU available: {cuda_ready}")
if cuda_ready:
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

# Custom Dataset class
class SneakerDataset(Dataset):
    """Dataset of sneaker images laid out as ``root_dir/<category>/<image>``.

    Each item is a dict with ``'pixel_values'`` (the processor output for one
    image, without its leading batch dimension) and ``'labels'`` (a 0-d
    tensor holding the category's index in ``categories``).

    Args:
        root_dir: directory with one sub-directory per category.
        processor: image processor called as
            ``processor(images=img, return_tensors="pt")`` and returning a
            mapping with a ``'pixel_values'`` tensor (e.g. ``ViTImageProcessor``).
        categories: ordered class names; a category without a sub-directory
            simply contributes no samples.
    """

    # File extensions (compared case-insensitively) treated as images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, root_dir, processor, categories):
        self.processor = processor
        self.images = []
        self.labels = []

        # Create label mapping: category name -> class index
        self.label_map = {cat: idx for idx, cat in enumerate(categories)}

        # Collect image paths per category; sorted for a reproducible order
        for category in categories:
            cat_dir = os.path.join(root_dir, category)
            if not os.path.isdir(cat_dir):
                continue
            for img_name in sorted(os.listdir(cat_dir)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(cat_dir, img_name))
                    self.labels.append(self.label_map[category])

    def __len__(self):
        return len(self.images)

    def _load(self, idx):
        """Load, convert to RGB and preprocess sample ``idx`` (raises if unreadable)."""
        # The context manager closes the file; convert() returns a new image
        # that stays usable after the file is closed.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        # Drop only the leading batch dimension of the single-image batch
        pixel_values = inputs['pixel_values'].squeeze(0)
        return {'pixel_values': pixel_values, 'labels': torch.tensor(self.labels[idx])}

    def __getitem__(self, idx):
        """Return sample ``idx``, falling back to the next readable image.

        Corrupted or unreadable files are skipped by trying ``idx + 1``,
        ``idx + 2``, ... (wrapping around), each image at most once.

        Raises:
            IndexError: if the dataset is empty.
            RuntimeError: if no image in the dataset can be loaded.
        """
        num_items = len(self)
        if num_items == 0:
            raise IndexError("SneakerDataset is empty")
        last_error = None
        for offset in range(num_items):
            try:
                return self._load((idx + offset) % num_items)
            except (OSError, ValueError) as err:
                # PIL reports unreadable/corrupt files with OSError subclasses
                # (e.g. UnidentifiedImageError); skip to the next image.
                last_error = err
        raise RuntimeError(f"No readable images in dataset (last error: {last_error!r})")

# Load processor and model
print("\nLoading model...")
MODEL_CHECKPOINT = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# Create label mappings (same index order as SneakerDataset.label_map)
label2id = {cat: idx for idx, cat in enumerate(categories)}
id2label = {idx: cat for cat, idx in label2id.items()}

model = ViTForImageClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(categories),
    id2label=id2label,
    label2id=label2id,
    # The checkpoint's classification head has a different number of outputs;
    # allow it to be replaced by a freshly initialised len(categories)-way head.
    ignore_mismatched_sizes=True
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Create datasets
print("Loading datasets...")
train_dataset = SneakerDataset(train_base, processor, categories)
val_dataset = SneakerDataset(val_base, processor, categories)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Create dataloaders
batch_size = 16 if torch.cuda.is_available() else 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Training setup. The learning rate lives in one variable so the optimizer and
# the configuration printout below cannot disagree.
learning_rate = 2e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
num_epochs = 5

print(f"\nTraining Configuration:")
print(f"   Batch size: {batch_size}")
print(f"   Epochs: {num_epochs}")
print(f"   Learning rate: {learning_rate}")
print(f"   Number of classes: {len(categories)}\n")

# ===== STEP 4: TRAINING =====
print("="*60)
print("STEP 4: TRAINING")
print("="*60)
print("\nStarting training...\n")

BEST_MODEL_DIR = "./sneakers_top10_best"
FINAL_MODEL_DIR = "./sneakers_top10_final"

# Start below any reachable accuracy so the first epoch always writes a "best"
# checkpoint; the analysis cells load ./sneakers_top10_best and would fail if
# it were never written (e.g. validation accuracy stuck at 0%).
best_val_acc = -1.0

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for batch in progress_bar:
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predictions = outputs.logits.argmax(dim=-1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        progress_bar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.2f}%'})

    # Guards keep an empty loader from raising ZeroDivisionError
    avg_train_loss = train_loss / max(len(train_loader), 1)
    train_acc = 100 * correct / total if total else 0.0

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            val_loss += loss.item()
            predictions = outputs.logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

            progress_bar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.2f}%'})

    # The validation split is empty when every class had fewer than 2 images
    avg_val_loss = val_loss / max(len(val_loader), 1)
    val_acc = 100 * correct / total if total else 0.0

    print(f'\nEpoch {epoch+1} Summary:')
    print(f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f'🎯 New best accuracy! Saving model...')
        model.save_pretrained(BEST_MODEL_DIR)
        processor.save_pretrained(BEST_MODEL_DIR)
    print()

# Save final model
print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print("Saving final model...")
model.save_pretrained(FINAL_MODEL_DIR)
processor.save_pretrained(FINAL_MODEL_DIR)
print(f"\n✅ Best validation accuracy: {best_val_acc:.2f}%")
print("✅ Models saved to:")
print("   - sneakers_top10_best (highest accuracy)")
print("   - sneakers_top10_final (last epoch)")



In [None]:
# ============================================================
# MODEL TRAINING ANALYSIS & VISUALIZATION
# Run this after training completes (uses `val_loader` from the training cell)
# ============================================================

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd
from tqdm.auto import tqdm

print("="*60)
print("MODEL PERFORMANCE ANALYSIS")
print("="*60)

# Load the trained model
print("\n📦 Loading trained model...")
from transformers import ViTForImageClassification, ViTImageProcessor

model = ViTForImageClassification.from_pretrained("./sneakers_top10_best")
processor = ViTImageProcessor.from_pretrained("./sneakers_top10_best")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Get class names (class index i <-> class_names[i])
class_names = [model.config.id2label[i] for i in range(len(model.config.id2label))]
class_ids = list(range(len(class_names)))
print(f"✅ Model loaded! Classes: {len(class_names)}")

# ===== 1. CONFUSION MATRIX =====
print("\n1️⃣ Generating Confusion Matrix...")

# Single pass over the validation set: collect predictions, true labels and the
# top-class probability (used in section 5) together.
all_preds = []
all_labels = []
confidences = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values=pixel_values)
        predictions = outputs.logits.argmax(dim=-1)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        confidences.extend(probs.max(dim=-1)[0].cpu().numpy())

if not all_labels:
    raise RuntimeError("The validation set is empty - nothing to analyze.")

# Pass `labels` so the matrix is always len(class_names) x len(class_names), even
# when a class never appears among labels or predictions; otherwise its row and
# column are dropped and the matrix no longer lines up with class_names.
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# Plot confusion matrix
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=[name[:20] for name in class_names],
            yticklabels=[name[:20] for name in class_names],
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Validation Set', fontsize=16, fontweight='bold', pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# ===== 2. PER-CLASS ACCURACY =====
print("\n2️⃣ Per-Class Performance...")

# np.diag returns a read-only *view* of cm's diagonal; copy it so the counts
# cannot change if cm is modified afterwards.
class_correct = np.diag(cm).copy()
class_total = cm.sum(axis=1)
# Classes without validation samples get NaN instead of a 0/0 division
class_accuracy = np.divide(class_correct * 100.0, class_total,
                           out=np.full(len(class_total), np.nan),
                           where=class_total > 0)

# Create DataFrame
metrics_df = pd.DataFrame({
    'Sneaker': class_names,
    'Correct': class_correct,
    'Total': class_total,
    'Accuracy (%)': class_accuracy
}).sort_values('Accuracy (%)', ascending=True)

# Plot per-class accuracy (grey bars mark classes with no validation samples)
plt.figure(figsize=(12, 8))
colors = ['#CCCCCC' if np.isnan(acc) else '#FF6B6B' if acc < 70 else '#FFA07A' if acc < 85 else '#4ECDC4'
          for acc in metrics_df['Accuracy (%)']]
bars = plt.barh(range(len(metrics_df)), metrics_df['Accuracy (%)'].fillna(0), color=colors)
plt.yticks(range(len(metrics_df)), [name[:30] for name in metrics_df['Sneaker']])
plt.xlabel('Accuracy (%)', fontsize=12)
plt.title('Per-Class Accuracy', fontsize=16, fontweight='bold', pad=20)
plt.xlim(0, 100)
plt.axvline(x=80, color='gray', linestyle='--', alpha=0.5, label='80% threshold')
plt.legend()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

# Print detailed metrics
print("\n📊 Detailed Per-Class Metrics:")
print(metrics_df.to_string(index=False))

# ===== 3. CLASSIFICATION REPORT =====
print("\n3️⃣ Classification Report...")
# `labels` keeps target_names aligned even if a class is absent from the data
print("\n" + classification_report(all_labels, all_preds, labels=class_ids,
                                   target_names=class_names, digits=3, zero_division=0))

# ===== 4. TOP CONFUSED PAIRS =====
print("\n4️⃣ Most Confused Sneaker Pairs...")

# Zero the diagonal on a copy so `cm` keeps its correct-prediction counts
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
confused_pairs = [
    {'True': class_names[i], 'Predicted': class_names[j], 'Count': int(off_diag[i, j])}
    for i, j in zip(*np.nonzero(off_diag))
]

if confused_pairs:
    confused_df = pd.DataFrame(confused_pairs).sort_values('Count', ascending=False).head(10)

    plt.figure(figsize=(14, 6))
    x_labels = [f"{row['True'][:15]}\n→ {row['Predicted'][:15]}" for _, row in confused_df.iterrows()]
    plt.bar(range(len(confused_df)), confused_df['Count'], color='#FF6B6B')
    plt.xticks(range(len(confused_df)), x_labels, rotation=45, ha='right')
    plt.ylabel('Number of Misclassifications', fontsize=12)
    plt.title('Top 10 Most Confused Sneaker Pairs', fontsize=16, fontweight='bold', pad=20)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("\nTop Confused Pairs:")
    print(confused_df.to_string(index=False))
else:
    # No off-diagonal entries: sorting an empty DataFrame by 'Count' would raise
    confused_df = pd.DataFrame(columns=['True', 'Predicted', 'Count'])
    print("\nNo misclassifications on the validation set - nothing to plot.")

# ===== 5. CONFIDENCE DISTRIBUTION =====
print("\n5️⃣ Model Confidence Distribution...")

# Top-class probabilities were collected during the evaluation pass above
confidences = np.array(confidences) * 100

plt.figure(figsize=(12, 6))
plt.hist(confidences, bins=50, color='#4ECDC4', edgecolor='black', alpha=0.7)
plt.axvline(confidences.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {confidences.mean():.1f}%')
plt.axvline(np.median(confidences), color='orange', linestyle='--', linewidth=2, label=f'Median: {np.median(confidences):.1f}%')
plt.xlabel('Prediction Confidence (%)', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.title('Distribution of Model Confidence Scores', fontsize=16, fontweight='bold', pad=20)
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nConfidence Statistics:")
print(f"   Mean Confidence: {confidences.mean():.2f}%")
print(f"   Median Confidence: {np.median(confidences):.2f}%")
print(f"   Min Confidence: {confidences.min():.2f}%")
print(f"   Max Confidence: {confidences.max():.2f}%")
print(f"   Std Deviation: {confidences.std():.2f}%")

# ===== 6. OVERALL SUMMARY =====
print("\n" + "="*60)
print("📈 OVERALL MODEL SUMMARY")
print("="*60)

total_correct = int(class_correct.sum())
overall_accuracy = total_correct / class_total.sum() * 100
print(f"\n✅ Overall Validation Accuracy: {overall_accuracy:.2f}%")
print(f"📊 Total Predictions: {len(all_preds)}")
print(f"✓ Correct: {total_correct}")
print(f"✗ Incorrect: {len(all_preds) - total_correct}")
# Rank only classes that had validation samples (NaN accuracy otherwise)
ranked = metrics_df.dropna(subset=['Accuracy (%)'])
best_row, worst_row = ranked.iloc[-1], ranked.iloc[0]
print(f"\n🏆 Best Performing Class: {best_row['Sneaker']} ({best_row['Accuracy (%)']:.2f}%)")
print(f"⚠️  Worst Performing Class: {worst_row['Sneaker']} ({worst_row['Accuracy (%)']:.2f}%)")

print("\n" + "="*60)
print("✅ ANALYSIS COMPLETE!")
print("="*60)

In [None]:
# ============================================================
# ATTENTION ROLLOUT & GRADIENT SALIENCY VISUALIZATION
# See what parts of the image the model focuses on!
# ============================================================

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from google.colab import files
import cv2

print("="*60)
print("VISUAL EXPLANATION - WHY THIS PREDICTION?")
print("="*60)

# ===== UPLOAD IMAGE =====
print("\n📤 Upload a sneaker image to analyze:\n")
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No image uploaded - run the cell again and pick a file.")
filename = next(iter(uploaded))

# Load image
original_image = Image.open(filename).convert('RGB')
print(f"✅ Loaded: {filename}")

# ===== SETUP MODEL =====
from transformers import ViTForImageClassification, ViTImageProcessor

# output_attentions=True makes the forward pass return per-layer attention
# weights. Request the eager attention implementation explicitly, because fused
# SDPA attention does not produce attention probabilities.
model = ViTForImageClassification.from_pretrained(
    "./sneakers_top10_best", output_attentions=True, attn_implementation="eager")
processor = ViTImageProcessor.from_pretrained("./sneakers_top10_best")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Process image
inputs = processor(images=original_image, return_tensors="pt").to(device)

# ===== 1. GET PREDICTION (and attention weights from the same pass) =====
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probs = F.softmax(logits, dim=-1)
    pred_class = logits.argmax(-1).item()
    confidence = probs[0][pred_class].item() * 100
    # Tuple with one attention tensor per layer; reused for the rollout below
    # instead of running a second, identical forward pass.
    attentions = outputs.attentions

predicted_sneaker = model.config.id2label[pred_class]
print(f"\n🎯 Prediction: {predicted_sneaker} ({confidence:.2f}% confidence)")

# ===== 2. ATTENTION ROLLOUT =====
print("\n🔍 Generating attention visualization...")

if attentions is None:
    raise RuntimeError("The model returned no attention weights; attention rollout needs "
                       "output_attentions=True with eager attention.")

# Attention rollout - combine attention from all layers
def attention_rollout(attentions, discard_ratio=0.9):
    """Compute attention rollout (Abnar & Zuidema, 2020) for the first image in the batch.

    Each layer's head-averaged attention is mixed with the identity (for the
    residual connection) and row-normalised. The layers are then chained as
    ``A_L @ ... @ A_2 @ A_1`` so the first layer is applied first. The CLS row
    of that product scores how much each patch token flows into the CLS token.

    Args:
        attentions: sequence of per-layer attention tensors ordered from the
            first layer to the last, each of shape
            (batch_size, num_heads, num_tokens, num_tokens), with the CLS
            token at index 0.
        discard_ratio: fraction of the lowest-scoring patches to zero out in
            the final mask (0 disables it). Discarding is skipped if it would
            remove every patch.

    Returns:
        1-D tensor of length num_tokens - 1, on the same device as the
        attentions, normalised to sum to 1.
    """
    num_tokens = attentions[0].size(-1)
    # Take the device from the inputs rather than from a global variable
    eye = torch.eye(num_tokens, device=attentions[0].device)
    result = eye.clone()

    # Walk layers first -> last and left-multiply so the final product is
    # A_L @ ... @ A_1 (walking in reverse while left-multiplying would build
    # A_1 @ ... @ A_L, the wrong composition order).
    for layer_attention in attentions:
        # Average over heads for the first batch element: (num_tokens, num_tokens)
        fused = layer_attention.mean(dim=1)[0]
        # Add identity for the residual connection, then re-normalise each row;
        # epsilon guards against division by zero
        current = fused + eye
        current = current / (current.sum(dim=-1, keepdim=True) + 1e-12)
        result = torch.matmul(current, result)

    # CLS row (index 0) restricted to the patch tokens (index 1 onwards)
    mask = result[0, 1:].clone()

    if discard_ratio > 0:
        num_discard = int(mask.numel() * discard_ratio)
        # Only discard when at least one patch survives
        if 0 < num_discard < mask.numel():
            _, indices = mask.topk(num_discard, largest=False)
            mask[indices] = 0

    # Re-normalise so the mask sums to 1
    return mask / (mask.sum() + 1e-12)

# Get attention mask
attention_mask = attention_rollout(attentions).cpu().numpy()

# Reshape the per-patch scores into the square patch grid
# (ViT uses 14x14 patches for 224x224 images)
num_patches = int(np.sqrt(len(attention_mask)))
attention_map = attention_mask.reshape(num_patches, num_patches)

# ===== 3. GRADIENT-BASED SALIENCY MAP =====
print("🔥 Generating gradient saliency map...")

# Track gradients with respect to the preprocessed input pixels
model.zero_grad()
inputs_grad = processor(images=original_image, return_tensors="pt").to(device)
inputs_grad['pixel_values'].requires_grad = True

# Forward pass, then backpropagate the predicted class's logit
outputs = model(**inputs_grad)
target_score = outputs.logits[0, pred_class]
target_score.backward()

# Input gradients for the single image (channels-first pixel layout)
gradients = inputs_grad['pixel_values'].grad[0].cpu().numpy()

# Saliency = largest absolute gradient across channels, min-max scaled to [0, 1]
saliency = np.abs(gradients).max(axis=0)
saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-12)

# ===== 4. CREATE VISUALIZATIONS =====
print("🎨 Creating visualizations...")

# Size every overlay to the processor's input resolution (the saliency map's
# shape) instead of assuming 224x224. PIL and OpenCV both take (width, height).
map_h, map_w = saliency.shape
img_array = np.array(original_image.resize((map_w, map_h)))

# Create figure with multiple subplots
fig = plt.figure(figsize=(20, 10))

# 1. Original Image
ax1 = plt.subplot(2, 4, 1)
ax1.imshow(original_image)
ax1.set_title('Original Image', fontsize=14, fontweight='bold')
ax1.axis('off')

# 2. Prediction Info
ax2 = plt.subplot(2, 4, 2)
ax2.axis('off')
info_text = f"Prediction:\n{predicted_sneaker}\n\nConfidence:\n{confidence:.2f}%"
ax2.text(0.5, 0.5, info_text, ha='center', va='center',
         fontsize=16, fontweight='bold',
         bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

# 3. Attention Heatmap (Raw)
ax3 = plt.subplot(2, 4, 3)
im3 = ax3.imshow(attention_map, cmap='hot', interpolation='bilinear')
ax3.set_title('Attention Map (Raw)', fontsize=14, fontweight='bold')
ax3.axis('off')
plt.colorbar(im3, ax=ax3, fraction=0.046)

# 4. Attention Overlay (upsampled to the image size, scaled to [0, 1])
ax4 = plt.subplot(2, 4, 4)
attention_resized = cv2.resize(attention_map, (map_w, map_h))
attention_resized = (attention_resized - attention_resized.min()) / (attention_resized.max() - attention_resized.min() + 1e-12)
ax4.imshow(img_array)
ax4.imshow(attention_resized, cmap='hot', alpha=0.5)
ax4.set_title('Attention Overlay', fontsize=14, fontweight='bold')
ax4.axis('off')

# 5. Gradient Saliency (Raw)
ax5 = plt.subplot(2, 4, 5)
im5 = ax5.imshow(saliency, cmap='hot')
ax5.set_title('Gradient Saliency Map', fontsize=14, fontweight='bold')
ax5.axis('off')
plt.colorbar(im5, ax=ax5, fraction=0.046)

# 6. Gradient Overlay
ax6 = plt.subplot(2, 4, 6)
ax6.imshow(img_array)
ax6.imshow(saliency, cmap='hot', alpha=0.5)
ax6.set_title('Gradient Overlay', fontsize=14, fontweight='bold')
ax6.axis('off')

# 7. Combined Attention + Gradient (both already in [0, 1])
ax7 = plt.subplot(2, 4, 7)
combined = (attention_resized + saliency) / 2
im7 = ax7.imshow(combined, cmap='hot')
ax7.set_title('Combined Map', fontsize=14, fontweight='bold')
ax7.axis('off')
plt.colorbar(im7, ax=ax7, fraction=0.046)

# 8. Combined Overlay on Original
ax8 = plt.subplot(2, 4, 8)
ax8.imshow(img_array)
ax8.imshow(combined, cmap='hot', alpha=0.6)
ax8.set_title('Combined Overlay', fontsize=14, fontweight='bold')
ax8.axis('off')

plt.suptitle(f'Visual Explanation: Why "{predicted_sneaker}"?',
             fontsize=18, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()

# ===== 5. DETAILED ATTENTION ANALYSIS =====
print("\n📊 Analyzing attention patterns...")

# Find top attention regions (only if mask is not all zeros/nans)
if np.nansum(attention_mask) > 0:
    top_patches = np.argsort(attention_mask)[-5:][::-1]
    print(f"\nTop 5 Most Attended Patches (out of {len(attention_mask)}):")
    for i, patch_idx in enumerate(top_patches, 1):
        # Flat patch index -> (row, col) on the square patch grid
        row = patch_idx // num_patches
        col = patch_idx % num_patches
        attention_val = attention_mask[patch_idx]
        print(f"{i}. Patch ({row}, {col}): {attention_val:.4f} attention weight")
else:
    print("\nCould not determine top attended patches as attention mask is all zeros or NaNs.")

# ===== 6. TOP-K PREDICTIONS WITH ATTENTION =====
print("\n🎯 Top 3 Predictions:")
top3_probs, top3_indices = torch.topk(probs[0], min(3, len(probs[0])))

for i, (prob, idx) in enumerate(zip(top3_probs, top3_indices), 1):
    sneaker_name = model.config.id2label[idx.item()]
    prob_percent = prob.item() * 100
    bar = "█" * int(prob_percent / 5)
    marker = "👟 " if i == 1 else "   "
    print(f"{marker}{i}. {sneaker_name:<35} {prob_percent:>6.2f}% {bar}")

# ===== INTERPRETATION GUIDE =====
print("\n" + "="*60)
print("📖 HOW TO INTERPRET THESE VISUALIZATIONS")
print("="*60)
print(f"""
🔴 RED/HOT COLORS = High importance/attention
🔵 BLUE/COOL COLORS = Low importance/attention

1. ATTENTION MAP: Shows which image patches the Vision Transformer
   focuses on when making its decision.

2. GRADIENT SALIENCY: Shows which pixels, if changed, would most
   affect the prediction (based on gradients).

3. COMBINED MAP: Merges both attention and gradient information
   for a comprehensive view.

💡 The model is confident about '{predicted_sneaker}' because it's
   focusing on the highlighted regions in these visualizations!
""")

print("="*60)
print("✅ VISUALIZATION COMPLETE!")
print("="*60)
print("\n💡 Run this cell again to analyze another image!")