Author: Franziska 
Trained with Vertex AI and not locally

# Experiment with YOLO

## Setup and Libraries

In [None]:
#%pip install ultralytics
#!pip install opencv-python-headless

In [None]:
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import PIL
import os
import pathlib
import glob
import random

from sklearn.utils import shuffle
from matplotlib.patches import Rectangle, Polygon
from ultralytics import YOLO
from PIL import Image
from IPython.display import display

warnings.simplefilter('ignore')

In [None]:
# Notebook-wide parameters.

# Root folder of the dataset; the dataset-related paths below derive from it.
base_path = "/home/jupyter/Remote Sensing Data.v2i.yolov8"
yaml_path = f"{base_path}/data.yaml"
test_images_path = f"{base_path}/test/images"

# Output folders referenced by later cells.
predicted_images_path = "/home/jupyter/runs/detect"
log_dir = f"{predicted_images_path}/train"  # Adjust this path based on your setup

# Data paths: same root folder as base_path, kept under its own name
# because the training cell reads `data_path`.
data_path = base_path
train_path = os.path.join(data_path, 'train')
val_path = os.path.join(data_path, 'val')

# Weights file handed to YOLO(...) when the model is loaded.
base_model = "yolov8n.pt"

## Data Exploration

In [None]:
# Folders holding the training split's label files and images.
labels_path = os.path.join(base_path, 'train/labels')
images_path = os.path.join(base_path, 'train/images')

# Class names as listed in data.yaml. The list is indexed by the integer
# class id read from the label files, so the order matters.
labels = [
    'Agriculture', 'Airport', 'Beach', 'City', 'Desert', 'Forest', 'Grassland',
    'Highway', 'Lake', 'Mountain', 'Parking', 'Port', 'Railway', 'River',
]

# Colormap with one entry per class; `colors(class_id)` yields that class's
# colour. plt.get_cmap is used because plt.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9, while plt.get_cmap accepts the same
# (name, lut) arguments on old and new releases.
colors = plt.get_cmap('hsv', len(labels))

In [None]:
def load_labels(label_file):
    """
    Load label data from a .txt file.

    Each non-empty line is expected to hold a class id followed by
    whitespace-separated numbers: exactly four numbers describe a rectangle
    (x_center, y_center, width, height); any other even, non-zero count is
    read as polygon vertices (x1, y1, x2, y2, ...).

    Lines that cannot be interpreted are skipped instead of raising: blank
    lines, non-numeric values, a class id with no coordinates, or an odd
    number of polygon coordinates.

    Parameters:
    label_file (str): Path to the label file

    Returns:
    list: List of annotations with classes, either
          (class_id, 'rectangle', xc, yc, w, h) or
          (class_id, 'polygon', [(x1, y1), (x2, y2), ...])
    """
    annotations = []
    with open(label_file, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # Blank line (e.g. a trailing newline at end of file)

            try:
                class_id = int(parts[0])
                coords = [float(value) for value in parts[1:]]
            except ValueError:
                continue  # Skip lines with invalid values

            if len(coords) == 4:
                # Rectangle: class_id, x_center, y_center, width, height
                annotations.append((class_id, 'rectangle', *coords))
            elif coords and len(coords) % 2 == 0:
                # Polygon: class_id, x1, y1, x2, y2, ..., xN, yN
                vertices = list(zip(coords[0::2], coords[1::2]))
                annotations.append((class_id, 'polygon', vertices))
            # Anything else (no coordinates, or an unpaired x without a y)
            # is malformed and is skipped.
    return annotations

def visualize_annotations(image_file, label_file, ax):
    """
    Visualize annotations (polygons and rectangles) on an image.

    The image is drawn on `ax` with its file name as the title, and every
    annotation returned by load_labels() is overlaid in its class colour
    together with the class name. Label coordinates are treated as fractions
    of the image size and scaled by the image width and height.

    Problems are handled without raising so that one bad sample does not
    abort a whole grid of plots:
    - an image that cannot be read is reported with a message and its
      subplot is left empty,
    - a missing label file is treated as "no annotations",
    - a class id outside `labels` is shown as "class <id>",
    - a polygon without vertices is ignored.

    Parameters:
    image_file (str): Path to the image
    label_file (str): Path to the label file
    ax (matplotlib.axes._subplots.AxesSubplot): The subplot axis to display the image
    """
    ax.set_title(os.path.basename(image_file))

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread signals an unreadable or missing file by returning None.
        print(f"Could not load image: {image_file}")
        ax.axis('off')
        return

    height, width = image.shape[:2]
    ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    annotations = load_labels(label_file) if os.path.isfile(label_file) else []

    for class_id, annotation_type, *data in annotations:
        color = colors(class_id)
        # Explicit range check: a negative id would otherwise silently index
        # from the end of `labels`, and a too-large id would raise IndexError.
        if 0 <= class_id < len(labels):
            class_name = labels[class_id]
        else:
            class_name = f"class {class_id}"

        if annotation_type == 'rectangle':
            xc, yc, w, h = data
            # Convert centre/size fractions to pixel corner coordinates.
            xmin = int((xc - w / 2) * width)
            ymin = int((yc - h / 2) * height)
            xmax = int((xc + w / 2) * width)
            ymax = int((yc + h / 2) * height)
            ax.add_patch(Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, edgecolor=color, fill=False, linewidth=2))
            ax.text(xmin, ymin, class_name, bbox={'facecolor': color, 'alpha': 0.5}, fontsize=10, color='white')

        elif annotation_type == 'polygon':
            vertices = [(x * width, y * height) for x, y in data[0]]
            if not vertices:
                continue  # Nothing to draw and no anchor point for the text
            ax.add_patch(Polygon(vertices, closed=True, edgecolor=color, fill=False, linewidth=2))
            # The class name is anchored at the polygon's first vertex.
            ax.text(vertices[0][0], vertices[0][1], class_name, bbox={'facecolor': color, 'alpha': 0.5}, fontsize=10, color='white')

    ax.axis('off')

In [None]:
# Pair every .jpg image with its label file (same base name, .txt extension).
# The directory is listed once and sorted so both lists are built from the
# same sequence of names and stay aligned index by index.
image_names = sorted(f for f in os.listdir(images_path) if f.endswith('.jpg'))
image_files = [os.path.join(images_path, f) for f in image_names]
label_files = [os.path.join(labels_path, os.path.splitext(f)[0] + '.txt') for f in image_names]

# Randomly select up to 9 images and their corresponding labels.
# random.sample raises ValueError when asked for more items than exist,
# so the sample size is capped at the number of available images.
selected_indices = random.sample(range(len(image_files)), min(9, len(image_files)))
selected_image_files = [image_files[i] for i in selected_indices]
selected_label_files = [label_files[i] for i in selected_indices]

# Visualize in a 3x3 grid
fig, axs = plt.subplots(3, 3, figsize=(15, 15))
for ax in axs.flat:
    ax.axis('off')  # Cells that receive no image stay blank instead of showing empty axes
for ax, image_file, label_file in zip(axs.flat, selected_image_files, selected_label_files):
    visualize_annotations(image_file, label_file, ax)
plt.tight_layout()
plt.show()

## Model training

In [None]:
# Create the model from the weights file named in the parameters cell.
model = YOLO(base_model)

# Train on the dataset described by data.yaml in the dataset folder.
results = model.train(
    data=os.path.join(data_path, 'data.yaml'),
    epochs=20,
    imgsz=640,
    save=True,
    fraction=1,
)


In [None]:
# Create the model from a weights file addressed relative to the notebook.
# NOTE(review): this cell uses a relative weights path and a
# /home/franziska/... dataset path, unlike the /home/jupyter/... paths in the
# parameters cell -- presumably a leftover from a local run; confirm before
# executing it. Running it rebinds `model` and `results` from the previous
# training cell.
model = YOLO("../models/yolov8n.pt")

# Train for 100 epochs on the dataset described by this data.yaml.
results = model.train(
    data="/home/franziska/code/FranziskaHaisch/EnviroClass/raw_data/satellite_images_14_classes/data.yaml",
    epochs=100,
    imgsz=640,
)

## Evaluate Model

In [None]:
# Evaluate the model on the validation set.
# `model` is whichever object the most recently executed training cell bound
# to that name; `yaml_path` is the data.yaml path from the parameters cell.
# The returned value is kept in `results_val` for later inspection.
results_val = model.val(data=yaml_path)

In [None]:
# Display the post-training images
def display_images(post_training_files_path, image_files):
    """
    Show each listed image file in its own 10x10 inch figure (dpi 120).

    Files that cannot be read are reported with a message and skipped.

    Parameters:
    post_training_files_path (str): Directory containing the image files
    image_files (list): File names, relative to that directory
    """
    for file_name in image_files:
        full_path = os.path.join(post_training_files_path, file_name)
        bgr = cv2.imread(full_path)
        if bgr is None:
            print(f"Could not load image: {file_name}")
            continue

        # OpenCV loads images as BGR; matplotlib expects RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        plt.figure(figsize=(10, 10), dpi=120)
        plt.imshow(rgb)
        plt.axis('off')
        plt.show()

# Directory the plot files are read from.
# NOTE(review): 'tval' differs from the 'train' run folder referenced by
# log_dir in the parameters cell -- confirm this directory exists.
post_training_files_path = '/home/jupyter/runs/detect/tval'

# Plot files to show, in display order. This rebinds `image_files`, which
# the data exploration cell used for the list of training images.
image_files = [
    'confusion_matrix_normalized.png',
    'F1_curve.png',
    'P_curve.png',
    'R_curve.png',
    'PR_curve.png',
    'results.png',
]

# Render every plot; unreadable files are reported by display_images.
display_images(post_training_files_path, image_files)

## Predict

In [None]:
# Predict on new images: run the model on the test images folder from the
# parameters cell.
# NOTE(review): save=True presumably makes the library write annotated
# copies to disk; the next cell looks for .jpg files under
# /home/jupyter/runs/detect/predict -- confirm that is where they land.
predictions = model.predict(source=test_images_path, save=True)

In [None]:
# Get the list of predicted images from the 'predict' folder below the
# output directory defined in the parameters cell.
# NOTE(review): a re-run of the predict cell may write to a differently
# named folder -- confirm the folder name if the grid looks stale or empty.
predicted_images = glob.glob(os.path.join(predicted_images_path, "predict", "*.jpg"))
if not predicted_images:
    print(f"No predicted images found in: {os.path.join(predicted_images_path, 'predict')}")

# Shuffle the list of images to select 9 random images
random.shuffle(predicted_images)

# Create a 3x3 grid of images
fig, axs = plt.subplots(3, 3, figsize=(15, 15))
for ax in axs.flat:
    # Hide the axes of every cell up front so cells without an image
    # (fewer than 9 files, or a file that fails to load) stay blank.
    ax.axis('off')
for ax, img_path in zip(axs.flat, predicted_images[:9]):
    img = cv2.imread(img_path)
    if img is None:
        print(f"Could not load image: {img_path}")  # Debugging info
        continue
    # OpenCV loads images as BGR; matplotlib expects RGB.
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.tight_layout()
plt.show()