In [67]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [68]:
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [69]:
IMAGE_SIZE = 256
BATCH_SIZE = 32
dataset_path = './PlantVillage-20240329T093356Z-001/PlantVillage'
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
class_names = dataset.classes

In [70]:
class PlantDiseaseModel(nn.Module):
    def __init__(self, num_classes=3, init_image_size=256):
        super(PlantDiseaseModel, self).__init__()
        self.init_image_size = init_image_size 
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        feature_size = self.init_image_size // (2 ** 6)  
        linear_input_features = 64 * (feature_size ** 2)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(linear_input_features, 64),  
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [71]:
import torch

# Define the path to the checkpoint file
checkpoint_file = 'model_checkpoint_final.pth'

# Load the checkpoint
checkpoint = torch.load(checkpoint_file)

# Initialize the model (assuming PlantDiseaseCNN class is defined)
model = PlantDiseaseModel(num_classes=10)  # Adjust num_classes if necessary

# Load the model's state dictionary from the checkpoint
model.load_state_dict(checkpoint['model_state_dict'])

# Print the model's parameters
for name, param in model.named_parameters():
    print(f'Parameter name: {name}')
    print(f'Parameter value: {param}')



Parameter name: features.0.weight
Parameter value: Parameter containing:
tensor([[[[ 0.0074,  0.1840, -0.1879],
          [-0.0432,  0.0402,  0.0096],
          [ 0.2253, -0.1163,  0.0675]],

         [[-0.0872, -0.0725,  0.1476],
          [ 0.0430,  0.1140,  0.0375],
          [-0.1750, -0.0426, -0.0379]],

         [[-0.1131, -0.2161, -0.0697],
          [ 0.2227,  0.1050, -0.0745],
          [-0.0834,  0.0270,  0.1657]]],


        [[[ 0.0286,  0.1157,  0.2601],
          [ 0.1231, -0.0510, -0.0484],
          [-0.1303, -0.0833,  0.0710]],

         [[ 0.1670,  0.1816, -0.0098],
          [ 0.1413,  0.2285,  0.1256],
          [ 0.0832,  0.1628, -0.0253]],

         [[ 0.0604, -0.2223, -0.1992],
          [-0.1836, -0.0451, -0.0732],
          [-0.2005, -0.1016, -0.1857]]],


        [[[-0.0150, -0.0512, -0.0635],
          [-0.0558, -0.1459, -0.0356],
          [-0.0075, -0.0246, -0.1054]],

         [[ 0.1523,  0.2525, -0.0399],
          [ 0.2523,  0.0359, -0.0438],
          [-

In [72]:
import torch

# Define the path to the checkpoint file
checkpoint_file = 'model_checkpoint_final.pth'

# Load the checkpoint
checkpoint = torch.load(checkpoint_file)

# Retrieve epoch_loss and epoch_acc from the checkpoint dictionary
test_loss = checkpoint['epoch_loss']
accuracy = checkpoint['epoch_acc']

# Print epoch_loss and epoch_acc
print("Epoch Loss:", test_loss)
print("Epoch Accuracy:", accuracy)

Epoch Loss: 0.07739106259495933
Epoch Accuracy: 97.6932668329177


In [73]:
device = torch.device("cuda")

In [74]:
def predict_pytorch(model, img, class_names, transform):
    # img is expected to be a PIL Image
    # Apply the same transform as during training/testing
    img_transformed = transform(img).unsqueeze(0)  # Add batch dimension

    # Move the image tensor to the same device as the model
    device = next(model.parameters()).device
    img_transformed = img_transformed.to(device)

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        outputs = model(img_transformed)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
        predicted_class = class_names[predicted.item()]
        confidence = round(100 * confidence.item(), 2)

    return predicted_class, confidence

In [75]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image
import io

# Ensure the model is in evaluation mode and moved to the appropriate device
model.eval()
model.to(device)

# Define the image upload button
upload = widgets.FileUpload()
button = widgets.Button(description="Predict")
output = widgets.Output()

def on_button_clicked(b):
    # Display the uploaded image and prediction
    with output:
        clear_output()
        if upload.value:
            # Get the uploaded file
            input_image = upload.value[0]
            content = input_image['content']
            img = Image.open(io.BytesIO(content))
            
            # Display the uploaded image
            display(img)
            
            # Predict the class of the uploaded image
            predicted_class, confidence = predict_pytorch(model, img, class_names, transform)
            
            # Display the prediction and confidence
            print(f"Predicted Class: {predicted_class}, Confidence: {confidence}%")
        else:
            print("No image uploaded. Please upload an image of a diseased plant leaf.")

# Link the button to the prediction function
button.on_click(on_button_clicked)

# Display the widgets
display(upload, button, output)


FileUpload(value=(), description='Upload')

Button(description='Predict', style=ButtonStyle())

Output()