In [None]:
from torchvision import datasets, transforms,models
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNGs so the new head's initialization and DataLoader
# shuffling repeat across "Restart & Run All".
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)


In [None]:
# Import the torchvision module under a private alias so this cell can be re-run.
# The assignment below rebinds the global name `transforms` to the pipeline,
# because later cells use that name. Calling `transforms.Compose(...)` again
# would then fail, since `transforms` would no longer be the module.
from torchvision import transforms as _tv_transforms

# Preprocessing shared by the train/val datasets and test_image():
# resize every image to 224x224, then convert it to a tensor.
# NOTE(review): pretrained ResNet weights expect ImageNet-normalized input.
# Consider appending _tv_transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225]) and un-normalizing before plotting samples.
transforms = _tv_transforms.Compose([
    _tv_transforms.Resize((224, 224)),
    _tv_transforms.ToTensor(),
])

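Load the training and validation images. The helper below is used as a file filter so that any file whose path contains "mask" is skipped.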

In [None]:
def ignore_masks(path):
    """Decide whether ImageFolder should load the file at ``path``.

    Intended as the ``is_valid_file`` filter. Any path containing the
    substring "mask" (case-insensitive) is rejected. Note that the whole
    path is checked, not just the file name.

    Returns:
        True to keep the file, False to skip it.
    """
    lowered = path.lower()
    return lowered.find("mask") == -1

# Data Preparation for Fruit Multi-Class Classification

In [None]:
def _fruit_image_folder(root_dir):
    """Build an ImageFolder for one fruit split, applying the shared transform and mask filter."""
    return datasets.ImageFolder(
        root=root_dir,
        transform=transforms,
        is_valid_file=ignore_masks,
    )


fruit_multiclass_train = _fruit_image_folder(r"/kaggle/input/food-fruit-dataset/Project Data/Fruit/Train")
fruit_multiclass_val = _fruit_image_folder(r"/kaggle/input/food-fruit-dataset/Project Data/Fruit/Validation")

# Create the loaders: shuffle the training split; keep validation order fixed.
BATCH_SIZE = 32
fruit_train_loader = DataLoader(fruit_multiclass_train, batch_size=BATCH_SIZE, shuffle=True)
fruit_val_loader = DataLoader(fruit_multiclass_val, batch_size=BATCH_SIZE, shuffle=False)

### Visualize a batch

In [None]:
# Index -> class-name mapping derived by ImageFolder from the folder names.
fruit_names = fruit_multiclass_train.classes

# Take one shuffled training batch to eyeball images and their labels.
data_iter = iter(fruit_train_loader)
images, labels = next(data_iter)

# Up to a 4x4 grid; a batch can hold fewer than 16 images (e.g. a tiny dataset).
n_show = min(16, len(images))
fig, axes = plt.subplots(4, 4, figsize=(12, 12),
                         subplot_kw={"xticks": [], "yticks": []})
fig.suptitle("Sample training batch")

for i, ax in enumerate(axes.flat):
    if i >= n_show:
        ax.axis("off")  # hide unused panels instead of indexing past the batch
        continue
    # Image tensors are channel-first; matplotlib wants channel-last.
    ax.imshow(images[i].numpy().transpose((1, 2, 0)))
    idx = labels[i].item()
    ax.set_title(f"{idx}: {fruit_names[idx]}", color="green")

plt.show()

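# Training

Fine-tune an ImageNet-pretrained ResNet-18: freeze the backbone and train only a new final classification layer.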

In [None]:
# Start from ImageNet-pretrained ResNet-18. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze the pretrained backbone; only the new head below will be trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with one output per fruit class.
# Take the count from the dataset rather than hard-coding it, so the head
# always matches the labels ImageFolder produces.
num_classes = len(fruit_multiclass_train.classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)  # freshly created, so trainable

model = model.to(device)

In [None]:
# Cross-entropy over the class logits produced by the replacement head.
criterion = nn.CrossEntropyLoss()

# Optimize only the head's parameters; the frozen backbone is left out.
head_params = model.fc.parameters()
optimizer = optim.Adam(head_params, lr=1e-3)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Returns:
        (mean per-sample loss, accuracy) over every sample seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        # CrossEntropyLoss averages over the batch by default; undo that so
        # the epoch loss is a true per-sample mean.
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += batch_size
    if total_seen == 0:
        raise ValueError("training loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` in eval mode without tracking gradients.

    Returns:
        (mean per-sample loss, accuracy).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = images.size(0)
            total_loss += criterion(outputs, labels).item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size
    if total_seen == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


num_epochs = 10

for epoch in range(num_epochs):
    epoch_loss, epoch_acc = _train_one_epoch(
        model, fruit_train_loader, criterion, optimizer, device)
    val_loss, val_acc = _evaluate(model, fruit_val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train Loss: {epoch_loss:.4f} "
          f"Train Acc: {epoch_acc:.4f} "
          f"Val Loss: {val_loss:.4f} "
          f"Val Acc: {val_acc:.4f}")

In [None]:
# Persist only the learned weights (the state_dict), not the whole module.
checkpoint_path = "fruit_model.pth"
torch.save(model.state_dict(), checkpoint_path)


# Testing

Classify a single image with the saved checkpoint.

In [None]:
def test_image(img_path, model_path="fruit_model.pth", class_names=fruit_names):
    """Predict the class of a single image using a saved ResNet-18 state_dict.

    Args:
        img_path: Path to the image file to classify.
        model_path: Path to a state_dict saved from the fine-tuned ResNet-18.
        class_names: Index -> class-name list. Defaults to the training
            dataset's classes, bound when this cell is executed.

    Returns:
        The predicted class name (it is also printed).

    Raises:
        ValueError: if the checkpoint's number of outputs differs from
            ``len(class_names)``.
    """
    # Load the image inside a context manager so the file handle is closed;
    # convert() returns an independent image that stays usable afterwards.
    with Image.open(img_path) as img:
        img_tensor = transforms(img.convert("RGB")).unsqueeze(0).to(device)

    # NOTE: torch.load unpickles the file. Only load checkpoints you trust
    # (weights_only=True restricts this on torch >= 1.13).
    state_dict = torch.load(model_path, map_location=device)

    # Size the head from the checkpoint itself so it always matches training.
    num_classes = state_dict["fc.weight"].shape[0]
    if num_classes != len(class_names):
        raise ValueError(
            f"checkpoint has {num_classes} outputs but {len(class_names)} class names were given"
        )

    model_test = models.resnet18(weights=None)
    model_test.fc = nn.Linear(model_test.fc.in_features, num_classes)
    model_test.load_state_dict(state_dict)
    model_test.to(device)
    model_test.eval()

    # Predict: index of the highest logit for the single-image batch.
    with torch.no_grad():
        predicted_idx = model_test(img_tensor).argmax(dim=1).item()

    predicted_class = class_names[predicted_idx]
    print(f"Predicted Class: {predicted_class}")
    return predicted_class
    


In [None]:
# Classify one held-out example image; the result is printed by test_image.
sample_image_path = "/kaggle/input/tst-imgg/tofa7a.jpg"
test_image(sample_image_path);