In [None]:
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights

# Select the compute device: the default CUDA GPU when one is visible,
# otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Dataset root in ImageFolder layout: one sub-directory per class, each
# holding that class's images.
train_dir = './predicted_masks'  # Adjust path as needed

# Image transformations: resize to the 224x224 input used by ResNet-50 and
# normalize with the standard ImageNet per-channel mean/std, matching the
# statistics the pretrained weights expect.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Dataset and DataLoader
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)

# Take the class list from the dataset itself instead of re-scanning the
# directory: ImageFolder assigns integer label i to train_dataset.classes[i],
# so this guarantees the names we report (and later map predictions back to)
# are exactly the ones the training labels were built from.
class_names = train_dataset.classes
num_classes = len(class_names)
print(f"Found {num_classes} classes: {class_names}")

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    # Page-locked host memory speeds up host->GPU copies; no benefit on CPU.
    pin_memory=device.type == "cuda",
)

# Backbone: torchvision ResNet-50 initialised with its recommended default
# pretrained weights (multi-weight enum API).
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# New classification head: keep the width of the features feeding the old
# `fc` layer, but emit one logit per class discovered in the dataset.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=num_classes)

# nn.Module.to() moves parameters/buffers in place (and returns the module).
model.to(device)

# Cross-entropy on raw logits; Adam updates every parameter, i.e. the whole
# backbone is fine-tuned together with the freshly initialised head.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training loop: `num_epochs` passes over `train_loader`, printing the batch
# loss every `log_every` batches and the epoch's loss/accuracy at the end.
# No validation split is used here, so the reported figures are
# training-set metrics only, accumulated while the weights are changing.
num_epochs = 10
log_every = 10  # print a progress line every `log_every` batches

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far this epoch
    correct_preds = 0
    total_preds = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()
        # criterion is nn.CrossEntropyLoss with its default reduction='mean',
        # so `batch_loss` is a per-batch mean. Weight it by the batch size so
        # the (usually smaller) last batch does not skew the epoch average.
        running_loss += batch_loss * batch_size
        preds = outputs.argmax(dim=1)
        correct_preds += (preds == labels).sum().item()
        total_preds += batch_size

        if batch_idx % log_every == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)}, Loss: {batch_loss:.4f}")

    if total_preds == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise RuntimeError("train_loader yielded no samples; check the dataset directory")

    epoch_loss = running_loss / total_preds
    epoch_acc = correct_preds / total_preds
    print(f"Epoch {epoch+1} Completed | Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

# Save the trained weights. Only the state_dict is written, so a consumer must
# rebuild resnet50 with a num_classes-way `fc` layer before loading it.
model_path = 'plant_disease_resnet50.pth'
torch.save(model.state_dict(), model_path)
print(f"Model saved as {model_path}")

# The state_dict holds no label names, so store the class-name list beside it:
# entry i is the class that the model's output logit i refers to.
classes_path = 'plant_disease_resnet50_classes.json'
with open(classes_path, 'w', encoding='utf-8') as f:
    json.dump(list(class_names), f, indent=2)
print(f"Class names saved as {classes_path}")


Using device: cuda
Found 38 classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spot

100%|█████████████████████████████████████| 97.8M/97.8M [00:06<00:00, 16.8MB/s]


Epoch 1, Batch 0/2194, Loss: 3.6414
Epoch 1, Batch 10/2194, Loss: 3.3626
Epoch 1, Batch 20/2194, Loss: 2.8919
Epoch 1, Batch 30/2194, Loss: 3.0167
Epoch 1, Batch 40/2194, Loss: 2.9204
Epoch 1, Batch 50/2194, Loss: 2.7246
Epoch 1, Batch 60/2194, Loss: 2.7415
Epoch 1, Batch 70/2194, Loss: 2.5306
Epoch 1, Batch 80/2194, Loss: 2.5238
Epoch 1, Batch 90/2194, Loss: 2.7465
Epoch 1, Batch 100/2194, Loss: 2.6042
Epoch 1, Batch 110/2194, Loss: 3.0067
Epoch 1, Batch 120/2194, Loss: 2.6088
Epoch 1, Batch 130/2194, Loss: 2.4339
Epoch 1, Batch 140/2194, Loss: 2.5476
Epoch 1, Batch 150/2194, Loss: 2.1752
Epoch 1, Batch 160/2194, Loss: 2.3020
Epoch 1, Batch 170/2194, Loss: 2.6167
