In [14]:
# !pip install torch torchvision torchaudio --upgrade --quiet

import os, math, time, copy, random
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

DATA_ROOT = "dataset"              # adjust if needed
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
TEST_DIR  = os.path.join(DATA_ROOT, "test")
SAVE_BEST = "best_fruit_model.pt"
NUM_WORKERS = min(8, os.cpu_count() or 1)
PIN_MEM = torch.cuda.is_available()
BATCH_SIZE = 32
IMG_SIZE = 224


Device: cpu


In [15]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

test_tfms = transforms.Compose([
    transforms.Resize(int(IMG_SIZE*1.15)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_tfms)
test_ds  = datasets.ImageFolder(TEST_DIR,  transform=test_tfms)
class_names = train_ds.classes
num_classes = len(class_names)
class_names, num_classes


(['banana', 'blueberry', 'pear'], 3)

In [16]:
# Compute per-sample weights for WeightedRandomSampler
class_count = np.bincount([y for _, y in train_ds.samples])
class_weights = 1.0 / np.maximum(class_count, 1)
sample_weights = [class_weights[y] for _, y in train_ds.samples]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEM)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEM)

len(train_loader), len(test_loader), class_count


(22, 9, array([215, 268, 219], dtype=int64))

In [17]:
# Option A: EfficientNet-B0 (recommended)
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)

model = model.to(device)


In [21]:
EPOCHS = 5
base_lr = 3e-4
weight_decay = 1e-4

# Label smoothing makes training more robust on small datasets
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)

optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)

# OneCycle schedules a higher peak LR, usually converges faster/better
steps_per_epoch = len(train_loader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=base_lr*10, epochs=EPOCHS, steps_per_epoch=steps_per_epoch,
    pct_start=0.15, anneal_strategy='cos', div_factor=10.0, final_div_factor=1e4
)

scaler = torch.amp.GradScaler(enabled="cuda" if torch.cuda.is_available() else "cpu")


In [25]:
def run_epoch(model, loader, train=True):
    model.train(train)
    running_loss, correct, total = 0.0, 0, 0
    for images, targets in loader:
        images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        with torch.amp.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            outputs = model(images)
            loss = criterion(outputs, targets)
        if train:
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(1)
        correct += (preds == targets).sum().item()
        total += images.size(0)
    return running_loss/total, correct/total

best_acc, best_state, patience, patience_ctr = 0.0, None, 5, 0

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss, train_acc = run_epoch(model, train_loader, train=True)
    test_loss,  test_acc  = run_epoch(model, test_loader,  train=False)

    if test_acc > best_acc:
        best_acc = test_acc
        best_state = copy.deepcopy(model.state_dict())
        torch.save({'state_dict': best_state, 'classes': class_names}, SAVE_BEST)
        patience_ctr = 0
    else:
        patience_ctr += 1

    dt = time.time() - t0
    print(f"Epoch {epoch:02d}/{EPOCHS} \n"
          f"| train loss: {train_loss * 100:.4f}%, accuracy: {train_acc * 100:.2f}% \n"
          f"| test loss: {test_loss * 100:.4f}%, accuracy: {test_acc * 100:.2f}% \n"
          f"| best test accuracy {best_acc * 100:.4f}% | {dt:.1f}s")

    if patience_ctr >= patience:
        print("Early stopping.")
        break

# Load best weights
if best_state is not None:
    model.load_state_dict(best_state)


Epoch 01/5 
| train loss: 51.9623%, accuracy: 86.04% 
| test loss: 86.6204%, accuracy: 82.16% 
| best test accuracy 82.1561% | 92.5s
Epoch 02/5 
| train loss: 40.2459%, accuracy: 92.17% 
| test loss: 70.3333%, accuracy: 85.87% 
| best test accuracy 85.8736% | 92.6s
Epoch 03/5 
| train loss: 32.3175%, accuracy: 93.16% 
| test loss: 41.8233%, accuracy: 92.19% 
| best test accuracy 92.1933% | 95.1s
Epoch 04/5 
| train loss: 25.1073%, accuracy: 97.01% 
| test loss: 40.5851%, accuracy: 90.71% 
| best test accuracy 92.1933% | 101.1s
Epoch 05/5 
| train loss: 21.4598%, accuracy: 99.43% 
| test loss: 40.9261%, accuracy: 90.33% 
| best test accuracy 92.1933% | 104.4s


In [26]:
# Collect predictions
all_preds, all_targets = [], []
model.eval()
with torch.inference_mode():
    for images, targets in test_loader:
        images = images.to(device, non_blocking=True)
        logits = model(images)
        preds = logits.argmax(1).cpu().numpy()
        all_preds.append(preds)
        all_targets.append(targets.numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

print(classification_report(all_targets, all_preds, target_names=class_names, digits=4))
print("Confusion matrix:\n", confusion_matrix(all_targets, all_preds))


              precision    recall  f1-score   support

      banana     0.8864    0.9176    0.9017        85
   blueberry     0.9892    0.9684    0.9787        95
        pear     0.8864    0.8764    0.8814        89

    accuracy                         0.9219       269
   macro avg     0.9207    0.9208    0.9206       269
weighted avg     0.9227    0.9219    0.9222       269

Confusion matrix:
 [[78  0  7]
 [ 0 92  3]
 [10  1 78]]
