In [119]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchmetrics
from torchvision import datasets, transforms

from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import random
from PIL import Image
from matplotlib import pyplot as plt
import seaborn as sns

import sys 
import os
from pathlib import Path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(parent_dir)
from torch_device import DEVICE
DEVICE

'cuda'

#### Checking Directory and Subdirectory Structure

In [120]:
!tree /A

Folder PATH listing for volume Data
Volume serial number is 3209-780B
D:.
\---GuavaDiseaseDataset
    +---test
    |   +---Anthracnose
    |   +---fruit_fly
    |   \---healthy_guava
    +---train
    |   +---Anthracnose
    |   +---fruit_fly
    |   \---healthy_guava
    \---val
        +---Anthracnose
        +---fruit_fly
        \---healthy_guava


In [121]:
data_path = Path("GuavaDiseaseDataset")

In [122]:
dir_path = data_path
for dirpath, dirnames, filenames in os.walk(dir_path):
    if len(filenames) > 0:
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

There are 3 directories and 1 images in 'GuavaDiseaseDataset'.
There are 0 directories and 156 images in 'GuavaDiseaseDataset\test\Anthracnose'.
There are 0 directories and 132 images in 'GuavaDiseaseDataset\test\fruit_fly'.
There are 0 directories and 94 images in 'GuavaDiseaseDataset\test\healthy_guava'.
There are 0 directories and 1080 images in 'GuavaDiseaseDataset\train\Anthracnose'.
There are 0 directories and 918 images in 'GuavaDiseaseDataset\train\fruit_fly'.
There are 0 directories and 649 images in 'GuavaDiseaseDataset\train\healthy_guava'.
There are 0 directories and 308 images in 'GuavaDiseaseDataset\val\Anthracnose'.
There are 0 directories and 262 images in 'GuavaDiseaseDataset\val\fruit_fly'.
There are 0 directories and 185 images in 'GuavaDiseaseDataset\val\healthy_guava'.


#### Train & Test Paths

In [123]:
train_dir = data_path / "train"
test_dir = data_path / "test"
train_dir, test_dir

(WindowsPath('GuavaDiseaseDataset/train'),
 WindowsPath('GuavaDiseaseDataset/test'))

#### Helper Function To Display An Image

In [124]:
def DisplayImage(img : Image, title : str = None):
    plt.imshow(img)
    if title != None:
        plt.title(title)
    plt.axis("off")
    plt.show()

#### Loading Train & Test Images Into A DataLoader

In [125]:
train_transform = transforms.Compose([
    # transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256,256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    # transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256,256)),
    transforms.ToTensor(),
])

train_transform, test_transform

(Compose(
     Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=True)
     RandomHorizontalFlip(p=0.5)
     ToTensor()
 ),
 Compose(
     Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=True)
     ToTensor()
 ))

In [126]:
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.ImageFolder(root=test_dir, transform=train_transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)


In [127]:
classes = train_dataset.classes
class_dict = train_dataset.class_to_idx
classes, class_dict

(['Anthracnose', 'fruit_fly', 'healthy_guava'],
 {'Anthracnose': 0, 'fruit_fly': 1, 'healthy_guava': 2})

#### Model Setup (CNN)

In [128]:
torch.manual_seed(24)

<torch._C.Generator at 0x18eddf6d3f0>

In [129]:
class Guava_CNN(nn.Module):
    def __init__(self, input_shape, hidden_units, output_shape, conv_k_size = 3, conv_stride = 1, pool_k_size = 2, pool_stride =2):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=conv_k_size, stride=conv_stride, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=hidden_units),
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=conv_k_size, stride=conv_stride, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_k_size, stride=pool_stride)
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=conv_k_size, stride=conv_stride, padding=1),
            nn.ReLU(),
            # nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=conv_k_size, stride=conv_stride, padding=1),
            # nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_k_size, stride=pool_stride)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_units * 64*64 # Change!!!
                      ,out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
            nn.Linear(in_features=32, out_features=3)

        )

    def forward(self, x):
        y_b1 = self.block1(x)
        # print(f"--------\n{y_b1.shape=}")
        y_b2 = self.block2(y_b1)
        # print(f"{y_b2.shape=}")
        y = self.classifier(y_b2)
        # print(f"{y.shape=}\n--------")
        return y

cnn = Guava_CNN(input_shape=3, hidden_units=10, output_shape=len(classes)).to(DEVICE)
total_parameteres = sum(p.numel() for p in cnn.parameters())
cnn, f"{total_parameteres=}"

(Guava_CNN(
   (block1): Sequential(
     (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): ReLU()
     (2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (4): ReLU()
     (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   )
   (block2): Sequential(
     (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): ReLU()
     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   )
   (classifier): Sequential(
     (0): Flatten(start_dim=1, end_dim=-1)
     (1): Linear(in_features=40960, out_features=128, bias=True)
     (2): ReLU()
     (3): Linear(in_features=128, out_features=64, bias=True)
     (4): ReLU()
     (5): Linear(in_features=64, out_features=32, bias=True)
     (6): Linear(in_features=32, out_features=3, bias=True)
   )
 ),
 'total_parameteres=5255563

In [130]:
for batch, (X,y) in enumerate(train_dataloader):
    X, y = X.to(DEVICE), y.to(DEVICE)

    y_pred = cnn.forward(X)
    break

#### Loss Function (CEL) and Optimizer (SGD)

In [131]:
loss_fn_cnn = nn.CrossEntropyLoss()
optimizer_cnn = torch.optim.Adam(params=cnn.parameters(), lr=1e-3)

#### Train & Test Steps

In [132]:
def train_step(model : nn.Module , dataloader: torch.utils.data.DataLoader,
            loss_func : nn.Module, optimizer : torch.optim.Optimizer, num_classes,
            accuracy_func = torchmetrics.functional.accuracy, device = DEVICE):
    train_loss, train_acc = 0, 0
    model.to(device=device)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model.forward(X)

        batch_loss = loss_func(y_pred, y)
        batch_acc = accuracy_func(preds=y_pred, target=y,task='multiclass', num_classes=num_classes)
        if batch % (int(len(dataloader) / 10)) == 0:
            print(f"\tBatch {batch} | Loss: {batch_loss} | Accuracy: {batch_acc}")

        train_loss += batch_loss 
        train_acc += batch_acc

        optimizer.zero_grad()

        batch_loss.backward()

        optimizer.step()

    train_loss /= len(dataloader)
    train_acc /= len(dataloader)

    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}")
    
def test_step(model: nn.Module, dataloader : torch.utils.data.DataLoader, 
              loss_func: nn.Module, num_classes, accuracy_func: torchmetrics.functional.accuracy, device = DEVICE):
    test_loss, test_acc = 0, 0
    model.to(device=DEVICE)
    model.eval()
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_pred = model.forward(X)

            test_loss += loss_func(y_pred, y)
            test_acc += accuracy_func(preds=y_pred, target=y,task='multiclass', num_classes=num_classes)

        test_loss /= len(dataloader)
        test_acc /= len(dataloader)

        print(f"Test loss: {test_loss:.5f} | Test   accuracy: {test_acc:.2f}")

#### Training & Evaluating

In [133]:
epochs = 5
for epoch in tqdm(range(epochs)):
    print(f"Epoch {epoch}:")
    train_step(cnn, train_dataloader, loss_fn_cnn, optimizer_cnn, len(classes), device=DEVICE)
    test_step(cnn, test_dataloader, loss_fn_cnn, len(classes), torchmetrics.functional.accuracy)

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0:
	Batch 0 | Loss: 1.1211286783218384 | Accuracy: 0.28125
	Batch 8 | Loss: 0.43153515458106995 | Accuracy: 0.8125
	Batch 16 | Loss: 0.47175294160842896 | Accuracy: 0.78125
	Batch 24 | Loss: 0.5257145762443542 | Accuracy: 0.78125
	Batch 32 | Loss: 0.48481959104537964 | Accuracy: 0.78125
	Batch 40 | Loss: 0.1851448118686676 | Accuracy: 0.90625
	Batch 48 | Loss: 0.6364563703536987 | Accuracy: 0.8125
	Batch 56 | Loss: 0.09008537977933884 | Accuracy: 1.0
	Batch 64 | Loss: 0.2955690026283264 | Accuracy: 0.8125
	Batch 72 | Loss: 0.29318535327911377 | Accuracy: 0.90625
	Batch 80 | Loss: 0.20058035850524902 | Accuracy: 0.90625
Train loss: 0.41840 | Train accuracy: 0.83


 20%|██        | 1/5 [00:24<01:37, 24.50s/it]

Test loss: 0.34268 | Test   accuracy: 0.88
Epoch 1:
	Batch 0 | Loss: 0.2379794865846634 | Accuracy: 0.9375
	Batch 8 | Loss: 0.1517874300479889 | Accuracy: 0.96875
	Batch 16 | Loss: 0.38639748096466064 | Accuracy: 0.84375
	Batch 24 | Loss: 0.17329426109790802 | Accuracy: 0.96875
	Batch 32 | Loss: 0.4456913471221924 | Accuracy: 0.875
	Batch 40 | Loss: 0.4044700264930725 | Accuracy: 0.8125
	Batch 48 | Loss: 0.38312986493110657 | Accuracy: 0.8125
	Batch 56 | Loss: 0.493373841047287 | Accuracy: 0.84375
	Batch 64 | Loss: 0.11528097093105316 | Accuracy: 0.96875
	Batch 72 | Loss: 0.22918367385864258 | Accuracy: 0.875
	Batch 80 | Loss: 0.19230718910694122 | Accuracy: 0.90625
Train loss: 0.25819 | Train accuracy: 0.90


 40%|████      | 2/5 [00:49<01:13, 24.59s/it]

Test loss: 0.24352 | Test   accuracy: 0.91
Epoch 2:
	Batch 0 | Loss: 0.06823782622814178 | Accuracy: 1.0
	Batch 8 | Loss: 0.09553790837526321 | Accuracy: 0.96875
	Batch 16 | Loss: 0.3530217409133911 | Accuracy: 0.875
	Batch 24 | Loss: 0.18327775597572327 | Accuracy: 0.9375
	Batch 32 | Loss: 0.17086823284626007 | Accuracy: 0.90625
	Batch 40 | Loss: 0.18486663699150085 | Accuracy: 0.96875
	Batch 48 | Loss: 0.12216278165578842 | Accuracy: 0.9375
	Batch 56 | Loss: 0.16119451820850372 | Accuracy: 0.9375
	Batch 64 | Loss: 0.30955764651298523 | Accuracy: 0.875
	Batch 72 | Loss: 0.2625788450241089 | Accuracy: 0.875
	Batch 80 | Loss: 0.15167368948459625 | Accuracy: 0.96875
Train loss: 0.19026 | Train accuracy: 0.93


 60%|██████    | 3/5 [01:13<00:49, 24.50s/it]

Test loss: 0.20932 | Test   accuracy: 0.91
Epoch 3:
	Batch 0 | Loss: 0.09649941325187683 | Accuracy: 0.96875
	Batch 8 | Loss: 0.22563593089580536 | Accuracy: 0.875
	Batch 16 | Loss: 0.03310680761933327 | Accuracy: 1.0
	Batch 24 | Loss: 0.3322550058364868 | Accuracy: 0.875
	Batch 32 | Loss: 0.15670832991600037 | Accuracy: 0.9375
	Batch 40 | Loss: 0.09246756136417389 | Accuracy: 0.96875
	Batch 48 | Loss: 0.1990651935338974 | Accuracy: 0.9375
	Batch 56 | Loss: 0.2957402467727661 | Accuracy: 0.84375
	Batch 64 | Loss: 0.16950276494026184 | Accuracy: 0.90625
	Batch 72 | Loss: 0.13132794201374054 | Accuracy: 0.90625
	Batch 80 | Loss: 0.20033875107765198 | Accuracy: 0.96875
Train loss: 0.15468 | Train accuracy: 0.94


 80%|████████  | 4/5 [01:37<00:24, 24.48s/it]

Test loss: 0.14340 | Test   accuracy: 0.95
Epoch 4:
	Batch 0 | Loss: 0.20389433205127716 | Accuracy: 0.90625
	Batch 8 | Loss: 0.20085987448692322 | Accuracy: 0.96875
	Batch 16 | Loss: 0.2985471487045288 | Accuracy: 0.875
	Batch 24 | Loss: 0.05439038202166557 | Accuracy: 0.96875
	Batch 32 | Loss: 0.14723806083202362 | Accuracy: 0.9375
	Batch 40 | Loss: 0.11825314164161682 | Accuracy: 0.96875
	Batch 48 | Loss: 0.168707475066185 | Accuracy: 0.9375
	Batch 56 | Loss: 0.08076707273721695 | Accuracy: 0.96875
	Batch 64 | Loss: 0.040293581783771515 | Accuracy: 1.0
	Batch 72 | Loss: 0.02844599261879921 | Accuracy: 1.0
	Batch 80 | Loss: 0.2086792141199112 | Accuracy: 0.9375
Train loss: 0.11863 | Train accuracy: 0.95


100%|██████████| 5/5 [02:02<00:00, 24.50s/it]

Test loss: 0.23543 | Test   accuracy: 0.93





#### Saving & Loading

In [137]:
torch.save({
    'model_state_dict':cnn.state_dict(),
    'optimizer_state_dict':optimizer_cnn.state_dict()
}, 'guava_classifer_cnn.pth')

In [143]:
model = Guava_CNN(3, 10, 3)
model.load_state_dict(torch.load('guava_classifer_cnn.pth', weights_only=False)['model_state_dict'])
optimizer_cnn = torch.optim.Adam(params=model.parameters(), lr=1e-3)
optimizer_cnn.load_state_dict(torch.load('guava_classifer_cnn.pth', weights_only=False)['optimizer_state_dict'])
model.eval()

Guava_CNN(
  (block1): Sequential(
    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block2): Sequential(
    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=40960, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=32, bias=True)
    (6): Linear(in_features=32, out_features=3, bias=True)
  )
)