#### Scripts for training a YOLOv8 model 

In [None]:
###### Script to remap YOLO class ids in .txt label files from 15 to 0 (and 16 to 1, etc., if needed) ######
### NB. I don't like keeping unnecessary classes in the model, so I only keep the classes I want to train on. The labelling method I use numbers any new label from class id 15 onwards, so those ids are remapped to start at 0. ###
 

## imports ##
import os

# Function to change labels in a single file
def change_labels(file_path, label_map=None):
    """Remap class ids in one YOLO label .txt file, rewriting the file in place.

    The first whitespace-separated token of each line is treated as the class id.
    Ids found in ``label_map`` are replaced; all other tokens are kept and
    re-joined with single spaces. Blank or whitespace-only lines are dropped
    instead of raising IndexError. The file is written back with one trailing
    newline per line.

    Args:
        file_path: path of the label file to rewrite.
        label_map: dict mapping old class id -> new class id, both as strings.
            Defaults to ``{'15': '0'}``; add e.g. ``'16': '1'`` for more classes.
    """
    if label_map is None:
        label_map = {'15': '0'}

    with open(file_path, 'r') as f:
        lines = f.readlines()

    new_lines = []
    for line in lines:
        parts = line.split()
        if not parts:
            continue  # blank line: nothing to remap (and parts[0] would raise)
        parts[0] = label_map.get(parts[0], parts[0])
        new_lines.append(' '.join(parts))

    # Terminate every line with '\n' so the file stays a well-formed text file.
    with open(file_path, 'w') as f:
        f.write(''.join(new_line + '\n' for new_line in new_lines))

# Folder holding the YOLO .txt label files to remap (only files directly inside it are touched)
folder_path = '../FV/' ##etc

# Rewrite every .txt label file in the folder in place
label_files = [entry for entry in os.listdir(folder_path) if entry.endswith('.txt')]
for label_file in label_files:
    change_labels(os.path.join(folder_path, label_file))


In [None]:
###### NB. DO NOT RUN THIS CELL. IT SHOWS THE CONTENTS OF THE DATASET YAML FILE USED FOR YOLOv8 TRAINING (i.e. save it as keyhole.yaml, the file passed as data= in the training cell) ######


path: data/  # dataset root dir; relative train/val/test entries are resolved against it
train: /home/User/data/train/ ##etc
val: /home/User/data/validation/ ##etc
#test: /home/User/data/test/ ##etc

# Classes
nc: 1  # number of classes; must equal the number of entries in `names`

# Class names, listed in class-id order (id 0 first).
# Replace with your own class names, e.g. names: ['hole', 'shell']
names: ['hole']
    

In [None]:
###### Scripts for training a YOLOv8 model ######



## imports ##
# Import clear_output directly instead of binding `display` first to the module and then to the function.
from IPython.display import Image, clear_output, display  # NB. this is IPython's Image, not PIL's

clear_output()  # clear stale output from previous runs of this cell
import ultralytics
ultralytics.checks()  # print ultralytics' software/hardware environment check
from ultralytics import YOLO



# Load the pretrained YOLO model from a .pt file.
model = YOLO('yolov8n.pt') ### NB. this is downloaded automatically on first use ###

# Train on the first GPU when CUDA is available; otherwise fall back to the CPU instead
# of failing on a machine without a GPU (torch is installed as an ultralytics dependency).
import torch
train_device = 0 if torch.cuda.is_available() else 'cpu'

# Define individual training hyperparameters.
custom_lr0 = 0.01               # initial learning rate
custom_momentum = 0.937         # SGD momentum (beta1 for Adam-type optimizers)
custom_weight_decay = 0.0005    # optimizer weight decay

# Train the model using your YAML file and additional parameters.
results = model.train(
    data='keyhole.yaml',   ##etc
    epochs=100,
    imgsz=320,
    batch=16,
    lr0=custom_lr0,
    momentum=custom_momentum,
    weight_decay=custom_weight_decay,
    optimizer='SGD',             # change the optimizer if needed (e.g. 'Adam', 'AdamW')
    device=train_device,
    project='runs/keyhole',      # directory to save the training results
    name='keyhole_yolo_model',   # run sub-directory -> runs/keyhole/keyhole_yolo_model/
    exist_ok=True,               # allow overwriting the existing experiment directory

    # Optional augmentation parameters.
    mosaic=1.0,     # mosaic augmentation probability
    mixup=0.0,      # mixup augmentation probability
    degrees=0.0,    # rotation degrees
    translate=0.1,  # translation as a fraction of image size
    scale=0.5,      # scale gain
    flipud=0.0,     # vertical flip probability
    fliplr=0.5,     # horizontal flip probability
)

# Report where the best checkpoint actually landed (<project>/<name>/weights/best.pt);
# this is the file the segmentation cell has to load.
print(f"Training complete. Best weights saved to: {model.trainer.best}")

## Script for creating segmentation masks from the YOLO model's detections (via SAM)

In [None]:
###### Script to create standardized masks using the YOLO and SAM models (this one is set up for the FV keyhole), derive shape metrics from them, and save the values to a .csv ######


import os
import numpy as np
from PIL import Image
from ultralytics import YOLO
from segment_anything import sam_model_registry, SamPredictor
from skimage import measure
from skimage.transform import rotate, resize
from skimage.morphology import opening, disk
import pandas as pd

def clean_mask(mask, radius=5):
    """Morphologically open ``mask`` with a disk-shaped footprint of ``radius`` pixels.

    Opening (erosion followed by dilation) removes foreground specks and thin
    protrusions smaller than the disk while keeping larger shapes roughly intact.
    """
    return opening(mask, disk(radius))

def center_mask(mask, desired_size):
    """Translate ``mask`` onto a ``desired_size`` x ``desired_size`` canvas.

    The centroid of the largest connected foreground region is moved to the
    canvas centre. Pixels shifted outside the canvas are dropped. If ``mask``
    has no foreground, an all-zero canvas is returned. The output dtype matches
    ``mask``.
    """
    centered_mask = np.zeros((desired_size, desired_size), dtype=mask.dtype)
    regions = measure.regionprops(measure.label(mask))
    if not regions:
        return centered_mask

    # regionprops lists regions in label (scan) order, not by size; specks left
    # after cleaning must not decide where the object is centred.
    centroid_row, centroid_col = max(regions, key=lambda region: region.area).centroid
    # Round to the nearest pixel (int() alone truncates toward zero, biasing by sign).
    offset_row = int(round(desired_size / 2 - centroid_row))
    offset_col = int(round(desired_size / 2 - centroid_col))

    # Copy the overlapping window in one slice assignment (same result as a per-pixel loop).
    mask_height, mask_width = mask.shape
    src_row_start, src_row_stop = max(0, -offset_row), min(mask_height, desired_size - offset_row)
    src_col_start, src_col_stop = max(0, -offset_col), min(mask_width, desired_size - offset_col)
    if src_row_start < src_row_stop and src_col_start < src_col_stop:
        centered_mask[src_row_start + offset_row:src_row_stop + offset_row,
                      src_col_start + offset_col:src_col_stop + offset_col] = \
            mask[src_row_start:src_row_stop, src_col_start:src_col_stop]
    return centered_mask

def resize_to_standard(mask, target_major_axis_length=100):
    """Rescale ``mask`` to a standard size, then centre it on a square canvas.

    The scale is chosen so the largest region's major axis is about
    ``target_major_axis_length`` pixels. The canvas is square, with a side that
    is a multiple of 64.

    Returns ``mask`` unchanged when it has no foreground, or when its largest
    region has a zero-length major axis (e.g. a single pixel), because no scale
    factor can be derived in those cases.
    """
    regions = measure.regionprops(measure.label(mask))
    if not regions:
        return mask

    major_axis_length = max(regions, key=lambda region: region.area).major_axis_length
    if major_axis_length <= 0:
        return mask  # degenerate region: dividing by it would raise ZeroDivisionError

    scale_factor = target_major_axis_length / major_axis_length
    # Keep at least one pixel per axis so resize never receives an empty output shape.
    new_shape = tuple(max(1, int(dim * scale_factor)) for dim in mask.shape)
    resized_mask = resize(mask.astype(np.float32), new_shape, preserve_range=True, anti_aliasing=False) > 0.5

    # Square canvas large enough for the resized mask, rounded up to a multiple of 64
    desired_size = int(np.ceil(max(new_shape) / 64.0)) * 64
    return center_mask(resized_mask, desired_size)

def _mask_output_path(image_path, input_dir, output_dir):
    """Return where the standardized mask for ``image_path`` is saved.

    The image's location relative to ``input_dir`` is mirrored under ``output_dir``.
    Only the file extension is swapped for ``_mask.png`` (``a/b.JPG`` -> ``a/b_mask.png``).
    """
    stem, _ext = os.path.splitext(os.path.relpath(image_path, input_dir))
    return os.path.join(output_dir, stem + '_mask.png')


def process_mask(image_path, input_dir, output_dir, target_major_axis_length=100, conf=0.25):
    """Detect, segment, standardize and measure the main object in one image.

    Pipeline:
    1. Take a YOLO box from the module-level ``yolo_model``.
    2. Prompt SAM (module-level ``predictor``) with that box to get a mask.
    3. Morphologically open the mask.
    4. Rotate it by the orientation of its largest region.
    5. Rescale it to ``target_major_axis_length`` and centre it.
    6. Measure the result and save it as a PNG.

    Args:
        image_path: path of the image to process.
        input_dir: root image folder. Used to build the relative ``image_path``
            recorded in the CSV and the mirrored output location.
        output_dir: root folder where ``<relative stem>_mask.png`` is written.
        target_major_axis_length: major-axis length (pixels) the mask is rescaled to.
        conf: YOLO confidence threshold for the detection.

    Returns:
        A list with one measurement dict. The list is ``[]`` if nothing was
        detected or segmented, or if processing failed before measuring. A
        failure is printed rather than raised, so one bad image does not stop
        the batch.
    """
    measurements = []
    try:
        # Load image as an RGB array (this array is what SAM receives)
        image = np.array(Image.open(image_path).convert("RGB"))

        # YOLO detection; only the first returned box (presumably the most confident) is used
        yolo_results = yolo_model.predict(source=image_path, conf=conf)
        if len(yolo_results[0].boxes) == 0:
            return []
        bbox = yolo_results[0].boxes.xyxy.tolist()[0]

        # SAM segmentation prompted with the YOLO box
        predictor.set_image(image)
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bbox)[None, :],
            multimask_output=False,
        )

        # Clean the (non-inverted) mask; use ~masks[0] instead to measure the inverted mask
        cleaned_mask = clean_mask(masks[0].astype(bool))

        # Rotate using the orientation of the largest region, not whichever region got label 1
        regions = measure.regionprops(measure.label(cleaned_mask))
        if not regions:
            return []
        main_region = max(regions, key=lambda region: region.area)
        rotation_angle = -np.degrees(main_region.orientation)
        rotated_mask = rotate(cleaned_mask, angle=rotation_angle, resize=True, preserve_range=True) > 0.5

        # Resize to standard major axis length (and centre)
        standardized_mask = resize_to_standard(rotated_mask, target_major_axis_length)

        # Measure the largest region of the standardized mask
        standardized_regions = measure.regionprops(measure.label(standardized_mask))
        if not standardized_regions:
            return []
        props = max(standardized_regions, key=lambda region: region.area)
        measurements.append({
            'image_path': os.path.relpath(image_path, input_dir),
            'area': props.area,
            'perimeter': props.perimeter,
            'eccentricity': props.eccentricity,
            'solidity': props.solidity,
            'extent': props.extent,
            'major_axis_length': target_major_axis_length,  # Enforced length
            'minor_axis_length': props.minor_axis_length,
        })

        # Save the standardized mask (only the extension is replaced, never other path parts)
        output_path = _mask_output_path(image_path, input_dir, output_dir)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        Image.fromarray((standardized_mask * 255).astype(np.uint8)).save(output_path)

    except Exception as exc:
        # Best-effort batch processing: report the failure instead of hiding it
        print(f"Skipping {image_path}: {type(exc).__name__}: {exc}")

    return measurements

# Load models
# The training cell saves checkpoints to <project>/<name>/weights/{best,last}.pt,
# i.e. runs/keyhole/keyhole_yolo_model/weights/best.pt for the settings used there.
yolo_weights = 'runs/keyhole/keyhole_yolo_model/weights/best.pt'
# NB. The SAM checkpoint is NOT downloaded automatically: download sam_vit_h_4b8939.pth
# from the segment-anything repository and point this path at it.
sam_checkpoint = '../notebooks/sam_vit_h_4b8939.pth'
model_type = 'vit_h'  # must match the checkpoint (vit_h / vit_l / vit_b)

# Fail early with a clear message instead of deep inside the model loaders.
for required_file in (yolo_weights, sam_checkpoint):
    if not os.path.isfile(required_file):
        raise FileNotFoundError(f"Model file not found: {required_file}")

yolo_model = YOLO(yolo_weights)

# Run SAM on the GPU when available, otherwise on the CPU (slower, but it does not crash).
import torch
sam_device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(sam_device)
predictor = SamPredictor(sam)

# Paths
input_dir = 'FV' ##etc
output_dir = 'processed_masks_for_keyhole' ##etc

# Collect image paths recursively; sorted because os.walk order is arbitrary,
# so the CSV row order is reproducible (the comprehension also avoids clobbering IPython's `_`).
image_paths = sorted(
    os.path.join(dirpath, image_name)
    for dirpath, _subdirs, filenames in os.walk(input_dir)
    for image_name in filenames
    if image_name.lower().endswith(('.png', '.jpg', '.jpeg'))
)
if not image_paths:
    print(f"No .png/.jpg/.jpeg images found under {input_dir!r}")

# Process all images
all_measurements = []
for image_path in image_paths:
    all_measurements.extend(process_mask(image_path, input_dir, output_dir))

# Save measurements to CSV
measurements_df = pd.DataFrame(all_measurements)
measurements_df.to_csv('keyhole_measurements.csv', index=False)
print(f"Measured {len(measurements_df)} of {len(image_paths)} images -> keyhole_measurements.csv")
