# This notebook performs data augmentation via neural style transfer, using the YOLOv8 architecture as the feature extractor.

Input : 
1. Directory Names of content images. [Healthy Leaf Directory of Different Plants]
2. Directory Names of style images. [Hand Picked Infected Leaf Images of Different Plants]

Output :
1. One directory per style (disease) class, where the augmented images will be saved. [14 directories in the case of the segmented PlantVillage dataset]
2. This output will be used by the Augmentation Validation Classifier to include only relevant images in the actual dataset.

Functioning of this code : 
1. Initially, we find the disparity in the dataset based on the number of images present in each class.
2. Find the mean number of images that should be present in each class and export this value for all available classes. This value will be used to select the top images from the augmented dataset.
3. Then we initiate style transfer using the hand-picked styles from each class on the healthy image dataset for each crop, and save the images in the respective folders.

# Loading the Model

In [41]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
from torch import optim
from ultralytics import YOLO
import matplotlib.pyplot as plt
from torchvision import transforms as T
from scipy.spatial.distance import euclidean
from skimage.feature import graycomatrix, graycoprops

# Use the Apple-Silicon GPU backend (MPS) when PyTorch reports it, otherwise the CPU.
device: str = "mps" if torch.backends.mps.is_available() else "cpu"

# Load the YOLOv8x segmentation checkpoint and keep a handle on its inner
# layer stack, which the style-transfer code uses as a feature extractor.
model = YOLO('./models/yolov8x-seg.pt')
yolo_model = model.model.model

# Freeze the extractor: only the target image is optimised during style
# transfer. `requires_grad_` already applies to every parameter of the module,
# so a single call replaces the original per-layer loop that repeated it.
yolo_model.requires_grad_(False)

<class 'ultralytics.nn.tasks.SegmentationModel'>


In [42]:
# Flatten the modules found three levels below `model` into a dict keyed by
# their running index ("0", "1", ...), in registration order.
model_layers = {}
for _, outer_module in model._modules.items():
    for _, inner_module in outer_module._modules.items():
        for leaf_module in inner_module._modules.values():
            model_layers[str(len(model_layers))] = leaf_module
# Keep the running counter available under its original name.
i = len(model_layers)
yolo_model.to(device)

Sequential(
  (0): Conv(
    (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (1): Conv(
    (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (2): C2f(
    (cv1): Conv(
      (conv): Conv2d(160, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (cv2): Conv(
      (conv): Conv2d(400, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (m): ModuleList(
      (0-2): 3 x Bottleneck(
        (cv1): 

# Augmentation code starts from here.

In [43]:
# Healthy-leaf (content) directories, one per crop. The driver cells below
# index this list positionally: 0=Apple, 1=Corn, 2=Grape, 3=Pepper,
# 4=Strawberry, 5=Tomato.
content_images_dir_list = [
    '../data/PlantVillage/train/Apple___healthy',
    '../data/PlantVillage/train/Corn_(maize)___healthy',
    '../data/PlantVillage/train/Grape___healthy',
    '../data/PlantVillage/train/Pepper,_bell___healthy',
    '../data/PlantVillage/train/Strawberry___healthy',
    '../data/PlantVillage/train/Tomato___healthy',
]

# Diseased-leaf (style) directories per crop. Only the keys are consumed by
# the driver cells; the integer values are not read anywhere in this notebook
# (presumably a per-class image budget -- TODO confirm).
style_images_apple_dir_list = {
    '../data/PlantVillage/train/Apple___Apple_scab': 100,
    '../data/PlantVillage/train/Apple___Black_rot': 100,
    '../data/PlantVillage/train/Apple___Cedar_apple_rust': 100,
}

style_images_corn_dir_list = {
    '../data/PlantVillage/train/Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot': 100,
    '../data/PlantVillage/train/Corn_(maize)___Common_rust_': 100,
    # Fixed: this entry was missing the leading '../' that every other path uses.
    '../data/PlantVillage/train/Corn_(maize)___Northern_Leaf_Blight': 100,
}

style_images_grape_dir_list = {
    '../data/PlantVillage/train/Grape___Black_rot': 100,
    '../data/PlantVillage/train/Grape___Esca_(Black_Measles)': 100,
    '../data/PlantVillage/train/Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': 100,
}

style_images_pepper_dir_list = {
    '../data/PlantVillage/train/Pepper,_bell___Bacterial_spot': 100,
}

style_images_strawberry_dir_list = {
    '../data/PlantVillage/train/Strawberry___Leaf_scorch': 100,
}

style_images_tomato_dir_list = {
    '../data/PlantVillage/train/Tomato___Bacterial_spot': 100,
    '../data/PlantVillage/train/Tomato___Early_blight': 100,
    '../data/PlantVillage/train/Tomato___Late_blight': 100,
    '../data/PlantVillage/train/Tomato___Leaf_Mold': 100,
    '../data/PlantVillage/train/Tomato___Septoria_leaf_spot': 100,
    '../data/PlantVillage/train/Tomato___Spider_mites Two-spotted_spider_mite': 100,
    '../data/PlantVillage/train/Tomato___Target_Spot': 100,
    '../data/PlantVillage/train/Tomato___Tomato_mosaic_virus': 100,
    '../data/PlantVillage/train/Tomato___Tomato_Yellow_Leaf_Curl_Virus': 100,
}

In [44]:
def preprocess(img_path, max_size = 640):
    """Load an image file and turn it into a normalised batch-of-one tensor.

    Args:
        img_path: Path of the image to open with PIL.
        max_size: Upper bound, in pixels, for the longer image side. Larger
            images are shrunk (aspect ratio preserved) before conversion;
            smaller images are left untouched. Pass ``None`` to disable.
            Previously this argument was accepted but silently ignored.

    Returns:
        A float tensor of shape (1, 3, H, W), normalised per channel with the
        same mean/std constants that ``deprocess`` undoes.
    """
    image = Image.open(img_path).convert('RGB')

    # Honour max_size so very large inputs do not blow up the optimisation
    # loop; thumbnail() only ever shrinks and works in place.
    if max_size is not None and max(image.size) > max_size:
        image.thumbnail((max_size, max_size))

    img_transforms = T.Compose([
        T.ToTensor(),  # (H, W, 3) image -> (3, H, W) tensor
        T.Normalize(mean = [0.485, 0.456, 0.406],
                    std = [0.229, 0.224, 0.225])
    ])
    image = img_transforms(image)
    image = image.unsqueeze(0)  # (3, H, W) -> (1, 3, H, W)
    return image

def deprocess(tensor):
    """Convert a normalised (1, C, H, W) tensor back into an (H, W, C) array.

    The per-channel normalisation is reversed and the result is clamped to
    the [0, 1] range. The input tensor is not modified.
    """
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    # Work on a CPU copy, drop the batch axis and move channels last.
    pixels = tensor.to('cpu').clone().numpy().squeeze(0).transpose(1, 2, 0)

    # Undo the normalisation, then clamp into the displayable range.
    restored = pixels * channel_std + channel_mean
    return restored.clip(0, 1)

def get_features(image, model):
    """Run `image` through the first eleven extractor layers and collect taps.

    The layers come from the module-level `model_layers` dict; the `model`
    argument is accepted for interface compatibility but is not used.

    Returns:
        Dict mapping a tap label ('conv1_1', ..., 'conv5_1') to the
        activation produced by the corresponding layer.
    """
    # Layer index (as stored in `model_layers`) -> label used by the losses.
    tap_labels = {
        '1': 'conv1_1',
        '2': 'conv2_1',
        '4': 'conv3_1',
        '7': 'conv4_1',
        '8': 'conv4_2',
        '10': 'conv5_1',
    }
    activations = {}
    current = image
    for layers_done, layer_key in enumerate(model_layers, start=1):
        current = model_layers[layer_key](current)
        label = tap_labels.get(layer_key)
        if label is not None:
            activations[label] = current
        # Nothing past the eleventh layer is tapped, so stop there.
        if layers_done > 10:
            break
    return activations

def gram_matrix(tensor):
    """Return the (C, C) Gram matrix of a (1, C, H, W) feature map.

    The batch axis is folded away by the reshape, so a batch size of one is
    required.
    """
    _, channels, height, width = tensor.size()
    flattened = tensor.view(channels, height * width)
    return torch.mm(flattened, flattened.t())

def content_loss(target_conv4_2, content_conv4_2):
    """Mean squared difference between target and content activations."""
    difference = target_conv4_2 - content_conv4_2
    return torch.mean(difference ** 2)

# Per-layer style-loss weights. Each preset maps a tap label to the factor
# applied to that layer's Gram-matrix error in `style_loss`.

# Preset 1: weight grows towards the deeper taps.
style_weights_1 = dict(conv1_1=0.2, conv2_1=0.4, conv3_1=0.3, conv4_1=0.9, conv5_1=1.0)

# Preset 2: first tap dominant.
style_weights_2 = dict(conv1_1=1.0, conv2_1=0.4, conv3_1=0.3, conv4_1=0.2, conv5_1=0.1)

# Preset 3: second tap dominant.
style_weights_3 = dict(conv1_1=0.2, conv2_1=1.0, conv3_1=0.3, conv4_1=0.2, conv5_1=0.1)

# Preset 4: third tap dominant.
style_weights_4 = dict(conv1_1=0.2, conv2_1=0.4, conv3_1=1.0, conv4_1=0.2, conv5_1=0.1)

# Preset 5: fourth tap dominant.
style_weights_5 = dict(conv1_1=0.2, conv2_1=0.4, conv3_1=0.3, conv4_1=1.0, conv5_1=0.1)

# All presets, in the order the driver cell iterates over them.
style_weights_arr = [
    style_weights_1,
    style_weights_2,
    style_weights_3,
    style_weights_4,
    style_weights_5,
]

def style_loss(style_weights, target_features, style_grams):
    """Weighted sum of Gram-matrix errors over the layers in `style_weights`.

    For every layer the mean squared difference between the target's Gram
    matrix and the precomputed style Gram matrix is scaled by the layer
    weight and divided by the feature map's C*H*W.
    """
    accumulated = 0
    for layer_name, weight in style_weights.items():
        feature_map = target_features[layer_name]
        gram_gap = gram_matrix(feature_map) - style_grams[layer_name]
        _, channels, height, width = feature_map.shape
        weighted_error = weight * torch.mean(gram_gap ** 2)
        accumulated += weighted_error / (channels * height * width)
    return accumulated

def total_loss(c_loss, s_loss, alpha, beta):
    """Combine content and style losses as ``alpha * c_loss + beta * s_loss``."""
    return alpha * c_loss + beta * s_loss

def glcm_features(image):
    """Compute four GLCM texture descriptors for a BGR image.

    The image is converted to grayscale and a symmetric, normalised
    co-occurrence matrix is built for distance 1 and angle 0 with 256 levels.

    Returns:
        Tuple ``(contrast, correlation, energy, homogeneity)``; each entry is
        the flattened array returned by ``graycoprops``.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    cooccurrence = graycomatrix(gray, [1], [0], 256, symmetric=True, normed=True)
    property_names = ('contrast', 'correlation', 'energy', 'homogeneity')
    contrast, correlation, energy, homogeneity = (
        graycoprops(cooccurrence, prop).flatten() for prop in property_names
    )
    return contrast, correlation, energy, homogeneity

def calculate_similarity(feature1, feature2):
    """Euclidean distance between two feature sequences (lower = more alike).

    Components are paired positionally; surplus components of the longer
    sequence are ignored.
    """
    squared_total = sum((a - b) ** 2 for a, b in zip(feature1, feature2))
    return squared_total ** 0.5

def imread(path):
    """Read an image with OpenCV and return it as a 3-channel float array.

    Args:
        path: Path of the image file.

    Returns:
        Float array of shape (H, W, 3). A 2-D (grayscale) result is stacked
        into three identical channels and a fourth (alpha) channel is dropped.

    Raises:
        OSError: If OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure with None instead of raising.
        raise OSError("Could not read image: {}".format(path))

    # `np.float` was only an alias of the builtin `float` and has been removed
    # from NumPy; the builtin yields the same float64 array.
    img = img.astype(float)
    if len(img.shape) == 2:
        # grayscale
        img = np.dstack((img, img, img))
    elif img.shape[2] == 4:
        # PNG with alpha channel
        img = img[:, :, :3]
    return img

def imsave(path, img):
    """Clamp `img` to 0..255, cast it to uint8 and save it via PIL (quality 95)."""
    as_bytes = np.clip(img, 0, 255).astype(np.uint8)
    picture = Image.fromarray(as_bytes)
    picture.save(path, quality=95)

# Augment Dataset

In [45]:
# Rebind `model` to the custom-trained segmentation checkpoint. From here on
# `model` is what nst_data_augment calls to obtain the leaf mask, while
# `yolo_model` / `model_layers` (created in the earlier cells from
# yolov8x-seg.pt) remain the feature extractor for the style transfer.
model = YOLO('./models/custom-trained-yolov8-seg.pt')
model.to(device)

<class 'ultralytics.nn.tasks.SegmentationModel'>


YOLO(
  (model): SegmentationModel(
    (model): Sequential(
      (0): Conv(
        (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): Conv(
        (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (2): C2f(
        (cv1): Conv(
          (conv): Conv2d(160, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (cv2): Conv(
          (conv): Conv2d(400, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, trac

In [46]:
def _run_style_transfer(content_img_path, style_img_path, style_weights,
                        alpha=1, beta=1e6, epochs=1501, show_every=500, lr=0.08):
    """Optimise a copy of the content image towards the style image.

    Shared by both public augmentation functions, which previously each
    carried an identical copy of this loop.

    Args:
        content_img_path: Path of the content (healthy leaf) image.
        style_img_path: Path of the style (diseased leaf) image.
        style_weights: Mapping of tap label -> weight, passed to `style_loss`.
        alpha: Factor applied to the content loss.
        beta: Factor applied to the style loss.
        epochs: Number of Adam steps.
        show_every: The total loss is printed every this many steps.
        lr: Adam learning rate.

    Returns:
        ``(content_p, target)``: the preprocessed content tensor and the
        optimised, detached target tensor, both still normalised.
    """
    content_p = preprocess(content_img_path).to(device)
    style_p = preprocess(style_img_path).to(device)

    # Features of the two fixed images and the style Gram matrices are
    # computed once; only the target's features change per step.
    content_f = get_features(content_p, yolo_model)
    style_f = get_features(style_p, yolo_model)
    style_grams = {layer: gram_matrix(style_f[layer]) for layer in style_f}

    # The target starts as a copy of the content image and is the only tensor
    # handed to the optimiser.
    target = content_p.clone().requires_grad_(True).to(device)
    optimizer = optim.Adam([target], lr=lr)

    for epoch in range(epochs):
        target_f = get_features(target, yolo_model)

        c_loss = content_loss(target_f['conv4_2'], content_f['conv4_2'])
        s_loss = style_loss(style_weights, target_f, style_grams)
        t_loss = total_loss(c_loss, s_loss, alpha, beta)

        optimizer.zero_grad()
        t_loss.backward()
        optimizer.step()

        if epoch % show_every == 0:
            print("Total loss at epoch {}: {}".format(epoch, t_loss))

    return content_p, target.detach()


def nst_data_augment(content_img_path, style_img_path, final_image_name, final_image_path, style_weights):
    """Style-transfer a diseased look onto a healthy leaf, inside the leaf mask only.

    The stylised image is composited over the original content image using
    the first mask predicted by the module-level segmentation `model`, and the
    result is written to ``final_image_path/final_image_name``. When no mask
    is predicted nothing is saved.

    Side effects: writes the scratch files ``target.png``, ``content.png`` and
    ``./mask.png`` into the current working directory.

    Args:
        content_img_path: Path of the healthy (content) image.
        style_img_path: Path of the diseased (style) image.
        final_image_name: File name of the output image.
        final_image_path: Existing directory the output is written to.
        style_weights: Mapping of tap label -> weight for the style loss.
    """
    content_p, target = _run_style_transfer(content_img_path, style_img_path, style_weights)

    # Round-trip through PNG files so the arrays handed to OpenCV are 8-bit.
    plt.imsave("target.png", deprocess(target))
    plt.imsave("content.png", deprocess(content_p))

    imgcon = cv2.imread('content.png')
    imgcon = cv2.cvtColor(imgcon, cv2.COLOR_BGR2RGB)
    H, W, _ = imgcon.shape

    # NOTE(review): imgcon is RGB at this point -- confirm that this is the
    # channel order the segmentation model expects for array inputs.
    detections = model(imgcon)

    mask_present = False
    for result in detections:
        # `masks` is None when nothing was segmented; the original code
        # crashed on `.data` in that case.
        if result.masks is None or len(result.masks.data) == 0:
            continue
        # Only the first mask of a result is used.
        mask = result.masks.data[0].cpu().numpy() * 255
        mask = cv2.resize(mask, (W, H))
        cv2.imwrite('./mask.png', mask)
        mask_present = True

    if not mask_present:
        print("No leaf mask found for {}; nothing saved.".format(content_img_path))
        return

    imgtar = cv2.imread("target.png")
    imgtar = cv2.cvtColor(imgtar, cv2.COLOR_BGR2RGB)
    imgtar = cv2.resize(imgtar, (W, H))

    mask = cv2.imread('./mask.png', cv2.IMREAD_GRAYSCALE)
    _, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

    # Stylised pixels inside the mask, original pixels outside of it.
    masked_overlay = cv2.bitwise_or(imgtar, imgtar, mask=binary_mask)
    inverted_mask = cv2.bitwise_not(binary_mask)
    roi = cv2.bitwise_and(imgcon, imgcon, mask=inverted_mask)
    result_image = cv2.add(roi, masked_overlay)
    plt.imsave(os.path.join(final_image_path, final_image_name), result_image)


def nst_data_augment_without_segmentation(content_img_path, style_img_path, final_image_name, final_image_path, style_weights):
    """Style-transfer a diseased look onto a healthy leaf image (whole frame).

    The stylised image is written unmasked to
    ``final_image_path/final_image_name``.

    Args:
        content_img_path: Path of the healthy (content) image.
        style_img_path: Path of the diseased (style) image.
        final_image_name: File name of the output image.
        final_image_path: Existing directory the output is written to.
        style_weights: Mapping of tap label -> weight for the style loss.
    """
    _, target = _run_style_transfer(content_img_path, style_img_path, style_weights)
    plt.imsave(os.path.join(final_image_path, final_image_name), deprocess(target))

In [47]:
def _collect_style_features(style_dir):
    """Return ``[(path, glcm_features)]`` for every readable image file in `style_dir`."""
    entries = []
    for file_name in os.listdir(style_dir):
        image_path = os.path.join(style_dir, file_name)
        if not os.path.isfile(image_path):
            continue
        image = cv2.imread(image_path)
        if image is None:  # not an image OpenCV can decode
            continue
        entries.append((image_path, glcm_features(image)))
    return entries


def _closest_style_image(healthy_features, style_entries):
    """Return the path whose features are nearest to `healthy_features`, or None."""
    best_path = None
    best_similarity = float('inf')
    for image_path, features in style_entries:
        similarity = calculate_similarity(healthy_features, features)
        # Strict '<' keeps the first candidate on ties.
        if similarity < best_similarity:
            best_similarity = similarity
            best_path = image_path
    return best_path


def _augmented_file_name(file_name):
    """Insert '_augmented' before the extension: 'leaf.JPG' -> 'leaf_augmented.JPG'."""
    stem, extension = os.path.splitext(file_name)
    return stem + "_augmented" + extension


def augment_plant_dataset(content_data_dir, style_data_dir_list, final_output_directory, with_segmentation, style_weights, num_images):
    """Create style-transferred images for healthy leaves of one crop.

    For up to `num_images` files in `content_data_dir`, and for every folder
    in `style_data_dir_list`, the diseased image with the closest GLCM texture
    features is chosen as the style and the stylised result is written to
    ``final_output_directory/<style folder name>/<name>_augmented.<ext>``.

    Args:
        content_data_dir: Directory with healthy (content) images.
        style_data_dir_list: Iterable of directories with diseased (style) images.
        final_output_directory: Root directory for the augmented images.
        with_segmentation: If true use `nst_data_augment` (masked composite),
            otherwise `nst_data_augment_without_segmentation`.
        style_weights: Mapping of tap label -> weight for the style loss.
        num_images: Maximum number of content files to process. A file counts
            against this budget even if it later fails to load.
    """
    # GLCM features of the style images do not depend on the healthy image,
    # so each style folder is scanned once (on first use) instead of once
    # per healthy image.
    style_cache = {}

    for healthy_image_file in os.listdir(content_data_dir):
        healthy_image_path = os.path.join(content_data_dir, healthy_image_file)
        if not os.path.isfile(healthy_image_path):
            print(f"Invalid file: {healthy_image_path}")
            continue

        if num_images <= 0:
            break
        num_images -= 1

        healthy_image = cv2.imread(healthy_image_path)
        if healthy_image is None:
            print(f"Error loading healthy leaf image: {healthy_image_path}")
            continue
        healthy_features = glcm_features(healthy_image)

        for diseased_folder in style_data_dir_list:
            if diseased_folder not in style_cache:
                style_cache[diseased_folder] = _collect_style_features(diseased_folder)

            best_diseased_leaf_image = _closest_style_image(healthy_features, style_cache[diseased_folder])
            if best_diseased_leaf_image is None:
                print(f"No suitable diseased leaf image found for healthy leaf image {healthy_image_file} in folder {diseased_folder}")
                continue

            # normpath drops a trailing separator so the class name is never empty.
            diseased_folder_name = os.path.basename(os.path.normpath(diseased_folder))
            augmented_path = os.path.join(final_output_directory, diseased_folder_name)
            os.makedirs(augmented_path, exist_ok=True)
            augmented_image_name = _augmented_file_name(healthy_image_file)

            print("Transfering the style for : " + str(healthy_image_path))
            if with_segmentation:
                nst_data_augment(healthy_image_path, best_diseased_leaf_image, augmented_image_name, augmented_path, style_weights)
            else:
                nst_data_augment_without_segmentation(healthy_image_path, best_diseased_leaf_image, augmented_image_name, augmented_path, style_weights)


In [48]:
# nst_data_augment('../data/plantvillage-dataset/segmented/apple_healthy/00a6039c-e425-4f7d-81b1-d6b0e668517e___RS_HL 7669_final_masked.jpg', '../data/shortlisted-style-images/apple_black_rot/4dadb9f1-27b1-4d3c-8111-d1602febd585___JR_FrgE.S 8632_final_masked.jpg', 'ouput.jpg', '../data/experiment-augmentation/')

In [49]:
# Passed as `num_images` to augment_plant_dataset by the Apple driver cell:
# the maximum number of healthy content images processed per call.
num_required = 10

Augment Apple Dataset

In [50]:
# Run the Apple augmentation once per style-weight preset; every preset
# writes into its own 'weights_<index>/' output directory.
for i, preset_weights in enumerate(style_weights_arr):
    augment_plant_dataset(
        content_images_dir_list[0],
        list(style_images_apple_dir_list),
        f'../data/plantvillage-augmented/weights_{i}/',
        True,
        preset_weights,
        num_required,
    )

Transfering the style for : ../data/PlantVillage/train/Apple___healthy/4c7d424b-d418-4b72-84a9-2a639eabb5ec___RS_HL 5734.JPG
Total loss at epoch 0: 908767936.0
Total loss at epoch 500: 51322.01953125
Total loss at epoch 1000: 13843.5244140625
Total loss at epoch 1500: 1525989.625
<ultralytics.models.yolo.segment.predict.SegmentationPredictor object at 0x352ce8a60>

0: 640x640 1 Healthy_Leaf, 134.8ms
Speed: 15.0ms preprocess, 134.8ms inference, 290.9ms postprocess per image at shape (1, 3, 640, 640)
Transfering the style for : ../data/PlantVillage/train/Apple___healthy/4c7d424b-d418-4b72-84a9-2a639eabb5ec___RS_HL 5734.JPG
Total loss at epoch 0: 895906112.0
Total loss at epoch 500: 107494.1484375
Total loss at epoch 1000: 6359823.5
Total loss at epoch 1500: 1744862.25

0: 640x640 1 Healthy_Leaf, 154.2ms
Speed: 14.1ms preprocess, 154.2ms inference, 128.7ms postprocess per image at shape (1, 3, 640, 640)
Transfering the style for : ../data/PlantVillage/train/Apple___healthy/4c7d424b-d418-4

KeyboardInterrupt: 

Augment Corn Dataset

In [None]:
# augment_plant_dataset takes six positional arguments; the original call
# passed only three and raised TypeError. The added values mirror the Apple
# cell (segmentation on, `num_required` images) with the first weight preset
# -- adjust if a different preset is intended for corn.
augment_plant_dataset(content_images_dir_list[1], list(style_images_corn_dir_list.keys()), '../data/plantvillage-augmented/', True, style_weights_arr[0], num_required)

Augment Tomato Dataset

In [None]:
# Two fixes:
#  * index 2 of content_images_dir_list is Grape___healthy; the Tomato
#    content directory is index 5.
#  * augment_plant_dataset takes six positional arguments; the original call
#    passed only three and raised TypeError. The added values mirror the Apple
#    cell (segmentation on, `num_required` images) with the first weight
#    preset -- adjust if a different preset is intended for tomato.
augment_plant_dataset(content_images_dir_list[5], list(style_images_tomato_dir_list.keys()), '../data/plantvillage-augmented/', True, style_weights_arr[0], num_required)

In [6]:
from ultralytics import YOLO
import cv2
import numpy as np

# Standalone check: predict leaf masks for one image and save them as a
# white-on-black mask picture.
model = YOLO('./models/custom-trained-yolov8-seg.pt')
image_path = '../data/PlantVillage/train/Tomato___healthy/000bf685-b305-408b-91f4-37030f8e62db___GH_HL Leaf 308.1.JPG'
img = cv2.imread(image_path)
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Set confidence threshold
conf = 0.5

# Predict using the model
results = model.predict(img, conf=conf)

# Black canvas with the same shape and dtype as the input image
# (zeros_like already yields all zeros, so no extra fill is needed).
binary_mask = np.zeros_like(img)

for result in results:
    # `masks` is None when nothing was segmented above the threshold.
    if result.masks is None:
        continue
    for polygon in result.masks.xy:
        if len(polygon) == 0:
            continue
        points = np.int32([polygon])
        # Use white color for the segmented region, making it visible against the black background
        cv2.fillPoly(binary_mask, points, (255, 255, 255))

# Display the binary mask
# cv2.imshow(binary_mask)
# cv2.waitKey(0)

# Save the binary mask image
cv2.imwrite("./mask.jpg", binary_mask)

<class 'ultralytics.nn.tasks.SegmentationModel'>
<ultralytics.models.yolo.segment.predict.SegmentationPredictor object at 0x171cac8e0>

0: 640x640 1 Non_Healthy_Leaf, 917.5ms
Speed: 1.6ms preprocess, 917.5ms inference, 1.4ms postprocess per image at shape (1, 3, 640, 640)


True

In [None]:



No changes as such were made to the model; we just used the YOLOv8 backbone to extract features from the content and style images.