In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
import time
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from tqdm.notebook import tqdm  # For progress bars in notebook

In [2]:
# Check if MPS (Apple Silicon GPU) is available - thorough checking
def get_device():
    if not torch.backends.mps.is_available():
        print("MPS not available - checking why...")
        if not torch.backends.mps.is_built():
            print("PyTorch not compiled with MPS support. Verify your PyTorch version (needs 1.12+ and proper installation)")
        else:
            print("PyTorch has MPS support but MPS is not available on this device")
        return torch.device("cpu")
    
    # MPS is available, but verify we can actually create a tensor on it
    try:
        # Attempt to create a small tensor on MPS
        test_tensor = torch.zeros(1, device="mps")
        print(f"Test tensor created on MPS successfully: {test_tensor.device}")
        print("MPS is working properly")
        return torch.device("mps")
    except Exception as e:
        print(f"Error initializing MPS: {e}")
        print("Falling back to CPU")
        return torch.device("cpu")

# Get device for training
device = get_device()
print(f"Using device: {device}")

Test tensor created on MPS successfully: mps:0
MPS is working properly
Using device: mps


In [3]:
# ## Hyperparameters
# You can modify these to experiment with different settings

# Batch size - larger values are generally more efficient on the GPU
batch_size = 128  # Try 64, 128, 256 for efficiency testing

# Training parameters
learning_rate = 0.001
num_epochs = 15
num_classes = 10

In [4]:
## Data Loading and Preprocessing

# Standard transforms for training
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # EfficientNet expects 224x224 images
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])

# Transforms for validation/testing (no augmentation)
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [5]:
# Load datasets
train_dataset = ImageFolder(root='data/train', transform=train_transform)
valid_dataset = ImageFolder(root='data/valid', transform=eval_transform)
test_dataset = ImageFolder(root='data/test', transform=eval_transform)

# Create data loaders - pin_memory=True improves GPU transfer speed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                         num_workers=4, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, 
                         num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, 
                        num_workers=4, pin_memory=True)

# Print dataset info
class_names = train_dataset.classes
print(f"Classes: {class_names}")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Training samples: 90000
Validation samples: 90000
Test samples: 90000


In [6]:
# ## Model Architecture - EfficientNet B0
# We'll use a pre-trained EfficientNet B0 model and adapt it to our dataset

def create_efficientnet_model(num_classes=10, pretrained=True):
    """Create and return an EfficientNet B0 model adapted to our dataset"""
    if pretrained:
        # Load pre-trained weights
        weights = EfficientNet_B0_Weights.DEFAULT
        model = efficientnet_b0(weights=weights)
        print("Loaded EfficientNet B0 with pre-trained weights")
    else:
        model = efficientnet_b0(weights=None)
        print("Initialized EfficientNet B0 with random weights")
    
    # Replace the classifier (final layer)
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    
    return model

# Create the model and move it to the device
model = create_efficientnet_model(num_classes=num_classes, pretrained=True)
model = model.to(device)

# ## Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# Learning rate scheduler for better convergence
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

Loaded EfficientNet B0 with pre-trained weights


In [7]:
# ## Training and Validation Functions

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Use tqdm for progress bar
    with tqdm(train_loader, desc="Training", leave=False) as t:
        for images, labels in t:
            # Move to device
            images, labels = images.to(device), labels.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Statistics
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Update progress bar
            t.set_postfix(loss=loss.item(), acc=100.*correct/total)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = 100 * correct / total
    
    return epoch_loss, epoch_acc

In [8]:
def validate(model, valid_loader, criterion, device):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        with tqdm(valid_loader, desc="Validation", leave=False) as t:
            for images, labels in t:
                images, labels = images.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Statistics
                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                
                # Update progress bar
                t.set_postfix(loss=loss.item(), acc=100.*correct/total)
    
    epoch_loss = running_loss / len(valid_loader.dataset)
    epoch_acc = 100 * correct / total
    
    return epoch_loss, epoch_acc

In [9]:
# ## Training Loop
print("Starting training...")
start_time = time.time()

# Verify device placement
sample_inputs, _ = next(iter(train_loader))
sample_inputs = sample_inputs.to(device)
print(f"Input batch device: {sample_inputs.device}")

model_device = next(model.parameters()).device
print(f"Model parameters device: {model_device}")

# Lists to store metrics
train_losses = []
valid_losses = []
train_accuracies = []
valid_accuracies = []
lrs = []

Starting training...


Input batch device: mps:0
Model parameters device: mps:0


In [None]:
# Loop over epochs
for epoch in range(num_epochs):
    epoch_start = time.time()
    
    # Train and validate
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    valid_loss, valid_acc = validate(model, valid_loader, criterion, device)
    
    # Learning rate scheduler step
    scheduler.step(valid_loss)
    current_lr = optimizer.param_groups[0]['lr']
    lrs.append(current_lr)
    
    # Store metrics
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    train_accuracies.append(train_acc)
    valid_accuracies.append(valid_acc)
    
    epoch_time = time.time() - epoch_start
    
    # Print statistics
    print(f"Epoch {epoch+1}/{num_epochs} | Time: {epoch_time:.2f}s | LR: {current_lr:.6f}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.2f}%")
    print("-" * 50)

total_time = time.time() - start_time
print(f"Training completed in {total_time/60:.2f} minutes")

Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 1/15 | Time: 606.43s | LR: 0.001000
Train Loss: 0.5713 | Train Acc: 79.92%
Valid Loss: 0.4158 | Valid Acc: 85.24%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 2/15 | Time: 593.97s | LR: 0.001000
Train Loss: 0.3965 | Train Acc: 85.96%
Valid Loss: 0.3771 | Valid Acc: 86.67%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 3/15 | Time: 589.66s | LR: 0.001000
Train Loss: 0.3287 | Train Acc: 88.43%
Valid Loss: 0.3595 | Valid Acc: 87.52%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 4/15 | Time: 602.21s | LR: 0.001000
Train Loss: 0.2881 | Train Acc: 89.78%
Valid Loss: 0.3631 | Valid Acc: 87.95%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 5/15 | Time: 605.31s | LR: 0.001000
Train Loss: 0.2567 | Train Acc: 90.88%
Valid Loss: 0.3561 | Valid Acc: 88.19%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 6/15 | Time: 609.67s | LR: 0.001000
Train Loss: 0.2252 | Train Acc: 91.87%
Valid Loss: 0.3419 | Valid Acc: 88.61%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 7/15 | Time: 585.49s | LR: 0.001000
Train Loss: 0.2019 | Train Acc: 92.74%
Valid Loss: 0.3665 | Valid Acc: 88.23%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 8/15 | Time: 588.24s | LR: 0.001000
Train Loss: 0.1811 | Train Acc: 93.57%
Valid Loss: 0.3916 | Valid Acc: 87.94%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 9/15 | Time: 593.59s | LR: 0.001000
Train Loss: 0.1661 | Train Acc: 94.05%
Valid Loss: 0.3884 | Valid Acc: 88.41%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

Validation:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 10/15 | Time: 594.28s | LR: 0.000500
Train Loss: 0.1487 | Train Acc: 94.70%
Valid Loss: 0.3915 | Valid Acc: 88.64%
--------------------------------------------------


Training:   0%|          | 0/704 [00:00<?, ?it/s]

In [None]:
# ## Plotting Training Curves

# Plot training curves
plt.figure(figsize=(15, 10))

# Loss subplot
plt.subplot(2, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(valid_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')

# Accuracy subplot
plt.subplot(2, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(valid_accuracies, label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Training and Validation Accuracy')

# Learning rate subplot
plt.subplot(2, 2, 3)
plt.plot(lrs)
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')

plt.tight_layout()
plt.savefig('efficientnet_training_curves.png')
plt.show()

In [None]:
# ## Evaluate on Test Set

def evaluate(model, test_loader, device):
    """Evaluate model on test set with per-class accuracy"""
    model.eval()
    correct = 0
    total = 0
    
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Per-class accuracy
            for i in range(labels.size(0)):
                label = labels[i]
                pred = predicted[i]
                if label == pred:
                    class_correct[label] += 1
                class_total[label] += 1
    
    # Print overall accuracy
    print(f"Test Accuracy: {100 * correct / total:.2f}%")
    
    # Print per-class accuracy
    print("\nPer-class accuracy:")
    for i in range(num_classes):
        if class_total[i] > 0:
            print(f"Accuracy of {class_names[i]}: {100 * class_correct[i] / class_total[i]:.2f}%")
    
    # Return overall accuracy
    return 100 * correct / total

print("\nEvaluating on test set...")
test_accuracy = evaluate(model, test_loader, device)

In [None]:
# ## Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(model, test_loader, class_names, device):
    """Create and plot confusion matrix"""
    # Get predictions
    all_preds = []
    all_labels = []
    
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Collecting predictions"):
            images = images.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    # Create confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    
    # Plot
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.show()

In [None]:


# Plot confusion matrix
try:
    plot_confusion_matrix(model, test_loader, class_names, device)
except ImportError:
    print("Skipping confusion matrix - seaborn or scikit-learn not installed.")
    print("Install with: pip install seaborn scikit-learn")

# ## Save the Model
# Save the trained model
model_save_path = 'efficientnet_model.pth'
torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_acc': train_accuracies[-1],
    'valid_acc': valid_accuracies[-1],
    'test_acc': test_accuracy,
}, model_save_path)

print(f"Model saved to '{model_save_path}'")

# ## How to Load and Use the Model Later
print("\nTo load and use this model later, use the following code:")
print("""
# Load the model
import torch
from torchvision.models import efficientnet_b0

# Create model architecture
model = efficientnet_b0(weights=None)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1280, out_features=10),
)

# Load saved weights
checkpoint = torch.load('efficientnet_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # Set to evaluation mode

# Use the model
# image = preprocess_image(your_image)  # Apply same transforms as during evaluation
# with torch.no_grad():
#     outputs = model(image)
#     _, predicted = torch.max(outputs, 1)
#     print(f"Predicted class: {class_names[predicted.item()]}")
""")