In [1]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split 
from torchvision.datasets import ImageFolder
from sklearn.metrics import classification_report
from PIL import Image

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Model: DenseNet-121 initialised from ImageNet weights. Its stock classifier
# head is swapped for a freshly initialised linear layer with 4 outputs
# (one per class), then the whole network is moved to `device`.
model = models.densenet121(weights="IMAGENET1K_V1")
model.classifier = nn.Linear(
    in_features=model.classifier.in_features,
    out_features=4,
)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 116MB/s] 


In [4]:
data_dir = "/kaggle/input/dataset-cotton/cotton"  # root folder with one sub-folder per class

# Human-readable class names, indexed by label id. ImageFolder numbers classes
# in sorted sub-folder-name order, so this list must stay alphabetical.
classes = [
    'bacterial_blight',
    'curl_virus',
    'fussarium_wilt',
    'healthy',
]

In [5]:
# Preprocessing applied identically to every split (no augmentation):
# resize to 224x224, convert to a tensor, then normalise each channel with the
# standard ImageNet mean/std, matching the ImageNet-pretrained weights.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
dataset = ImageFolder(data_dir, transform=transform)

# ImageFolder derives label indices from the sorted sub-folder names. Make sure
# they agree with `classes`, which is what the classification reports print;
# otherwise per-class metrics would be silently mislabelled.
assert dataset.classes == classes, (
    f"dataset classes {dataset.classes} do not match expected {classes}"
)

# Split the dataset into training, validation, and test subsets.
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, "split ratios must sum to 1"

train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
# The test split absorbs the int() rounding remainder so every image is used exactly once.
test_size = len(dataset) - train_size - val_size

# Seed the split so train/val/test membership (and therefore the reported test
# metrics) is reproducible across kernel restarts.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

In [7]:
batch_size = 32  # mini-batch size shared by all three loaders

# Only the training split is shuffled; validation and test batches keep a fixed
# order from run to run.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=shuffle)
    for split, shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [8]:
# Loss: cross-entropy on the raw class logits. Optimiser: Adam over all model
# parameters (every layer is fine-tuned, not just the new classifier head).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
# Training
def _model_device(model):
    """Return the device holding the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader`: train when `optimizer` is given, else evaluate.

    Returns:
        (mean per-sample loss, accuracy) over every sample the loader yielded.

    Raises:
        ValueError: if the loader yields no samples (nothing to average over).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n = labels.size(0)
            # Weight the batch-mean loss by batch size so a small final batch
            # does not skew the epoch average (assumes criterion uses
            # reduction="mean", the nn.CrossEntropyLoss default).
            total_loss += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n

    if total == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=None):
    """Train `model` for `num_epochs`, validating and printing metrics after each epoch.

    Args:
        model: network to optimise in place.
        train_loader / val_loader: iterables of (inputs, labels) batches.
        criterion: loss function applied to (outputs, labels).
        optimizer: optimiser over the model's parameters.
        num_epochs: number of full passes over `train_loader`.
        device: where to move each batch; defaults to the device of the
            model's first parameter.

    Returns:
        list of per-epoch dicts with keys "train_loss", "train_acc",
        "val_loss", "val_acc" (losses are per-sample means).
    """
    if device is None:
        device = _model_device(model)

    history = []
    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history.append({
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
        })
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    return history

# Fine-tune for 10 epochs; the trailing semicolon stops the notebook from
# echoing the call's return value as cell output.
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=10,
);

Epoch 1/10, Train Loss: 0.3583, Train Acc: 0.8647, Val Loss: 0.1995, Val Acc: 0.9327
Epoch 2/10, Train Loss: 0.1828, Train Acc: 0.9365, Val Loss: 0.1332, Val Acc: 0.9444
Epoch 3/10, Train Loss: 0.1253, Train Acc: 0.9624, Val Loss: 0.0485, Val Acc: 0.9766
Epoch 4/10, Train Loss: 0.0914, Train Acc: 0.9657, Val Loss: 0.1692, Val Acc: 0.9327
Epoch 5/10, Train Loss: 0.0515, Train Acc: 0.9825, Val Loss: 0.3524, Val Acc: 0.9240
Epoch 6/10, Train Loss: 0.0704, Train Acc: 0.9791, Val Loss: 0.0883, Val Acc: 0.9766
Epoch 7/10, Train Loss: 0.0388, Train Acc: 0.9858, Val Loss: 0.0429, Val Acc: 0.9883
Epoch 8/10, Train Loss: 0.0711, Train Acc: 0.9799, Val Loss: 0.1722, Val Acc: 0.9561
Epoch 9/10, Train Loss: 0.1246, Train Acc: 0.9607, Val Loss: 0.1565, Val Acc: 0.9503
Epoch 10/10, Train Loss: 0.0657, Train Acc: 0.9808, Val Loss: 0.1066, Val Acc: 0.9737


In [10]:
#Evaluating the model on valid and test data
def evaluate_model(model, loader, dataset_type, class_names=None, device=None):
    """Print a per-class classification report for `model` on one data split.

    Args:
        model: classifier whose output is expected to be (batch, num_classes)
            logits; predictions are the argmax over dim 1.
        loader: iterable of (images, labels) batches to evaluate.
        dataset_type: split name used in the printed heading (e.g. "Test").
        class_names: display names indexed by label id; defaults to the
            notebook-level `classes` list.
        device: where to run inference; defaults to the device of the
            model's first parameter.
    """
    if class_names is None:
        class_names = classes
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    print(f"Evaluation on {dataset_type} Set:")
    # Pass the full label set explicitly: without `labels=`, sklearn infers the
    # classes present in the data and raises a ValueError when a small split
    # happens to miss one. zero_division=0 reports 0 for such classes without
    # emitting warnings.
    print(classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        zero_division=0,
    ))

# Per-class precision / recall / F1 on both held-out splits, validation first.
for split_loader, split_name in ((val_loader, "Validation"), (test_loader, "Test")):
    evaluate_model(model, split_loader, split_name)

Evaluation on Validation Set:
                  precision    recall  f1-score   support

bacterial_blight       0.99      0.98      0.98        91
      curl_virus       0.98      0.97      0.97        88
  fussarium_wilt       0.98      0.98      0.98        84
         healthy       0.95      0.97      0.96        79

        accuracy                           0.97       342
       macro avg       0.97      0.97      0.97       342
    weighted avg       0.97      0.97      0.97       342

Evaluation on Test Set:
                  precision    recall  f1-score   support

bacterial_blight       1.00      0.96      0.98        52
      curl_virus       1.00      0.91      0.95        45
  fussarium_wilt       0.92      0.97      0.94        34
         healthy       0.91      1.00      0.95        40

        accuracy                           0.96       171
       macro avg       0.96      0.96      0.96       171
    weighted avg       0.96      0.96      0.96       171



In [11]:
# Persist only the learned parameters (the state_dict), not the pickled module;
# reloading requires rebuilding DenseNet-121 with a 4-output classifier first.
torch.save(obj=model.state_dict(), f="densenet121_model.pth")
print("Model saved successfully!")

Model saved successfully!


# Link for pth file: https://drive.google.com/file/d/13IDt6BdCcqs6B7IxpOWuoW6I3KmDl_0E/view?usp=sharing