In [1]:
import glob
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.io as tv_io

from src import utils

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(torch.cuda.is_available())

False


In [2]:
from torchvision.models import VGG16_Weights, vgg16

# Build VGG16 initialised from torchvision's default pretrained weights.
weights = VGG16_Weights.DEFAULT
vgg_model = vgg16(weights=weights)

In [3]:
# Freeze base model: disable gradient tracking on every VGG16 parameter.
vgg_model.requires_grad_(False)
# Spot-check the first parameter (parameters() is already an iterator).
next(vgg_model.parameters()).requires_grad

False

In [4]:
# Inspect the first three classifier layers that are reused in the new head.
vgg_model.classifier[:3]

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
)

In [5]:
N_CLASSES = 6  # number of output classes of the new head

# Reuse the VGG16 feature extractor, pooling and first classifier stage
# (Linear -> ReLU -> Dropout), then append a small head ending in N_CLASSES logits.
my_model = nn.Sequential(
    vgg_model.features,
    vgg_model.avgpool,
    nn.Flatten(),
    vgg_model.classifier[:3],
    nn.Linear(in_features=4096, out_features=500),
    nn.ReLU(),
    nn.Linear(in_features=500, out_features=N_CLASSES),
)
my_model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# nn.Module.to() moves the parameters in place, so the optimizer created
# afterwards holds the same Parameter objects as the compiled model.
my_model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = Adam(params=my_model.parameters())
my_model = torch.compile(my_model)

In [7]:
# Preprocessing transforms supplied by the selected VGG16 weights; MyDataset
# applies them to each image as it is loaded.
pre_trans = weights.transforms()

In [8]:
IMG_WIDTH, IMG_HEIGHT = (224, 224)

# Random augmentation pipeline handed to utils.train: small rotation, a square
# crop covering 80-100% of the image, a horizontal flip, and mild
# brightness/contrast jitter.
random_trans = transforms.Compose([
    transforms.RandomRotation(10),
    # torchvision expects `size` as (height, width); keep that order so the
    # crop stays correct if the two constants ever differ.
    transforms.RandomResizedCrop((IMG_HEIGHT, IMG_WIDTH), scale=(.8, 1), ratio=(1, 1)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=.2, contrast=.2)
])

In [9]:
DATA_LABELS = ["freshapples", "freshbanana", "freshoranges", "rottenapples", "rottenbanana", "rottenoranges"]


class MyDataset(Dataset):
    """In-memory image dataset with one sub-directory of PNG files per class.

    Every file matching ``<data_dir>/<label>/*.png`` for each label in
    ``DATA_LABELS`` is read as RGB, passed through ``pre_trans`` and stored on
    ``device`` at construction time. The class index of a sample is the
    position of its label in ``DATA_LABELS``.
    """

    def __init__(self, data_dir):
        """Load and preprocess every labelled image found under ``data_dir``.

        Args:
            data_dir: Directory holding one sub-directory per entry of
                ``DATA_LABELS``. A trailing path separator is optional.
        """
        self.imgs = []
        self.labels = []

        for l_idx, label in enumerate(DATA_LABELS):
            # os.path.join tolerates a missing trailing separator; sorting makes
            # the sample order independent of filesystem enumeration order.
            data_paths = sorted(glob.glob(os.path.join(data_dir, label, '*.png')))
            for path in data_paths:
                img = tv_io.read_image(path, tv_io.ImageReadMode.RGB)
                self.imgs.append(pre_trans(img).to(device))
                # Create the label tensor directly on the target device.
                self.labels.append(torch.tensor(l_idx, device=device))

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.imgs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of images that were loaded."""
        return len(self.imgs)

In [11]:
n = 32  # batch size shared by both loaders

train_path = "data/fruits/train/"
valid_path = "data/fruits/valid/"

# Training split: shuffled on every pass.
train_data = MyDataset(train_path)
print(train_data)
train_loader = DataLoader(dataset=train_data, batch_size=n, shuffle=True)
train_N = len(train_data)

# Validation split: fixed order.
valid_data = MyDataset(valid_path)
valid_loader = DataLoader(dataset=valid_data, batch_size=n, shuffle=False)
valid_N = len(valid_data)

<__main__.MyDataset object at 0x00000269A29142B0>


In [13]:
epochs = 10

# Train with the VGG16 base frozen: one training pass followed by one
# validation pass per epoch.
for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    utils.train(my_model, train_loader, train_N, random_trans, optimizer, loss_function)
    utils.validate(my_model, valid_loader, valid_N, loss_function)

Epoch: 0
Train - Loss: 23.0756 Accuracy: 0.7707
Valid - Loss: 2.4391 Accuracy: 0.9149
Epoch: 1
Train - Loss: 10.8086 Accuracy: 0.8875
Valid - Loss: 5.3398 Accuracy: 0.8632
Epoch: 2
Train - Loss: 10.6391 Accuracy: 0.9103
Valid - Loss: 1.7336 Accuracy: 0.9453
Epoch: 3
Train - Loss: 7.0152 Accuracy: 0.9239
Valid - Loss: 2.4444 Accuracy: 0.9179
Epoch: 4
Train - Loss: 6.3226 Accuracy: 0.9349
Valid - Loss: 4.0935 Accuracy: 0.8815
Epoch: 5
Train - Loss: 7.2802 Accuracy: 0.9332
Valid - Loss: 2.1131 Accuracy: 0.9301
Epoch: 6
Train - Loss: 4.8225 Accuracy: 0.9611
Valid - Loss: 3.6572 Accuracy: 0.9027
Epoch: 7
Train - Loss: 4.5255 Accuracy: 0.9585
Valid - Loss: 4.6094 Accuracy: 0.8967
Epoch: 8
Train - Loss: 5.6939 Accuracy: 0.9391
Valid - Loss: 3.3990 Accuracy: 0.8936
Epoch: 9
Train - Loss: 5.3016 Accuracy: 0.9492
Valid - Loss: 2.6197 Accuracy: 0.9210


In [14]:
# Unfreeze the base model so the pretrained VGG16 weights are updated too,
# and start a fresh optimizer with an explicit learning rate of 1e-4.
vgg_model.requires_grad_(True)
optimizer = Adam(params=my_model.parameters(), lr=1e-4)

In [15]:
epochs = 1

# Fine-tuning pass over the unfrozen network: train, then validate.
for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    utils.train(my_model, train_loader, train_N, random_trans, optimizer, loss_function)
    utils.validate(my_model, valid_loader, valid_N, loss_function)

Epoch: 0
Train - Loss: 16.8570 Accuracy: 0.8596
Valid - Loss: 2.1032 Accuracy: 0.9331


In [16]:
# Run validation once more on the fine-tuned model before saving it.
utils.validate(my_model, valid_loader, valid_N, loss_function)

Valid - Loss: 2.1032 Accuracy: 0.9331


In [22]:
# Make sure the output directory exists; torch.save does not create it.
os.makedirs('models', exist_ok=True)
# my_model is the torch.compile wrapper, whose state_dict keys carry an
# "_orig_mod." prefix. Save the wrapped module's weights instead so the file
# loads into a plain, uncompiled copy of the network.
torch.save(getattr(my_model, '_orig_mod', my_model).state_dict(), 'models/fruit_classifier_state_dict.pt')