In [1]:
%cd ..

e:\AI_projects\plant_disease


## Import Libraries

In [2]:
import os
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
device = torch.device("cuda")

In [4]:
batch_size = 256
epochs = 1

img_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


In [5]:
train_ds = datasets.ImageFolder(root=r"data\New Plant Diseases Dataset(Augmented)\train", transform=img_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

In [6]:
val_ds = datasets.ImageFolder(root=r"data\New Plant Diseases Dataset(Augmented)\valid", transform=img_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)

In [7]:
len(train_ds.classes)

38

In [8]:
train_ds.classes

['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_

## Load Model

In [9]:
model = models.resnet18(pretrained=True)



In [10]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
model.fc = nn.Linear(512, 38)

In [12]:
model.fc

Linear(in_features=512, out_features=38, bias=True)

In [13]:
model = model.to(device)

In [14]:
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [15]:
for e in range(epochs):
    model.train()
     
    total_loss = 0
    for batch in tqdm(train_loader, total=len(train_loader)):
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)
         
        optimizer.zero_grad()
        outs = model(imgs)
        loss_val = loss_func(outs, labels)
        loss_val.backward()
        optimizer.step()
         
        total_loss += loss_val.item()
         
    print(f"loss for epoch {e+1} = ", total_loss)

  0%|          | 0/275 [00:00<?, ?it/s]

100%|██████████| 275/275 [1:27:19<00:00, 19.05s/it]

loss for epoch 1 =  648.9544268846512





In [18]:
model_save_path = os.path.join("model", "plant_disease_model.pth")
torch.save(model, "model/plant_disease_model.pth")
print(f"Model saved to: {model_save_path}")

Model saved to: model\plant_disease_model.pth
