# This notebook performs data augmentation through neural style transfer, using a custom-trained YOLOv8 model as the feature extractor.

Input : 
1. Directory Names of content images. [Healthy Leaf Directory of Different Plants]
2. Directory Names of style images. [Hand Picked Infected Leaf Images of Different Plants]

Output :
1. One output directory per style-image class, where the augmented images will be saved. [14 directories in the case of the segmented PlantVillage dataset]
2. This output is then used by the Augmentation Validation Classifier so that only relevant images are included in the actual dataset.

Functioning of this code : 
1. We run the style transfer using the hand-picked style images of each class on the healthy-image dataset of each crop, and save the resulting images in the respective folder.

# Loading the Model

In [None]:
!pip install ultralytics

In [None]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
from torch import optim
from ultralytics import YOLO
import matplotlib.pyplot as plt
from torchvision import transforms as T
from scipy.spatial.distance import euclidean
from skimage.feature import graycomatrix, graycoprops

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Custom-trained YOLOv8 segmentation checkpoint. `model.model.model` is the
# inner module whose child layers are used further down for feature extraction.
model = YOLO('/kaggle/input/custom-trained-yolov8-plant-disease/custom-trained-yolov8-seg.pt')
yolo_model = model.model.model

# Freeze the network: only the target image is optimised during style
# transfer. `requires_grad_` applies to the whole module, so one call is
# enough (the previous loop repeated the identical call once per child layer).
yolo_model.requires_grad_(False)

In [None]:
# Collect the modules found three levels below `model` into a flat dict keyed
# by their running position ("0", "1", ...), in traversal order.
model_layers = {}
for top_module in model._modules.values():
    for mid_module in top_module._modules.values():
        for leaf_module in mid_module._modules.values():
            model_layers[str(len(model_layers))] = leaf_module
# Total number of collected layers (kept as a module-level name, as before).
i = len(model_layers)
yolo_model.to(device)

# Augmentation code starts from here.

In [None]:
# Root of the PlantVillage training split; every directory below lives here.
_TRAIN_ROOT = '/kaggle/input/plantvillage/PlantVillage/train'


def _style_dirs(crop, diseases, count=100):
    """Map each `<root>/<crop>___<disease>` directory to `count`.

    NOTE(review): the meaning of the count (100 everywhere) is not visible in
    this file; only the dictionary keys are consumed by the driver call.
    """
    return {f'{_TRAIN_ROOT}/{crop}___{disease}': count for disease in diseases}


# Healthy-leaf (content) directories, one per crop. The driver cell selects
# one of these by position, so the order matters.
content_images_dir_list = [
    f'{_TRAIN_ROOT}/{crop}___healthy'
    for crop in (
        'Apple',
        'Corn_(maize)',
        'Grape',
        'Pepper,_bell',
        'Strawberry',
        'Tomato',
    )
]

# Diseased-leaf (style) directories, grouped per crop.
style_images_apple_dir_list = _style_dirs(
    'Apple',
    ('Apple_scab', 'Black_rot', 'Cedar_apple_rust'),
)

style_images_corn_dir_list = _style_dirs(
    'Corn_(maize)',
    ('Cercospora_leaf_spot Gray_leaf_spot', 'Common_rust_', 'Northern_Leaf_Blight'),
)

style_images_grape_dir_list = _style_dirs(
    'Grape',
    ('Black_rot', 'Esca_(Black_Measles)', 'Leaf_blight_(Isariopsis_Leaf_Spot)'),
)

style_images_pepper_dir_list = _style_dirs('Pepper,_bell', ('Bacterial_spot',))

style_images_strawberry_dir_list = _style_dirs('Strawberry', ('Leaf_scorch',))

style_images_tomato_dir_list = _style_dirs(
    'Tomato',
    (
        'Bacterial_spot',
        'Early_blight',
        'Late_blight',
        'Leaf_Mold',
        'Septoria_leaf_spot',
        'Spider_mites Two-spotted_spider_mite',
        'Target_Spot',
        'Tomato_mosaic_virus',
        'Tomato_Yellow_Leaf_Curl_Virus',
    ),
)

In [None]:
def preprocess(img_path, max_size = 640):
  """Load an image file and turn it into a normalised batch-of-one tensor.

  Args:
    img_path: Path of the image file to load.
    max_size: Upper bound, in pixels, for the longer image side. Larger
      images are shrunk (aspect ratio kept); smaller ones are left as-is.
      Pass ``None`` to disable resizing.

  Returns:
    A float tensor of shape (1, 3, H, W), normalised with the mean/std
    constants below (the same constants `deprocess` undoes).
  """
  # Close the file handle deterministically; `convert` returns a loaded copy.
  with Image.open(img_path) as source:
    image = source.convert('RGB')

  # `max_size` used to be accepted but silently ignored. Honour it by only
  # ever shrinking, so images already within the bound are unchanged.
  if max_size is not None and max(image.size) > max_size:
    image.thumbnail((max_size, max_size))

  img_transforms = T.Compose([
      T.ToTensor(),  # PIL image (H, W, 3) -> float tensor (3, H, W) in [0, 1]
      T.Normalize(mean = [0.485, 0.456, 0.406],
                  std = [0.229, 0.224, 0.225])
  ])
  image = img_transforms(image)
  return image.unsqueeze(0)  # (3, H, W) -> (1, 3, H, W)

def deprocess(tensor):
  """Convert a normalised (1, 3, H, W) tensor back into a displayable image.

  Args:
    tensor: Batch-of-one image tensor normalised with the mean/std constants
      used below. It may live on any device and may require grad.

  Returns:
    A NumPy array of shape (H, W, 3) with values clipped to [0, 1].
  """
  # `detach` makes this safe for tensors that are still part of the autograd
  # graph; `.numpy()` raises on tensors that require grad.
  image = tensor.detach().to('cpu').numpy()
  # (1, 3, H, W) -> (3, H, W) -> (H, W, 3)
  image = image.squeeze(0).transpose(1, 2, 0)
  # Undo the normalisation: x * std + mean (allocates a new array, so the
  # input tensor's storage is never modified).
  image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
  return image.clip(0, 1)

def get_features(image, model):
  """Run `image` through the first layers in `model_layers` and tap some outputs.

  Args:
    image: Input tensor fed to the first layer.
    model: Unused; the layers are taken from the module-level `model_layers`.

  Returns:
    Dict mapping a label ('conv1_1', ..., 'conv5_1') to the output of the
    layer stored under the corresponding key of `model_layers`.
  """
  # model_layers key -> label under which that layer's output is returned.
  tapped_layers = {
      '1': 'conv1_1',
      '2': 'conv2_1',
      '4': 'conv3_1',
      '7': 'conv4_1',
      '8': 'conv4_2',
      '10': 'conv5_1',
  }
  features = {}
  activation = image
  for position, (key, layer) in enumerate(model_layers.items()):
    # Only the first eleven layers (positions 0..10) are evaluated.
    if position > 10:
      break
    activation = layer(activation)
    label = tapped_layers.get(key)
    if label is not None:
      features[label] = activation
  return features

def gram_matrix(tensor):
  """Return the (C, C) Gram matrix of a (1, C, H, W) feature map.

  The batch dimension is dropped by the reshape, so the input must hold a
  single sample.
  """
  _, channels, height, width = tensor.size()
  # One row per channel, spatial positions flattened into columns.
  flattened = tensor.view(channels, height * width)
  return torch.mm(flattened, flattened.t())

def content_loss(target_conv4_2, content_conv4_2):
  """Mean squared difference between the target and content feature maps."""
  difference = target_conv4_2 - content_conv4_2
  return torch.mean(difference ** 2)

# Per-layer weights applied to the style loss; keys are the feature labels
# produced by `get_features`.
style_weights_1 = dict(
    conv1_1=0.2,
    conv2_1=0.2,
    conv3_1=0.5,
    conv4_1=1.0,
    conv5_1=0.2,
)

# Weight sets selectable by index (the integer stored per class in
# `weights_data`).
# NOTE(review): only index 0 exists here, while `weights_data` also uses the
# indices 1, 3 and 4 - those classes would raise IndexError. Confirm whether
# further weight sets are missing.
style_weights_arr = [style_weights_1]

def style_loss(style_weights, target_features, style_grams):
  """Weighted sum of per-layer Gram-matrix mismatches.

  Args:
    style_weights: Dict mapping a layer label to its weight.
    target_features: Dict mapping a layer label to the target's feature map.
    style_grams: Dict mapping a layer label to the style image's Gram matrix.

  Returns:
    The accumulated loss (the int 0 when `style_weights` is empty).
  """
  def layer_term(layer):
    feature_map = target_features[layer]
    gram_gap = gram_matrix(feature_map) - style_grams[layer]
    _, channels, height, width = feature_map.shape
    weighted = style_weights[layer] * torch.mean(gram_gap ** 2)
    # Scale each layer's contribution by the size of its feature map.
    return weighted / (channels * height * width)

  return sum(layer_term(layer) for layer in style_weights)

def total_loss(c_loss, s_loss, alpha, beta):
  """Combine the losses as ``alpha * c_loss + beta * s_loss``."""
  weighted_content = alpha * c_loss
  weighted_style = beta * s_loss
  return weighted_content + weighted_style

def glcm_features(image):
    """Compute four GLCM texture descriptors of a BGR image.

    The image is converted to grayscale, a symmetric, normalised
    co-occurrence matrix is built for distance 1 and angle 0 with 256 grey
    levels, and each property is returned as a flattened array.

    Returns:
        Tuple ``(contrast, correlation, energy, homogeneity)``.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    cooccurrence = graycomatrix(
        gray, [1], [0], 256, symmetric=True, normed=True
    )
    return tuple(
        graycoprops(cooccurrence, prop).flatten()
        for prop in ('contrast', 'correlation', 'energy', 'homogeneity')
    )

def calculate_similarity(feature1, feature2):
    """Return the Euclidean distance between two feature sequences.

    Despite the name, a smaller value means the inputs are more alike.
    Elements are paired with `zip`, so surplus elements of the longer
    sequence are ignored.
    """
    squared_gaps = ((a - b) ** 2 for a, b in zip(feature1, feature2))
    return sum(squared_gaps) ** 0.5

def imread(path):
    """Read an image with OpenCV as a float64 array with three channels.

    Args:
        path: Path of the image file.

    Returns:
        Array of shape (H, W, 3) and dtype float64.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {path}")
    # `np.float` was removed in NumPy 1.24; it was an alias of the builtin
    # float, i.e. float64.
    img = img.astype(np.float64)
    if len(img.shape) == 2:
        # grayscale
        img = np.dstack((img, img, img))
    elif img.shape[2] == 4:
        # PNG with alpha channel
        img = img[:, :, :3]
    return img

def imsave(path, img):
    """Clip `img` to [0, 255], cast to uint8 and save it with Pillow."""
    as_bytes = np.clip(img, 0, 255).astype(np.uint8)
    picture = Image.fromarray(as_bytes)
    picture.save(path, quality=95)

# Augment Dataset

**Upload the custom-trained YOLOv8 weights before running this section**

In [None]:
# Reload the custom-trained segmentation checkpoint and move it to `device`.
# `nst_data_augment` calls this `model` to predict the leaf mask.
model = YOLO('/kaggle/input/custom-trained-yolov8-plant-disease/custom-trained-yolov8-seg.pt')
model.to(device)

In [None]:
def _run_style_transfer(content_img_path, style_img_path, style_weights):
    """Optimise a copy of the content image towards the style image.

    Args:
        content_img_path: Path of the content (healthy leaf) image.
        style_img_path: Path of the style (diseased leaf) image.
        style_weights: Index into `style_weights_arr`.

    Returns:
        Tuple ``(target, content_p)``: the detached, optimised image tensor
        and the preprocessed content tensor, both on `device`.

    Raises:
        IndexError: If `style_weights` is not a valid index.
    """
    # Resolve the weight set first so a bad index fails before any
    # optimisation work is done.
    layer_weights = style_weights_arr[style_weights]

    content_p = preprocess(content_img_path).to(device)
    style_p = preprocess(style_img_path).to(device)

    # Fixed references: content features and style Gram matrices.
    content_f = get_features(content_p, yolo_model)
    style_f = get_features(style_p, yolo_model)
    style_grams = {layer: gram_matrix(style_f[layer]) for layer in style_f}

    # The image itself is the only thing being optimised. `content_p` is
    # already on `device`, so the clone is a leaf tensor on that device.
    target = content_p.clone().requires_grad_(True)
    optimizer = optim.Adam([target], lr=0.08)
    alpha = 1      # weight of the content loss
    beta = 1e6     # weight of the style loss
    epochs = 801
    show_every = 200
    for epoch in range(epochs):
        target_f = get_features(target, yolo_model)

        c_loss = content_loss(target_f['conv4_2'], content_f['conv4_2'])
        s_loss = style_loss(layer_weights, target_f, style_grams)
        t_loss = total_loss(c_loss, s_loss, alpha, beta)

        optimizer.zero_grad()
        t_loss.backward()
        optimizer.step()

        if epoch % show_every == 0:
            print("Total loss at epoch {}: {}".format(epoch, t_loss))

    return target.detach(), content_p


def _write_first_mask(results, size, mask_path):
    """Write the first predicted mask, resized to `size`, to `mask_path`.

    Returns True when a mask was written, False when no result has a mask.
    """
    for result in results:
        # `masks` is None when nothing was detected; skip instead of crashing.
        if result.masks is None:
            continue
        for mask in result.masks.data:
            mask = mask.cpu().numpy() * 255
            mask = cv2.resize(mask, size)
            cv2.imwrite(mask_path, mask)
            return True
    return False


def nst_data_augment(content_img_path, style_img_path, final_image_name, final_image_path, style_weights):
    """Style-transfer an image and keep the stylised pixels only inside the mask.

    The stylised image is pasted onto the content image wherever the first
    mask predicted by `model` is set; everywhere else the content pixels are
    kept. If no mask is predicted, nothing is saved.

    Side effects: writes `target.png`, `content.png` and `mask.png` to the
    working directory, and the final image to
    ``final_image_path/final_image_name``.

    Args:
        content_img_path: Path of the content (healthy leaf) image.
        style_img_path: Path of the style (diseased leaf) image.
        final_image_name: File name of the augmented image.
        final_image_path: Directory in which the augmented image is saved.
        style_weights: Index into `style_weights_arr`.
    """
    target, content_p = _run_style_transfer(content_img_path, style_img_path, style_weights)

    # Round-trip through PNG files to obtain 8-bit images for OpenCV.
    plt.imsave("target.png", deprocess(target))
    plt.imsave("content.png", deprocess(content_p))
    imgcon = cv2.imread('content.png')
    imgcon = cv2.cvtColor(imgcon, cv2.COLOR_BGR2RGB)
    H, W, _ = imgcon.shape

    # NOTE(review): `imgcon` is RGB here; confirm that matches the channel
    # order the model expects for NumPy inputs.
    if not _write_first_mask(model(imgcon), (W, H), './mask.png'):
        print(f"No mask predicted for {content_img_path}; skipping {final_image_name}")
        return

    imgtar = cv2.imread("target.png")
    imgtar = cv2.cvtColor(imgtar, cv2.COLOR_BGR2RGB)
    imgtar = cv2.resize(imgtar.astype(np.uint8), (W, H))

    mask = cv2.imread('./mask.png', cv2.IMREAD_GRAYSCALE)
    _, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

    # Stylised pixels inside the mask ...
    masked_overlay = cv2.bitwise_or(imgtar, imgtar, mask=binary_mask).astype(np.uint8)
    # ... and original content pixels outside of it.
    inverted_mask = cv2.bitwise_not(binary_mask.astype(np.uint8))
    roi = cv2.bitwise_and(imgcon, imgcon, mask=inverted_mask)

    result_image = cv2.add(roi, masked_overlay)
    plt.imsave(os.path.join(final_image_path, final_image_name), result_image)


def nst_data_augment_without_segmentation(content_img_path, style_img_path, final_image_name, final_image_path, style_weights):
    """Style-transfer an image and save the whole stylised result.

    Args:
        content_img_path: Path of the content (healthy leaf) image.
        style_img_path: Path of the style (diseased leaf) image.
        final_image_name: File name of the augmented image.
        final_image_path: Directory in which the augmented image is saved.
        style_weights: Index into `style_weights_arr`.
    """
    target, _ = _run_style_transfer(content_img_path, style_img_path, style_weights)
    plt.imsave(os.path.join(final_image_path, final_image_name), deprocess(target))

In [None]:
# Maps a disease class (the name of its style directory) to the index of the
# weight set in `style_weights_arr` used for that class.
# NOTE(review): `style_weights_arr` as defined in this file has a single
# entry, so the indices 1, 3 and 4 used below would raise IndexError -
# confirm whether additional weight sets are missing.
weights_data = {
    'Apple___Cedar_apple_rust': 1,
    'Apple___Apple_scab': 3,
    'Apple___Black_rot': 4,
    'Corn_(maize)___Common_rust_': 0,
    'Corn_(maize)___Northern_Leaf_Blight': 0,
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot': 0,
    'Grape___Black_rot': 0,
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': 0,
    'Grape___Esca_(Black_Measles)': 0,
    'Pepper,_bell___Bacterial_spot': 0,
    'Strawberry___Leaf_scorch': 1,
    'Tomato___Target_Spot': 0,
    'Tomato___Late_blight': 0,
    'Tomato___Tomato_mosaic_virus': 0,
    'Tomato___Leaf_Mold': 0,
    'Tomato___Bacterial_spot': 0,
    'Tomato___Early_blight': 0,
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus': 0,
    'Tomato___Spider_mites Two-spotted_spider_mite': 0,
    'Tomato___Septoria_leaf_spot': 0
}

In [None]:
def _glcm_index(folder):
    """Return ``[(path, glcm_features)]`` for every readable image in `folder`."""
    index = []
    for file_name in os.listdir(folder):
        path = os.path.join(folder, file_name)
        if not os.path.isfile(path):
            continue
        image = cv2.imread(path)
        if image is None:  # not an image / could not be decoded
            continue
        index.append((path, glcm_features(image)))
    return index


def _closest_style_image(healthy_features, style_index):
    """Return the path whose GLCM features are nearest, or None if empty."""
    best_path = None
    best_distance = float('inf')
    for path, features in style_index:
        distance = calculate_similarity(healthy_features, features)
        # Strict comparison: on ties the first image encountered wins.
        if distance < best_distance:
            best_distance = distance
            best_path = path
    return best_path


def augment_plant_dataset(content_data_dir, style_data_dir_list, final_output_directory, with_segmentation, skip):
    """Style-transfer every healthy image with its best-matching diseased image.

    For each healthy image and each style directory, the diseased image with
    the closest GLCM texture features is used as the style. The result is
    written to ``final_output_directory/<style directory name>/`` as
    ``<original stem>_augmented<original extension>``.

    Args:
        content_data_dir: Directory containing the healthy (content) images.
        style_data_dir_list: Iterable of directories with diseased (style)
            images; each directory name must be a key of `weights_data`.
        final_output_directory: Root directory for the augmented images.
        with_segmentation: If true, use `nst_data_augment`; otherwise use
            `nst_data_augment_without_segmentation`.
        skip: Number of loadable healthy images to skip first (in
            `os.listdir` order), e.g. to resume an interrupted run.
    """
    transfer = nst_data_augment if with_segmentation else nst_data_augment_without_segmentation

    # Style-image features do not depend on the healthy image, so compute
    # them once per folder (lazily) instead of once per healthy image.
    style_index_cache = {}

    num_skip = skip
    for healthy_image_file in os.listdir(content_data_dir):
        healthy_image_path = os.path.join(content_data_dir, healthy_image_file)
        if not os.path.isfile(healthy_image_path):
            print(f"Invalid file: {healthy_image_path}")
            continue

        healthy_image = cv2.imread(healthy_image_path)
        if healthy_image is None:
            print(f"Error loading healthy leaf image: {healthy_image_path}")
            continue

        # Only successfully loaded images count towards `skip`.
        if num_skip > 0:
            num_skip -= 1
            continue

        healthy_features = glcm_features(healthy_image)
        stem, extension = os.path.splitext(healthy_image_file)
        augmented_image_name = f"{stem}_augmented{extension}"

        for diseased_folder in style_data_dir_list:
            if diseased_folder not in style_index_cache:
                style_index_cache[diseased_folder] = _glcm_index(diseased_folder)

            best_diseased_leaf_image = _closest_style_image(
                healthy_features, style_index_cache[diseased_folder]
            )
            if best_diseased_leaf_image is None:
                print(f"No suitable diseased leaf image found for healthy leaf image {healthy_image_file} in folder {diseased_folder}")
                continue

            # normpath tolerates a trailing slash on the folder path.
            diseased_folder_name = os.path.basename(os.path.normpath(diseased_folder))
            augmented_path = os.path.join(final_output_directory, diseased_folder_name)
            os.makedirs(augmented_path, exist_ok=True)

            print("Transfering the style for : " + str(healthy_image_path))
            transfer(
                healthy_image_path,
                best_diseased_leaf_image,
                augmented_image_name,
                augmented_path,
                weights_data[diseased_folder_name],
            )


**Edit the call below to choose the content directory, style directories, output directory, segmentation flag and skip count for the augmentation run**

In [None]:
# Augment the healthy tomato images (entry 5 of the content list) with every
# tomato style directory, using segmentation and skipping the first 548
# loadable images.
augment_plant_dataset(
    content_data_dir=content_images_dir_list[5],
    style_data_dir_list=list(style_images_tomato_dir_list),
    final_output_directory='/kaggle/working/plantvillage-augmented-tomato-3',
    with_segmentation=True,
    skip=548,
)

In [None]:
!zip -r file-tomato-3.zip /kaggle/working/plantvillage-augmented-tomato-3/Tomato___Tomato_Yellow_Leaf_Curl_Virus