In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.optim import lr_scheduler
import json
import pickle

In [2]:
dataset_path = "FreshHarvest_Dataset/FRUIT-16K"

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [3]:
# 2. LOAD DATASETS WITH TRANSFORMS

full_dataset = datasets.ImageFolder(dataset_path)

# Train: 70%, Val: 15%, Test: 15%
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size, test_size]
)

train_dataset.dataset.transform = train_transform
val_dataset.dataset.transform = val_test_transform
test_dataset.dataset.transform = val_test_transform

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

num_classes = len(full_dataset.classes)

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)


In [5]:
num_epochs = 3
best_accuracy = 0

for epoch in range(num_epochs):
    model.train()
    correct, total = 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_acc = 100 * correct / total

    # VALIDATION
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs} | Train: {train_acc:.2f}% | Val: {val_acc:.2f}%")

    if val_acc > best_accuracy:
        best_accuracy = val_acc
        torch.save(model.state_dict(), "best_resnet50.pth")
        print("ðŸ”¥ Saved new best model!")

    scheduler.step()

Epoch 1/3 | Train: 90.00% | Val: 98.21%
ðŸ”¥ Saved new best model!
Epoch 2/3 | Train: 98.54% | Val: 99.17%
ðŸ”¥ Saved new best model!
Epoch 3/3 | Train: 99.18% | Val: 99.17%


In [6]:
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"ðŸ”¥ Final Test Accuracy: {100 * correct / total:.2f}%")

ðŸ”¥ Final Test Accuracy: 99.25%


In [7]:
torch.save(model.state_dict(), "resnet50_fruit_model.pth")

with open("class_to_idx.json", "w") as f:
    json.dump(full_dataset.class_to_idx, f)

with open("preprocess.pkl", "wb") as f:
    pickle.dump(val_test_transform, f)

print("Model + class mapping + preprocess saved!")

Model + class mapping + preprocess saved!
