In [4]:
import json
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
# Keep cell output short.
# NOTE(review): a blanket 'ignore' also hides genuinely useful warnings; consider narrowing it.
warnings.filterwarnings('ignore')

# Seed torch's global RNG so the freshly initialised classifier head and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Dataset root (one sub-folder per class, as ImageFolder expects). Set the
# PLANTVILLAGE_ROOT environment variable instead of editing this machine-specific path.
data_root = os.environ.get('PLANTVILLAGE_ROOT', r'C:\Users\Admin\Desktop\pytorch\torch_env\PlantVillage')

Using device: cpu


In [None]:
# Preprocessing applied to every image:
#   1. resize to a fixed 128x128,
#   2. convert to a float tensor,
#   3. standardise each RGB channel with the usual ImageNet statistics. These match
#      the ImageNet-pretrained ResNet-18 weights loaded further down.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
# Labels come from sub-folder names; dataset.classes is their sorted order, so
# class index i corresponds to dataset.classes[i].
dataset = datasets.ImageFolder(
    root=data_root,
    transform=transform
)

print("Total Images:", len(dataset))
print("Classes:", dataset.classes)

# 80/20 train/test split. A dedicated, seeded generator makes the partition identical
# on every run. Without it, a model reloaded from an earlier session could be scored
# on images it was trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_data, test_data = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Pinned host memory only speeds up host->GPU copies, so enable it only on CUDA.
pin = device.type == "cuda"
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2, pin_memory=pin)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2, pin_memory=pin)


Total Images: 20638
Classes: ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']


In [None]:
# Transfer learning: start from an ImageNet-pretrained ResNet-18 and train only a new
# classifier head.
model = models.resnet18(weights="IMAGENET1K_V1")

# Freeze every pretrained parameter so the backbone weights are not updated.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier layer with one output per dataset class. It is created after
# the freeze loop, so its parameters require grad: it is the only trainable part.
model.fc = nn.Linear(model.fc.in_features, len(dataset.classes))

model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Give the optimizer only the trainable parameters (the new fc layer); the frozen
# backbone never receives gradients, so passing it would only add dead entries.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=0.001,
)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\Admin/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100.0%


In [8]:
# -----------------------------
# 9. Training Loop
# -----------------------------
epochs = 5

for epoch in range(epochs):

    # NOTE(review): train() also puts the frozen backbone's BatchNorm layers in training
    # mode, so their running statistics still adapt to this dataset. Keep the backbone in
    # eval() if truly fixed features are intended.
    model.train()
    running_loss = 0.0
    seen = 0

    for images, labels in train_loader:

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction='mean') returns the batch mean; weight it
        # by batch size so a smaller final batch is not over-counted.
        batch_n = labels.size(0)
        running_loss += loss.item() * batch_n
        seen += batch_n

    # Mean loss per training image. A raw sum of batch means grows with the number of
    # batches, so it is not comparable across dataset or batch sizes.
    epoch_loss = running_loss / max(seen, 1)
    print(f"Epoch {epoch+1}/{epochs} Loss:{epoch_loss:.4f}")

KeyboardInterrupt: 

In [None]:
# Top-1 accuracy on the held-out split. eval() makes BatchNorm use its stored running
# statistics; no_grad() skips building the autograd graph during inference.
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Predicted class = index of the largest logit in each row.
        predicted = model(batch_images).argmax(dim=1)

        correct += (predicted == batch_labels).sum().item()
        total += batch_labels.size(0)

accuracy = 100 * correct / total
print("Test Accuracy:", accuracy)

In [None]:
# -----------------------------
# 11. Save Model
# -----------------------------
# Weights only (state_dict). Reload them into a ResNet-18 whose fc layer is rebuilt
# with len(classes) outputs before calling load_state_dict.
torch.save(model.state_dict(), "plant_disease_model.pth")

# The weights alone cannot map output indices back to disease names, so store the
# ImageFolder class order (index i -> classes[i]) alongside them.
with open("plant_disease_classes.json", "w", encoding="utf-8") as fh:
    json.dump(dataset.classes, fh, indent=2)

print("Model saved")