In [5]:
import time
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
# For torchvision >=0.13, you can use the new-style weights import:
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights

import matplotlib.pyplot as plt

# --------------------------
# Check MPS (Apple Silicon) Availability
# --------------------------
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# --------------------------
# Hyperparameters
# --------------------------
IMAGE_SIZE = 300
BATCH_SIZE = 32
EPOCHS = 40

FREEZE_EPOCH = 5     # epoch to partially unfreeze
UNFREEZE_EPOCH = 15  # epoch to unfreeze more layers

MIXUP_ALPHA = 0.4
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)

# --------------------------
# Mixup Utilities
# --------------------------
def mixup_data(inputs, labels, alpha=1.0):
    """Returns mixed inputs, pairs of targets, and lambda."""
    if alpha <= 0:
        return inputs, labels, labels, 1.0

    batch_size = inputs.size(0)
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(batch_size).to(inputs.device)

    mixed_inputs = lam * inputs + (1 - lam) * inputs[index, :]
    labels_a, labels_b = labels, labels[index]
    return mixed_inputs, labels_a, labels_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# --------------------------
# Data Augmentation
# --------------------------
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE + 32),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dir = "data/train"
test_dir  = "data/test"

image_datasets = {
    'train': datasets.ImageFolder(train_dir, transform=train_transforms),
    'test':  datasets.ImageFolder(test_dir, transform=test_transforms)
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes
num_classes = len(class_names)

# --------------------------
# DataLoaders
#   - MPS doesn't benefit much from pin_memory=True, so set it to False
#   - Adjust num_workers to your CPU
# --------------------------
dataloaders = {
    'train': DataLoader(
        image_datasets['train'], batch_size=BATCH_SIZE,
        shuffle=True, num_workers=4, pin_memory=False
    ),
    'test': DataLoader(
        image_datasets['test'], batch_size=BATCH_SIZE,
        shuffle=False, num_workers=4, pin_memory=False
    )
}

print(f"Found {dataset_sizes['train']} training images across {num_classes} classes.")
print(f"Found {dataset_sizes['test']} test images.")

# --------------------------
# Model Setup (EfficientNet B3)
# If you have an older torchvision version, replace with: efficientnet_b3(pretrained=True)
# --------------------------
model = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)

# Number of features in the last linear layer
num_ftrs = model.classifier[1].in_features

# Replace the classifier
model.classifier = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)

model = model.to(device)

# --------------------------
# Freeze Strategy (initially freeze all features)
# --------------------------
for param in model.features.parameters():
    param.requires_grad = False

# --------------------------
# Initial Optimizer & Scheduler (only classifier)
# --------------------------
optimizer = optim.AdamW(model.classifier.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    steps_per_epoch=len(dataloaders['train']),
    epochs=EPOCHS
)

# --------------------------
# Training Function
# --------------------------
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=EPOCHS):
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    train_losses, test_losses = [], []
    train_accuracies, test_accuracies = [], []

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs-1}")
        print('-' * 10)

        # ------------------------------
        # Progressive Unfreezing Steps
        # ------------------------------
        if epoch == FREEZE_EPOCH:
            print("Unfreezing top layers (partial)...")
            # Unfreeze the last half of model.features blocks
            total_blocks = len(model.features)
            blocks_to_unfreeze = total_blocks // 2
            for i in range(blocks_to_unfreeze, total_blocks):
                for param in model.features[i].parameters():
                    param.requires_grad = True

            # Redefine optimizer & scheduler
            optimizer = optim.AdamW(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=5e-4, weight_decay=1e-5
            )
            scheduler = OneCycleLR(
                optimizer,
                max_lr=5e-4,
                steps_per_epoch=len(dataloaders['train']),
                epochs=num_epochs - FREEZE_EPOCH,
                pct_start=0.3
            )

        if epoch == UNFREEZE_EPOCH:
            print("Unfreezing remaining layers (full fine-tuning)...")
            for param in model.features.parameters():
                param.requires_grad = True

            optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
            scheduler = OneCycleLR(
                optimizer,
                max_lr=1e-4,
                steps_per_epoch=len(dataloaders['train']),
                epochs=num_epochs - UNFREEZE_EPOCH,
                pct_start=0.3
            )

        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # ----------------
                # Mixup for training
                # ----------------
                if phase == 'train':
                    inputs_mixed, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha=MIXUP_ALPHA)
                    outputs = model(inputs_mixed)
                    loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
                else:
                    with torch.no_grad():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]

            # -- Fix: use float() not double() for MPS --
            epoch_acc = (running_corrects.float() / dataset_sizes[phase]).item()

            if phase == 'train':
                train_losses.append(epoch_loss)
                train_accuracies.append(epoch_acc)
            else:
                test_losses.append(epoch_loss)
                test_accuracies.append(epoch_acc)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Track best model
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - start_time
    print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best test Acc: {best_acc:.4f}")

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, (train_losses, test_losses, train_accuracies, test_accuracies)

# --------------------------
# Train the model
# --------------------------
model, history = train_model(
    model=model,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=EPOCHS
)

train_losses, test_losses, train_accuracies, test_accuracies = history

# --------------------------
# Save best model
# --------------------------
torch.save(model.state_dict(), "plant_disease_efficientnet_b3_mixup_best_mps.pth")

# --------------------------
# Visualization of Predictions
# --------------------------
def visualize_model(model, dataloader, class_names, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size(0)):
                images_so_far += 1

                plt.subplot(num_images // 2, 2, images_so_far)
                plt.axis('off')
                plt.title(f'Pred: {class_names[preds[j]]}')

                # Un-normalize for display
                img = inputs[j].cpu().numpy().transpose((1, 2, 0))
                img = std * img + mean
                img = np.clip(img, 0, 1)
                plt.imshow(img)
                if images_so_far == num_images:
                    model.train(mode=was_training)
                    plt.show()
                    return
    model.train(mode=was_training)

print("Visualizing some predictions from the test set...")
visualize_model(model, dataloaders['test'], class_names)

# --------------------------
# Plot Accuracy and Loss
# --------------------------
plt.figure()
plt.plot(range(EPOCHS), train_accuracies, label='Train Accuracy')
plt.plot(range(EPOCHS), test_accuracies, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Test Accuracy')
plt.legend()
plt.show()

plt.figure()
plt.plot(range(EPOCHS), train_losses, label='Train Loss')
plt.plot(range(EPOCHS), test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()
plt.show()


Using device: mps
Found 2336 training images across 28 classes.
Found 236 test images.
Epoch 0/39
----------
train Loss: 3.3815 Acc: 0.0458
test Loss: 3.2355 Acc: 0.1144

Epoch 1/39
----------
train Loss: 3.1510 Acc: 0.0989
test Loss: 3.0376 Acc: 0.1907

Epoch 2/39
----------
train Loss: 2.8991 Acc: 0.1438
test Loss: 2.7961 Acc: 0.3008

Epoch 3/39
----------
train Loss: 2.6699 Acc: 0.1742
test Loss: 2.5850 Acc: 0.3305

Epoch 4/39
----------
train Loss: 2.4488 Acc: 0.2299
test Loss: 2.3937 Acc: 0.4195

Epoch 5/39
----------
Unfreezing top layers (partial)...
