In [3]:
# Batch 1: Basic setup and imports
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision.models import efficientnet_b3
import wandb
from tqdm import tqdm


In [2]:
# Batch 2: Initialize Weights & Biases
wandb.init(
    project="cotton-disease-classification",
    name="EfficientNetB3-run",
    config={
        "epochs": 15,
        "batch_size": 16,
        "learning_rate": 1e-4,
        "optimizer": "Adam",
        "architecture": "EfficientNet-B3",
        "image_size": 300,
        "num_classes": 7
    }
)

config = wandb.config


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: akashbk0037 to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [None]:
# Batch 3: Data transforms and loading
from sklearn.model_selection import train_test_split

# Dataset directory
dataset_dir = r"D:\Pragramming\AgroCare IIT Madras\SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection\SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection\Original Dataset\Original Dataset"

# Transformations
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats
                         std=[0.229, 0.224, 0.225])
])

# Load dataset
full_dataset = ImageFolder(root=dataset_dir, transform=transform)

# Split dataset (80% train, 20% val)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)

# Log class names
class_names = full_dataset.classes
wandb.config.update({"class_names": class_names})


In [5]:
# Batch 4: Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained EfficientNet-B3
model = efficientnet_b3(weights="IMAGENET1K_V1")

# Replace classifier (final FC layer) to match 7 classes
model.classifier[1] = nn.Linear(model.classifier[1].in_features, config.num_classes)

model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)


In [6]:
# Batch 5: Training & Validation
best_val_acc = 0.0

for epoch in range(config.epochs):
    model.train()
    train_loss, train_correct = 0.0, 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        # Forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        _, preds = torch.max(outputs, 1)
        train_correct += (preds == labels).sum().item()
        train_loss += loss.item() * images.size(0)

        loop.set_postfix(loss=loss.item())

    train_acc = train_correct / len(train_loader.dataset)
    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss, val_correct = 0.0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_loss += loss.item() * images.size(0)

    val_acc = val_correct / len(val_loader.dataset)
    val_loss /= len(val_loader.dataset)

    # Log to wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "val_loss": val_loss,
        "val_accuracy": val_acc
    })

    print(f"[Epoch {epoch+1}] Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        wandb.run.summary["best_val_accuracy"] = best_val_acc


Epoch 1/15: 100%|██████████| 1710/1710 [02:12<00:00, 12.92it/s, loss=2.04]  


[Epoch 1] Train Acc: 0.4515, Val Acc: 0.6051


Epoch 2/15: 100%|██████████| 1710/1710 [02:12<00:00, 12.91it/s, loss=0.0246] 


[Epoch 2] Train Acc: 0.8316, Val Acc: 0.6425


Epoch 3/15: 100%|██████████| 1710/1710 [02:14<00:00, 12.72it/s, loss=1.29]    


[Epoch 3] Train Acc: 0.9450, Val Acc: 0.6963


Epoch 4/15: 100%|██████████| 1710/1710 [02:17<00:00, 12.44it/s, loss=0.461]   


[Epoch 4] Train Acc: 0.9778, Val Acc: 0.6308


Epoch 5/15: 100%|██████████| 1710/1710 [02:11<00:00, 12.97it/s, loss=0.00209] 


[Epoch 5] Train Acc: 0.9860, Val Acc: 0.4486


Epoch 6/15: 100%|██████████| 1710/1710 [02:19<00:00, 12.26it/s, loss=0.00335] 


[Epoch 6] Train Acc: 0.9772, Val Acc: 0.4416


Epoch 7/15: 100%|██████████| 1710/1710 [02:20<00:00, 12.16it/s, loss=0.000167]


[Epoch 7] Train Acc: 0.9889, Val Acc: 0.5724


Epoch 8/15: 100%|██████████| 1710/1710 [02:20<00:00, 12.21it/s, loss=0.000228]


[Epoch 8] Train Acc: 0.9942, Val Acc: 0.6145


Epoch 9/15: 100%|██████████| 1710/1710 [02:13<00:00, 12.85it/s, loss=0.000211]


[Epoch 9] Train Acc: 0.9784, Val Acc: 0.5911


Epoch 10/15: 100%|██████████| 1710/1710 [02:23<00:00, 11.91it/s, loss=0.000645]


[Epoch 10] Train Acc: 0.9889, Val Acc: 0.5888


Epoch 11/15: 100%|██████████| 1710/1710 [02:25<00:00, 11.73it/s, loss=0.00035] 


[Epoch 11] Train Acc: 0.9854, Val Acc: 0.6028


Epoch 12/15: 100%|██████████| 1710/1710 [02:25<00:00, 11.72it/s, loss=0.000103]


[Epoch 12] Train Acc: 0.9953, Val Acc: 0.5888


Epoch 13/15: 100%|██████████| 1710/1710 [02:26<00:00, 11.64it/s, loss=0.000389]


[Epoch 13] Train Acc: 0.9936, Val Acc: 0.5350


Epoch 14/15: 100%|██████████| 1710/1710 [02:21<00:00, 12.07it/s, loss=0.00486] 


[Epoch 14] Train Acc: 0.9830, Val Acc: 0.5724


Epoch 15/15: 100%|██████████| 1710/1710 [02:18<00:00, 12.37it/s, loss=0.00149] 


[Epoch 15] Train Acc: 0.9942, Val Acc: 0.5374


In [7]:
# Load best model for evaluation or inference
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
            (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActiv