In [None]:
# Reading 3 zipped datasets from Google Drive and unzipping them

import zipfile
from pathlib import Path
import gdown
import yaml
import shutil # Import shutil for file copying

# Google Drive file id -> local dataset name.
# NOTE: these ids must be updated whenever a newer version of a dataset is uploaded.

datasets = {
    '1hb-7KYjd_H-KPprpB7Pv4-fO2aKvET6m':'AquaTrash_yolo',
    '1Z9xElXVKoj62XCrYz2ZmRpzBh8Rgy2ue':'mju_waste_yolo',
    '1-2imLxXKszSvYGXkmrz3kbDfxHh2C5VR':'TACO_yolo'
}

# === Root paths ===

project_root = Path('/content')
zip_dir = project_root/'zips'
zip_dir.mkdir(parents=True, exist_ok=True)
downloaded_zip_files = []
extracted_yolo_dataset_paths = []


def _common_root(member_names):
    """Return the single wrapper folder shared by all archive members, or None.

    The result carries a trailing '/' so it can be stripped from member names.
    '__MACOSX/' entries (metadata some archivers add next to the real content)
    are ignored when deciding whether there is exactly one wrapper folder.
    """
    real_names = [name for name in member_names if not name.startswith('__MACOSX/')]
    roots = {name.split('/', 1)[0] for name in real_names}
    if len(roots) != 1:
        return None
    prefix = roots.pop() + '/'
    # A lone top-level *file* is not a wrapper folder.
    if all(name.startswith(prefix) for name in real_names):
        return prefix
    return None


def _extract_flattened(zip_path, dest_dir):
    """Extract ``zip_path`` into ``dest_dir``, dropping a single wrapper folder if present.

    Members whose resolved target would fall outside ``dest_dir`` (absolute
    paths, '..' components) are skipped with a warning instead of being written.
    """
    dest_root = dest_dir.resolve()
    with zipfile.ZipFile(zip_path, 'r') as archive:
        member_names = archive.namelist()
        prefix = _common_root(member_names)

        for member in member_names:
            if prefix and member.startswith(prefix):
                relative = member[len(prefix):]
            else:
                relative = member
            if not relative:
                continue  # the wrapper folder entry itself

            target = (dest_root / relative).resolve()
            if target == dest_root:
                continue
            if dest_root not in target.parents:
                print(f"  Warning: skipping archive member outside the target folder: {member!r}")
                continue

            if member.endswith('/'):
                target.mkdir(parents=True, exist_ok=True)
                continue

            target.parent.mkdir(parents=True, exist_ok=True)
            with archive.open(member) as source, open(target, "wb") as out:
                shutil.copyfileobj(source, out)


for file_id, friendly_name in datasets.items():
    zip_path = zip_dir/f"{friendly_name}.zip"
    extract_dir = project_root/friendly_name
    yaml_path = extract_dir/'data.yaml'

    # === Download ===

    url = f"https://drive.google.com/uc?id={file_id}"
    print(f"Downloading {friendly_name} from {url} ...")
    gdown.download(url, str(zip_path), quiet=False)
    # Fail with a clear message here rather than with an opaque error while unzipping.
    if not zip_path.exists():
        raise FileNotFoundError(
            f"Download of {friendly_name} failed: {zip_path} was not created."
        )
    downloaded_zip_files.append(zip_path.name)

    # === Extract ===

    print(f"Extracting to {extract_dir} ...")
    extract_dir.mkdir(parents=True, exist_ok=True)
    _extract_flattened(zip_path, extract_dir)

    extracted_yolo_dataset_paths.append(str(extract_dir))

    # === Update data.yaml ===

    if yaml_path.exists():
        print(f"Updating path in {yaml_path} ...")
        with yaml_path.open('r') as f:
            # An empty YAML file loads as None; treat it as an empty mapping.
            data = yaml.safe_load(f) or {}
        data['path'] = str(extract_dir)
        with yaml_path.open('w') as f:
            # sort_keys=False keeps the key order of the original file.
            yaml.dump(data, f, sort_keys=False)
        print(f"Updated 'path' in data.yaml to: {extract_dir}\n")
    else:
        print(f"Warning: data.yaml not found in {extract_dir}\n")

# === Summary ===

print("\nDownloaded ZIP files:")
for name in downloaded_zip_files:
    print(f"- {name}")
print("\nExtracted dataset folders:")
for path in extracted_yolo_dataset_paths:
    print(f"- {path}")

In [None]:
"""
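- Pools every image-label pair from all 3 datasets, shuffles them, and applies one global train/val/test split to the merged set.
- Copies images and labels under their original filenames, so any dataset-name prefix already present in a filename is kept.
- Rewrites every label line so that all original classes map to the single class 'trash' (id 0).
- Creates a data.yaml file for the merged dataset.
- Skips images that have no matching label file.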

"""
import os
import shutil
import yaml
import random

# --- Configuration ---

# Root folder of the merged YOLO dataset (a dedicated folder name, so earlier outputs are not touched)
merged_output_dir = '/content/merged_yolo_dataset_reshuffled'

# Fractions of the pooled data that go to each split
train_split_ratio = 0.8
val_split_ratio = 0.1
test_split_ratio = 0.1  # whatever is left after train and val

# images/ and labels/ roots inside the merged dataset
merged_images_base = os.path.join(merged_output_dir, 'images')
merged_labels_base = os.path.join(merged_output_dir, 'labels')

# Make sure images/<split> and labels/<split> exist for every split
for split_dir in ('train', 'val', 'test'):
    for base_dir in (merged_images_base, merged_labels_base):
        os.makedirs(os.path.join(base_dir, split_dir), exist_ok=True)

# Bookkeeping of the class names seen in the source datasets.
# The merged labels only ever use the single 'trash' class, so this is
# informational and does not influence the written label files.
class_names_set = set()
merged_class_names_info = []  # original class names, in first-seen order
class_name_to_id = {}  # ends up as {'trash': 0}

# Every collected image/label pair is appended here
all_files = []

print("Collecting all image and label file paths...")

# NOTE: `extracted_yolo_dataset_paths` is created by the download/extract cell.
# Run that cell first, otherwise this loop raises a NameError.

for dataset_path in extracted_yolo_dataset_paths:
    print(f"Processing dataset: {dataset_path}")

    # Short prefix derived from the dataset folder name; stored with every pair for later use/debugging.
    dataset_prefix = os.path.basename(dataset_path).replace('_extracted', '').replace('-', '_') + '_'

    # --- Process Class Names from data.yaml if available ---
    data_yaml_path = os.path.join(dataset_path, 'data.yaml')
    current_class_id_to_name = {}  # original class map of this dataset

    if os.path.exists(data_yaml_path):
        try:
            with open(data_yaml_path, 'r') as f:
                # An empty YAML file loads as None; treat it as "no names".
                current_data_yaml = yaml.safe_load(f) or {}

            raw_names = current_data_yaml.get('names') or []
            if isinstance(raw_names, dict):
                # `names` written as a mapping {id: name}: use the ids as given
                # instead of enumerating the keys.
                parsed_names = {int(class_id): name for class_id, name in raw_names.items()}
            else:
                parsed_names = dict(enumerate(raw_names))

            current_class_id_to_name = parsed_names
            current_class_names = [parsed_names[class_id] for class_id in sorted(parsed_names)]

            # Record original class names in first-seen order
            for original_class_name in current_class_names:
                if original_class_name not in class_names_set:
                    class_names_set.add(original_class_name)
                    merged_class_names_info.append(original_class_name)
            print(f"  Found {len(current_class_names)} original classes in data.yaml.")

        except Exception as e:
            # Best effort: class names are informational only, so a broken
            # data.yaml must not stop the merge.
            print(f"  Error loading data.yaml from {data_yaml_path}: {e}. Skipping class processing for this dataset.")
    else:
        print(f"  Warning: data.yaml not found at {data_yaml_path}. Cannot get class names from this dataset.")


    # --- Collect all image and label pairs from all available splits ---
    for split_subdir in ['train', 'val', 'test']:
        original_images_path = os.path.join(dataset_path, 'images', split_subdir)
        original_labels_path = os.path.join(dataset_path, 'labels', split_subdir)

        if os.path.exists(original_images_path) and os.path.exists(original_labels_path):
            print(f"  Collecting files from {split_subdir} split...")
            # os.listdir returns entries in arbitrary order; sort for a stable collection order.
            for filename in sorted(os.listdir(original_images_path)):
                img_name, img_ext = os.path.splitext(filename)
                label_filename = img_name + '.txt'
                original_image_file = os.path.join(original_images_path, filename)
                original_label_file = os.path.join(original_labels_path, label_filename)

                # Only keep images that have a label file
                if os.path.exists(original_label_file):
                    all_files.append({
                        'original_img': original_image_file,
                        'original_label': original_label_file,
                        'dataset_prefix': dataset_prefix,
                        'original_class_id_to_name': current_class_id_to_name
                    })
                else:
                    print(f"    Warning: Label file not found for image {filename} at {original_labels_path}. Skipping this image and its missing label.")
        else:
             print(f"  Info: Skipping {split_subdir} for this dataset as directory does not exist: {original_images_path}")


print(f"\nTotal image-label pairs collected: {len(all_files)}")

# --- Apply Global Train/Val/Test Split ---
print("Applying global train/val/test split...")

# train + val may use at most the whole dataset; the remainder becomes test.
if train_split_ratio < 0 or val_split_ratio < 0 or train_split_ratio + val_split_ratio > 1 + 1e-9:
    raise ValueError(
        f"Invalid split ratios: train={train_split_ratio}, val={val_split_ratio} "
        "(both must be >= 0 and sum to at most 1)."
    )

# Fixed seed + fixed starting order => the same file always lands in the same
# split. Without this, re-running the merge into the same output folder can
# leave one image in two different splits (train/test leakage).
split_seed = 42
all_files.sort(key=lambda pair: pair['original_img'])
random.Random(split_seed).shuffle(all_files)

n_total = len(all_files)

# Sizes of the splits; test takes whatever is left so every pair is used exactly once
n_train = int(n_total * train_split_ratio)
n_val = int(n_total * val_split_ratio)
n_test = n_total - n_train - n_val

train_files = all_files[:n_train]
val_files = all_files[n_train:n_train + n_val]
test_files = all_files[n_train + n_val:]

print(f"Split distribution: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}")

# --- Copy and Process Files to Merged Directory ---
print("Copying and processing files to merged dataset...")

file_splits = {'train': train_files, 'val': val_files, 'test': test_files}

for split_name, files_list in file_splits.items():
    print(f"Processing {split_name} split...")
    merged_images_split_path = os.path.join(merged_images_base, split_name)
    merged_labels_split_path = os.path.join(merged_labels_base, split_name)

    # File stems already written to this split during this run. Two sources
    # with the same stem would otherwise silently overwrite each other.
    used_stems = set()

    for file_info in files_list:
        original_img_path = file_info['original_img']
        original_label_path = file_info['original_label']

        # Keep the original filename; only rename when it would clash.
        original_stem, img_ext = os.path.splitext(os.path.basename(original_img_path))
        stem = original_stem
        duplicate_count = 0
        while stem in used_stems:
            duplicate_count += 1
            stem = f"{file_info['dataset_prefix']}{original_stem}_{duplicate_count}"
        if stem != original_stem:
            print(f"    Warning: Filename {original_stem}{img_ext} already exists in the {split_name} split. Saving this copy as {stem}{img_ext}.")
        used_stems.add(stem)

        img_filename = stem + img_ext
        label_filename = stem + '.txt'

        merged_dest_image_path = os.path.join(merged_images_split_path, img_filename)
        merged_dest_label_path = os.path.join(merged_labels_split_path, label_filename)

        try:
            # --- Build the single-class label content first ---
            # Every box is mapped to class 0 ('trash'), whatever its original class id was.
            # Parsing happens before any copy so a broken label does not leave an orphan image behind.
            updated_lines = []
            with open(original_label_path, 'r') as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        continue  # blank line, nothing to report
                    if len(parts) != 5:
                        print(f"    Warning: Skipping malformed line in {os.path.basename(merged_dest_label_path)}: '{line.strip()}'")
                        continue
                    try:
                        # The class id is overwritten anyway; only the four box values must be numeric.
                        for value in parts[1:]:
                            float(value)
                    except ValueError:
                        print(f"    Warning: Non-numeric box values in label file {os.path.basename(merged_dest_label_path)} line: '{line.strip()}'. Skipping this line.")
                        continue

                    parts[0] = '0'
                    updated_lines.append(" ".join(parts))

            shutil.copy2(original_img_path, merged_dest_image_path)

            with open(merged_dest_label_path, 'w') as f:
                f.write("\n".join(updated_lines))

        except FileNotFoundError:
             print(f"    Error copying files for {img_filename}. Image or label file not found unexpectedly.")
        except Exception as e:
             # Best effort: report the file and continue with the rest of the split.
             print(f"    Error processing file {img_filename}: {e}")


print("YOLO dataset merging complete.")

# --- Create Merged data.yaml ---
print("Creating merged data.yaml...")

# The merged dataset has exactly one class
final_merged_class_names = ['trash']
class_name_to_id = {'trash': 0}

# Assemble the YAML content; key order here is the order written to the file
merged_data_yaml = {'path': merged_output_dir}
for yaml_split in ('train', 'val', 'test'):
    merged_data_yaml[yaml_split] = os.path.join(merged_images_base, yaml_split)
merged_data_yaml['nc'] = 1
merged_data_yaml['names'] = ['trash']

# Write data.yaml into the root of the merged dataset
merged_data_yaml_path = os.path.join(merged_output_dir, 'data.yaml')
with open(merged_data_yaml_path, 'w') as yaml_out:
    yaml.dump(merged_data_yaml, yaml_out, sort_keys=False)

print(f"Merged data.yaml saved to: {merged_data_yaml_path}")
print(f"Total classes in merged dataset: {len(final_merged_class_names)}")
print(f"Merged class names: {final_merged_class_names}")

# --- Next Steps ---
print("\n".join([
    "\nNext Steps:",
    f"1. Your reshuffled merged YOLO dataset is located at: {merged_output_dir}",
    f"2. The data.yaml file is in: {merged_data_yaml_path}",
    "3. You can now use this merged dataset for training with YOLO.",
]))

In [None]:
# Summarizing the % of images from all 3 datasets in train, val, test datasets
# %%
import os

# Root of the merged dataset (must match the output folder of the merging cell)
merged_output_dir = '/content/merged_yolo_dataset_reshuffled'

# Filename prefix that identifies the source dataset of a merged image.
# An image is attributed to the first entry whose prefix its filename starts with.
dataset_patterns = {
    'mju_waste': 'mju_waste_',
    'TACO': 'TACO_',
    'AquaTrash': 'AquaTrash_',
}

# {split: {dataset key: number of images}}
image_counts_by_dataset = {}

# {split: number of attributed images}
total_images_per_split = {}

print("\nCounting images from each original dataset in train, val, and test splits...")

# Attributed images over all splits
total_image_label_pairs_collected = 0

for split_subdir in ['train', 'val', 'test']:
    split_images_path = os.path.join(merged_output_dir, 'images', split_subdir)
    image_counts_by_dataset[split_subdir] = dict.fromkeys(dataset_patterns, 0)
    total_images_per_split[split_subdir] = 0

    if not os.path.exists(split_images_path):
        print(f"  Info: Image directory for {split_subdir} not found at {split_images_path}")
        continue

    print(f"Processing split: {split_subdir}")
    for filename in os.listdir(split_images_path):
        # Only files with a known image extension are considered
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
            continue

        # First dataset whose prefix matches, or None
        matched_key = next(
            (name for name, prefix in dataset_patterns.items() if filename.startswith(prefix)),
            None,
        )

        if matched_key is None:
            # Unattributed images are reported and left out of every total
            print(f"  Warning: Image '{filename}' in {split_subdir} does not match any known dataset pattern. It will not be included in the per-dataset counts.")
            continue

        image_counts_by_dataset[split_subdir][matched_key] += 1
        total_images_per_split[split_subdir] += 1
        total_image_label_pairs_collected += 1

# Overall number of attributed images
print(f"\nTotal image-label pairs collected across all merged splits: {total_image_label_pairs_collected}")

# Per split: how many images came from each source dataset
print("\n--- Image Counts Summary (per original dataset within each split) ---")
for split_name, pattern_counts in image_counts_by_dataset.items():
    print(f"\nSplit: {split_name}")
    total_in_split = total_images_per_split[split_name]
    if not total_in_split:
        print(f"  No images found in the {split_name} split.")
        continue

    for display_name, count in pattern_counts.items():
        print(f"  - {display_name}: {count} images ({count/total_in_split:.1%})")
    print(f"  Total images in {split_name} split: {total_in_split}")

# Share of each split in the whole merged dataset
print("\n--- Overall Train/Val/Test Split Summary ---")

if not total_image_label_pairs_collected:
    print("No images found in the merged dataset to calculate split percentages.")
else:
    for split_name, count in total_images_per_split.items():
        percentage = (count / total_image_label_pairs_collected) * 100
        print(f"  - {split_name}: {count} images ({percentage:.1f}%)")

print("\n--- End of Summary ---")

In [None]:
"""# saving the merged_yolo-dataset_reshuffle as a zip file to be downloaded

import shutil
import os
from google.colab import files

# Specify the directory you want to zip
directory_to_zip = '/content/merged_yolo_dataset_reshuffled' # Replace with the actual path to your directory

# Specify the name for the output zip file
output_zip_file = 'merged_yolo_dataset_reshuffled.zip' # Replace with your desired zip file name

# Use shutil.make_archive to create the zip file
# The first argument is the base name of the archive (without extension)
# The second argument is the archive format ('zip')
# The third argument is the directory to archive
shutil.make_archive(output_zip_file.replace('.zip', ''), 'zip', directory_to_zip)

print(f"Created zip file: {output_zip_file}")

# Download the zip file
try:
    files.download(output_zip_file)
    print(f"Downloading {output_zip_file}...")
except FileNotFoundError:
    print(f"Error: The file {output_zip_file} was not found after zipping.")
except Exception as e:
    print(f"An error occurred during download: {e}")"""

In [None]:
#Inspecting a batch of x image-label pairs at a time

# %%
import os
import random
import yaml
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
# Root of the merged dataset (must match the output folder of the merging cell)
merged_output_dir = '/content/merged_yolo_dataset_reshuffled'

# Splits to visualise
splits_to_check = ['train', 'val', 'test']

# --- Visualization Limit ---
images_per_split_to_display = 5  # batch size: images shown per split in one run
start_index = 0  # index of the first pair of the batch (0 = first batch, then 5, 10, ... with a batch size of 5)

# --- Load Class Names from merged data.yaml ---
merged_data_yaml_path = os.path.join(merged_output_dir, 'data.yaml')
merged_class_names = []  # stays empty when data.yaml is missing or unreadable

if not os.path.exists(merged_data_yaml_path):
    print(f"Error: Merged data.yaml not found at {merged_data_yaml_path}. Cannot display class names.")
else:
    try:
        with open(merged_data_yaml_path, 'r') as yaml_file:
            merged_data_yaml = yaml.safe_load(yaml_file)
        merged_class_names = merged_data_yaml.get('names', [])
        print(f"Loaded {len(merged_class_names)} class names from data.yaml: {merged_class_names}")
    except Exception as load_error:
        print(f"Error loading merged data.yaml from {merged_data_yaml_path}: {load_error}")
        merged_class_names = []  # fall back to "no names" on any load problem

# --- Function to Draw Bounding Boxes and Display ---
def _load_label_font(font_size):
    """Return a TrueType font of ``font_size`` if one can be found, else Pillow's default font."""
    for font_file in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_file, font_size)
        except IOError:
            continue
    return ImageFont.load_default()


def _show_image(image, image_path):
    """Display ``image`` with matplotlib, titled with the file name of ``image_path``."""
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.title(f"Image: {os.path.basename(image_path)}")
    plt.axis('off')
    plt.show()


def plot_yolo_boxes(image_path, label_path, class_names):
    """Draw the YOLO boxes from ``label_path`` onto the image and display it.

    Each label line is expected as ``class_id center_x center_y width height``
    with the four box values normalised to the image size. Lines that do not
    parse are reported and skipped; blank lines are ignored.

    Args:
        image_path: Path of the image file.
        label_path: Path of the YOLO label file belonging to the image.
        class_names: Class names indexed by class id. Ids outside the range
            are shown as ``Unknown Class ID: <id>``.

    Returns:
        None. Problems (including missing files) are printed, not raised.
    """
    try:
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        draw = ImageDraw.Draw(image)
        width, height = image.size

        with open(label_path, 'r') as f:
            labels = [line for line in f if line.strip()]

        # Nothing annotated: show the plain image
        if not labels:
            print(f"  Info: No annotations found in {os.path.basename(label_path)}. Displaying image only.")
            _show_image(image, image_path)
            return

        # The font only depends on the image height, so load it once per image.
        try:
            font = _load_label_font(max(10, int(0.02 * height)))
        except Exception as font_e:
            print(f"  Warning: Could not load specific font for {os.path.basename(image_path)}: {font_e}. Using default.")
            font = ImageFont.load_default()

        for label in labels:
            parts = label.split()
            if len(parts) != 5:
                print(f"  Warning: Skipping malformed line in label file {os.path.basename(label_path)}: '{label.strip()}'")
                continue

            try:
                class_id = int(parts[0])
                center_x, center_y, box_width, box_height = map(float, parts[1:])

                # Normalised centre/size -> pixel corners, clamped to the image
                x_min = max(0, int((center_x - box_width / 2) * width))
                y_min = max(0, int((center_y - box_height / 2) * height))
                x_max = min(width, int((center_x + box_width / 2) * width))
                y_max = min(height, int((center_y + box_height / 2) * height))
            except (ValueError, OverflowError) as e:
                print(f"  Warning: Skipping malformed or invalid line in label file {os.path.basename(label_path)}: '{label.strip()}' - Error: {e}")
                continue

            # After clamping, a box that lies outside the image has inverted corners.
            if x_max < x_min or y_max < y_min:
                print(f"  Warning: Skipping box outside the image in label file {os.path.basename(label_path)}: '{label.strip()}'")
                continue

            # Reject negative ids explicitly; they would otherwise index from the end of the list.
            if 0 <= class_id < len(class_names):
                class_name = class_names[class_id]
            else:
                class_name = f'Unknown Class ID: {class_id}'

            draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="red", width=2)

            text = f"{class_name}"
            try:
                text_bbox = draw.textbbox((0, 0), text, font=font)
                text_width = text_bbox[2] - text_bbox[0]
                text_height = text_bbox[3] - text_bbox[1]
            except (AttributeError, ValueError, NotImplementedError):
                # Older Pillow releases may lack textbbox or reject the default bitmap font.
                text_width, text_height = draw.textsize(text, font=font)
                print(f"  Warning: Using fallback text size calculation for {os.path.basename(image_path)}. Text positioning might be less precise.")

            # Preferred position: just above the box (2 px padding).
            # If that would leave the image, put the text just inside the top edge of the box.
            text_y = y_min - text_height - 2
            if text_y < 0:
                text_y = y_min + 2

            # Shift left if the text would run past the right edge, but never past the left edge.
            text_x = max(0, min(x_min, width - text_width))

            draw.text((text_x, text_y), text, fill="red", font=font)

        _show_image(image, image_path)

    except FileNotFoundError:
        print(f"  Error: Image or label file not found for {os.path.basename(image_path)}")
    except Exception as e:
        # Best effort viewer: report the problem and let the caller continue with the next image.
        print(f"  An error occurred while processing {os.path.basename(image_path)}: {e}")


# --- Collect and Display a Batch of Samples per Split ---
print("\nCollecting and displaying a batch of image-label pairs from each split...")

for split_to_sample in splits_to_check:
    image_dir = os.path.join(merged_output_dir, 'images', split_to_sample)
    label_dir = os.path.join(merged_output_dir, 'labels', split_to_sample)

    image_label_pairs = []

    if os.path.exists(image_dir) and os.path.exists(label_dir):
        print(f"\nCollecting image-label pairs from {split_to_sample} split...")
        # os.listdir returns entries in arbitrary order. Sorting gives every
        # pair a fixed index, so paging with `start_index` neither repeats
        # nor skips images between runs.
        for filename in sorted(os.listdir(image_dir)):
            img_name, img_ext = os.path.splitext(filename)
            label_filename = img_name + '.txt'
            image_path = os.path.join(image_dir, filename)
            label_path = os.path.join(label_dir, label_filename)

            # Images without a label file are left out silently
            if os.path.exists(label_path):
                image_label_pairs.append({'image': image_path, 'label': label_path})

        print(f"Found {len(image_label_pairs)} image-label pairs in {split_to_sample} split.")

        # --- Select and Display a Batch ---
        end_index = start_index + images_per_split_to_display
        batch_to_display = image_label_pairs[start_index:end_index]

        if batch_to_display:
            # Report the index range actually shown; the last batch may be shorter.
            last_index = start_index + len(batch_to_display) - 1
            print(f"\nDisplaying batch of {len(batch_to_display)} samples (index {start_index} to {last_index}) from the {split_to_sample} split.")
            for sample in batch_to_display:
                plot_yolo_boxes(sample['image'], sample['label'], merged_class_names)
        else:
            print(f"No image-label pairs found in the {split_to_sample} split within the specified range (index {start_index} to {end_index-1}).")

    else:
        print(f"Error: Image or label directory not found for {split_to_sample} split.")
        print(f"Image dir: {image_dir}")
        print(f"Label dir: {label_dir}")

print("\nFinished visualization for the current batch.")

# --- How to View the Next Batch ---
print(f"\nTo view the next batch of {images_per_split_to_display} images:")
print(f"1. Change the `start_index` variable in the code cell to {start_index + images_per_split_to_display}.")
print("2. Re-run the code cell.")