In [None]:
import os
from pathlib import Path
import shutil
import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import random
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd

In [None]:
from source.visual_genome_meta_data import read_json_to_dict
from source.visual_genome_meta_data import get_image_meta_data
from source.visual_genome_meta_data import count_occurrences
from source.visual_genome_to_yolo import create_class_mapping_from_list
from source.visual_genome_to_yolo import save_class_map_to_yaml
from source.visual_genome_to_yolo import convert_single_image_to_yolo
from source.visual_genome_to_yolo import read_yaml_to_class_map
from source.visual_genome_to_yolo import read_yolo_metadata
from source.visual_genome_to_yolo import visual_genome_to_yolo_data_n
from source.visual_genome_meta_data import plot_image_with_multiple_bboxes
from source.visual_genome_meta_data import get_image_ids
from source.yolo_training_structure import distribute_train_val_files as dist_train_val
from source.visual_genome_data import get_file_by_id


In [None]:

def select_img_ids_with_obj(desired_objects, data_path, objects, image_ids_present, with_obj = True):
    """
    Select image ids whose annotated objects do (or do not) include a desired name.

    Args:
        desired_objects (iterable of str): Object names to look for.
        data_path: Not used; kept so existing calls keep working.
        objects (list of dict): Visual Genome object records. Each record is read
            as ``record['image_id']`` and ``record['objects'][k]['names']``
            (the list of names attached to one annotated object).
        image_ids_present (iterable): Ids of images available on disk; records
            for any other id are skipped.
        with_obj (bool): If truthy, keep images that contain at least one desired
            name; if falsy, keep images that contain none of them.

    Returns:
        list: Selected image ids, in the order they appear in ``objects``.
    """
    desired = set(desired_objects)
    # Set lookup instead of a linear scan of the id list for every record.
    present = set(image_ids_present)
    want_desired = bool(with_obj)

    # NOTE(review): occurrence_counts is never returned; it is kept only so that
    # count_occurrences is still called exactly as before. Consider returning it.
    occurrence_counts = dict.fromkeys(desired_objects, 0)
    selected_image_ids = []

    for record in objects:
        image_id = record['image_id']
        if image_id not in present:
            continue

        # Flatten the per-object name lists into one list for this image.
        names = [name for obj in record['objects'] for name in obj['names']]
        has_desired = not desired.isdisjoint(names)

        # Keep positives when with_obj is set, negatives otherwise.
        if has_desired == want_desired:
            count_occurrences(occurrence_counts, names)
            selected_image_ids.append(image_id)

    return selected_image_ids


In [None]:
def load_image_paths(image_directory, extensions=('.jpg', '.jpeg', '.tif', '.tiff')):
    """
    Load all image file paths from a directory (non-recursive).

    Args:
        image_directory (str or Path): Path to directory containing images.
        extensions (tuple or str): Allowed file extensions, e.g. ('.jpg', '.png').
            A single string such as '.jpg' is treated as one extension.
            Matching is case-insensitive.

    Returns:
        list: Sorted list of image file paths (as strings), without duplicates.
        Sub-directories are ignored even if their name ends with an extension.

    Raises:
        FileNotFoundError: If image_directory does not exist.
        NotADirectoryError: If image_directory is a file.
    """
    image_dir = Path(image_directory)

    if not image_dir.exists():
        raise FileNotFoundError(f"Directory {image_directory} does not exist")

    # A bare string would otherwise be iterated character by character
    # ('.', 'j', 'p', 'g'); the pattern '*g' then also matches .png/.svg files.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # Filter a single directory listing by the lower-cased file name. Unlike
    # separate '*ext' / '*EXT' globs, this also catches mixed case ('.Jpg') and
    # cannot list a file twice on case-insensitive filesystems.
    image_paths = sorted(
        str(path)
        for path in image_dir.iterdir()
        if path.name.lower().endswith(suffixes) and path.is_file()
    )

    print(f"Found {len(image_paths)} images in {image_directory}")
    return image_paths

In [None]:
from pathlib import Path
from PIL import Image


def meets_size_requirements(image_path, minimum_side_length):
    """
    Report whether an image's shorter side reaches ``minimum_side_length``.

    Args:
        image_path (str): Path to image file.
        minimum_side_length (int): Required minimum for both width and height.

    Returns:
        tuple: ``(meets_requirements, width, height)``. If the image cannot be
        opened or checked, the error is printed and ``(False, None, None)`` is
        returned instead of raising.
    """
    try:
        with Image.open(image_path) as img:
            width, height = img.size
            # Both sides pass exactly when the smaller one does.
            shorter_side = min(width, height)
            return shorter_side >= minimum_side_length, width, height
    except Exception as err:
        print(f"Error checking {image_path}: {err}")
        return False, None, None


def downsample_image(image, minimum_side_length):
    """
    Resize an image so that its shorter side equals ``minimum_side_length``.

    The longer side is scaled by the same factor and truncated to an int, so the
    aspect ratio is kept up to rounding. The resize uses LANCZOS resampling.
    Note: nothing here prevents upscaling; if the shorter side is below the
    target, the image is enlarged.

    Args:
        image (PIL.Image): Input image.
        minimum_side_length (int): Target size for the shorter dimension.

    Returns:
        PIL.Image: The same object if the shorter side already equals the
        target, otherwise a resized copy.
    """
    width, height = image.size

    # Shorter side already on target: nothing to do.
    if min(width, height) == minimum_side_length:
        return image

    if width < height:
        target_size = (minimum_side_length, int(height * minimum_side_length / width))
    else:
        target_size = (int(width * minimum_side_length / height), minimum_side_length)

    return image.resize(target_size, Image.Resampling.LANCZOS)


def center_crop_to_square(image, side_length):
    """
    Cut a ``side_length`` x ``side_length`` square from the middle of an image.

    Offsets use floor division, so for an odd surplus the extra pixel is left
    on the right/bottom edge.

    Args:
        image (PIL.Image): Input image.
        side_length (int): Side length of the output square.

    Returns:
        PIL.Image: Result of ``image.crop`` with the centered box.
    """
    width, height = image.size
    left, top = (width - side_length) // 2, (height - side_length) // 2
    crop_box = (left, top, left + side_length, top + side_length)
    return image.crop(crop_box)


def generate_processed_filename(original_path, tag="proc"):
    """
    Build the output filename for a processed image by inserting ``tag``.

    The tag is placed just before the last underscore-separated part of the
    file stem (the identifier). A stem without underscores gets ``_<tag>``
    appended. The original extension is kept and any directory is dropped.

    Args:
        original_path (str): Original file path.
        tag (str): Processing tag to insert.

    Returns:
        str: New filename (no directory) with the tag.

    Example: visual_genome_23.jpg -> visual_genome_proc_23.jpg
    """
    source = Path(original_path)
    *prefix_parts, identifier = source.stem.split('_')

    if prefix_parts:
        tagged_stem = '_'.join([*prefix_parts, tag, identifier])
    else:
        tagged_stem = f"{source.stem}_{tag}"

    return f"{tagged_stem}{source.suffix}"


def select_downsize_images(image_paths, output_directory, minimum_side_length):
    """
    Select, downsize, and center-crop images, saving the results as squares.

    For each path:
      1. skip it unless both sides are at least ``minimum_side_length``;
      2. downsample so the shorter side equals ``minimum_side_length``;
      3. center-crop to ``minimum_side_length`` x ``minimum_side_length``;
      4. save under a tagged filename in ``output_directory``.

    Images that pass the size check but fail later (unreadable, unsupported
    mode for the format, ...) are skipped. Unlike a silent ``continue``, they
    are counted and the first few errors are printed in the summary.

    Args:
        image_paths (list): List of image file paths.
        output_directory (str or Path): Directory to save processed images;
            created if missing.
        minimum_side_length (int): Minimum side length for filtering and the
            final square size.

    Returns:
        list: Paths (as strings) of the successfully processed images.
    """
    output_path = Path(output_directory)
    output_path.mkdir(parents=True, exist_ok=True)

    if not image_paths:
        return []

    processed_paths = []
    failures = []  # (image_path, error message) for images that failed processing
    selected_count = 0

    for image_path in image_paths:
        # Step 1: size filter (meets_size_requirements reports its own errors).
        meets_req, width, height = meets_size_requirements(image_path, minimum_side_length)
        if not meets_req:
            continue

        selected_count += 1

        try:
            # copy() forces the pixel data to load before the file is closed.
            with Image.open(image_path) as img:
                processed_img = img.copy()

            # Step 2: downsample only when the shorter side is larger than needed.
            if min(width, height) > minimum_side_length:
                processed_img = downsample_image(processed_img, minimum_side_length)

            # Step 3: center crop to a square.
            processed_img = center_crop_to_square(processed_img, minimum_side_length)

            # Step 4: save under the tagged name; format is inferred from the suffix.
            output_path_full = output_path / generate_processed_filename(image_path)
            processed_img.save(output_path_full)
            processed_paths.append(str(output_path_full))
        except Exception as e:
            # Best effort: keep going, but remember why this image was dropped.
            failures.append((image_path, str(e)))
            continue

        if selected_count % 50 == 0:
            print(f"Processed {selected_count} images...")

    print(f"\nCompleted processing:")
    print(f"Total images found: {len(image_paths)}")
    print(f"Images meeting size requirements: {selected_count}")
    print(f"Successfully processed: {len(processed_paths)}")
    print(f"Failed to process: {len(failures)}")
    for failed_path, error in failures[:5]:
        print(f"  {failed_path}: {error}")
    if len(failures) > 5:
        print(f"  ... and {len(failures) - 5} more")
    print(f"Final image size: {minimum_side_length}x{minimum_side_length}")

    return processed_paths

In [None]:
def analyze_image_sizes(image_paths, extensions=('.jpg', '.jpeg', '.tif', '.tiff')):
    """
    Analyze image sizes to help determine an appropriate target size.

    Prints a summary (ranges, quartiles, unreadable files) and a suggested
    square target size.

    Args:
        image_paths (list): Paths to all images.
        extensions (tuple): Not used by this function; kept so existing calls
            keep working. Filter the paths beforehand (e.g. load_image_paths).

    Returns:
        dict or None: Size statistics, or None if ``image_paths`` is empty or
        no image could be read. ``'failed_count'`` is the number of unreadable
        images; ``'failed_images'`` is the list of ``(path, error)`` pairs.
        ``'min_size'`` / ``'max_size'`` pair the per-axis extremes, which may
        come from different images.
    """
    if not image_paths:
        print("No images found in directory")
        return None

    sizes = []
    widths = []
    heights = []
    failed_images = []

    print(f"Analyzing {len(image_paths)} images...")

    for i, image_path in enumerate(image_paths):
        try:
            with Image.open(image_path) as img:
                width, height = img.size
                sizes.append((width, height))
                widths.append(width)
                heights.append(height)
        except Exception as e:
            # Best effort: record unreadable files and keep going.
            failed_images.append((image_path, str(e)))
            print(f"Failed to read {image_path}: {e}")

        # Progress indicator for large datasets
        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{len(image_paths)} images...")

    if not sizes:
        print("No valid images found")
        return None

    unique_sizes = set(sizes)
    all_same_size = len(unique_sizes) == 1

    min_width = min(widths)
    max_width = max(widths)
    avg_width = sum(widths) / len(widths)

    min_height = min(heights)
    max_height = max(heights)
    avg_height = sum(heights) / len(heights)

    width_q25, width_median, width_q75 = np.percentile(widths, [25, 50, 75])
    height_q25, height_median, height_q75 = np.percentile(heights, [25, 50, 75])

    min_size = (min_width, min_height)
    max_size = (max_width, max_height)
    avg_size = (avg_width, avg_height)

    results = {
        'total_images': len(image_paths),
        'valid_images': len(sizes),
        # The count needs its own key: 'failed_images' holds the (path, error)
        # list, and reusing that key would silently overwrite the count.
        'failed_count': len(failed_images),
        'all_same_size': all_same_size,
        'unique_sizes_count': len(unique_sizes),
        'min_size': min_size,
        'max_size': max_size,
        'avg_size': avg_size,
        'min_width': min_width,
        'max_width': max_width,
        'avg_width': avg_width,
        'min_height': min_height,
        'max_height': max_height,
        'avg_height': avg_height,
        'width_q25': width_q25,
        'width_median': width_median,
        'width_q75': width_q75,
        'height_q25': height_q25,
        'height_median': height_median,
        'height_q75': height_q75,
        'failed_images': failed_images,
    }

    print("\n" + "="*50)
    print("IMAGE SIZE ANALYSIS SUMMARY")
    print("="*50)
    print(f"Total images found: {results['total_images']}")
    print(f"Valid images: {results['valid_images']}")
    print(f"Failed to read: {results['failed_count']}")
    print(f"\nAll images same size: {'Yes' if all_same_size else 'No'}")
    print(f"Number of unique sizes: {results['unique_sizes_count']}")

    print(f"\nSize ranges:")
    print(f"  Minimum size: {min_size[0]} x {min_size[1]}")
    print(f"  Maximum size: {max_size[0]} x {max_size[1]}")
    print(f"  Average size: {avg_size[0]:.1f} x {avg_size[1]:.1f}")

    print(f"\nWidth range: {min_width} - {max_width} (avg: {avg_width:.1f})")
    print(f"Width quartiles: Q25={width_q25:.1f}, Median={width_median:.1f}, Q75={width_q75:.1f}")
    print(f"Height range: {min_height} - {max_height} (avg: {avg_height:.1f})")
    print(f"Height quartiles: Q25={height_q25:.1f}, Median={height_median:.1f}, Q75={height_q75:.1f}")

    if failed_images:
        print(f"\nFailed images:")
        for path, error in failed_images[:5]:  # Show first 5 failures
            print(f"  {path}: {error}")
        if len(failed_images) > 5:
            print(f"  ... and {len(failed_images) - 5} more")

    # Suggest target size
    if all_same_size:
        print(f"\nRecommendation: Use target_size={min_size} (all images are the same size)")
    else:
        # Snap the smallest dimension to the nearest common size.
        suggested_size = min(min_width, min_height)
        common_sizes = [28, 32, 64, 128, 224, 256, 512]
        suggested_size = min(common_sizes, key=lambda x: abs(x - suggested_size))
        print(f"\nRecommendation: Consider target_size=({suggested_size}, {suggested_size})")
        print(f"  (Based on minimum dimension and common image sizes)")

    return results

In [None]:
def get_file_paths_by_id(data_path, identifiers, file_extension):
    """
    Resolve each image identifier to a file path inside ``data_path``.

    Args:
        data_path (str or Path): Directory searched by get_file_by_id.
        identifiers (iterable): Image identifiers to look up.
        file_extension (str): Extension passed through to get_file_by_id, e.g. '.jpg'.

    Returns:
        list of Path: Exactly one path per identifier, in the same order, so the
        result stays index-aligned with ``identifiers``. When several files
        match, the first one is used and a message is printed.

    Raises:
        FileNotFoundError: If no file matches an identifier. Skipping the id
            instead would silently misalign ids and paths.
    """
    data_path = Path(data_path)
    file_paths = []
    for identifier in identifiers:
        # presumably a list of matching file names - TODO confirm against get_file_by_id
        matches = get_file_by_id(data_path, identifier, file_extension)
        if not matches:
            raise FileNotFoundError(
                f"No '{file_extension}' file found for id {identifier} in {data_path}")
        if len(matches) > 1:
            print(f'more than one image for id {identifier}; using {matches[0]}')
        file_paths.append(data_path / matches[0])
    return file_paths

### Define paths: 

In [None]:
# All paths are derived from the project root, taken to be the parent of the
# notebook's working directory (no user-specific absolute paths).
project_path = Path.cwd()
root_path = (project_path / '..').resolve()

# data_path: raw Visual Genome images plus objects.json
# (point it at root_path / 'visual_genome_data_all' to use the full dataset).
# data_proc_path: output for the resized/cropped images and labels.csv.
data_path = root_path / 'visual_genome_data'
data_proc_path = root_path / 'visual_genome_proc_data'

In [None]:
# Display the resolved project root (parent of the notebook's working directory).
root_path

In [None]:
# Display the directory holding the raw Visual Genome images and objects.json.
data_path

In [None]:
# Display the output directory for processed images and labels.csv.
data_proc_path

In [None]:
# `extensions` is a tuple of suffixes; pass a one-element tuple, not a bare string.
image_paths = load_image_paths(data_path, ('.jpg',))

In [None]:
# Summarize image dimensions to help choose minimum_side_length below.
# NOTE: analyze_image_sizes ignores its `extensions` argument; extension
# filtering already happened in load_image_paths.
results = analyze_image_sizes(image_paths, extensions=('.jpg', '.jpeg', '.tif', '.tiff'))
results

### Define the minimum side length (in pixels) required for the processed images:

In [None]:
# Shorter-side threshold in pixels: smaller images are dropped; the rest are
# downsampled and center-cropped to minimum_side_length x minimum_side_length.
minimum_side_length = 320

### Select images whose shorter side is at least `minimum_side_length`, downsize and center-crop them to a square if necessary, and save them in `data_proc_path`:

In [None]:


# Filter by size, downsample and center-crop; results are written to data_proc_path.
processed_paths = select_downsize_images(
    image_paths=image_paths,
    output_directory=data_proc_path,
    minimum_side_length=minimum_side_length
)

### Read the Visual Genome `objects.json` metadata file:

In [None]:
# Visual Genome object annotations for the raw images.
objects_file_path = data_path/'objects.json'


In [None]:
# presumably a list of {'image_id': ..., 'objects': [{'names': [...]}, ...]} records,
# which is the structure select_img_ids_with_obj indexes into - confirm.
objects = read_json_to_dict(objects_file_path)

### Get the identifiers of the images in the processed-image directory:

In [None]:
# Ids of the processed images on disk; only these are considered for labelling.
# NOTE(review): labels.csv is later written into this same directory; confirm
# get_image_ids ignores non-image files when the notebook is re-run.
image_ids_present = get_image_ids(data_proc_path)
image_ids_present.sort()
len(image_ids_present)

### Choose the desired Visual Genome objects and assign the new object class name:

In [None]:
#desired_objects = ['forest', 'mountain', 'mountains', 'building', 'house', 
#                   'church', 'city', 'village', 'lake', 'river', 'stream', 'glacier']

#desired_objects = ['mountain']
#desired_objects = ['church']
#desired_objects = ['lighthouse']


In [None]:
#desired_objects = ['mountain', 'mountains', 'hill', 'hills', 
#                       'church', 'city', 'village', 'lake', 'river', 'stream', 
#                       'glacier', 'water body', 'watercourse', 'waters', 'man']

#desired_objects = ['church']

#desired_objects = ['mountain', 'mountains']


# An image counts as positive if any of these Visual Genome object names occurs;
# new_object_class_name becomes the 0/1 label column in labels.csv.
desired_objects = ['house', 'building', 'village', 'city', 'church']
new_object_class_name = 'buildings'

### Get file paths of images with the desired object class:

In [None]:
# Positive set: processed images annotated with at least one desired object.
selected_image_ids_with = select_img_ids_with_obj(desired_objects, data_proc_path, objects, image_ids_present)
number_imgs_with = len(selected_image_ids_with)
print(number_imgs_with)
print(selected_image_ids_with[0:3])

In [None]:
# Map each positive id to its processed image file (same order as the ids).
file_paths_with = get_file_paths_by_id(data_proc_path, selected_image_ids_with, '.jpg')
print(len(file_paths_with))
print(file_paths_with[0:3])

In [None]:
# Peek at the first three positive file paths.
file_paths_with[0:3]

In [None]:
# Preview the first four images that contain the desired objects.
# The context manager closes each file once its figure has been shown.
for file_path in file_paths_with[:4]:
    with Image.open(file_path) as image:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(image)
        ax.axis('off')
        plt.show()
        plt.close(fig)

### Make labels (1) for images with the object class:

In [None]:
# Label 1 = image contains the object class.
labels_with = [1] * len(file_paths_with)
print(len(labels_with))
print(labels_with[0:3])

### Get file paths of images without the desired object class:

In [None]:
# Candidate negatives: processed images containing none of the desired objects.
selected_image_ids_without_o = select_img_ids_with_obj(desired_objects, data_proc_path, objects, image_ids_present, with_obj=False)
len(selected_image_ids_without_o)

In [None]:
import random

In [None]:
# Randomly draw as many images without the object class as there are images
# with it (balanced classes). A fixed seed keeps the sample, and therefore
# labels.csv, reproducible across runs.
SAMPLE_SEED = 42
n_without = min(number_imgs_with, len(selected_image_ids_without_o))
if n_without < number_imgs_with:
    # random.sample would raise ValueError for a sample larger than the population.
    print(f"Only {n_without} images without the object class are available; "
          f"classes will be unbalanced.")
selected_image_ids_without = random.Random(SAMPLE_SEED).sample(selected_image_ids_without_o, n_without)

In [None]:
# Map each sampled negative id to its processed image file (same order as the ids).
file_paths_without = get_file_paths_by_id(data_proc_path, selected_image_ids_without, '.jpg')
print(len(file_paths_without))
print(file_paths_without[0:3])

In [None]:
# Preview the first eleven images without the desired objects.
# The context manager closes each file once its figure has been shown.
for file_path in file_paths_without[:11]:
    with Image.open(file_path) as image:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(image)
        ax.axis('off')
        plt.show()
        plt.close(fig)

### Make labels (0) for images without the object class:

In [None]:
# Label 0 = image does not contain the object class.
labels_without = [0] * len(file_paths_without)
print(len(labels_without))
print(labels_without[0:3])

### Make the metadata (labels) file:

In [None]:
# Positives first, then negatives; the three lists stay index-aligned.
selected_image_ids_all = selected_image_ids_with + selected_image_ids_without
file_paths_all = file_paths_with + file_paths_without
labels_all = labels_with + labels_without

In [None]:
# Sanity check: all three lengths must match to form one row per image below.
print(len(selected_image_ids_all))
print(len(file_paths_all))
print(len(labels_all))

In [None]:
#labels = pd.DataFrame({'image_id': selected_image_ids_all, 'file_paths': file_paths_all, 'mountains': labels_all})
# One row per image: id, processed file path, and a 0/1 column named after the new class.
labels = pd.DataFrame({'image_id': selected_image_ids_all, 'file_paths': file_paths_all, new_object_class_name: labels_all})

labels.head()

In [None]:
# labels.csv is stored next to the processed images.
labels_file_path = data_proc_path / 'labels.csv'

In [None]:
# index=False: the pandas row index is not written; image_id is already a column.
labels.to_csv(labels_file_path, index=False, sep=',', encoding='utf-8')

In [None]:
# Display the raw data directory path.
data_path