# In [None]:
import os
import yaml
import cv2
import matplotlib.pyplot as plt
import random

def load_yaml(yaml_path):
    """Load and return the contents of a YAML file.

    The file is decoded as UTF-8, so non-ASCII content (e.g. class names)
    parses identically regardless of the platform's default locale encoding.
    ``yaml.safe_load`` is used, so only standard YAML tags are constructed.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        The parsed document (typically a dict), or None for an empty file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def denormalize_bbox(bbox_norm, img_width, img_height):
    """Convert a normalized center-format box into absolute pixel corners.

    Args:
        bbox_norm: Exactly four numbers ``(x_center, y_center, width, height)``,
            each expressed as a fraction of the image width/height.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)`` of ints. Each value is truncated
        toward zero with ``int()``; nothing is clipped, so corners may be
        negative or exceed the image size when the box extends past the edge.
    """
    cx_n, cy_n, w_n, h_n = bbox_norm

    # Scale horizontal quantities by the width and vertical ones by the height.
    cx, w = cx_n * img_width, w_n * img_width
    cy, h = cy_n * img_height, h_n * img_height

    half_w = w / 2
    half_h = h / 2
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return tuple(int(v) for v in corners)

def _resolve_dataset_root(config, data_yaml_path):
    """Return the absolute dataset root for ``config``, or None if a configured root is missing.

    An absolute ``path`` key is used as-is; a relative one is resolved against
    the directory containing the YAML file. Without a ``path`` key, the YAML
    file's directory is returned as the base for resolving split paths.
    """
    yaml_dir = os.path.dirname(os.path.abspath(data_yaml_path))
    root_cfg = config.get('path')
    if not root_cfg:
        print("Warning: 'path' key not found in data.yaml. Assuming paths for dataset splits are absolute or relative to the YAML file's directory.")
        return yaml_dir

    if os.path.isabs(root_cfg):
        root_abs = root_cfg
    else:
        root_abs = os.path.abspath(os.path.join(yaml_dir, root_cfg))

    # Only an explicitly configured root that is missing counts as an error.
    if not os.path.isdir(root_abs):
        print(f"Error: Dataset root directory '{root_abs}' (derived from 'path: {root_cfg}') does not exist.")
        return None
    return root_abs


def _infer_label_dir(full_image_dir, dataset_root_abs):
    """Guess the label directory that pairs with ``full_image_dir``.

    Resolution order:
      1. Replace the last path segment named 'images' (case-insensitive) with
         'labels', e.g. ``.../images/train`` -> ``.../labels/train``.
      2. A sibling directory whose name swaps 'images' for 'labels',
         e.g. ``.../train_images`` -> ``.../train_labels`` (only if it exists).
      3. ``<dataset_root>/labels/<image dir basename>`` (only if it exists).
      4. The image directory itself, if it already contains ``.txt`` files
         (labels stored next to the images).
    If none of these match, the path from step 3 is returned so the caller can
    report it as missing.
    """
    parts = full_image_dir.split(os.sep)
    for i in reversed(range(len(parts))):
        if parts[i].lower() == 'images':
            parts[i] = 'labels'
            return os.sep.join(parts)

    base = os.path.basename(full_image_dir)
    conventional = os.path.join(dataset_root_abs, 'labels', base)

    candidates = []
    if 'images' in base:
        candidates.append(os.path.join(os.path.dirname(full_image_dir), base.replace('images', 'labels')))
    candidates.append(conventional)
    for candidate in candidates:
        if os.path.isdir(candidate):
            return candidate

    if os.path.isdir(full_image_dir) and any(
        name.lower().endswith('.txt') for name in os.listdir(full_image_dir)
    ):
        return full_image_dir
    return conventional


def _lookup_class_name(class_names, class_id, label_path):
    """Map a numeric class ID to its display name, falling back to ``ID:<n>``."""
    fallback = f"ID:{class_id}"
    if isinstance(class_names, dict):  # names: {0: name1, 1: name2}
        # Accept quoted keys ('0': name) as well as integer keys.
        return class_names.get(class_id, class_names.get(str(class_id), fallback))
    if isinstance(class_names, list):  # names: [name1, name2]
        if 0 <= class_id < len(class_names):
            return class_names[class_id]
        print(f"Warning: Class ID {class_id} out of bounds for class_names list (len {len(class_names)}) in {label_path}.")
    return fallback


def _draw_annotations(img_display, label_path, class_names):
    """Draw every parseable box from ``label_path`` onto ``img_display`` in place.

    Each row must hold at least five values: a class ID followed by a
    normalized ``x_center y_center width height``. Only the first four
    coordinates are used, and extra columns are ignored. Malformed rows are
    reported and skipped. Any other error while reading the file is reported,
    and the remaining rows are skipped.

    Returns:
        True if at least one box was drawn.
    """
    img_height, img_width = img_display.shape[:2]
    drawn = False
    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, start=1):
                text = line.strip()
                if not text:
                    continue  # Blank lines carry no annotation.
                parts = text.split()
                if len(parts) < 5:
                    print(f"Warning: Malformed line #{line_num} in {label_path}: '{text}' (expected 5+ values)")
                    continue
                try:
                    class_id = int(float(parts[0]))  # Some tools write IDs as floats, e.g. "0.0".
                    bbox_norm = [float(v) for v in parts[1:5]]  # x_center, y_center, width, height
                except ValueError:
                    print(f"Warning: Could not parse numerical values in line #{line_num} in {label_path}: '{text}'")
                    continue

                x_min, y_min, x_max, y_max = denormalize_bbox(bbox_norm, img_width, img_height)
                # putText requires a str; YAML names may be ints or other scalars.
                name = str(_lookup_class_name(class_names, class_id, label_path))

                cv2.rectangle(img_display, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)  # Green box
                # Put the caption above the box unless that would run off the top edge.
                text_y = y_min - 10 if y_min > 20 else y_min + 20
                cv2.putText(img_display, name, (x_min, text_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                drawn = True
    except Exception as e:  # Best-effort: one bad label file must not stop the viewer.
        print(f"Error reading or processing label file {label_path}: {e}")
    return drawn


def visualize_dataset_sample(data_yaml_path, num_images_to_show=3, dataset_split='train',
                             image_extensions=('.png', '.jpg', '.jpeg')):
    """
    Visualizes a few sample images with their bounding boxes and class labels
    from an Ultralytics-formatted object detection dataset.

    Each selected image is shown in its own matplotlib figure. Boxes come from
    the matching ``.txt`` label file (same base name as the image) in the
    inferred label directory.

    Args:
        data_yaml_path (str): Path to the data.yaml file.
        num_images_to_show (int): Maximum number of random images to display.
            Values <= 0 display nothing.
        dataset_split (str): The dataset split to use ('train', 'val', or 'test').
            Its YAML value may be a single directory or a list of directories.
            Relative entries are resolved against the dataset root.
        image_extensions (tuple[str, ...]): File extensions (case-insensitive)
            treated as images.

    Returns:
        None. Most problems (missing keys or directories, unreadable images or
        label files) are reported with print() instead of being raised.
    """
    print(f"Attempting to visualize {num_images_to_show} images from the '{dataset_split}' split.")
    print(f"Loading dataset configuration from: {data_yaml_path}")

    try:
        config = load_yaml(data_yaml_path)
    except Exception as e:
        print(f"Error: Could not load or parse {data_yaml_path}. Exception: {e}")
        return
    if not isinstance(config, dict):
        # An empty file parses to None; a scalar or list has no keys to read.
        print(f"Error: {data_yaml_path} does not contain a YAML mapping of dataset settings.")
        return

    class_names = config.get('names')
    if not class_names:
        print(f"Error: 'names' (class names dictionary/list) not found in {data_yaml_path}.")
        return
    print(f"Class names loaded: {class_names}")

    dataset_root_abs = _resolve_dataset_root(config, data_yaml_path)
    if dataset_root_abs is None:
        return
    print(f"Interpreted dataset root (or base for resolving paths): {dataset_root_abs}")

    split_entry = config.get(dataset_split)
    if not split_entry:
        print(f"Error: Path for dataset split '{dataset_split}' not found in {data_yaml_path}.")
        return
    # Ultralytics allows a split to be one path or a list of paths.
    split_entries = split_entry if isinstance(split_entry, (list, tuple)) else [split_entry]

    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)
    extensions = tuple(ext.lower() for ext in image_extensions)

    # Every sampleable image as (image_dir, label_dir, file_name).
    candidates = []
    for entry in split_entries:
        entry = str(entry)
        image_dir = entry if os.path.isabs(entry) else os.path.join(dataset_root_abs, entry)
        image_dir = os.path.abspath(image_dir)
        print(f"Constructed full image directory: {image_dir}")

        if not os.path.isdir(image_dir):
            if os.path.isfile(image_dir):
                print(f"Error: '{image_dir}' is a file; only image directories are supported (not image-list files).")
            else:
                print(f"Error: Final constructed image directory '{image_dir}' does not exist.")
            continue

        label_dir = _infer_label_dir(image_dir, dataset_root_abs)
        print(f"Attempting to use label directory: {label_dir}")

        # Sorted so that random.seed(...) gives reproducible samples.
        image_files = sorted(f for f in os.listdir(image_dir) if f.lower().endswith(extensions))
        if not image_files:
            print(f"No images found in '{image_dir}'.")
            continue
        print(f"Found {len(image_files)} images in '{image_dir}'.")
        if not os.path.isdir(label_dir):
            print(f"Warning: Label directory '{label_dir}' does not exist. Bounding boxes will not be shown.")
        candidates.extend((image_dir, label_dir, name) for name in image_files)

    if not candidates:
        print(f"No images available to visualize for the '{dataset_split}' split.")
        return

    # Clamp so a negative request shows nothing instead of raising in random.sample.
    sample_size = max(0, min(num_images_to_show, len(candidates)))
    for image_dir, label_dir, image_name in random.sample(candidates, sample_size):
        image_path = os.path.join(image_dir, image_name)
        label_path = os.path.join(label_dir, os.path.splitext(image_name)[0] + '.txt')

        img = cv2.imread(image_path)
        if img is None:
            print(f"Warning: Could not read image: {image_path}")
            continue
        img_display = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB

        annotated = False
        if os.path.isdir(label_dir):
            if os.path.exists(label_path):
                annotated = _draw_annotations(img_display, label_path, class_names)
            else:
                print(f"Label file not found: {label_path}")

        plt.figure(figsize=(12, 10))
        plt.imshow(img_display)
        source = label_path if annotated else 'N/A'
        plt.title(f"Image: {image_name} (from {dataset_split} set)\nLabels from: {source}")
        plt.axis('off')
        plt.show()

# --- How to use: ---
# 1. Save the code above as a Python file (e.g., view_dataset.py).
# 2. Make sure you have the necessary libraries: pip install pyyaml opencv-python matplotlib
# 3. Call the function with the path to your data.yaml file.

# Example:
# If your data.yaml is in the same directory as the script:
# visualize_dataset_sample('data.yaml', num_images_to_show=5, dataset_split='train')

# Or provide the full path:
# visualize_dataset_sample('/path/to/your/dataset/data.yaml', num_images_to_show=5, dataset_split='val')

# To use alongside your training script:
# make sure your data.yaml is correctly configured and accessible, then either run
# this file directly or copy the block below into a new script (or onto the end of an
# existing one, outside any `if __name__ == '__main__':` block it already has).
if __name__ == '__main__':
    # Point this at your dataset's data.yaml. Your training script uses
    # data="data.yaml", so a bare 'data.yaml' works when this script runs from
    # the directory that contains data.yaml and the dataset.
    data_yaml_file = "/blue/hulcr/gmarais/PhD/phase_1_data/3_classification_phase_2/ultralytics/cv_iteration_1/data.yaml"

    # This viewer draws detection-style .txt boxes; remind the user how that
    # differs from a folder-per-class classification dataset.
    dataset_type_notes = (
        "\n--- Important Note on Dataset Type ---",
        "You mentioned your dataset is for 'image classification'.",
        "Standard Ultralytics image classification datasets are structured by class folders",
        "(e.g., dataset/train/class1/image.jpg) and do NOT use .txt files for bounding box annotations.",
        "Bounding boxes are typically used for 'object detection' datasets.",
        "This script is designed to visualize bounding boxes from .txt label files common in object detection.",
        "If your dataset is purely for classification without .txt label files, it will display images without bounding boxes.",
        "-------------------------------------\n",
    )
    for note in dataset_type_notes:
        print(note)

    visualize_dataset_sample(data_yaml_file, num_images_to_show=10, dataset_split='train')
    # To inspect the validation split instead:
    # visualize_dataset_sample(data_yaml_file, num_images_to_show=3, dataset_split='val')