# Understanding which segments contain most information

In [1]:
import os
from PIL import Image
import numpy as np
from tqdm import tqdm

# Root folder containing the segmentation classes
ROOT_DIR = "/home/andreafabbricatore/rainbot/pre_processing/segmented"  # <- CHANGE THIS

def compute_mask_ratio(image_path):
    img = Image.open(image_path).convert("L")  # grayscale
    arr = np.array(img)
    total_pixels = arr.size
    non_black_pixels = np.count_nonzero(arr)
    return non_black_pixels / total_pixels

segment_ratios = {}

segment_classes = sorted([d for d in os.listdir(ROOT_DIR) if os.path.isdir(os.path.join(ROOT_DIR, d))])

for segment_class in tqdm(segment_classes, desc="Segment Classes"):
    class_path = os.path.join(ROOT_DIR, segment_class)

    ratios = []
    country_dirs = sorted([d for d in os.listdir(class_path) if os.path.isdir(os.path.join(class_path, d))])

    for country in country_dirs:
        country_path = os.path.join(class_path, country)

        image_files = sorted([f for f in os.listdir(country_path) if f.endswith(".png")])
        print(f"Country: {country}. Images: {len(image_files)}")
        for file in image_files:
            image_path = os.path.join(country_path, file)
            try:
                ratio = compute_mask_ratio(image_path)
                ratios.append(ratio)
            except Exception as e:
                print(f"Error reading {image_path}: {e}")

    if ratios:
        avg_ratio = np.mean(ratios)
        segment_ratios[segment_class] = avg_ratio

# Sort and print results
print("\n--- Segment Scores ---")
sorted_segments = sorted(segment_ratios.items(), key=lambda x: x[1], reverse=True)
for segment, score in sorted_segments:
    print(f"{segment}: {score:.4f}")

Segment Classes:   0%|          | 0/17 [00:00<?, ?it/s]

Country: New Zealand. Images: 132
Country: United Kingdom. Images: 697
Country: United States. Images: 456


Segment Classes:   6%|▌         | 1/17 [00:03<00:54,  3.41s/it]

Country: New Zealand. Images: 1270
Country: United Kingdom. Images: 4176
Country: United States. Images: 3810


Segment Classes:  12%|█▏        | 2/17 [00:38<05:26, 21.78s/it]

Country: New Zealand. Images: 260
Country: United Kingdom. Images: 499
Country: United States. Images: 484


Segment Classes:  18%|█▊        | 3/17 [00:41<03:07, 13.36s/it]

Country: New Zealand. Images: 1111
Country: United Kingdom. Images: 3773


Segment Classes:  18%|█▊        | 3/17 [00:47<03:43, 15.99s/it]


KeyboardInterrupt: 

In [2]:
def delete_black_images(root_dir):
    """
    Delete images that have more than threshold% black pixels
    """
    deleted_count = 0
    total_count = 0
    
    for segment_class in tqdm(os.listdir(root_dir), desc="Processing classes"):
        class_path = os.path.join(root_dir, segment_class)
        if not os.path.isdir(class_path):
            continue
            
        for country in os.listdir(class_path):
            country_path = os.path.join(class_path, country)
            if not os.path.isdir(country_path):
                continue
                
            for file in os.listdir(country_path):
                if not file.endswith('.png'):
                    continue
                    
                image_path = os.path.join(country_path, file)
                total_count += 1
                
                try:
                    ratio = compute_mask_ratio(image_path)
                    if ratio <0.05:  # If more than 95% is black
                        os.remove(image_path)
                        deleted_count += 1
                except Exception as e:
                    print(f"Error processing {image_path}: {e}")
    
    print(f"\nDeleted {deleted_count} out of {total_count} images")
    print(f"Remaining images: {total_count - deleted_count}")

# Execute the deletion
delete_black_images(ROOT_DIR)


Processing classes: 100%|██████████| 17/17 [05:45<00:00, 20.33s/it]


Deleted 59851 out of 93308 images
Remaining images: 33457





In [3]:
def analyze_segment_coverage():
    # Initialize scores dictionary
    image_scores = {}
    
    # Get all original images
    original_images_dir = "/home/andreafabbricatore/rainbot/pre_processing/datasets/extra_data"
    segmented_dir = "/home/andreafabbricatore/rainbot/pre_processing/segmented"
    
    # First, initialize scores for all original images
    for country in os.listdir(original_images_dir):
        country_path = os.path.join(original_images_dir, country)
        if not os.path.isdir(country_path):
            continue
            
        for image_file in os.listdir(country_path):
            if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                base_name = os.path.splitext(image_file)[0]
                image_scores[f"{country}/{base_name}"] = 0
    
    # Check each segment class folder
    for segment_class in os.listdir(segmented_dir):
        segment_path = os.path.join(segmented_dir, segment_class)
        if not os.path.isdir(segment_path):
            continue
            
        # Check each country in the segment class
        for country in os.listdir(segment_path):
            country_path = os.path.join(segment_path, country)
            if not os.path.isdir(country_path):
                continue
                
            # Check each segmented image
            for segment_file in os.listdir(country_path):
                if not segment_file.endswith('.png'):
                    continue
                    
                # Extract original image name from segment filename
                # Format is: original_name_segmentname.png
                original_name = '_'.join(segment_file.split('_')[:-1])
                
                # Increment score for this original image
                key = f"{country}/{original_name}"
                if key in image_scores:
                    image_scores[key] += 1
    
    # Sort images by score in descending order
    sorted_scores = sorted(image_scores.items(), key=lambda x: x[1], reverse=True)
    
    # Print results
    print("\nImage Coverage Analysis:")
    print("-----------------------")
    for image_path, score in sorted_scores:
        print(f"{image_path}: {score} segments")
    
    return sorted_scores

# Execute the analysis
segment_coverage = analyze_segment_coverage()



Image Coverage Analysis:
-----------------------
United States/3477: 7 segments
United States/0695: 6 segments
United States/3187: 6 segments
United States/3973: 6 segments
United States/4162: 6 segments
United States/3594: 6 segments
United Kingdom/2507: 6 segments
United Kingdom/0899: 6 segments
United Kingdom/4612: 6 segments
United Kingdom/4750: 6 segments
United Kingdom/3910: 6 segments
United Kingdom/4418: 6 segments
United Kingdom/1486: 6 segments
United Kingdom/0228: 6 segments
United Kingdom/1274: 6 segments
United Kingdom/1768: 6 segments
United Kingdom/1786: 6 segments
United Kingdom/3864: 6 segments
United Kingdom/0977: 6 segments
United Kingdom/3249: 6 segments
United Kingdom/3226: 6 segments
New Zealand/0329: 5 segments
New Zealand/0955: 5 segments
New Zealand/0366: 5 segments
New Zealand/0192: 5 segments
New Zealand/0830: 5 segments
New Zealand/1681: 5 segments
New Zealand/1677: 5 segments
New Zealand/1126: 5 segments
New Zealand/0822: 5 segments
New Zealand/0364: 5 seg

In [8]:
segment_coverage

[('United States/3477', 7),
 ('United States/0695', 6),
 ('United States/3187', 6),
 ('United States/3973', 6),
 ('United States/4162', 6),
 ('United States/3594', 6),
 ('United Kingdom/2507', 6),
 ('United Kingdom/0899', 6),
 ('United Kingdom/4612', 6),
 ('United Kingdom/4750', 6),
 ('United Kingdom/3910', 6),
 ('United Kingdom/4418', 6),
 ('United Kingdom/1486', 6),
 ('United Kingdom/0228', 6),
 ('United Kingdom/1274', 6),
 ('United Kingdom/1768', 6),
 ('United Kingdom/1786', 6),
 ('United Kingdom/3864', 6),
 ('United Kingdom/0977', 6),
 ('United Kingdom/3249', 6),
 ('United Kingdom/3226', 6),
 ('New Zealand/0329', 5),
 ('New Zealand/0955', 5),
 ('New Zealand/0366', 5),
 ('New Zealand/0192', 5),
 ('New Zealand/0830', 5),
 ('New Zealand/1681', 5),
 ('New Zealand/1677', 5),
 ('New Zealand/1126', 5),
 ('New Zealand/0822', 5),
 ('New Zealand/0364', 5),
 ('New Zealand/0257', 5),
 ('New Zealand/1227', 5),
 ('New Zealand/0325', 5),
 ('New Zealand/0023', 5),
 ('New Zealand/0725', 5),
 ('New 

In [9]:
# Create a new directory for filtered images
import shutil
import os

filtered_dir = os.path.join(ROOT_DIR, "filtered_images")
os.makedirs(filtered_dir, exist_ok=True)

# Group images by country
country_images = {}
for image_path, score in segment_coverage:
    country = image_path.split('/')[0]
    if country not in country_images:
        country_images[country] = []
    country_images[country].append((image_path, score))

# Process each country
for country, images in country_images.items():
    # Sort images by score and take top 450
    top_images = sorted(images, key=lambda x: x[1], reverse=True)[:450]
    
    # Create country directory in filtered folder
    country_filtered_dir = os.path.join(filtered_dir, country)
    os.makedirs(country_filtered_dir, exist_ok=True)
    
    # Copy top images to filtered directory
    for image_path, _ in top_images:
        # Get the original image name
        # Define the image directory path
        IMAGE_DIR = "/home/andreafabbricatore/rainbot/pre_processing/datasets/extra_data"

        # Extract the image ID from the path
        image_id = image_path.split('/')[-1]
        
        # Construct source path with .jpg extension
        source_path = os.path.join(IMAGE_DIR, country, f"{image_id}.jpg")
        
        # Verify source file exists before copying
        if not os.path.exists(source_path):
            print(f"Warning: Source file not found: {source_path}")
            continue
            
        dest_path = os.path.join(country_filtered_dir, f"{image_id}.jpg")
        
        try:
            # Copy the image
            shutil.copy2(source_path, dest_path)
            print(f"Copied {image_path} to filtered directory")
        except Exception as e:
            print(f"Error copying {image_path}: {str(e)}")

print(f"\nFiltered images have been saved to: {filtered_dir}")


Copied United States/3477 to filtered directory
Copied United States/0695 to filtered directory
Copied United States/3187 to filtered directory
Copied United States/3973 to filtered directory
Copied United States/4162 to filtered directory
Copied United States/3594 to filtered directory
Copied United States/1536 to filtered directory
Copied United States/0397 to filtered directory
Copied United States/1612 to filtered directory
Copied United States/3340 to filtered directory
Copied United States/3291 to filtered directory
Copied United States/3641 to filtered directory
Copied United States/4775 to filtered directory
Copied United States/3934 to filtered directory
Copied United States/0216 to filtered directory
Copied United States/3188 to filtered directory
Copied United States/1035 to filtered directory
Copied United States/3602 to filtered directory
Copied United States/0321 to filtered directory
Copied United States/0172 to filtered directory
Copied United States/1967 to filtered di

In [None]:
import os
import shutil
import random
from pathlib import Path

# Set random seed for reproducibility
random.seed(42)

# Define split ratios
TRAIN_RATIO = 0.8
TEST_RATIO = 0.1
VAL_RATIO = 0.1

filtered_dir = "/home/andreafabbricatore/rainbot/pre_processing/datasets/extra_data/filtered_images"
# Create train, test, and val directories
for split in ['train', 'test', 'val']:
    split_dir = os.path.join(filtered_dir, split)
    os.makedirs(split_dir, exist_ok=True)

# Process each country
for country in os.listdir(filtered_dir):
    country_path = os.path.join(filtered_dir, country)
    
    # Skip if not a directory or if it's one of our split directories
    if not os.path.isdir(country_path) or country in ['train', 'test', 'val']:
        continue
    
    # Get all images for this country
    images = [f for f in os.listdir(country_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
    random.shuffle(images)
    
    # Calculate split indices
    n_images = len(images)
    train_end = int(n_images * TRAIN_RATIO)
    test_end = train_end + int(n_images * TEST_RATIO)
    
    # Split images
    train_images = images[:train_end]
    test_images = images[train_end:test_end]
    val_images = images[test_end:]
    
    # Create country subdirectories in each split
    for split in ['train', 'test', 'val']:
        country_split_dir = os.path.join(filtered_dir, split, country)
        os.makedirs(country_split_dir, exist_ok=True)
    
    # Copy images to their respective split directories
    for img_list, split in [(train_images, 'train'), (test_images, 'test'), (val_images, 'val')]:
        for img in img_list:
            src = os.path.join(country_path, img)
            dst = os.path.join(filtered_dir, split, country, img)
            shutil.copy2(src, dst)
    
    print(f"Processed {country}: {len(train_images)} train, {len(test_images)} test, {len(val_images)} val images")

print("\nDataset split complete!")
