<a href="https://colab.research.google.com/github/Denisganga/the_plant_doctor/blob/main/The_plant_doctor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [43]:

#importing the required modules
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')
import torchvision.transforms as transforms

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [44]:
from torchvision import datasets
from torch.utils.data import random_split
import numpy as np
import random

In [45]:
#unzipping my dataset
import zipfile

zip_file_path =  "/content/drive/My Drive/the_plant_doctor/archive.zip"
extract_path = "/content/dataset"

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
  zip_ref.extractall(extract_path)

In [46]:
# Define the data directory (the path where your unzipped dataset is located)
data_dir="/content/dataset"

In [47]:
#setup device agnostic code(using the GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [48]:
#define the transformations
data_transforms =transforms.Compose([
    #make the model more robust to differently oriented images.
    transforms.RandomRotation(degrees=15),

    #random cropping of images to create multiple views of the same image
    transforms.RandomCrop(size=(224,224), padding=10),

    #Apply random color transformations to the images to make the model more invariant to changes in lighting and color
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2,hue=0.2),

    #help the model focus on important features and reduce noise
    transforms.GaussianBlur(kernel_size=3),

    #resize and crop the image to create variations in the field of view.
    transforms.RandomResizedCrop(size=(224,224), scale=(0.8, 1.0)),

    #introduce controlled occlusions or "erasing" of parts of the image during training
    transforms.RandomErasing(),

    #Randomly flip the image horizontally
    transforms.RandomHorizontalFlip(),

    #Convert the image to a PyTorch tensor
    transforms.ToTensor(),

     # Normalize the image based on typical RGB mean and standard deviation
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])



])

In [49]:
#create a dataset using ImageFolder
dataset= datasets.ImageFolder(data_dir, transform=data_transforms)

In [50]:
# Define class labels based on the dataset structure
classes = dataset.classes

In [51]:
#splitting the dataset into training and testing sets
train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset,[train_size, test_size])

In [52]:
print(len(dataset))
print(len(train_dataset))
print(len(test_dataset))

41276
33020
8256


In [75]:
image, label = train_dataset[0]
image, label

(tensor([[[1.6495, 1.6667, 1.6838,  ..., 0.7591, 0.8276, 0.7933],
          [1.6667, 1.6667, 1.6667,  ..., 0.7933, 0.8961, 0.9132],
          [1.6495, 1.6495, 1.6324,  ..., 0.8961, 1.0331, 1.0673],
          ...,
          [1.2214, 1.1015, 0.9817,  ..., 0.7762, 0.7762, 0.7933],
          [1.0331, 0.9303, 0.8961,  ..., 0.8276, 0.7419, 0.7248],
          [0.8789, 0.8618, 0.8961,  ..., 0.8618, 0.7419, 0.6221]],
 
         [[1.9384, 1.9559, 1.9734,  ..., 0.8354, 0.9230, 0.9230],
          [1.9559, 1.9734, 1.9559,  ..., 0.8529, 0.9755, 1.0280],
          [1.9384, 1.9384, 1.9384,  ..., 0.9580, 1.1155, 1.2031],
          ...,
          [1.1331, 0.9580, 0.8004,  ..., 1.0280, 1.0280, 1.0455],
          [0.8179, 0.7129, 0.6429,  ..., 1.0805, 0.9930, 0.9755],
          [0.5903, 0.5553, 0.5728,  ..., 1.1155, 0.9930, 0.8880]],
 
         [[2.2217, 2.2391, 2.2566,  ..., 0.5659, 0.7054, 0.7751],
          [2.2391, 2.2391, 2.2391,  ..., 0.6008, 0.7576, 0.8797],
          [2.2217, 2.2217, 2.2043,  ...,

In [76]:
image.shape

torch.Size([3, 224, 224])

In [87]:
# See classes
class_names = dataset.classes

class_names

['PlantVillage', 'plantvillage']

In [101]:

#contents of plantvillages and PlantVillages directories
import os

# Define the paths to the "PlantVillage" and "plantvillage" directories
plant_village_path = os.path.join(data_dir, "PlantVillage")
plant_village_lower_path = os.path.join(data_dir, "plantvillage")

# List the subdirectories within "PlantVillage"
plant_village_classes = os.listdir(plant_village_path)

# List the subdirectories within "plantvillage"
plant_village_lower_classes = os.listdir(plant_village_lower_path)

print("Classes in 'PlantVillage':", plant_village_classes)
print("Classes in 'plantvillage':", plant_village_lower_classes)


Classes in 'PlantVillage': ['Tomato_Early_blight', 'Tomato_Bacterial_spot', 'Tomato_Leaf_Mold', 'Potato___Early_blight', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato_healthy', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Pepper__bell___healthy', 'Tomato_Late_blight', 'Potato___healthy', 'Pepper__bell___Bacterial_spot', 'Tomato_Septoria_leaf_spot', 'Potato___Late_blight']
Classes in 'plantvillage': ['PlantVillage']


In [54]:
#create a Dataloader for efficient dataloading and batch
batch_size = 7
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_samples = len(dataset)

In [104]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import random

# Define the path to the "plantvillage" directory
plant_village_path = os.path.join(data_dir, "plantvillage", "train")

# List the subdirectories (class labels) within "plantvillage"
class_labels = os.listdir(plant_village_path)

# Create a 3x4 grid to display random images
num_rows = 3
num_cols = 4
fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8))

for i, label in enumerate(class_labels):
    # Find a random image within the class directory
    class_dir = os.path.join(plant_village_path, label)
    image_files = os.listdir(class_dir)
    random_image_file = os.path.join(class_dir, random.choice(image_files))
    random_image = Image.open(random_image_file)

    # Plot the image
    row = i // num_cols
    col = i % num_cols
    axes[row, col].imshow(random_image)
    axes[row, col].set_title(label)
    axes[row, col].axis('off')

plt.tight_layout()
plt.show()


FileNotFoundError: ignored