<a href="https://www.kaggle.com/code/louisdelignac/fruits-classification?scriptVersionId=249479014" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import glob
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.io as tv_io
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
cuda_available

True

In [2]:
from torchvision.models import VGG16_Weights, vgg16

# Build VGG16 initialised from the DEFAULT pretrained weights.
# `weights` is kept as its own name because its bundled preprocessing
# transforms are requested from it later in the notebook.
weights = VGG16_Weights.DEFAULT
vgg_model = vgg16(weights=weights)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 173MB/s]  


In [3]:
# Freeze the pretrained base model so its weights receive no gradient updates.
vgg_model.requires_grad_(False)
# Sanity check: the first parameter should now report requires_grad == False.
next(vgg_model.parameters()).requires_grad

False

In [4]:
# Inspect the first three classifier layers (Linear -> ReLU -> Dropout);
# this slice is the part of the VGG16 classifier reused in the new model.
vgg_model.classifier[0:3]

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
)

In [5]:
N_CLASSES = 6

# Reused pretrained stages: convolutional features, pooling, flattening and
# the first fully connected stage of the VGG16 classifier (4096 outputs).
pretrained_layers = [
    vgg_model.features,
    vgg_model.avgpool,
    nn.Flatten(),
    vgg_model.classifier[0:3],
]

# Newly initialised head that maps the 4096 features to one score per class.
head_layers = [
    nn.Linear(4096, 500),
    nn.ReLU(),
    nn.Linear(500, N_CLASSES),
]

model = nn.Sequential(*pretrained_layers, *head_layers)
model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# Multi-class classification over raw logits.
loss_function = nn.CrossEntropyLoss()

# Move the model to the target device and compile it *before* building the
# optimizer, so the optimizer is constructed from the parameters of the model
# that is actually trained (the ordering recommended by the torch.optim docs).
model = torch.compile(model.to(device))
optimizer = Adam(model.parameters())

In [7]:
# Preprocessing pipeline bundled with the pretrained weights; it is applied
# to every image once, when the dataset is loaded.
pre_trans = weights.transforms()

In [8]:
IMG_WIDTH, IMG_HEIGHT = (224, 224)

# Random augmentations applied to training batches only.
random_trans = transforms.Compose([
    transforms.RandomRotation(20),
    # torchvision expects `size` as (height, width); the previous
    # (width, height) order only worked because both values are equal.
    transforms.RandomResizedCrop((IMG_HEIGHT, IMG_WIDTH), scale=(.8, 1), ratio=(1, 1)),
    transforms.RandomHorizontalFlip()
])

In [9]:
DATA_LABELS = ["freshapples", "freshbanana", "freshoranges", "rottenapples", "rottenbanana", "rottenoranges"]


class FruitDataset(Dataset):
    """Dataset that eagerly loads and preprocesses every PNG image.

    For each entry of ``DATA_LABELS`` the directory ``<data_dir>/<label>`` is
    scanned for ``*.png`` files. Each image is read as RGB, passed through
    ``pre_trans`` and stored on ``device`` together with its class index
    (the position of the label in ``DATA_LABELS``).

    Args:
        data_dir: Directory that contains one sub-directory per label. A
            trailing path separator is accepted but not required.
    """

    def __init__(self, data_dir):
        self.imgs = []
        self.labels = []

        for l_idx, label in enumerate(DATA_LABELS):
            # os.path.join works whether or not data_dir ends with a separator.
            pattern = os.path.join(data_dir, label, "*.png")
            # glob returns files in arbitrary order; sort so the sample order
            # is identical from run to run.
            for path in sorted(glob.glob(pattern)):
                img = tv_io.read_image(path, tv_io.ImageReadMode.RGB)
                self.imgs.append(pre_trans(img).to(device))
                self.labels.append(torch.tensor(l_idx, device=device))

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.imgs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)

In [10]:
n = 32  # batch size shared by both loaders

KAGGLE_PATH = "/kaggle/input/fruits-fresh-and-rotten-for-classification/"
train_path = KAGGLE_PATH + "dataset/train/"
valid_path = KAGGLE_PATH + "dataset/test/"

# Both splits are loaded eagerly by FruitDataset.
train_data = FruitDataset(train_path)
valid_data = FruitDataset(valid_path)

# Only the training split is shuffled.
train_loader = DataLoader(train_data, batch_size=n, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=n, shuffle=False)

# Dataset sizes, used to normalise the accumulated accuracy.
train_N = len(train_loader.dataset)
valid_N = len(valid_loader.dataset)

In [11]:
def get_batch_accuracy(output, y, N):
    """Return the number of correct predictions in a batch divided by ``N``.

    Args:
        output: Per-class scores, one row per sample; the predicted class is
            the argmax along dimension 1.
        y: Target class indices, one per sample.
        N: Denominator for the returned fraction.

    Returns:
        ``correct / N`` as a Python float.
    """
    predicted = output.argmax(dim=1)
    n_correct = (predicted == y.view_as(predicted)).sum().item()
    return n_correct / N

In [12]:
def train(model, train_loader, train_N, random_trans, optimizer, loss_function):
    """Run one training epoch and print the accumulated loss and accuracy.

    Args:
        model: Module to optimise; switched to training mode.
        train_loader: Iterable yielding ``(x, y)`` batches.
        train_N: Denominator passed to ``get_batch_accuracy`` for each batch.
        random_trans: Augmentation applied to every input batch before the
            forward pass.
        optimizer: Optimizer stepped once per batch.
        loss_function: Callable computing the loss from ``(output, y)``.

    The printed loss is the sum of the per-batch losses; the printed accuracy
    is the sum of the per-batch values returned by ``get_batch_accuracy``.
    """
    running_loss = 0
    running_accuracy = 0

    model.train()
    for images, targets in train_loader:
        optimizer.zero_grad()
        logits = model(random_trans(images))
        step_loss = loss_function(logits, targets)
        step_loss.backward()
        optimizer.step()

        running_loss += step_loss.item()
        running_accuracy += get_batch_accuracy(logits, targets, train_N)

    summary = 'Train - Loss: {:.4f} Accuracy: {:.4f}'
    print(summary.format(running_loss, running_accuracy))

In [13]:
def validate(model, valid_loader, valid_N, loss_function):
    """Evaluate the model on a loader and print accumulated loss and accuracy.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        valid_loader: Iterable yielding ``(x, y)`` batches.
        valid_N: Denominator passed to ``get_batch_accuracy`` for each batch.
        loss_function: Callable computing the loss from ``(output, y)``.

    The printed loss is the sum of the per-batch losses; the printed accuracy
    is the sum of the per-batch values returned by ``get_batch_accuracy``.
    """
    running_loss = 0
    running_accuracy = 0

    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, targets in valid_loader:
            logits = model(images)
            running_loss += loss_function(logits, targets).item()
            running_accuracy += get_batch_accuracy(logits, targets, valid_N)

    summary = 'Valid - Loss: {:.4f} Accuracy: {:.4f}'
    print(summary.format(running_loss, running_accuracy))

In [14]:
epochs = 10

# First phase: train with the pretrained base frozen, validating after every epoch.
for epoch in range(epochs):
    print('Epoch: {}'.format(epoch))
    train(
        model=model,
        train_loader=train_loader,
        train_N=train_N,
        random_trans=random_trans,
        optimizer=optimizer,
        loss_function=loss_function,
    )
    validate(
        model=model,
        valid_loader=valid_loader,
        valid_N=valid_N,
        loss_function=loss_function,
    )

Epoch: 0


W0708 20:35:08.559000 36 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


Train - Loss: 56.2734 Accuracy: 0.9424
Valid - Loss: 6.7579 Accuracy: 0.9711
Epoch: 1
Train - Loss: 27.2212 Accuracy: 0.9716
Valid - Loss: 7.8371 Accuracy: 0.9733
Epoch: 2
Train - Loss: 22.9996 Accuracy: 0.9760
Valid - Loss: 7.5190 Accuracy: 0.9733
Epoch: 3
Train - Loss: 19.2102 Accuracy: 0.9800
Valid - Loss: 3.8445 Accuracy: 0.9837
Epoch: 4
Train - Loss: 18.6357 Accuracy: 0.9816
Valid - Loss: 5.4662 Accuracy: 0.9804
Epoch: 5
Train - Loss: 14.3902 Accuracy: 0.9850
Valid - Loss: 3.8091 Accuracy: 0.9844
Epoch: 6
Train - Loss: 14.1988 Accuracy: 0.9857
Valid - Loss: 1.7513 Accuracy: 0.9915
Epoch: 7
Train - Loss: 11.8433 Accuracy: 0.9884
Valid - Loss: 5.8043 Accuracy: 0.9804
Epoch: 8
Train - Loss: 15.1144 Accuracy: 0.9850
Valid - Loss: 6.0200 Accuracy: 0.9785
Epoch: 9
Train - Loss: 10.4516 Accuracy: 0.9899
Valid - Loss: 2.2172 Accuracy: 0.9922


In [15]:
# Unfreeze the pretrained base model and rebuild the optimizer with a small
# learning rate for fine-tuning.
vgg_model.requires_grad_(True)
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [17]:
epochs = 1

# Second phase: fine-tune with the base model unfrozen, validating after every epoch.
for epoch in range(epochs):
    print('Epoch: {}'.format(epoch))
    train(
        model=model,
        train_loader=train_loader,
        train_N=train_N,
        random_trans=random_trans,
        optimizer=optimizer,
        loss_function=loss_function,
    )
    validate(
        model=model,
        valid_loader=valid_loader,
        valid_N=valid_N,
        loss_function=loss_function,
    )

Epoch: 0
Train - Loss: 5.9747 Accuracy: 0.9942
Valid - Loss: 1.7032 Accuracy: 0.9930


In [18]:
# Final evaluation of the fine-tuned model on the validation loader.
validate(
    model=model,
    valid_loader=valid_loader,
    valid_N=valid_N,
    loss_function=loss_function,
)

Valid - Loss: 1.7032 Accuracy: 0.9930
