In [1]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
from torchvision import datasets , transforms
from torch.utils.data import DataLoader, random_split,Dataset
from PIL import Image
import os
import tqdm

In [7]:
class FruitFreshnessDataset(Dataset):
    def __init__(self , root_dir , transform = None): 
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.fruit_map = {'Apple': 0, 'Banana': 1, 'Strawberry': 2}
        self.freshness_map = {'Fresh': 0, 'Rotten': 1} 

        for fruit in os.listdir(root_dir): 
            fruit_path = os.path.join(root_dir , fruit)
            
            for freshness in os.listdir(fruit_path): 
                freshness_path = os.path.join(fruit_path , freshness)

                for image_file in os.listdir(freshness_path):
                    image_file_path = os.path.join(freshness_path , image_file)
                    self.images.append(image_file_path)
                    self.labels.append((self.fruit_map[fruit] , self.freshness_map[freshness]))

    def __len__(self):
        return len(self.images)
        
    def __getitem__(self , index): 
        image = Image.open(self.images[index]).convert("RGB")
        if self.transform: 
            image = self.transform(image)
        return image , self.labels[index]

In [8]:
root_dir = "Data/Fruit Freshness Dataset/"
os.listdir(root_dir)

['Apple', 'Banana', 'Strawberry']

In [9]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), # Resize all images to 224x224
    transforms.RandomHorizontalFlip(), # augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

In [10]:
fruit_dataset = FruitFreshnessDataset(root_dir = root_dir , transform = transform)

In [11]:
print(f"Total images: {len(fruit_dataset)}")
print(f"Sample image & labels: {fruit_dataset[0][0].shape}, {fruit_dataset[0][1]}")

Total images: 566
Sample image & labels: torch.Size([3, 224, 224]), (0, 0)


In [12]:
total_size = len(fruit_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

In [13]:
train_dataset, val_dataset, test_dataset = random_split(fruit_dataset, [train_size, val_size, test_size])

In [14]:
batch_size = 32

In [15]:
# Define dataloader 
train_loader = DataLoader(train_dataset , batch_size = batch_size , shuffle = True)
val_loader = DataLoader(val_dataset , batch_size = batch_size , shuffle = False)
test_loader = DataLoader(test_dataset , batch_size = batch_size , shuffle = False)

In [16]:
import torchvision.models as models

In [17]:
vgg16 = models.vgg16(weights = models.VGG16_Weights.DEFAULT)

In [18]:
vgg16

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [19]:
num_fruit_classes = 3
num_freshness_classes = 2

In [20]:
import torch.nn as nn

In [21]:
vgg16.classifier[0].in_features

25088

In [22]:
class MultiOutputVGG16(nn.Module): 
    def __init__(self , base_model, num_fruit_classes, num_freshness_classes): 
        super().__init__()
        self.features = base_model.features # keeping the convolution layers
        # Now freeze the cnn part 
        for param in self.features.parameters(): 
            param.requires_grad = False

        # Fully connect layers 
        in_features = base_model.classifier[0].in_features

        # Two heads: 1.Fruit type , 2.Freshness 
        self.fc_fruit = nn.Linear(in_features , num_fruit_classes)
        self.fc_freshness = nn.Linear(in_features , num_freshness_classes) 

    def forward(self , x): 
        x = self.features(x)
        x = x.view(x.size(0) , -1) # flatten the image
        fruit_out = self.fc_fruit(x)
        freshness_out = self.fc_freshness(x)
        return fruit_out , freshness_out

In [23]:
len(list(vgg16.children()))

3

In [24]:
list(vgg16.children())[2]

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

In [25]:
model = MultiOutputVGG16(
    base_model = vgg16,
    num_fruit_classes = num_fruit_classes,
    num_freshness_classes = num_freshness_classes
).to(device)

In [26]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

In [27]:
num_epochs = 5
for epoch in range(num_epochs): 
    model.train()
    total_loss = 0 
    for images , labels in train_loader: 
        images = images.to(device)
        fruit_labels = labels[0].to(device)
        freshness_labels = labels[1].to(device)

        optimizer.zero_grad()
        fruit_out, freshness_out = model(images)
        loss = criterion(fruit_out, fruit_labels) + criterion(freshness_out, freshness_labels)
        loss.backward() 
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        
    print(f"Epoch: {epoch + 1}/{num_epochs} , Loss: {total_loss/len(train_loader.dataset):.4f}")

Epoch: 1/5 , Loss: 0.8439
Epoch: 2/5 , Loss: 0.1834
Epoch: 3/5 , Loss: 0.1059
Epoch: 4/5 , Loss: 0.0716
Epoch: 5/5 , Loss: 0.0526


In [28]:
model.eval()
correct_fruit = 0
correct_freshness = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        fruit_labels = labels[0].to(device)
        freshness_labels = labels[1].to(device)

        fruit_out, freshness_out = model(images)
        _, pred_fruit = torch.max(fruit_out, 1)
        _, pred_freshness = torch.max(freshness_out, 1)

        total += fruit_labels.size(0)
        correct_fruit += (pred_fruit == fruit_labels).sum().item()
        correct_freshness += (pred_freshness == freshness_labels).sum().item()

print(f"Fruit Accuracy: {100*correct_fruit/total:.2f}%")
print(f"Freshness Accuracy: {100*correct_freshness/total:.2f}%")


Fruit Accuracy: 100.00%
Freshness Accuracy: 92.86%


In [29]:
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    classification_report
)
import numpy as np

In [30]:
model.eval()
true_fruit = []
pred_fruit = []
true_freshness = []
pred_freshness = []

prob_freshness = []
prob_fruit = [] 


with torch.no_grad(): 
    for images , labels in val_loader: 
        images = images.to(device)
        fruit_labels = labels[0].to(device)
        freshness_labels = labels[1].to(device)

        fruit_out, freshness_out = model(images)

        fruit_preds = torch.argmax(fruit_out, dim=1)
        freshness_preds = torch.argmax(freshness_out, dim=1)

        fruit_probs = torch.softmax(fruit_out, dim=1)
        freshness_probs = torch.softmax(freshness_out, dim=1)

        true_fruit.extend(fruit_labels.cpu().numpy())
        pred_fruit.extend(fruit_preds.cpu().numpy())
        true_freshness.extend(freshness_labels.cpu().numpy())
        pred_freshness.extend(freshness_preds.cpu().numpy())

        prob_freshness.extend(freshness_probs[:, 1].cpu().numpy())

In [31]:
print("FRESHNESS METRICS")

print("Accuracy :", accuracy_score(true_freshness, pred_freshness))
print("Precision:", precision_score(true_freshness, pred_freshness))
print("Recall   :", recall_score(true_freshness, pred_freshness))
print("F1 Score :", f1_score(true_freshness, pred_freshness))

auc = roc_auc_score(true_freshness, prob_freshness)
print("AUC      :", auc)

FRESHNESS METRICS
Accuracy : 0.9285714285714286
Precision: 0.9642857142857143
Recall   : 0.84375
F1 Score : 0.9
AUC      : 0.9819711538461539


In [32]:
print("FRUIT METRICS")
print("Accuracy :", accuracy_score(true_fruit, pred_fruit))
print("Precision:", precision_score(true_fruit, pred_fruit, average="macro"))
print("Recall   :", recall_score(true_fruit, pred_fruit, average="macro"))
print("F1 Score :", f1_score(true_fruit, pred_fruit, average="macro"))

FRUIT METRICS
Accuracy : 1.0
Precision: 1.0
Recall   : 1.0
F1 Score : 1.0
