In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from thesis.utils.pytorch import train

In [11]:
# Define data transforms
# MobileNetV2 expects 224x224 input images and specific normalization
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),          # phóng to rồi crop ngẫu nhiên vùng 224x224
        transforms.RandomHorizontalFlip(p=0.5),     # lật ngang ảnh ngẫu nhiên
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # thay đổi màu sắc
        transforms.RandomRotation(degrees=15),      # xoay ngẫu nhiên ±15 độ
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # dịch ảnh
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5), # hiệu ứng phối cảnh
        transforms.ToTensor(),                      # chuyển thành tensor
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([ # Often same as validation for consistent evaluation
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


base_dir = "/Users/anand/Desktop/1mg/repos/PlantDoc-Dataset/data"
# base_dir = "/Users/anand/Desktop/temp/Data for Identification of Plant Leaf Diseases Using a 9-layer Deep Convolutional Neural Network/Plant_leave_diseases_dataset_without_augmentation"
base_dataset = datasets.ImageFolder(base_dir)

total_size = len(base_dataset)
print('Total size:', total_size)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

print('Train size:', train_size)
print('Val size:', val_size)
print('Test size:', test_size)

train_data, val_data, test_data = random_split(
    base_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

train_data.dataset.transform = data_transforms["train"]
val_data.dataset.transform = data_transforms["val"]
test_data.dataset.transform = data_transforms["test"]

batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_data, shuffle=True, pin_memory=True)


Total size: 2572
Train size: 2057
Val size: 257
Test size: 258


In [12]:
from collections import Counter

# These give you the label for each sample in the subset
train_labels = [label for _, label in train_data]
val_labels = [label for _, label in val_data]
test_labels = [label for _, label in test_data]

train_counters = Counter(train_labels)
val_counters =  Counter(val_labels)
test_counters = Counter(test_labels)

In [13]:
x = list(train_counters.keys())
x.sort()
print(x)

y = list(val_counters.keys())
y.sort()
print(y)

z = list(test_counters.keys())
z.sort()
print(z)


[0, 1]
[0, 1]
[0, 1]


In [8]:
# device = torch.device("mps" if torch.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps")
print("Using device:", device)

Using device: mps


In [None]:
"""Quantisation step"""

import torch
import torchvision.models

# Check what backends are available
print("Available quantization engines:", torch.backends.quantized.supported_engines)

# Use the first available backend
if torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = torch.backends.quantized.supported_engines[0]
    print(f"Using quantization engine: {torch.backends.quantized.engine}")
else:
    print("No quantization engines available")

model = torchvision.models.mobilenet_v3_large(weights="DEFAULT")
model = model.to("cpu")
model.eval()

model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

Available quantization engines: ['qnnpack', 'none']
Using quantization engine: qnnpack


In [9]:
import torchvision
import torchvision.models
model = torchvision.models.mobilenet_v3_large(weights="DEFAULT")

# Get number of classes
num_classes = len(base_dataset.classes) 
print("Num classes: ", num_classes)

# Check classifier structure and get input features
print("Classifier:", model.classifier)

# For MobileNetV3, the final layer is usually at index -1 or index 3
# Let's try the last layer
final_layer = model.classifier[-1]  # or try model.classifier[3]
in_features = final_layer.in_features
print("No of features: ", in_features)

# Replace the final layer
model.classifier[-1] = nn.Linear(in_features, num_classes)

# Move to device
model = model.to(device)

# Freeze feature extraction layers
for param in model.features.parameters():
    param.requires_grad = False

# Setup loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Num classes:  2
Classifier: Sequential(
  (0): Linear(in_features=960, out_features=1280, bias=True)
  (1): Hardswish()
  (2): Dropout(p=0.2, inplace=True)
  (3): Linear(in_features=1280, out_features=1000, bias=True)
)
No of features:  1280


In [38]:
num_epochs = 30
print("device", device)
train("mobile_net_v3", num_epochs, model, train_data, train_loader, test_data, test_loader, criterion, optimizer, device)

device mps


Train: 100%|██████████| 694/694 [01:45<00:00,  6.56it/s]


===Train:	|	Accuracy: 0.9047	|	Loss: 0.3294


Val: 100%|██████████| 5546/5546 [00:54<00:00, 101.74it/s]


===Valid:	|	Accuracy: 0.9522	|	Loss: 0.1361





Train: 100%|██████████| 694/694 [01:43<00:00,  6.71it/s]


===Train:	|	Accuracy: 0.9596	|	Loss: 0.1181


Val: 100%|██████████| 5546/5546 [00:55<00:00, 100.07it/s]


===Valid:	|	Accuracy: 0.9620	|	Loss: 0.1167





Train: 100%|██████████| 694/694 [01:44<00:00,  6.63it/s]


===Train:	|	Accuracy: 0.9730	|	Loss: 0.0817


Val: 100%|██████████| 5546/5546 [00:55<00:00, 99.98it/s] 


===Valid:	|	Accuracy: 0.9609	|	Loss: 0.1150





Train: 100%|██████████| 694/694 [01:43<00:00,  6.68it/s]


===Train:	|	Accuracy: 0.9784	|	Loss: 0.0640


Val: 100%|██████████| 5546/5546 [00:54<00:00, 101.74it/s]


===Valid:	|	Accuracy: 0.9602	|	Loss: 0.1188





Train: 100%|██████████| 694/694 [01:45<00:00,  6.59it/s]


===Train:	|	Accuracy: 0.9837	|	Loss: 0.0478


Val: 100%|██████████| 5546/5546 [00:55<00:00, 100.22it/s]


===Valid:	|	Accuracy: 0.9639	|	Loss: 0.1129





Train: 100%|██████████| 694/694 [01:45<00:00,  6.55it/s]


===Train:	|	Accuracy: 0.9846	|	Loss: 0.0434


Val: 100%|██████████| 5546/5546 [00:52<00:00, 104.97it/s]


===Valid:	|	Accuracy: 0.9636	|	Loss: 0.1180





Train: 100%|██████████| 694/694 [01:43<00:00,  6.70it/s]


===Train:	|	Accuracy: 0.9867	|	Loss: 0.0380


Val: 100%|██████████| 5546/5546 [00:55<00:00, 99.88it/s] 


===Valid:	|	Accuracy: 0.9695	|	Loss: 0.1073





Train: 100%|██████████| 694/694 [01:44<00:00,  6.62it/s]


===Train:	|	Accuracy: 0.9853	|	Loss: 0.0435


Val: 100%|██████████| 5546/5546 [00:54<00:00, 101.80it/s]


===Valid:	|	Accuracy: 0.9693	|	Loss: 0.1160





Train: 100%|██████████| 694/694 [01:44<00:00,  6.62it/s]


===Train:	|	Accuracy: 0.9871	|	Loss: 0.0387


Val: 100%|██████████| 5546/5546 [00:53<00:00, 104.41it/s]


===Valid:	|	Accuracy: 0.9690	|	Loss: 0.1271





Train: 100%|██████████| 694/694 [01:42<00:00,  6.75it/s]


===Train:	|	Accuracy: 0.9903	|	Loss: 0.0288


Val: 100%|██████████| 5546/5546 [00:53<00:00, 103.70it/s]


===Valid:	|	Accuracy: 0.9679	|	Loss: 0.1324





Train: 100%|██████████| 694/694 [01:43<00:00,  6.73it/s]


===Train:	|	Accuracy: 0.9910	|	Loss: 0.0269


Val: 100%|██████████| 5546/5546 [00:53<00:00, 102.87it/s]


===Valid:	|	Accuracy: 0.9690	|	Loss: 0.1271





Train: 100%|██████████| 694/694 [01:41<00:00,  6.85it/s]


===Train:	|	Accuracy: 0.9908	|	Loss: 0.0306


Val: 100%|██████████| 5546/5546 [00:53<00:00, 103.93it/s]


===Valid:	|	Accuracy: 0.9663	|	Loss: 0.1233





Train: 100%|██████████| 694/694 [01:41<00:00,  6.81it/s]


===Train:	|	Accuracy: 0.9915	|	Loss: 0.0271


Val: 100%|██████████| 5546/5546 [00:53<00:00, 103.94it/s]


===Valid:	|	Accuracy: 0.9692	|	Loss: 0.1476





Train: 100%|██████████| 694/694 [01:44<00:00,  6.66it/s]


===Train:	|	Accuracy: 0.9910	|	Loss: 0.0288


Val: 100%|██████████| 5546/5546 [00:52<00:00, 104.73it/s]


===Valid:	|	Accuracy: 0.9638	|	Loss: 0.1697





Train: 100%|██████████| 694/694 [01:43<00:00,  6.69it/s]


===Train:	|	Accuracy: 0.9909	|	Loss: 0.0318


Val: 100%|██████████| 5546/5546 [00:55<00:00, 100.80it/s]


===Valid:	|	Accuracy: 0.9708	|	Loss: 0.1361





Train: 100%|██████████| 694/694 [01:42<00:00,  6.80it/s]


===Train:	|	Accuracy: 0.9914	|	Loss: 0.0270


Val: 100%|██████████| 5546/5546 [00:52<00:00, 104.74it/s]


===Valid:	|	Accuracy: 0.9701	|	Loss: 0.1396





Train: 100%|██████████| 694/694 [01:42<00:00,  6.74it/s]


===Train:	|	Accuracy: 0.9904	|	Loss: 0.0328


Val: 100%|██████████| 5546/5546 [00:53<00:00, 103.42it/s]


===Valid:	|	Accuracy: 0.9681	|	Loss: 0.1534





Train: 100%|██████████| 694/694 [01:41<00:00,  6.83it/s]


===Train:	|	Accuracy: 0.9938	|	Loss: 0.0204


Val: 100%|██████████| 5546/5546 [00:56<00:00, 98.94it/s] 


===Valid:	|	Accuracy: 0.9726	|	Loss: 0.1268





Train: 100%|██████████| 694/694 [01:43<00:00,  6.68it/s]


===Train:	|	Accuracy: 0.9913	|	Loss: 0.0308


Val: 100%|██████████| 5546/5546 [00:56<00:00, 97.90it/s] 


===Valid:	|	Accuracy: 0.9686	|	Loss: 0.1415





Train: 100%|██████████| 694/694 [01:46<00:00,  6.52it/s]


===Train:	|	Accuracy: 0.9937	|	Loss: 0.0207


Val: 100%|██████████| 5546/5546 [00:55<00:00, 99.88it/s] 


===Valid:	|	Accuracy: 0.9702	|	Loss: 0.1621





Train: 100%|██████████| 694/694 [01:45<00:00,  6.58it/s]


===Train:	|	Accuracy: 0.9917	|	Loss: 0.0300


Val: 100%|██████████| 5546/5546 [00:54<00:00, 101.41it/s]


===Valid:	|	Accuracy: 0.9719	|	Loss: 0.1527





Train: 100%|██████████| 694/694 [01:47<00:00,  6.46it/s]


===Train:	|	Accuracy: 0.9921	|	Loss: 0.0273


Val: 100%|██████████| 5546/5546 [00:55<00:00, 99.89it/s] 


===Valid:	|	Accuracy: 0.9665	|	Loss: 0.1986





Train: 100%|██████████| 694/694 [01:46<00:00,  6.49it/s]


===Train:	|	Accuracy: 0.9909	|	Loss: 0.0336


Val: 100%|██████████| 5546/5546 [00:54<00:00, 101.49it/s]


===Valid:	|	Accuracy: 0.9717	|	Loss: 0.1593





Train: 100%|██████████| 694/694 [01:45<00:00,  6.57it/s]


===Train:	|	Accuracy: 0.9920	|	Loss: 0.0306


Val: 100%|██████████| 5546/5546 [00:54<00:00, 102.65it/s]


===Valid:	|	Accuracy: 0.9674	|	Loss: 0.1733





Train: 100%|██████████| 694/694 [01:46<00:00,  6.53it/s]


===Train:	|	Accuracy: 0.9932	|	Loss: 0.0274


Val: 100%|██████████| 5546/5546 [00:54<00:00, 101.07it/s]


===Valid:	|	Accuracy: 0.9717	|	Loss: 0.1684





Train: 100%|██████████| 694/694 [01:48<00:00,  6.37it/s]


===Train:	|	Accuracy: 0.9936	|	Loss: 0.0231


Val: 100%|██████████| 5546/5546 [00:55<00:00, 100.03it/s]


===Valid:	|	Accuracy: 0.9704	|	Loss: 0.1688





Train: 100%|██████████| 694/694 [01:47<00:00,  6.44it/s]


===Train:	|	Accuracy: 0.9950	|	Loss: 0.0163


Val: 100%|██████████| 5546/5546 [00:55<00:00, 100.27it/s]


===Valid:	|	Accuracy: 0.9675	|	Loss: 0.1769





Train: 100%|██████████| 694/694 [01:44<00:00,  6.64it/s]


===Train:	|	Accuracy: 0.9940	|	Loss: 0.0207


Val: 100%|██████████| 5546/5546 [00:54<00:00, 102.35it/s]


===Valid:	|	Accuracy: 0.9749	|	Loss: 0.1550





Train: 100%|██████████| 694/694 [01:44<00:00,  6.62it/s]


===Train:	|	Accuracy: 0.9927	|	Loss: 0.0287


Val: 100%|██████████| 5546/5546 [00:56<00:00, 98.65it/s] 


===Valid:	|	Accuracy: 0.9749	|	Loss: 0.1483





Train: 100%|██████████| 694/694 [01:45<00:00,  6.59it/s]


===Train:	|	Accuracy: 0.9934	|	Loss: 0.0251


Val: 100%|██████████| 5546/5546 [00:54<00:00, 102.68it/s]


===Valid:	|	Accuracy: 0.9701	|	Loss: 0.1861



Best checkpoinpt: 27
Train Accuracy: 0.9949501780963975	|	Train Loss: 0.016309875207539112
Test Accuracy: 0.9675441759826903	|	Test Loss: 0.17686573445059442
✅ Accuracy: 0.9701 (5380/5546)
