# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import timm
import cv2
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report, f1_score
from collections import defaultdict
import json
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"üî• Using device: {device}")

# ==================== DATASET CLASS ====================

class CloveGradingDataset(Dataset):
    """
    Image-folder dataset for four-grade clove classification.

    ``data_dir`` is scanned for sub-folders named ``Grade 1`` .. ``Grade 4``
    (or ``Grade_1`` .. ``Grade_4``). Every ``.jpg``/``.jpeg``/``.png``/``.bmp``
    file inside them becomes one ``(path, label)`` entry of ``self.samples``
    with a zero-based label (Grade 1 -> 0). Other folders and files are ignored.

    Args:
        data_dir: Root directory containing the grade folders.
        transform: Optional callable applied to the RGB image array read by
            OpenCV. When omitted, ``transforms.ToTensor()`` is applied instead.
        resolution: Resolution tag; stored on the instance but not used here.
    """

    # Folder name -> zero-based class index; both spellings are accepted.
    GRADE_MAPPING = {
        'Grade 1': 0, 'Grade 2': 1, 'Grade 3': 2, 'Grade 4': 3,
        'Grade_1': 0, 'Grade_2': 1, 'Grade_3': 2, 'Grade_4': 3,
    }
    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, data_dir, transform=None, resolution='224x224'):
        self.data_dir = data_dir
        self.transform = transform
        self.resolution = resolution
        self.samples = []

        # os.listdir() returns entries in arbitrary, filesystem-dependent
        # order. Sorting keeps sample indices stable, so an index-based split
        # with a fixed seed selects the same files on every machine.
        for grade_folder in sorted(os.listdir(data_dir)):
            label = self.GRADE_MAPPING.get(grade_folder)
            if label is None:
                continue

            grade_path = os.path.join(data_dir, grade_folder)
            if not os.path.isdir(grade_path):
                continue

            for img_file in sorted(os.listdir(grade_path)):
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(grade_path, img_file), label))

        print(f"üìä Loaded {len(self.samples)} samples")

        # Print class distribution (guarded so an empty folder tree reports
        # 0.0% instead of raising ZeroDivisionError).
        labels = [sample[1] for sample in self.samples]
        total = len(labels)
        for i in range(4):
            count = labels.count(i)
            share = count / total * 100 if total else 0.0
            print(f"   Grade {i+1}: {count} samples ({share:.1f}%)")

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load sample ``idx`` and return ``(image, label)``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_path, label = self.samples[idx]

        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        image = cv2.imread(img_path)
        if image is None:
            raise OSError(f"Could not read image file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        return image, label

# ==================== VISION TRANSFORMER (VIT) MODEL ====================

class VisionTransformer:
    """Factory and helper utilities for timm Vision Transformer classifiers."""

    # Static reference data per supported timm model name. These values are
    # only reported to the user; nothing here configures the model.
    _MODEL_INFO = {
        'vit_base_patch16_224': {
            'params': '86.6M',
            'input_size': 224,
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'heads': 12,
            'mlp_ratio': 4.0,
            'speed': 'Medium',
            'description': 'Base Vision Transformer with 16x16 patches',
            'strengths': 'Global attention, excellent for structured patterns',
            'attention_type': 'Global self-attention',
            'memory': 'High'
        },
        'vit_base_patch16_384': {
            'params': '86.6M',
            'input_size': 384,
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'heads': 12,
            'mlp_ratio': 4.0,
            'speed': 'Slow',
            'description': 'Base ViT with higher resolution',
            'strengths': 'Better fine-grained details',
            'attention_type': 'Global self-attention',
            'memory': 'Very High'
        },
        'vit_large_patch16_224': {
            'params': '304.3M',
            'input_size': 224,
            'patch_size': 16,
            'embed_dim': 1024,
            'depth': 24,
            'heads': 16,
            'mlp_ratio': 4.0,
            'speed': 'Slow',
            'description': 'Large Vision Transformer with 16x16 patches',
            'strengths': 'Higher capacity global attention',
            'attention_type': 'Global self-attention',
            'memory': 'Very High'
        },
        'vit_large_patch16_384': {
            'params': '304.7M',
            'input_size': 384,
            'patch_size': 16,
            'embed_dim': 1024,
            'depth': 24,
            'heads': 16,
            'mlp_ratio': 4.0,
            'speed': 'Very Slow',
            'description': 'Large ViT with higher resolution',
            'strengths': 'Higher capacity with fine-grained details',
            'attention_type': 'Global self-attention',
            'memory': 'Very High'
        },
    }

    # Every key a model-info dict carries; used to build a complete fallback.
    _INFO_KEYS = (
        'params', 'input_size', 'patch_size', 'embed_dim', 'depth', 'heads',
        'mlp_ratio', 'speed', 'description', 'strengths', 'attention_type',
        'memory',
    )

    @staticmethod
    def create_model(model_name='vit_base_patch16_224', num_classes=4, pretrained=True, img_size=224):
        """
        Create a timm Vision Transformer with an enlarged classification head.

        Args:
            model_name: One of the supported timm ViT names (see ``vit_models``).
            num_classes: Number of output classes.
            pretrained: Whether timm should load pretrained weights.
            img_size: Input image side length passed to timm.

        Returns:
            model: PyTorch model whose ``head`` is replaced by a three-layer MLP
            when the timm head exposes ``in_features``.

        Raises:
            ValueError: If ``model_name`` is not a supported ViT variant.
        """
        # Available ViT models
        vit_models = [
            'vit_base_patch16_224',
            'vit_base_patch16_384',
            'vit_large_patch16_224',
            'vit_large_patch16_384'
        ]

        if model_name not in vit_models:
            raise ValueError(f"Model {model_name} not supported. Available: {vit_models}")

        # Report the embedding width of the requested variant instead of a
        # constant 768, which is wrong for the large models.
        info = VisionTransformer.get_model_info(model_name)

        print(f"üîß Creating {model_name}...")
        print(f"   Input size: {img_size}x{img_size}")
        print(f"   Patch size: 16x16")
        print(f"   Embedding dimension: {info['embed_dim']}")

        model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            img_size=img_size
        )

        # Swap the single linear classifier for a deeper head. The head is
        # only replaced when it reports its input width; a head without
        # ``in_features`` is left untouched.
        in_features = getattr(getattr(model, 'head', None), 'in_features', None)
        if in_features is not None:
            model.head = nn.Sequential(
                nn.LayerNorm(in_features),
                nn.Dropout(0.5),
                nn.Linear(in_features, 1024),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(1024, 512),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes)
            )

        return model

    @staticmethod
    def get_model_info(model_name='vit_base_patch16_224'):
        """
        Return descriptive metadata for ``model_name``.

        The returned dict is a fresh copy and always contains every key in
        ``_INFO_KEYS``; for an unknown model each value is ``'Unknown'``.
        """
        info = VisionTransformer._MODEL_INFO.get(model_name)
        if info is None:
            return dict.fromkeys(VisionTransformer._INFO_KEYS, 'Unknown')
        return dict(info)

    @staticmethod
    def _attention_probs(attn_module, tokens):
        """
        Recompute the softmax attention matrix of one attention layer.

        The attention module's forward output is the projected token tensor,
        not the attention weights, so the weights are rebuilt from the layer's
        own ``qkv`` projection.

        Args:
            attn_module: Module exposing ``qkv`` (a projection producing
                queries, keys and values concatenated on the last axis) and
                ``num_heads``; ``scale``, ``q_norm`` and ``k_norm`` are used
                when present.
            tokens: Tensor of shape (batch, tokens, channels) fed to the layer.

        Returns:
            Tensor of shape (batch, heads, tokens, tokens); rows sum to 1.
        """
        batch, num_tokens, _ = tokens.shape
        num_heads = attn_module.num_heads

        qkv = attn_module.qkv(tokens)
        head_dim = qkv.shape[-1] // (3 * num_heads)
        qkv = qkv.reshape(batch, num_tokens, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]

        # Some attention variants normalise queries/keys before the dot product.
        q_norm = getattr(attn_module, 'q_norm', None)
        k_norm = getattr(attn_module, 'k_norm', None)
        if q_norm is not None:
            q = q_norm(q)
        if k_norm is not None:
            k = k_norm(k)

        scale = getattr(attn_module, 'scale', head_dim ** -0.5)
        scores = (q @ k.transpose(-2, -1)) * scale
        return scores.softmax(dim=-1)

    @staticmethod
    def visualize_attention(model, image_tensor, save_path=None,
                            mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """
        Plot the last-layer CLS-token attention map next to the input image.

        Args:
            model: ViT exposing ``blocks`` whose last element has an ``attn``
                module compatible with ``_attention_probs``.
            image_tensor: Normalised image tensor of shape (C, H, W).
            save_path: Optional file path for the figure.
            mean: Per-channel mean used to undo normalisation for display.
            std: Per-channel std used to undo normalisation for display.

        Returns:
            2-D numpy array (H, W) scaled to [0, 1], or None when no attention
            was captured.
        """
        was_training = model.training
        model.eval()

        captured = []

        def hook_fn(module, inputs, output):
            # ``output`` holds projected tokens only; rebuild the weights from
            # the layer input instead.
            captured.append(VisionTransformer._attention_probs(module, inputs[0]).detach())

        # Only the last block is visualised, so only that block is hooked. The
        # handle is removed afterwards so later forward passes are unaffected.
        handle = model.blocks[-1].attn.register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                model_device = next(model.parameters()).device
                _ = model(image_tensor.unsqueeze(0).to(model_device))
        finally:
            handle.remove()
            model.train(was_training)

        if not captured:
            return None

        # (heads, tokens, tokens) for the single image, averaged over heads.
        attention = captured[-1][0].float().mean(dim=0).cpu().numpy()

        # Row 0 is the CLS query; skip the non-patch (prefix) key tokens.
        num_prefix = getattr(model, 'num_prefix_tokens', 1)
        cls_attention = attention[0, num_prefix:]

        # Reshape to patch grid
        grid_size = int(np.sqrt(cls_attention.shape[0]))
        attention_map = cls_attention.reshape(grid_size, grid_size)

        # Upsample to image size
        attention_map = torch.from_numpy(attention_map).unsqueeze(0).unsqueeze(0)
        attention_map = nn.functional.interpolate(
            attention_map,
            size=(image_tensor.shape[1], image_tensor.shape[2]),
            mode='bilinear',
            align_corners=False
        ).squeeze().numpy()

        # Min-max normalise; a constant map becomes all zeros instead of NaN.
        low = attention_map.min()
        span = attention_map.max() - low
        if span > 0:
            attention_map = (attention_map - low) / span
        else:
            attention_map = np.zeros_like(attention_map)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Undo the input normalisation so the image displays in natural colours.
        img = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
        img = img * np.array(std) + np.array(mean)
        img = np.clip(img, 0, 1)

        ax1.imshow(img)
        ax1.set_title('Original Image')
        ax1.axis('off')

        im = ax2.imshow(attention_map, cmap='jet')
        ax2.set_title('ViT Attention Map (CLS token to patches)')
        ax2.axis('off')
        plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)

        # Lay out before saving so the saved file matches what is shown.
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()
        # Release the figure; this function may be called once per epoch.
        plt.close(fig)

        return attention_map

# ==================== VIT-SPECIFIC TRAINING FRAMEWORK ====================

class ViTTrainer:
    """
    Training / evaluation loop for a four-class Vision Transformer classifier.

    Uses AdamW, label-smoothed cross entropy, linear warmup followed by cosine
    decay (stepped once per epoch), gradient-norm clipping and, on CUDA
    devices only, mixed-precision training.

    Args:
        model: Model to train; moved to ``device``.
        device: ``torch.device`` (or device string) to run on.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader used for per-epoch validation.
        test_loader: Loader used for the final evaluation.
        model_name: Name used in log lines and output file names.
    """

    # Display names of the classes, indexed by integer label.
    CLASS_NAMES = ['Grade 1', 'Grade 2', 'Grade 3', 'Grade 4']

    # Model-info keys printed by train(); missing ones are shown as 'Unknown'.
    _INFO_KEYS = ('params', 'input_size', 'patch_size', 'depth', 'heads',
                  'description', 'strengths')

    def __init__(self, model, device, train_loader, val_loader, test_loader, model_name):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.model_name = model_name

        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=3e-4,
            betas=(0.9, 0.999),
            weight_decay=0.05,
            eps=1e-8
        )

        # Schedule shape; total_epochs is overwritten by train(num_epochs) so
        # the cosine decay always ends on the last trained epoch.
        self.warmup_epochs = 5
        self.total_epochs = 50
        self.scheduler = self._create_scheduler()

        # Autocast and loss scaling are CUDA features; on any other device
        # they are disabled and act as pass-throughs.
        self.use_amp = torch.device(device).type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        self.history = {
            'train_loss': [], 'train_acc': [], 'train_f1': [],
            'val_loss': [], 'val_acc': [], 'val_f1': [],
            'test_acc': None, 'test_f1': None,
            'learning_rates': [],
            'gradient_norms': [],
            'attention_maps': []
        }

    def _create_scheduler(self):
        """
        Build a per-epoch LambdaLR: linear warmup, then cosine decay.

        The multiplier rises from ``1/warmup_epochs`` to 1 during warmup, then
        follows half a cosine down to 0 at ``total_epochs`` and stays there.
        ``warmup_epochs`` and ``total_epochs`` are read on every call, so they
        may be changed after the scheduler is built.
        """
        def lr_lambda(epoch):
            if epoch < self.warmup_epochs:
                return (epoch + 1) / self.warmup_epochs
            # max(1, ...) avoids dividing by zero when training is no longer
            # than the warmup. Clamping progress keeps the LR from rising
            # again if the scheduler is stepped past total_epochs.
            decay_epochs = max(1, self.total_epochs - self.warmup_epochs)
            progress = min(1.0, (epoch - self.warmup_epochs) / decay_epochs)
            return 0.5 * (1 + np.cos(np.pi * progress))

        return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_epoch(self, epoch):
        """
        Run one optimisation pass over ``train_loader`` and step the scheduler.

        Args:
            epoch: Zero-based epoch index (not used inside the loop).

        Returns:
            (avg_loss, accuracy, weighted_f1, avg_grad_norm); the gradient norm
            is the mean pre-clipping norm over batches with a finite norm.
        """
        self.model.train()
        total_loss = 0.0
        all_preds = []
        all_labels = []
        gradient_norms = []

        for images, labels in tqdm(self.train_loader, desc=f"Training {self.model_name}"):
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.use_amp):
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            self.scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping so max_norm applies
            # to their true magnitude.
            self.scaler.unscale_(self.optimizer)
            grad_norm = float(torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0))
            # Batches whose scaled gradients overflowed report inf/nan and are
            # skipped by the scaler; keep them out of the average too.
            if np.isfinite(grad_norm):
                gradient_norms.append(grad_norm)

            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

        self.scheduler.step()

        avg_loss = total_loss / max(len(self.train_loader), 1)
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average='weighted')
        avg_grad_norm = np.mean(gradient_norms) if gradient_norms else float('nan')

        return avg_loss, accuracy, f1, avg_grad_norm

    def validate(self, loader):
        """
        Evaluate the model on ``loader`` without updating weights.

        Returns:
            (avg_loss, accuracy, weighted_f1, predictions, labels)
        """
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)

                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                total_loss += loss.item()
                all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        avg_loss = total_loss / max(len(loader), 1)
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_labels

    def _visualize_sample_attention(self, save_path):
        """Plot attention for the first image of a training batch; return the map or None."""
        if len(self.train_loader) == 0:
            return None
        sample_images, _ = next(iter(self.train_loader))
        if len(sample_images) == 0:
            return None
        return VisionTransformer.visualize_attention(
            self.model,
            sample_images[0].cpu(),
            save_path=save_path
        )

    def train(self, num_epochs=50, save_path=None):
        """
        Train with early stopping, then evaluate on the test loader.

        Args:
            num_epochs: Maximum number of epochs; also the schedule length.
            save_path: Optional checkpoint path written whenever the
                validation F1 improves.

        Returns:
            The history dict with per-epoch metrics and final test metrics.
            Test metrics are computed with the weights of the best-validation
            epoch when one was recorded.
        """
        print(f"\nüöÄ Training {self.model_name} - Vision Transformer")
        print("=" * 80)

        # Fill in 'Unknown' for any key the lookup does not provide, so an
        # unlisted model cannot abort training with a KeyError.
        model_info = {
            **dict.fromkeys(self._INFO_KEYS, 'Unknown'),
            **VisionTransformer.get_model_info(self.model_name),
        }
        print(f"üìä Model: {self.model_name}")
        print(f"   Parameters: {model_info['params']}")
        print(f"   Input Size: {model_info['input_size']}x{model_info['input_size']}")
        print(f"   Patch Size: {model_info['patch_size']}x{model_info['patch_size']}")
        print(f"   Depth: {model_info['depth']} transformer blocks")
        print(f"   Attention Heads: {model_info['heads']}")
        print(f"   Description: {model_info['description']}")
        print(f"   Strengths: {model_info['strengths']}")

        print(f"\n‚ö° Training Configuration:")
        print(f"   Warmup epochs: {self.warmup_epochs}")
        print(f"   Total epochs: {num_epochs}")
        print(f"   Learning rate: {self.optimizer.param_groups[0]['lr']}")
        print(f"   Weight decay: {self.optimizer.param_groups[0]['weight_decay']}")

        # The scheduler lambda reads total_epochs lazily; align it with the
        # requested run length so the decay spans exactly this run.
        self.total_epochs = num_epochs

        best_val_f1 = 0
        best_state = None
        patience_counter = 0
        max_patience = 15

        for epoch in range(num_epochs):
            # Learning rate in effect for this epoch. train_epoch() advances
            # the scheduler before returning, so read it first.
            epoch_lr = self.optimizer.param_groups[0]['lr']

            train_loss, train_acc, train_f1, grad_norm = self.train_epoch(epoch)
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['train_f1'].append(train_f1)
            self.history['gradient_norms'].append(grad_norm)

            val_loss, val_acc, val_f1, _, _ = self.validate(self.val_loader)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['val_f1'].append(val_f1)
            self.history['learning_rates'].append(epoch_lr)

            phase = "WARMUP" if epoch < self.warmup_epochs else "TRAINING"

            print(f"\nüìà Epoch {epoch+1}/{num_epochs} [{phase}]:")
            print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
            print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
            print(f"  Grad Norm: {grad_norm:.2f}, LR: {epoch_lr:.2e}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                # Keep a CPU copy of the best weights so the final test
                # evaluation does not depend on whatever the last epoch left.
                best_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                if save_path:
                    torch.save({
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'scheduler_state_dict': self.scheduler.state_dict(),
                        'scaler_state_dict': self.scaler.state_dict(),
                        'epoch': epoch,
                        'val_f1': val_f1,
                        'history': self.history
                    }, save_path)
                    print(f"  üíæ New best model saved! F1: {val_f1:.4f}")
            else:
                patience_counter += 1

            if patience_counter >= max_patience:
                print(f"‚èπÔ∏è Early stopping triggered")
                break

            # Visualise attention on the first epoch, the last epoch and
            # whenever this epoch matches the best validation F1 so far.
            if epoch == 0 or epoch == num_epochs - 1 or val_f1 == best_val_f1:
                attention_map = self._visualize_sample_attention(
                    f'/kaggle/working/attention_epoch_{epoch+1}.png' if save_path else None
                )
                if attention_map is not None:
                    self.history['attention_maps'].append({
                        'epoch': epoch,
                        'val_f1': val_f1,
                        'attention_map': attention_map.tolist()  # list form is JSON-serialisable
                    })

        # Report test metrics for the model that was selected on validation.
        if best_state is not None:
            self.model.load_state_dict(best_state)

        print(f"\nüìä Final Test Evaluation:")
        test_loss, test_acc, test_f1, test_preds, test_labels = self.validate(self.test_loader)
        self.history['test_acc'] = test_acc
        self.history['test_f1'] = test_f1

        print(f"Test - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}")

        self._detailed_analysis(test_preds, test_labels)
        self.plot_training_history()
        self._visualize_sample_attention('/kaggle/working/attention_final.png')

        return self.history

    def _detailed_analysis(self, preds, labels):
        """Print the classification report, per-class metrics and confusion summary."""
        # Passing the full label set keeps every report aligned with
        # CLASS_NAMES even when a class is absent from this split.
        class_ids = list(range(len(self.CLASS_NAMES)))

        print("\nüìã Detailed Classification Report:")
        print(classification_report(labels, preds, labels=class_ids,
                                    target_names=self.CLASS_NAMES, zero_division=0))

        cm = confusion_matrix(labels, preds, labels=class_ids)
        self.plot_confusion_matrix(cm)

        print("\nüéØ Per-Class Performance:")
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, labels=class_ids, average=None, zero_division=0
        )
        for i, class_name in enumerate(self.CLASS_NAMES):
            print(f"   {class_name}: Precision={precision[i]:.3f}, Recall={recall[i]:.3f}, F1={f1[i]:.3f}")

        print(f"\nüîç Transformer Insights:")
        print(f"   Number of misclassified samples: {np.sum(np.array(preds) != np.array(labels))}")
        print(f"   Most confused pair: {self._get_most_confused_pair(preds, labels)}")

    def _get_most_confused_pair(self, preds, labels):
        """Return a description of the most frequent (true -> predicted) error."""
        # Fixed label set so matrix index i always means class i + 1.
        cm = confusion_matrix(labels, preds, labels=list(range(len(self.CLASS_NAMES))))
        np.fill_diagonal(cm, 0)  # Ignore correct predictions

        if cm.sum() > 0:
            max_idx = np.unravel_index(cm.argmax(), cm.shape)
            return f"Class {max_idx[0]+1} ‚Üí Class {max_idx[1]+1} ({cm[max_idx]} samples)"
        return "No misclassifications"

    def plot_confusion_matrix(self, cm):
        """Draw ``cm`` as an annotated heatmap, save it as a PNG and show it."""
        fig = plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=self.CLASS_NAMES,
                    yticklabels=self.CLASS_NAMES)
        plt.title(f'Confusion Matrix - {self.model_name}')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(f'/kaggle/working/confusion_matrix_{self.model_name}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    def plot_training_history(self):
        """Plot loss, accuracy, F1 and learning-rate curves; save as PNG and show."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # The three metric panels share one layout: (axis, history key, label).
        panels = [
            (ax1, 'loss', 'Loss'),
            (ax2, 'acc', 'Accuracy'),
            (ax3, 'f1', 'F1'),
        ]
        y_labels = {'loss': 'Loss', 'acc': 'Accuracy', 'f1': 'F1 Score'}
        for ax, key, label in panels:
            ax.plot(epochs, self.history[f'train_{key}'], 'b-', label=f'Train {label}', linewidth=2, alpha=0.8)
            ax.plot(epochs, self.history[f'val_{key}'], 'r-', label=f'Val {label}', linewidth=2, alpha=0.8)
            ax.axvline(self.warmup_epochs, color='gray', linestyle='--', alpha=0.5, label='Warmup End')
            ax.set_title(f'ViT Training History - {y_labels[key]}', fontsize=14, fontweight='bold')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(y_labels[key])
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Learning rate schedule
        learning_rates = self.history['learning_rates']
        ax4.plot(epochs, learning_rates, 'g-', linewidth=2, alpha=0.8)
        ax4.axvline(self.warmup_epochs, color='gray', linestyle='--', alpha=0.5, label='Warmup End')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Learning Rate')
        ax4.set_title(f'ViT Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax4.set_yscale('log')
        ax4.grid(True, alpha=0.3)

        # Annotate the warmup phase only when the run reached the epoch the
        # arrow points at; shorter runs would index past the recorded history.
        mid_warmup = self.warmup_epochs // 2
        if mid_warmup < len(learning_rates):
            ax4.annotate('Warmup Phase',
                         xy=(self.warmup_epochs/2, learning_rates[mid_warmup]),
                         xytext=(self.warmup_epochs/2, learning_rates[mid_warmup] * 5),
                         arrowprops=dict(arrowstyle='->', color='black'),
                         fontweight='bold')

        plt.tight_layout()
        plt.savefig(f'/kaggle/working/vit_training_history_{self.model_name}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

# ==================== VIT EXPERIMENT MANAGER ====================

class ViTExperimentManager:
    """
    Run Vision Transformer experiments end to end and collect their results.

    Args:
        data_dir: Root directory passed to ``CloveGradingDataset``.
        resolution: Resolution tag stored with each result.
        batch_size: Batch size for all three data loaders.
    """

    # Model-info keys read when reporting; missing ones are shown as 'Unknown'.
    _INFO_KEYS = ('params', 'input_size', 'patch_size', 'embed_dim', 'depth',
                  'heads', 'mlp_ratio')

    def __init__(self, data_dir, resolution='224x224', batch_size=16):
        self.data_dir = data_dir
        self.resolution = resolution
        self.batch_size = batch_size
        self.results = []

    @staticmethod
    def _split_sizes(total):
        """Return (train, val, test) sizes for a 70/15/15 split; test takes the remainder."""
        train_size = int(0.7 * total)
        val_size = int(0.15 * total)
        return train_size, val_size, total - train_size - val_size

    @classmethod
    def _complete_info(cls, model_info):
        """Return ``model_info`` with every reporting key present ('Unknown' if absent)."""
        return {**dict.fromkeys(cls._INFO_KEYS, 'Unknown'), **model_info}

    def prepare_dataloaders(self):
        """
        Build the train/val/test loaders from a seeded 70/15/15 split.

        Training samples go through the augmenting transform; validation and
        test samples go through the deterministic one.

        Returns:
            (train_loader, val_loader, test_loader)
        """
        import copy

        img_size = 224  # ViT-B/16 standard size

        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # ViT normalization
        ])

        val_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        full_dataset = CloveGradingDataset(
            self.data_dir,
            transform=train_transform,
            resolution=self.resolution
        )

        # random_split() subsets all point at the SAME underlying dataset, so
        # assigning a transform through one subset changes it for all of them.
        # A shallow copy shares the sample list but carries its own transform,
        # which lets the evaluation subsets skip augmentation while the
        # training subset keeps it.
        eval_dataset = copy.copy(full_dataset)
        eval_dataset.transform = val_transform

        train_size, val_size, test_size = self._split_sizes(len(full_dataset))

        train_dataset, val_split, test_split = torch.utils.data.random_split(
            full_dataset, [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42)
        )
        val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)
        test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)

        print(f"‚úÖ ViT Data Preparation Complete:")
        print(f"   Input size: {img_size}x{img_size}")
        print(f"   Batch size: {self.batch_size}")
        print(f"   Train: {train_size}, Val: {val_size}, Test: {test_size}")
        print(f"   Normalization: mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]")

        return train_loader, val_loader, test_loader

    def run_vit_experiment(self, model_name='vit_base_patch16_224'):
        """
        Train and evaluate one ViT model and record its result.

        Args:
            model_name: ViT model name

        Returns:
            Result dict (also appended to ``self.results``).
        """
        print(f"\n{'='*80}")
        print(f"üöÄ VISION TRANSFORMER EXPERIMENT: {model_name.upper()}")
        print(f"{'='*80}")

        train_loader, val_loader, test_loader = self.prepare_dataloaders()

        model = VisionTransformer.create_model(
            model_name=model_name,
            num_classes=4,
            pretrained=True,
            img_size=224
        )

        trainer = ViTTrainer(
            model, device, train_loader, val_loader, test_loader,
            model_name=model_name
        )

        save_path = f"/kaggle/working/best_{model_name}.pth"
        history = trainer.train(num_epochs=50, save_path=save_path)

        result = {
            'model_name': model_name,
            'resolution': self.resolution,
            'test_accuracy': history['test_acc'],
            'test_f1': history['test_f1'],
            'best_val_f1': max(history['val_f1']),
            'final_train_acc': history['train_acc'][-1],
            'final_val_acc': history['val_acc'][-1],
            'train_val_gap': history['train_acc'][-1] - history['val_acc'][-1],
            'model_info': VisionTransformer.get_model_info(model_name),
            'attention_maps': history.get('attention_maps', []),
            'history': history
        }

        self.results.append(result)

        return result

    def compare_results(self):
        """Print a results table, write it to CSV, plot a summary and dump JSON details."""
        if not self.results:
            print("No results to compare!")
            return

        df_data = []
        for r in self.results:
            model_info = self._complete_info(r['model_info'])
            df_data.append({
                'Model': r['model_name'],
                'Parameters': model_info['params'],
                'Input Size': f"{model_info['input_size']}x{model_info['input_size']}",
                'Test Accuracy': f"{r['test_accuracy']:.4f}",
                'Test F1': f"{r['test_f1']:.4f}",
                'Best Val F1': f"{r['best_val_f1']:.4f}",
                'Train-Val Gap': f"{r['train_val_gap']:.4f}",
                'Attention Heads': model_info['heads'],
                'Transformer Depth': model_info['depth']
            })

        df = pd.DataFrame(df_data)

        print("\nüìä VISION TRANSFORMER RESULTS")
        print("=" * 100)
        print(df.to_string(index=False))

        df.to_csv('/kaggle/working/vit_results.csv', index=False)
        print("\n‚úÖ Results saved to 'vit_results.csv'")

        self.plot_results()

        # default=str stringifies any value json cannot encode natively.
        with open('/kaggle/working/vit_detailed_results.json', 'w') as f:
            json.dump(self.results, f, indent=2, default=str)

    def plot_results(self):
        """Plot curves and an architecture/result summary for the first recorded result."""
        if not self.results:
            return

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        first = self.results[0]
        model_name = first['model_name']
        history = first['history']

        epochs = range(1, len(history['train_loss']) + 1)

        # Training curves
        ax1.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title(f'{model_name} - Training & Validation Loss', fontsize=14, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.plot(epochs, history['train_acc'], 'b-', label='Train Accuracy', linewidth=2)
        ax2.plot(epochs, history['val_acc'], 'r-', label='Val Accuracy', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title(f'{model_name} - Training & Validation Accuracy', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Learning rate
        ax3.plot(epochs, history['learning_rates'], 'g-', linewidth=2)
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_title(f'{model_name} - Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)

        # Model architecture info
        ax4.axis('off')
        model_info = self._complete_info(first['model_info'])
        info_text = f"""
        üèóÔ∏è Model Architecture:
        
        ‚Ä¢ Model: {model_name}
        ‚Ä¢ Parameters: {model_info['params']}
        ‚Ä¢ Input Size: {model_info['input_size']}x{model_info['input_size']}
        ‚Ä¢ Patch Size: {model_info['patch_size']}x{model_info['patch_size']}
        ‚Ä¢ Embedding Dim: {model_info['embed_dim']}
        ‚Ä¢ Depth: {model_info['depth']} blocks
        ‚Ä¢ Attention Heads: {model_info['heads']}
        ‚Ä¢ MLP Ratio: {model_info['mlp_ratio']}
        
        üéØ Final Results:
        
        ‚Ä¢ Test Accuracy: {first['test_accuracy']:.4f}
        ‚Ä¢ Test F1: {first['test_f1']:.4f}
        ‚Ä¢ Best Val F1: {first['best_val_f1']:.4f}
        """

        ax4.text(0.1, 0.9, info_text, transform=ax4.transAxes, fontsize=12,
                 verticalalignment='top', bbox=dict(boxstyle="round,pad=1.0", facecolor="lightblue"))

        plt.tight_layout()
        plt.savefig(f'/kaggle/working/vit_results_summary_{model_name}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

# ==================== MAIN EXECUTION ====================

def run_vit_analysis(*, data_dir="/kaggle/input/processed-images-224x224", batch_size=16):
    """
    MAIN FUNCTION: Run the full Vision Transformer (ViT-B/16) analysis.

    Prints system and model information, runs a single experiment through
    ``ViTExperimentManager.run_vit_experiment``, prints the metrics found in
    the returned result and finally calls ``manager.compare_results()``.

    Args:
        data_dir: Dataset root handed to ``ViTExperimentManager``. Defaults to
            the Kaggle input path this notebook was written for.
        batch_size: Batch size handed to ``ViTExperimentManager``. Must be
            at least 1.

    Returns:
        The ``ViTExperimentManager`` instance that ran the experiment.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail fast: reject an unusable batch size before any long-running work.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print("\nüöÄ VISION TRANSFORMER (ViT-B/16) ANALYSIS")
    print("=" * 80)

    # Show system info
    print(f"System: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    # Initialize manager. The resolution stays fixed because the model selected
    # below is the 224-pixel ViT variant.
    manager = ViTExperimentManager(
        data_dir=data_dir,
        resolution='224x224',
        batch_size=batch_size
    )

    # ViT model to test
    vit_model = 'vit_base_patch16_224'

    print("\nüéØ TESTING VISION TRANSFORMER:")
    model_info = VisionTransformer.get_model_info(vit_model)
    print(f"   Model: {vit_model}")
    print(f"   Parameters: {model_info['params']}")
    print(f"   Architecture: {model_info['description']}")
    print(f"   Strengths: {model_info['strengths']}")

    print("\n‚ö° ViT-SPECIFIC FEATURES:")
    print("   ‚Ä¢ Global self-attention mechanism")
    print("   ‚Ä¢ 16x16 patch embedding")
    print(f"   ‚Ä¢ {model_info['depth']} transformer blocks")
    print(f"   ‚Ä¢ {model_info['heads']} attention heads")
    print("   ‚Ä¢ Cosine learning rate schedule with warmup")

    # NOTE(review): the figures below are informational text only; the real
    # epoch/warmup settings are presumably configured inside the manager or
    # trainer - confirm they match.
    print("\n‚è±Ô∏è TIME ESTIMATE:")
    print("   ‚Ä¢ Total training: ~60-80 minutes")
    print("   ‚Ä¢ Epochs: 50")
    print("   ‚Ä¢ Warmup epochs: 5")

    print("\nüîç ATTENTION VISUALIZATION:")
    print("   ‚Ä¢ Attention maps will be saved during training")
    print("   ‚Ä¢ Shows what parts of image the model focuses on")

    # Run experiment
    result = manager.run_vit_experiment(vit_model)

    # Display results
    print(f"\n{'='*80}")
    print("VISION TRANSFORMER RESULTS SUMMARY")
    print(f"{'='*80}")

    print("\nüèÜ FINAL PERFORMANCE:")
    print(f"   Test Accuracy: {result['test_accuracy']:.4f}")
    print(f"   Test F1 Score: {result['test_f1']:.4f}")
    print(f"   Best Validation F1: {result['best_val_f1']:.4f}")

    print("\nüìà TRAINING INSIGHTS:")
    print(f"   Train-Val Accuracy Gap: {result['train_val_gap']:.4f}")
    # Only the two extremes get a verdict; a gap between 0.05 and 0.15
    # (inclusive) is reported without comment.
    if result['train_val_gap'] > 0.15:
        print("   ‚ö†Ô∏è  Significant overfitting detected")
    elif result['train_val_gap'] < 0.05:
        print("   ‚úÖ Excellent generalization")

    print("\nüîç TRANSFORMER-SPECIFIC ANALYSIS:")
    print("   ‚Ä¢ Attention maps saved to /kaggle/working/")
    print("   ‚Ä¢ Model checkpoints saved")
    print("   ‚Ä¢ Training history plots generated")

    manager.compare_results()

    print("\n‚úÖ VISION TRANSFORMER ANALYSIS COMPLETE!")
    print("üìÅ Output files:")
    print("   - vit_results.csv (performance metrics)")
    print("   - vit_results_summary_*.png (visual summary)")
    print("   - vit_training_history_*.png (training curves)")
    print("   - attention_*.png (attention visualization)")
    print("   - confusion_matrix_*.png")
    print("   - best_*.pth (trained model)")

    return manager

def quick_vit_test(*, data_dir="/kaggle/input/processed-images-224x224", batch_size=16,
                   num_epochs=30, warmup_epochs=3,
                   save_path='/kaggle/working/best_vit_quick.pth'):
    """
    Quick test with fewer epochs.

    Builds the dataloaders through ``ViTExperimentManager``, creates a
    pretrained 4-class ``vit_base_patch16_224`` model and trains it with a
    ``ViTTrainer`` subclass whose ``warmup_epochs`` / ``total_epochs``
    attributes are overridden with the values given here.

    Args:
        data_dir: Dataset root handed to ``ViTExperimentManager``.
        batch_size: Batch size handed to ``ViTExperimentManager``.
        num_epochs: Number of training epochs. Used for the banner, the
            trainer's ``total_epochs`` attribute and the ``train()`` call so
            the three can no longer drift apart. Must be at least 1.
        warmup_epochs: Value for the trainer's ``warmup_epochs`` attribute.
            Must satisfy ``0 <= warmup_epochs <= num_epochs``.
        save_path: Checkpoint path forwarded to ``trainer.train``.

    Returns:
        The ``ViTExperimentManager`` instance. Only ``prepare_dataloaders``
        was called on it; training went through the trainer, not through
        ``run_vit_experiment``.

    Raises:
        ValueError: If ``num_epochs`` or ``warmup_epochs`` is out of range.
    """
    # Validate before doing any expensive work (data loading, model download).
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    if not 0 <= warmup_epochs <= num_epochs:
        raise ValueError(
            f"warmup_epochs must be between 0 and num_epochs ({num_epochs}), "
            f"got {warmup_epochs}"
        )

    print(f"\n‚ö° QUICK ViT TEST ({num_epochs} epochs)")
    print("=" * 60)

    manager = ViTExperimentManager(
        data_dir=data_dir,
        resolution='224x224',
        batch_size=batch_size
    )

    # Prepare data
    train_loader, val_loader, test_loader = manager.prepare_dataloaders()

    # Create model
    model = VisionTransformer.create_model(
        model_name='vit_base_patch16_224',
        num_classes=4,
        pretrained=True,
        img_size=224
    )

    # Quick trainer with fewer epochs
    class QuickViTTrainer(ViTTrainer):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # NOTE(review): these are assigned only after the base __init__
            # has run. If ViTTrainer builds its LR scheduler inside __init__
            # from these attributes, the override comes too late - verify
            # against ViTTrainer.
            self.warmup_epochs = warmup_epochs
            self.total_epochs = num_epochs

    trainer = QuickViTTrainer(
        model, device, train_loader, val_loader, test_loader,
        model_name='vit_base_patch16_224'
    )

    history = trainer.train(num_epochs=num_epochs, save_path=save_path)

    print("\n‚úÖ QUICK TEST COMPLETE!")
    print("   Model: ViT-B/16")
    print(f"   Test F1: {history['test_f1']:.4f}")
    print(f"   Test Accuracy: {history['test_acc']:.4f}")

    return manager

# ==================== EXECUTION ====================

if __name__ == "__main__":
    # Entry point: execute the Vision Transformer analysis when this file is
    # run directly (or as a notebook cell).
    print(
        "\n" + "=" * 80,
        "üéØ VISION TRANSFORMER (ViT-B/16) FOR CLOVE GRADING",
        "=" * 80,
        sep="\n",
    )

    print(
        "\nüîß EXECUTION OPTIONS:",
        "1. Full ViT analysis (50 epochs) - ~60-80 minutes",
        "2. Quick ViT test (30 epochs) - ~35-45 minutes",
        sep="\n",
    )

    # Pick exactly one of the two options below.

    # Option 1 (recommended): the full analysis.
    manager = run_vit_analysis()

    # Option 2: the shorter run - swap the call above for:
    # manager = quick_vit_test()

# Emitted unconditionally, i.e. also when the module is merely imported.
# NOTE(review): when run as a script these lines appear after training has
# already finished, even though the text reads like a pre-run hint.
print(
    "\n‚úÖ VISION TRANSFORMER CODE READY!",
    "üöÄ Run the cell to start ViT-B/16 training with attention visualization",
    sep="\n",
)