Dataset source: https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection/data (the dataset is in the Public Domain)

Some code is from https://www.kaggle.com/code/harpdeci/yolo-nas-fruit-detection/notebook

# [Fruit Detection](https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection/data)



In [None]:
# This notebook will be used for an Object Detection task that trains a model on the fruits_detection dataset using YOLOv8

# Importing the necessary libraries
import os
import sys
import json
import random
import pathlib
import requests
import zipfile
import time

import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision import datasets
from torchvision import models
from PIL import Image
import numpy as np
import pandas as pd
import torch
import yaml 
from ultralytics import YOLO


import matplotlib.pyplot as plt
%matplotlib inline

'''from super_gradients.training import Trainer, dataloaders, models
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val
)
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import (
    PPYoloEPostPredictionCallback
)'''

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Training on {device}')




# Get the dataset

In [None]:
def download_file(url="https://www.dropbox.com/scl/fi/79fa4j2sjb0ohvlv843lr/fruits_detection.zip?rlkey=0sskhcwckcgvvwknuu6w21n5q&st=jx3vrsl8&dl=1", filename="fruits_detection.zip"):
    """Download ``url`` and store it as ``filename`` in the working directory.

    The local ``datasets`` folder is created as well, because the archive is
    unpacked there afterwards.

    Args:
        url (str): Address of the file to download.
        filename (str): Path the downloaded file is written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    # Make sure the folder the archive will be unpacked into exists
    os.makedirs("datasets", exist_ok=True)

    # Write to a temporary name first so an interrupted download never leaves
    # a truncated file under the final name
    partial = f"{filename}.part"
    try:
        # Stream the body in chunks; the timeout keeps a dead connection from
        # hanging the notebook forever
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail loudly instead of saving an HTML error page as the "zip"
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial, filename)
    finally:
        # Only still present when the download failed part-way
        if os.path.exists(partial):
            os.remove(partial)

    # Print a success message
    print(f"Downloaded {filename} successfully.")

def extract_file(filename, data_folder):
    """Unpack the zip archive ``filename`` into ``data_folder``.

    A confirmation is printed after a successful extraction. When
    ``filename`` is not a zip archive, its type and an error message are
    printed and nothing is extracted.

    Args:
        filename (str): Path of the archive to unpack.
        data_folder (str): Folder the archive members are extracted to.
    """
    # Guard clause: report and bail out when this is not a zip archive
    if not zipfile.is_zipfile(filename):
        print(type(filename))
        print(f"{filename} is not a valid zip file.")
        return

    # Unpack every member of the archive into the target folder
    with zipfile.ZipFile(filename, "r") as archive:
        archive.extractall(data_folder)
        print(f"Extracted {filename} to {data_folder} successfully.")
    
def _is_yes(answer):
    '''Return True when a prompt reply means yes ("yes" or "y", any case).'''
    return answer.strip().lower() in ('yes', 'y')


def _add_to_gitignore(entries, gitignore_path='.gitignore'):
    '''Append the entries missing from the .gitignore file, creating it if needed.'''
    try:
        with open(gitignore_path, 'r') as f:
            content = f.read()
    except FileNotFoundError:
        content = ''

    present = set(content.splitlines())
    missing = [entry for entry in entries if entry not in present]
    if not missing:
        return

    # Append instead of rewriting so the existing lines stay untouched
    with open(gitignore_path, 'a') as f:
        # Start on a fresh line when the file does not end with a newline
        if content and not content.endswith('\n'):
            f.write('\n')
        f.write('\n'.join(missing) + '\n')


def manage_data(folder_name='fruits_detection'):
    '''Try to find the data for the exercise and return the path.

    The function checks a list of common locations first, then asks the user
    for a path, and finally offers to download and unpack the dataset.

    Args:
        folder_name (str): Name of the dataset folder to look for.

    Returns:
        str or None: Path of the dataset folder, or None when the data was
        neither found nor downloaded. After a download the path is always
        datasets/fruits_detection, whatever folder_name is.
    '''

    # Check common paths of where the data might be on different systems
    likely_paths = [os.path.normpath(f'/blue/practicum-ai/share/data/{folder_name}'),
                    os.path.normpath(f'/project/scinet_workshop2/data/{folder_name}'),
                    os.path.join('datasets', folder_name),
                    os.path.normpath(folder_name)]

    for path in likely_paths:
        if os.path.exists(path):
            print(f'Found data at {path}.')
            return path

    answer = input('Could not find data in the common locations. Do you know the path? (yes/no): ')

    if _is_yes(answer):
        base = os.path.normpath(input('Please enter the path to the data folder: ').strip())
        # Accept the parent folder (original behavior) ...
        candidates = [os.path.join(base, folder_name)]
        # ... and also the dataset folder itself when the user typed that
        if os.path.basename(base) == folder_name:
            candidates.append(base)
        for path in candidates:
            if os.path.exists(path):
                print(f'Thanks! Found your data at {path}.')
                return path
        print('Sorry, that path does not exist.')

    answer = input('Do you want to download the data? (yes/no): ')

    if _is_yes(answer):
        # Keep the archive and the unpacked data out of the repository,
        # as the files are too large for GitHub
        _add_to_gitignore(['fruits_detection.zip', 'datasets/'])

        print("Updated .gitignore file.")
        print('Downloading data, this may take a minute.')
        download_file()
        print('Data downloaded, unpacking')
        extract_file("fruits_detection.zip", "datasets")
        print('Data downloaded and unpacked. Now available at datasets/fruits_detection.')
        return os.path.normpath('datasets/fruits_detection')

    print('Sorry, I cannot find the data. Please download it manually from https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection/ and unpack it to the datasets folder.')
    return None


# Locate the dataset (or download it) and remember the path that was returned
data_path = manage_data(folder_name='fruits_detection')

# Explore the dataset

In [None]:
# Download the dataset from Kaggle with the Kaggle command line tool into the datasets folder

# Check if the Kaggle API key is present
if os.path.exists('kaggle.json'):
    # Load the Kaggle API key
    with open('kaggle.json', 'r') as f:
        kaggle = json.load(f)
    # Save the Kaggle API key to the environment
    os.environ['KAGGLE_USERNAME'] = kaggle['username']
    os.environ['KAGGLE_KEY'] = kaggle['key']
    # Download the dataset; a non-zero exit status means the command failed
    exit_status = os.system('kaggle datasets download lakshaytyagi01/fruit-detection -p datasets')
    if exit_status != 0:
        print(f'The kaggle download command failed with exit status {exit_status}.')
else:
    print('Please upload your Kaggle API key to this notebook.')

# Unzip the downloaded file, but only when a valid archive is actually there
# (the key may be missing or the download may have failed)
if zipfile.is_zipfile('datasets/fruit-detection.zip'):
    with zipfile.ZipFile('datasets/fruit-detection.zip', 'r') as zip_ref:
        zip_ref.extractall('datasets')
else:
    print('No valid archive at datasets/fruit-detection.zip, nothing to unpack.')
    

In [None]:
# Assign the path to the dataset: use the folder located by manage_data()
# and fall back to the previous default when it returned nothing
data_dir = data_path or r"data/fruits_detection"

# Show sample images and plot how many labeled objects each class has
def _count_label_classes(folder_path):
    """Count the class ids found in the label files of one dataset split.

    Reads every ``.txt`` file in ``<folder_path>/labels`` and takes the first
    whitespace-separated field of each non-empty line as the class id
    (assumes YOLO-format labels).

    Returns:
        dict: Maps class id (int) to the number of label lines with that id.
    """
    class_counts = {}
    labels_path = os.path.join(folder_path, 'labels')
    for filename in os.listdir(labels_path):
        # Skip anything that is not a label file (hidden files, caches, ...)
        if not filename.endswith('.txt'):
            continue
        with open(os.path.join(labels_path, filename), 'r') as f:
            for line in f:
                fields = line.split()
                # Blank lines carry no label
                if not fields:
                    continue
                class_id = int(fields[0])
                class_counts[class_id] = class_counts.get(class_id, 0) + 1
    return class_counts


def explore_data(data_dir, show_picture=True, show_hist=True, class_names=None, num_samples=5):
    """Show random sample images and a bar chart of labels per class.

    Args:
        data_dir (str): Dataset root holding the ``train``, ``valid`` and
            ``test`` folders, each with ``images`` and ``labels`` inside.
        show_picture (bool): Plot randomly chosen sample images.
        show_hist (bool): Plot the number of labeled objects per class for
            every split.
        class_names (list): Class names ordered by class id. Defaults to the
            six fruit classes of this dataset.
        num_samples (int): Number of sample images to show.
    """
    if class_names is None:
        class_names = ['Apple', 'Banana', 'Grape', 'Orange', 'Pineapple', 'Watermelon']

    # Examine some sample images
    if show_picture:
        # Only folders that really contain an images/ subfolder can be sampled
        image_folders = [
            f for f in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, f, 'images'))
        ]

        if image_folders and num_samples > 0:
            sample_images = []
            for _ in range(num_samples):
                images_dir = os.path.join(data_dir, random.choice(image_folders), 'images')
                sample_images.append(os.path.join(images_dir, random.choice(os.listdir(images_dir))))

            # squeeze=False keeps axes two-dimensional even for one sample
            _, axes = plt.subplots(1, num_samples, figsize=(4 * num_samples, 5), squeeze=False)
            for ax, img_path in zip(axes[0], sample_images):
                with Image.open(img_path) as img:
                    ax.imshow(img)
                ax.axis('off')
            plt.show()
        else:
            print(f'No image folders found in {data_dir}, skipping the sample pictures.')

    # Make a bar chart of the number of labeled objects in each class
    if show_hist:
        num_classes = len(class_names)
        split_folders = (('train', 'train'), ('val', 'valid'), ('test', 'test'))
        df = pd.DataFrame({
            split: pd.Series(_count_label_classes(os.path.join(data_dir, folder)), dtype='float64')
            for split, folder in split_folders
        })

        # Put one row per class id in ascending order so every bar sits under
        # its own class name, including classes that never occur
        all_ids = sorted(set(range(num_classes)).union(df.index))
        df = df.reindex(all_ids).fillna(0).astype(int)
        df.index = [
            class_names[i] if 0 <= i < num_classes else f'class {i}'
            for i in all_ids
        ]

        df.plot.bar(figsize=(10, 6), rot=0)
        plt.xlabel('Class Name')
        # The counts are label lines (objects), not images
        plt.ylabel('Number of Labeled Objects')
        plt.title('Distribution of Labeled Objects per Class')
        plt.legend()
        plt.show()

# Run the exploration with both the sample pictures and the class chart enabled
explore_data(data_dir=data_dir, show_picture=True, show_hist=True)

# Create the YAML file

In [None]:
# Create a YAML file for the YOLOv8 model configuration

def create_yaml(data_dir, class_names, yaml_file='fruits_detection_data.yaml'):
    """
    Creates a YOLOv8 data.yaml file. YAML stands for "YAML Ain't Markup Language" and is a human-readable data serialization format.
    A YAML file is used to define the dataset configuration for training a YOLOv8 model.

    Args:
        data_dir (str): Path to the dataset root directory.
        class_names (list): List of class names.
        yaml_file (str): Name of the YAML file to save. Defaults to 'fruits_detection_data.yaml'.
    """

    # Write absolute paths: per the Ultralytics dataset docs, relative entries
    # are resolved against its own datasets directory rather than the current
    # working directory, so they can point to the wrong place
    root = os.path.abspath(data_dir)
    # Copy into a plain list so a tuple of names is serialized as a YAML list
    names = list(class_names)

    yaml_dict = {
        'train': os.path.join(root, 'train', 'images'),  # Training images
        'val': os.path.join(root, 'valid', 'images'),    # Validation images
        'test': os.path.join(root, 'test', 'images'),    # Testing images

        'nc': len(names),            # Number of classes (key name used by Ultralytics)
        'num_classes': len(names),   # Kept for backward compatibility
        'names': names               # List of class names
    }

    with open(yaml_file, 'w') as outfile:
        yaml.dump(yaml_dict, outfile, default_flow_style=False)

    print(f'YAML file created: {yaml_file}')

# Use the dataset folder located by manage_data() and fall back to the
# previous default when it returned nothing
data_dir = data_path or 'data/fruits_detection'
class_names = ['Apple', 'Banana', 'Grape', 'Orange', 'Pineapple', 'Watermelon']

create_yaml(data_dir, class_names)

# Create the model

In [None]:
# Create the YOLOv8 model from the 'yolov8n.yaml' model definition
# NOTE(review): a .yaml presumably builds an untrained network, whereas
# 'yolov8n.pt' would start from pretrained weights; confirm this is intended
model = YOLO('yolov8n.yaml')

# Train on the dataset described by the YAML file created above
results = model.train(
    data='fruits_detection_data.yaml',
    epochs=3,
    imgsz=640,
)