In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from data_preprocessing import get_dataloaders

# ==========================================
# 1. CONFIGURATION
# ==========================================
BATCH_SIZE = 32
NUM_CLASSES = 10  # Number of Stanford Cars classes sampled at random for this run
NUM_EPOCHS = 30
# GoogLeNet (Inception V1) ends in adaptive average pooling, so it accepts any
# sufficiently large input; torchvision's ImageNet weights are evaluated at
# 224x224. 299x299 is the Inception V3 convention and is used here by choice,
# not because the architecture requires it.
IMG_SIZE = 299
# Derived from NUM_CLASSES so the checkpoint name never drifts from the config.
MODEL_SAVE_PATH = f"inception_v1_stanford_cars_{NUM_CLASSES}classes.pth"

# ==========================================
# 2. MODEL ARCHITECTURE
# ==========================================
def get_inception_model(num_classes=10,
                        trainable_blocks=("inception4e", "inception5a", "inception5b")):
    """
    Build an ImageNet-pretrained GoogLeNet (Inception V1) set up for fine-tuning.

    All pretrained parameters are frozen except the top-level submodules named
    in ``trainable_blocks``. The main classifier is replaced by a
    ``in_features -> 1024 -> num_classes`` MLP head, and both auxiliary
    classifiers get a new ``num_classes``-way output layer and are fully
    trainable.

    Args:
        num_classes: number of output classes.
        trainable_blocks: names of GoogLeNet submodules (e.g. ``"inception4e"``)
            to unfreeze. Defaults to the last three inception modules.

    Returns:
        The modified torchvision GoogLeNet. In train mode it returns a
        namedtuple ``(main, aux2, aux1)`` of logits; in eval mode only the
        main logits.

    Raises:
        AttributeError: if a name in ``trainable_blocks`` is not a submodule.
    """
    # Load pre-trained Inception V1 (GoogLeNet) with auxiliary classifiers enabled.
    model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1, aux_logits=True)

    # --- STRATEGY: FINE-TUNING ---
    # 1. Freeze everything (early layers hold generic features like lines/edges).
    for param in model.parameters():
        param.requires_grad = False

    # 2. Unfreeze the selected high-level inception modules.
    for block_name in trainable_blocks:
        for param in getattr(model, block_name).parameters():
            param.requires_grad = True

    # 3. Replace the final classifier head (GoogLeNet's fc receives 1024 features).
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(1024, num_classes),
    )

    # 4. Auxiliary classifiers. torchvision does not ship pretrained weights for
    # them (it warns "auxiliary heads ... are NOT pretrained"), so leaving their
    # conv/fc1 frozen would pin them at random initialisation. Re-target them to
    # num_classes and make the whole head trainable.
    # Note: in torchvision, aux1/aux2 tap the outputs of inception4a/inception4d,
    # so with the default trainable_blocks their losses do not reach the backbone.
    for aux in (model.aux1, model.aux2):
        aux.fc2 = nn.Linear(aux.fc2.in_features, num_classes)
        for param in aux.parameters():
            param.requires_grad = True

    return model

# ==========================================
# 3. VALIDATION FUNCTION
# ==========================================
def validate_model(model, val_loader, criterion, device):
    """
    Evaluate the model on a held-out loader (validation or test).

    Puts ``model`` in eval mode and runs without gradient tracking. If the
    model returns a tuple, only its first element (the main logits) is scored.

    Args:
        model: network to evaluate.
        val_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function; assumed to use mean reduction (the
            ``nn.CrossEntropyLoss`` default) so it can be re-weighted by
            batch size.
        device: device the batches are moved to.

    Returns:
        val_loss: Average loss per sample
        val_acc: Accuracy (%)

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()  # Set to evaluation mode
    loss_sum = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient computation for evaluation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            # Handle potential tuple output from GoogLeNet: keep the main logits.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight each batch's mean loss by its size so a short final batch
            # does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            total_samples += batch_size
            correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()

    if total_samples == 0:
        raise ValueError("val_loader yielded no samples; cannot compute loss/accuracy")

    val_loss = loss_sum / total_samples
    val_acc = 100 * correct_predictions / total_samples

    return val_loss, val_acc

# ==========================================
# 4. TRAINING WITH VALIDATION
# ==========================================

def _train_one_epoch(model, train_loader, criterion, optimizer, device, aux_weight=0.3):
    """
    Run one training pass over ``train_loader``.

    When the model returns a tuple (torchvision GoogLeNet in train mode with
    ``aux_logits=True`` yields ``(main, aux2, aux1)``), the two auxiliary
    losses are added with weight ``aux_weight``. Accuracy is always measured
    on the main output.

    Returns:
        (avg_loss, acc_percent): per-sample average of the combined loss
        (assumes ``criterion`` uses mean reduction) and training accuracy (%).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        if isinstance(outputs, tuple):
            # Training mode: main output plus auxiliary outputs.
            main_output, aux2_output, aux1_output = outputs
            loss = (criterion(main_output, labels)
                    + aux_weight * criterion(aux2_output, labels)
                    + aux_weight * criterion(aux1_output, labels))
            outputs = main_output  # Use main output for accuracy
        else:
            loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / seen, 100 * correct / seen


if __name__ == "__main__":

    print("=" * 60)
    print(f"🚀 Inception V1 (GoogLeNet) Training on {NUM_CLASSES} Random Classes")
    print("=" * 60)

    try:
        # 1. Load data using the HF pipeline (handles downloading, splitting, and transforms)
        print("\n📦 Loading data from Hugging Face Hub...")
        train_dl, val_dl, test_dl, selected_classes, label_mapping = get_dataloaders(
            batch_size=BATCH_SIZE,
            img_size=IMG_SIZE,
            num_workers=0,
            num_classes=NUM_CLASSES,
            seed=42,
        )

        print(f"\n🎯 Selected Classes: {selected_classes}")
        print(f"📊 Number of Classes: {len(selected_classes)}")

        print("\n🤖 Initializing Inception V1 (GoogLeNet)...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"💻 Using device: {device}")

        model = get_inception_model(num_classes=NUM_CLASSES).to(device)

        # Hyperparameters
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)

        # LR scheduler driven by validation loss. `verbose=` is deprecated in
        # recent PyTorch, so LR reductions are reported manually below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=3
        )

        # Start at -inf so the first epoch always writes a checkpoint (given
        # NUM_EPOCHS >= 1) and the test-set reload below finds a file.
        best_val_acc = float("-inf")

        print(f"\n🏋️ Starting Training for {NUM_EPOCHS} epochs...")
        print("=" * 60)

        for epoch in range(NUM_EPOCHS):
            # Training phase
            train_epoch_loss, train_epoch_acc = _train_one_epoch(
                model, train_dl, criterion, optimizer, device
            )

            # Validation phase
            val_loss, val_acc = validate_model(model, val_dl, criterion, device)

            print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}] | "
                  f"Train Loss: {train_epoch_loss:.4f} | Train Acc: {train_epoch_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

            # Update learning rate based on validation loss
            lr_before = optimizer.param_groups[0]["lr"]
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]["lr"]
            if lr_after < lr_before:
                print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Save best model based on validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                    'selected_classes': selected_classes,
                    'label_mapping': label_mapping
                }, MODEL_SAVE_PATH)
                print(f"  ✅ Best model saved! (Val Acc: {val_acc:.2f}%)")

        print("\n" + "=" * 60)
        print("✅ Training completed successfully!")
        print(f"🏆 Best Validation Accuracy: {best_val_acc:.2f}%")
        print(f"💾 Model saved as '{MODEL_SAVE_PATH}'")
        print("=" * 60)

        # ==========================================
        # FINAL TEST EVALUATION
        # ==========================================
        print("\n🧪 Evaluating on Test Set...")
        # The checkpoint was written by this script. It also stores non-tensor
        # metadata (class list, label mapping) whose exact types come from
        # get_dataloaders and may not be allow-listed for weights_only loading,
        # so full unpickling is requested explicitly. map_location keeps the
        # reload working if the saving and loading devices differ.
        checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        test_loss, test_acc = validate_model(model, test_dl, criterion, device)
        print(f"📈 Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%")

    except Exception as e:
        print("\n❌ An error occurred during execution:")
        print(e)
        import traceback
        traceback.print_exc()

  from .autonotebook import tqdm as notebook_tqdm


🚀 Inception V1 (GoogLeNet) Training on 20 Random Classes

📦 Loading data from Hugging Face Hub...
🚀 Loading 'tanganke/stanford_cars' from Hugging Face Hub...
📊 Total classes in dataset: 196
🎯 Selected 10 random classes: [6, 26, 28, 35, 57, 62, 70, 163, 188, 189]


Filter: 100%|██████████| 8144/8144 [00:44<00:00, 183.94 examples/s]


✅ Filtered dataset size: 410 samples
📊 Total classes in dataset: 196
🎯 Selected 10 random classes: [6, 26, 28, 35, 57, 62, 70, 163, 188, 189]


Filter: 100%|██████████| 8041/8041 [00:35<00:00, 229.38 examples/s]


✅ Filtered dataset size: 406 samples


Filter: 100%|██████████| 8041/8041 [00:36<00:00, 220.86 examples/s]


✅ Data Split: 328 Train | 82 Val | 406 Test
📌 Classes remapped to range: 0-9

🎯 Selected Classes: [6, 26, 28, 35, 57, 62, 70, 163, 188, 189]
📊 Number of Classes: 10

🤖 Initializing Inception V1 (GoogLeNet)...
💻 Using device: cuda





🏋️ Starting Training for 30 epochs...
Epoch [ 1/30] | Train Loss: 3.7806 | Train Acc: 11.89% | Val Loss: 2.2804 | Val Acc: 10.98%
  ✅ Best model saved! (Val Acc: 10.98%)
Epoch [ 2/30] | Train Loss: 3.6778 | Train Acc: 22.87% | Val Loss: 2.2573 | Val Acc: 13.41%
  ✅ Best model saved! (Val Acc: 13.41%)
Epoch [ 3/30] | Train Loss: 3.6267 | Train Acc: 27.44% | Val Loss: 2.1970 | Val Acc: 26.83%
  ✅ Best model saved! (Val Acc: 26.83%)
Epoch [ 4/30] | Train Loss: 3.5339 | Train Acc: 38.41% | Val Loss: 2.1050 | Val Acc: 32.93%
  ✅ Best model saved! (Val Acc: 32.93%)
Epoch [ 5/30] | Train Loss: 3.3926 | Train Acc: 49.39% | Val Loss: 1.9625 | Val Acc: 47.56%
  ✅ Best model saved! (Val Acc: 47.56%)
Epoch [ 6/30] | Train Loss: 3.2837 | Train Acc: 59.15% | Val Loss: 1.8057 | Val Acc: 62.20%
  ✅ Best model saved! (Val Acc: 62.20%)
Epoch [ 7/30] | Train Loss: 3.1365 | Train Acc: 62.50% | Val Loss: 1.6283 | Val Acc: 64.63%
  ✅ Best model saved! (Val Acc: 64.63%)
Epoch [ 8/30] | Tr

  checkpoint = torch.load(MODEL_SAVE_PATH)


📈 Test Loss: 0.3582 | Test Accuracy: 88.18%
