In [None]:
import datetime
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Filesystem layout, relative to the notebook's working directory.
TRAIN_DIR = './dataset/train/'
TEST_DIR = './dataset/test/'
MODEL_PATH = './model/'

# Human-readable class names, indexed by integer label.
# NOTE(review): assumes this order matches the label indices produced by the
# dataset loader -- confirm against the dataset's class_to_idx mapping.
CLASSES = ['apple', 'banana', 'mango', 'orange']
NUM_CLASSES = len(CLASSES)

# Hyperparameters
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 0.001

In [None]:
def create_dataset(dataset_dir, mean=None, std=None, batch_size=32, shuffle=False):
    """Build an ``ImageFolder`` dataset and a ``DataLoader`` that iterates it.

    Every image is resized to 224x224 and converted to a tensor. A
    ``Normalize`` step is appended only when BOTH ``mean`` and ``std`` are
    supplied; if either is ``None`` the images are left un-normalized.

    Args:
        dataset_dir: Root directory handed to ``ImageFolder``.
        mean: Per-channel mean for normalization, or ``None``.
        std: Per-channel standard deviation for normalization, or ``None``.
        batch_size: Batch size passed to the ``DataLoader``.
        shuffle: Whether the ``DataLoader`` shuffles the samples.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    steps = [transforms.Resize((224, 224)), transforms.ToTensor()]

    wants_normalization = mean is not None and std is not None
    if wants_normalization:
        steps += [transforms.Normalize(mean=mean, std=std)]

    dataset = ImageFolder(root=dataset_dir, transform=transforms.Compose(steps))
    return dataset, DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
# First pass: load the training images WITHOUT normalization so the
# per-channel statistics are measured on the raw ToTensor() pixel values.
train_dataset, train_loader = create_dataset(
    dataset_dir=TRAIN_DIR,
    batch_size=BATCH_SIZE,
)

# Accumulate per-channel sum and sum of squares over the WHOLE training set.
# (Statistics taken from a single image depend on which file happens to come
# first and are not representative of the dataset.)
# float64 accumulators keep the running sums accurate over many pixels.
channel_sum = 0.0
channel_sq_sum = 0.0
num_pixels = 0
for images, _ in train_loader:
    images = images.to(torch.float64)          # (batch, channels, height, width)
    channel_sum = channel_sum + images.sum(dim=(0, 2, 3))
    channel_sq_sum = channel_sq_sum + (images ** 2).sum(dim=(0, 2, 3))
    num_pixels += images.size(0) * images.size(2) * images.size(3)

# Var[x] = E[x^2] - E[x]^2 (population variance; with this many pixels the
# difference from the unbiased estimate is negligible). The clamp guards
# against a tiny negative value caused by floating-point rounding.
mean = channel_sum / num_pixels
std = (channel_sq_sum / num_pixels - mean ** 2).clamp(min=0).sqrt()
mean, std = mean.float(), std.float()

print(f'Mean: {mean}')
print(f'Std: {std}')

# Second pass: rebuild the loaders with normalization enabled. The test set is
# normalized with the TRAINING statistics so both splits share one input scale.
train_dataset, train_loader = create_dataset(
    dataset_dir=TRAIN_DIR,
    mean=mean,
    std=std,
    batch_size=BATCH_SIZE,
    shuffle=True
)
# Evaluation does not depend on sample order, so the test loader is not
# shuffled; this keeps evaluation runs reproducible.
test_dataset, test_loader = create_dataset(
    dataset_dir=TEST_DIR,
    mean=mean,
    std=std,
    batch_size=BATCH_SIZE,
    shuffle=False
)

In [None]:
# Scratch shape check: push one random 224x224 RGB image through a small
# conv/pool stack, then average-pool over the full spatial extent to see
# what shape comes out.
conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
pool = nn.MaxPool2d(kernel_size=2, stride=2)
conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
inputs = torch.randn(1, 3, 224, 224)

# Apply the stages one after another: conv1 -> pool -> conv2 -> pool -> conv3.
x = inputs
for stage in (conv1, pool, conv2, pool, conv3):
    x = stage(x)

# Kernel equal to the remaining height/width collapses each feature map to 1x1.
outputs = F.avg_pool2d(x, kernel_size=x.shape[2:])

print(outputs.shape)

In [None]:
class ConvNet(nn.Module):
  """Small CNN classifier: two conv + max-pool stages, then two linear layers.

  ``fc1`` is sized for 128*54*54 input features, which is what the two
  conv/pool stages produce for a 224x224 input image. Inputs of another
  spatial size fail at ``fc1`` with a shape-mismatch error.

  Args:
    num_classes: Number of output logits. Defaults to the module-level
      ``NUM_CLASSES`` when omitted.
  """

  def __init__(self, num_classes=None):
    super().__init__()
    if num_classes is None:
      num_classes = NUM_CLASSES
    self.pool = nn.MaxPool2d(2, 2)
    self.conv1 = nn.Conv2d(3, 64, 3)
    self.conv2 = nn.Conv2d(64, 128, 3)
    self.fc1 = nn.Linear(128*54*54, 1024)
    self.fc2 = nn.Linear(1024, num_classes)

  def forward(self, x):
    """Return raw class logits for a batch of images.

    The shape comments below assume a (n, 3, 224, 224) input.
    """
                                          # -> n, 3, 224, 224
    x = self.pool(F.relu(self.conv1(x)))  # -> n, 64, 111, 111
    x = self.pool(F.relu(self.conv2(x)))  # -> n, 128, 54, 54
    # Flatten everything except the batch dimension. Unlike view(-1, C*H*W),
    # this can never fold a spatial-size mismatch into the batch dimension;
    # a wrong input size surfaces as a clear shape error in fc1 instead.
    x = torch.flatten(x, 1)               # -> n, 373248
    x = F.relu(self.fc1(x))               # -> n, 1024
    x = self.fc2(x)                       # -> n, num_classes
    return x
  
# Build the network and move its parameters to the selected device.
model = ConvNet()
model = model.to(DEVICE)

# Cross-entropy loss on the raw logits, optimized with plain SGD.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Mini-batch training: NUM_EPOCHS full passes over train_loader, one optimizer
# step per batch, logging the loss of every batch.
num_batches = len(train_loader)
for epoch in range(NUM_EPOCHS):
  for i, batch in enumerate(train_loader):
    # Move both the images and their labels to the training device.
    images, labels = (tensor.to(DEVICE) for tensor in batch)

    # Drop the gradients of the previous step before computing new ones.
    optimizer.zero_grad()

    # Forward
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Backward and optimize
    loss.backward()
    optimizer.step()

    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{i+1}/{num_batches}], Loss: {loss.item():.4f}')

In [None]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing; otherwise a fresh checkout loses the
# trained weights at the very last step.
os.makedirs(MODEL_PATH, exist_ok=True)

# Timestamp (minute resolution) in the file name so that successive runs do
# not overwrite each other's checkpoints.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
checkpoint_path = os.path.join(MODEL_PATH, f'custom_cnn_fruit_dataset_{timestamp}.pth')

# Only the parameters (state_dict) are stored, not the pickled module object.
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Evaluate on the test set: overall accuracy plus a per-class breakdown.
# Switch to inference mode first; this matters as soon as the model contains
# layers that behave differently in training (dropout, batch norm). The model
# stays in eval mode afterwards.
model.eval()

n_correct = 0
n_samples = 0
# One counter slot per class (sized from NUM_CLASSES rather than a fixed 10).
n_class_correct = [0] * NUM_CLASSES
n_class_samples = [0] * NUM_CLASSES

# No gradients are needed for evaluation.
with torch.no_grad():
  for images, labels in test_loader:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    outputs = model(images)
    # Index of the largest logit along the class dimension is the prediction.
    _, predicted = torch.max(outputs, 1)

    n_samples += labels.size(0)
    n_correct += (predicted == labels).sum().item()

    # Visit every sample of THIS batch. Looping over len(test_loader) would
    # use the number of batches instead, skipping samples or running past the
    # end of a short final batch.
    for label, pred in zip(labels.tolist(), predicted.tolist()):
      if label == pred:
        n_class_correct[label] += 1
      n_class_samples[label] += 1

# Guard against an empty test set instead of dividing by zero.
acc = 100.0 * n_correct / n_samples if n_samples else 0.0
print(f'Accuracy of the network: {acc} %')

for i in range(NUM_CLASSES):
  # Classes with no test samples are reported as 0 rather than dividing by zero.
  acc = 0
  if n_class_samples[i] != 0:
    acc = 100.0 * n_class_correct[i] / n_class_samples[i]
  print(f'Accuracy of {CLASSES[i]}: {acc} %')