In [2]:
from data_preprocessing import get_dataloaders
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# ==========================================
# 1. CONFIGURATION
# ==========================================
MODEL_SAVE_PATH = "mobilenet_v2_stanford_cars_20classes.pth"
BATCH_SIZE = 32
NUM_CLASSES = 20  # Updated to 20 classes
NUM_EPOCHS = 30
IMG_SIZE = 224  # Image resolution for MobileNetV2

# ==========================================
# 2. MODEL ARCHITECTURE
# ==========================================
def get_mobilenet_model(num_classes=20):
    # Load Pre-trained MobileNetV2
    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

    # --- STRATEGY: FINE-TUNING ---
    # 1. Freeze the early layers (generic features like lines/edges)
    for param in model.parameters():
        param.requires_grad = False
        
    # 2. Unfreeze the last few inverted residual blocks for better feature learning
    # MobileNetV2 has features organized in a Sequential container
    # Unfreeze the last 4 blocks (out of 19 total blocks)
    for param in model.features[-4:].parameters():
        param.requires_grad = True
        
    # 3. Replace the Classifier Head with Higher Dropout
    in_features = model.classifier[1].in_features  # MobileNetV2 has 1280 features
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 1024),  # Add an intermediate layer
        nn.ReLU(),
        nn.Dropout(0.5),                # INCREASED dropout
        nn.Linear(1024, num_classes)
    )

    return model

# ==========================================
# 3. VALIDATION FUNCTION
# ==========================================
def validate_model(model, val_loader, criterion, device):
    """
    Evaluates the model on the validation set.
    
    Returns:
        val_loss: Average validation loss
        val_acc: Validation accuracy (%)
    """
    model.eval()  # Set to evaluation mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    
    with torch.no_grad():  # Disable gradient computation for validation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
    
    val_loss = running_loss / len(val_loader)
    val_acc = 100 * correct_predictions / total_samples
    
    return val_loss, val_acc

# ==========================================
# 4. TRAINING WITH VALIDATION
# ==========================================

if __name__ == "__main__":
    
    print("=" * 60)
    print("üöÄ MobileNetV2 Training on 20 Random Classes")
    print("=" * 60)
    
    try:
        # 1. Load Data using HF pipeline (handles downloading, splitting, and transforms)
        print("\nüì¶ Loading data from Hugging Face Hub...")
        train_dl, val_dl, test_dl, selected_classes, label_mapping = get_dataloaders(
            batch_size=BATCH_SIZE, 
            img_size=IMG_SIZE, 
            num_workers=0,
            num_classes=NUM_CLASSES,
            seed=42
        )
        
        print(f"\nüéØ Selected Classes: {selected_classes}")
        print(f"üìä Number of Classes: {len(selected_classes)}")
        
        print("\nü§ñ Initializing MobileNetV2...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"üíª Using device: {device}")
        
        model = get_mobilenet_model(num_classes=NUM_CLASSES).to(device)
        
        # Hyperparameters
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
        
        # Learning Rate Scheduler (now based on validation loss)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=3, verbose=True
        )
        
        # Track best validation accuracy for model saving
        best_val_acc = 0.0
        
        print(f"\nüèãÔ∏è Starting Training for {NUM_EPOCHS} epochs...")
        print("=" * 60)
        
        for epoch in range(NUM_EPOCHS):
            # ==========================================
            # TRAINING PHASE
            # ==========================================
            model.train()
            running_loss = 0.0
            correct_predictions = 0 
            total_samples = 0
            
            for images, labels in train_dl:
                images, labels = images.to(device), labels.to(device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                
                # Track Accuracy during training
                _, predicted = torch.max(outputs, 1)
                total_samples += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()

            train_loss = running_loss / len(train_dl)
            train_acc = 100 * correct_predictions / total_samples
            
            # ==========================================
            # VALIDATION PHASE
            # ==========================================
            val_loss, val_acc = validate_model(model, val_dl, criterion, device)
            
            # Print epoch results
            print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}] | "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
            
            # Update learning rate based on validation loss
            scheduler.step(val_loss)
            
            # Save best model based on validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                    'selected_classes': selected_classes,
                    'label_mapping': label_mapping
                }, MODEL_SAVE_PATH)
                print(f"  ‚úÖ Best model saved! (Val Acc: {val_acc:.2f}%)")
        
        print("\n" + "=" * 60)
        print("‚úÖ Training completed successfully!")
        print(f"üèÜ Best Validation Accuracy: {best_val_acc:.2f}%")
        print(f"üíæ Model saved as '{MODEL_SAVE_PATH}'")
        print("=" * 60)

        # ==========================================
        # FINAL TEST EVALUATION
        # ==========================================
        print("\nüß™ Evaluating on Test Set...")
        model.load_state_dict(torch.load(MODEL_SAVE_PATH)['model_state_dict'])
        test_loss, test_acc = validate_model(model, test_dl, criterion, device)
        print(f"üìà Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%")

    except Exception as e:
        print("\n‚ùå An error occurred during execution:")
        print(e)
        import traceback
        traceback.print_exc()

üöÄ MobileNetV2 Training on 20 Random Classes

üì¶ Loading data from Hugging Face Hub...
üöÄ Loading 'tanganke/stanford_cars' from Hugging Face Hub...
üìä Total classes in dataset: 196
üéØ Selected 20 random classes: [6, 7, 8, 22, 23, 26, 28, 35, 55, 57, 59, 62, 70, 108, 139, 151, 163, 173, 188, 189]
‚úÖ Filtered dataset size: 829 samples
üìä Total classes in dataset: 196
üéØ Selected 20 random classes: [6, 7, 8, 22, 23, 26, 28, 35, 55, 57, 59, 62, 70, 108, 139, 151, 163, 173, 188, 189]
‚úÖ Filtered dataset size: 820 samples
‚úÖ Data Split: 663 Train | 166 Val | 820 Test
üìå Classes remapped to range: 0-19

üéØ Selected Classes: [6, 7, 8, 22, 23, 26, 28, 35, 55, 57, 59, 62, 70, 108, 139, 151, 163, 173, 188, 189]
üìä Number of Classes: 20

ü§ñ Initializing MobileNetV2...
üíª Using device: cuda

üèãÔ∏è Starting Training for 30 epochs...
Epoch [ 1/30] | Train Loss: 2.9684 | Train Acc: 8.60% | Val Loss: 2.9025 | Val Acc: 12.65%
  ‚úÖ Best model saved! (Val Acc: 12.65%)
Epoch [

  model.load_state_dict(torch.load(MODEL_SAVE_PATH)['model_state_dict'])


üìà Test Loss: 0.8266 | Test Accuracy: 73.29%
