In [1]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
#model: MobileNet v3 Small
# Load the network with the "IMAGENET1K_V1" weights, then swap the last
# classifier layer for a fresh linear head with 4 outputs (one per class)
# and move the whole model to `device`.
model = models.mobilenet_v3_small(weights="IMAGENET1K_V1")
model.classifier[3] = nn.Linear(
    in_features=model.classifier[3].in_features,
    out_features=4,
)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth
100%|██████████| 9.83M/9.83M [00:00<00:00, 95.0MB/s]


In [4]:
# Root folder handed to ImageFolder (one sub-directory per class).
data_dir = "/kaggle/input/ctn-dataset/cotton"

# Display names used as `target_names` in the classification report.
# NOTE(review): the order is assumed to match the integer labels ImageFolder
# assigns to the class folders -- confirm against `dataset.classes`.
classes = [
    'bacterial_blight',
    'curl_virus',
    'fussarium_wilt',
    'healthy',
]

In [5]:
# data transforms
# Every image is resized to 224x224, converted to a tensor, and normalised
# per channel with the mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
dataset = ImageFolder(data_dir, transform=transform)

# Split the dataset into training, validation, and test
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1  # informational only: the test split receives whatever remains

train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
# Computed as the remainder so the three sizes always sum to len(dataset),
# even when the int() truncation above drops fractional samples.
test_size = len(dataset) - train_size - val_size

# Use a dedicated, seeded generator so that the same images land in the same
# split on every Restart & Run All; without it the reported validation/test
# metrics are computed on a different random partition each run.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [7]:
batch_size = 32

# One loader per split, all with the same batch size; only the training
# loader shuffles its samples.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [8]:
#loss function and optimizer
# Cross-entropy over the model outputs; Adam updates every model parameter
# with a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
#training
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    For every epoch the function runs one optimisation pass over
    ``train_loader``, then one gradient-free pass over ``val_loader``, and
    prints a single summary line with the mean per-batch loss and the
    accuracy for both passes.

    Args:
        model: ``nn.Module`` whose output has one score per class along dim 1.
        train_loader: iterable yielding ``(images, labels)`` training batches.
        val_loader: iterable yielding ``(images, labels)`` validation batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor (``.item()`` is called on it).
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        None. Metrics are printed; the model is updated in place and is left
        in eval mode after the last epoch.
    """
    # Follow the device the model actually lives on instead of relying on a
    # notebook-level global, so the function works for any model passed in.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        train_batches = 0  # counted manually so loaders without __len__ also work

        for images, labels in train_loader:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            train_batches += 1

        # NaN (rather than ZeroDivisionError) when the loader yielded nothing.
        train_acc = _safe_ratio(correct, total)

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_batches = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(run_device), labels.to(run_device)

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                val_batches += 1

        val_acc = _safe_ratio(val_correct, val_total)

        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {_safe_ratio(running_loss, train_batches):.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {_safe_ratio(val_loss, val_batches):.4f}, Val Acc: {val_acc:.4f}")

# Fine-tune the model for 10 epochs using the loaders, loss and optimizer defined above.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)

Epoch 1/10, Train Loss: 0.3402, Train Acc: 0.8872, Val Loss: 0.2737, Val Acc: 0.9269
Epoch 2/10, Train Loss: 0.0620, Train Acc: 0.9858, Val Loss: 0.0820, Val Acc: 0.9766
Epoch 3/10, Train Loss: 0.0907, Train Acc: 0.9758, Val Loss: 0.2094, Val Acc: 0.9298
Epoch 4/10, Train Loss: 0.0785, Train Acc: 0.9741, Val Loss: 0.2278, Val Acc: 0.9357
Epoch 5/10, Train Loss: 0.0285, Train Acc: 0.9933, Val Loss: 0.0188, Val Acc: 0.9971
Epoch 6/10, Train Loss: 0.0113, Train Acc: 0.9975, Val Loss: 0.0559, Val Acc: 0.9912
Epoch 7/10, Train Loss: 0.0140, Train Acc: 0.9950, Val Loss: 0.0792, Val Acc: 0.9708
Epoch 8/10, Train Loss: 0.0114, Train Acc: 0.9967, Val Loss: 0.0207, Val Acc: 0.9883
Epoch 9/10, Train Loss: 0.0234, Train Acc: 0.9916, Val Loss: 0.2551, Val Acc: 0.9386
Epoch 10/10, Train Loss: 0.0183, Train Acc: 0.9933, Val Loss: 0.1188, Val Acc: 0.9649


In [10]:
#evaluating the model
def evaluate_model(model, loader, dataset_type):
    """Print a per-class classification report for ``model`` on ``loader``.

    Args:
        model: ``nn.Module`` whose output has one score per class along dim 1.
        loader: iterable yielding ``(images, labels)`` batches.
        dataset_type: split name inserted into the printed header
            (e.g. "Validation" or "Test").

    Returns:
        None. The report is printed. Row names come from the notebook-level
        ``classes`` list.
    """
    # Follow the device the model actually lives on instead of relying on a
    # notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(run_device))
            _, preds = torch.max(outputs, 1)

            # Labels are only compared on the host, so they never need to be
            # moved to the model's device.
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())

    print(f"Evaluation on {dataset_type} Set:")
    # Pin the label set explicitly: without `labels`, a split that happens to
    # contain fewer classes than `classes` makes `target_names` mismatch the
    # labels actually present. `zero_division=0` keeps the same 0.0 value
    # for empty classes but without the warning.
    print(classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(classes))),
        target_names=classes,
        zero_division=0,
    ))

# Report on the validation split first, then on the held-out test split.
for split_name, split_loader in (("Validation", val_loader), ("Test", test_loader)):
    evaluate_model(model, split_loader, split_name)

Evaluation on Validation Set:
                  precision    recall  f1-score   support

bacterial_blight       0.89      1.00      0.94        88
      curl_virus       1.00      0.94      0.97        78
  fussarium_wilt       0.99      1.00      0.99        97
         healthy       1.00      0.91      0.95        79

        accuracy                           0.96       342
       macro avg       0.97      0.96      0.96       342
    weighted avg       0.97      0.96      0.97       342

Evaluation on Test Set:
                  precision    recall  f1-score   support

bacterial_blight       0.96      1.00      0.98        43
      curl_virus       1.00      0.98      0.99        45
  fussarium_wilt       0.98      1.00      0.99        41
         healthy       1.00      0.95      0.98        42

        accuracy                           0.98       171
       macro avg       0.98      0.98      0.98       171
    weighted avg       0.98      0.98      0.98       171



In [11]:
# Persist only the model's state dict (not the whole module object) to the
# working directory, then confirm on stdout.
torch.save(obj=model.state_dict(), f="mobilenet_v3_model.pth")
print("Model saved successfully!")

Model saved successfully!


# Link to the saved model weights (.pth file): https://drive.google.com/file/d/1xxoZhQyz3-PaAq2QCL0y2_Cnm5qPRBwF/view?usp=sharing