In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("vipoooool/new-plant-diseases-dataset")

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: /kaggle/input/new-plant-diseases-dataset


In [1]:
# pd2se_train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

# -----------------------------
# 1. Data Transforms & Loaders
# -----------------------------
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

valid_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder('/kaggle/input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train', transform=train_transforms)
valid_dataset = datasets.ImageFolder('/kaggle/input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/valid', transform=valid_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)

num_classes = len(train_dataset.classes)

# -----------------------------
# 2. Define PD^2SE-like Model
# -----------------------------
class PD2SENet(nn.Module):
    def __init__(self, num_classes):
        super(PD2SENet, self).__init__()
        # Pretrained ResNet50 backbone
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.fc = nn.Identity()  # remove default classifier

        # Main classification head
        self.fc_main = nn.Linear(2048, num_classes)

        # Optional: auxiliary heads (example: severity)
        # self.fc_severity = nn.Linear(2048, 3)  # mild/moderate/severe

    def forward(self, x):
        features = self.backbone(x)
        out_main = self.fc_main(features)
        # out_severity = self.fc_severity(features)
        return out_main  # , out_severity

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PD2SENet(num_classes).to(device)

# -----------------------------
# 3. Loss and Optimizer
# -----------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# -----------------------------
# 4. Training Loop
# -----------------------------
num_epochs = 5  # adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Acc: {epoch_acc:.4f}')

    # Validation
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
    val_acc = val_correct / val_total
    print(f'Validation Acc: {val_acc:.4f}')

# -----------------------------
# 5. Save Model
# -----------------------------
torch.save(model.state_dict(), 'pd2se_model.pth')
print("Model saved!")

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 210MB/s]


Epoch 1/5 - Loss: 0.1650 - Acc: 0.9577
Validation Acc: 0.9878
Epoch 2/5 - Loss: 0.0499 - Acc: 0.9846
Validation Acc: 0.9899
Epoch 3/5 - Loss: 0.0371 - Acc: 0.9889
Validation Acc: 0.9927
Epoch 4/5 - Loss: 0.0292 - Acc: 0.9910
Validation Acc: 0.9949
Epoch 5/5 - Loss: 0.0264 - Acc: 0.9919
Validation Acc: 0.9913
Model saved!


In [20]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os

# -----------------------------
# SETTINGS
# -----------------------------
MODEL_PATH = "pd2se_model.pth"  # Your trained PyTorch model
IMAGE_SIZE = (224, 224)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# DISEASE-TO-TREATMENT WITH DOSAGE
# -----------------------------
treatment_dict = {
    "Apple___Apple_scab": "Remove infected leaves/fruit, apply Captan fungicide; 2 g per liter of water.",
    "Apple___Black_rot": "Remove infected leaves/fruit, apply Captan fungicide; 2 g per liter of water.",
    "Apple___Cedar_apple_rust": "Remove infected leaves, apply fungicide; 2 g per liter of water.",
    "Apple___healthy": "No treatment needed. Keep monitoring the plant.",
    "Blueberry___healthy": "No treatment needed.",
    "Cherry_(including_sour)___Powdery_mildew": "Apply sulfur-based fungicide; 2 g per liter of water.",
    "Cherry_(including_sour)___healthy": "No treatment needed.",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot": "Apply fungicide (e.g., azoxystrobin); 1.5 ml per liter of water.",
    "Corn_(maize)___Common_rust_": "Apply azoxystrobin fungicide; 1.5 ml per liter of water, use resistant hybrids.",
    "Corn_(maize)___Northern_Leaf_Blight": "Apply mancozeb; 2 g per liter of water.",
    "Corn_(maize)___healthy": "No treatment needed.",
    "Potato___Early_blight": "Use chlorothalonil-based fungicide; 2 ml per liter of water, remove debris.",
    "Potato___Late_blight": "Apply mancozeb or metalaxyl fungicides; 2–3 g per liter, destroy infected plants.",
    "Potato___healthy": "No treatment required.",
    "Tomato___Leaf_Mold": "Improve ventilation, apply copper-based fungicide; 2 ml per liter of water.",
    "Tomato___healthy": "No treatment needed.",
    "Strawberry___Leaf_scorch": "Apply copper-based fungicide; 2 ml per liter of water.",
    "Strawberry___healthy": "No treatment needed.",
    # Add all other classes as needed
}

# -----------------------------
# SEVERITY ESTIMATION
# -----------------------------
def estimate_severity(probability):
    if probability < 0.60:
        return "Mild"
    elif probability < 0.85:
        return "Moderate"
    else:
        return "Severe"

# -----------------------------
# LOAD MODEL
# -----------------------------
class PD2SENet(nn.Module):
    def __init__(self, num_classes):
        super(PD2SENet, self).__init__()
        self.backbone = torch.hub.load('pytorch/vision:v0.11.3', 'resnet50', pretrained=False)
        self.backbone.fc = nn.Identity()
        self.fc_main = nn.Linear(2048, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        out_main = self.fc_main(features)
        return out_main

# Number of classes in your dataset
NUM_CLASSES = 38
model = PD2SENet(NUM_CLASSES).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# -----------------------------
# CLASS INDICES
# -----------------------------
class_indices = {
    0: "Apple___Apple_scab",
    1: "Apple___Black_rot",
    2: "Apple___Cedar_apple_rust",
    3: "Apple___healthy",
    4: "Blueberry___healthy",
    5: "Cherry_(including_sour)___Powdery_mildew",
    6: "Cherry_(including_sour)___healthy",
    7: "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    8: "Corn_(maize)___Common_rust_",
    9: "Corn_(maize)___Northern_Leaf_Blight",
    10: "Corn_(maize)___healthy",
    11: "Grape___Black_rot",
    12: "Grape___Esca_(Black_Measles)",
    13: "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    14: "Grape___healthy",
    15: "Orange___Haunglongbing_(Citrus_greening)",
    16: "Peach___Bacterial_spot",
    17: "Peach___healthy",
    18: "Pepper,_bell___Bacterial_spot",
    19: "Pepper,_bell___healthy",
    20: "Potato___Early_blight",
    21: "Potato___Late_blight",
    22: "Potato___healthy",
    23: "Raspberry___healthy",
    24: "Soybean___healthy",
    25: "Squash___Powdery_mildew",
    26: "Strawberry___Leaf_scorch",
    27: "Strawberry___healthy",
    28: "Tomato___Bacterial_spot",
    29: "Tomato___Early_blight",
    30: "Tomato___Late_blight",
    31: "Tomato___Leaf_Mold",
    32: "Tomato___Septoria_leaf_spot",
    33: "Tomato___Spider_mites Two-spotted_spider_mite",
    34: "Tomato___Target_Spot",
    35: "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    36: "Tomato___Tomato_mosaic_virus",
    37: "Tomato___healthy"
}

# -----------------------------
# IMAGE PREPROCESS
# -----------------------------
preprocess = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# -----------------------------
# PREDICTION FUNCTION
# -----------------------------
def predict_disease(img_path):
    img = Image.open(img_path).convert('RGB')
    img_tensor = preprocess(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probs, 1)
        predicted_idx = predicted_idx.item()
        confidence = confidence.item()

    disease_name = class_indices[predicted_idx]
    severity = estimate_severity(confidence)
    treatment = treatment_dict.get(disease_name, "No treatment info available.")

    return f"Disease: {disease_name.replace('_',' ')} | Infection Level: {severity} | Treatment: {treatment}"

# -----------------------------
# TEST
# -----------------------------
test_image_path = "/kaggle/input/sample4/leaf-rust-fungicide-crop-protection.jpg"
if os.path.exists(test_image_path):
    print(predict_disease(test_image_path))
else:
    print("⚠️ Please provide a valid test image path.")


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.11.3


Disease: Corn (maize)   Northern Leaf Blight | Infection Level: Mild | Treatment: Apply mancozeb; 2 g per liter of water.
