In [1]:
import pandas as pd
import seaborn as sns
import numpy as np
import os
import sys
import re

from PIL import Image
from matplotlib import pyplot as plt
from dotenv import load_dotenv

load_dotenv()
# Make the colour-blind-friendly palette the default for later plots.
# sns.color_palette() only *returns* a palette; calling it and discarding the
# result changes nothing, so set_palette() is used to actually apply it.
sns.set_palette('colorblind')
# plt.style.use('Solarize_Light2')

sys.path.append("../")

from scripts.process_batch_images import process_batch
from scripts.augment_batch import augment_mp

# Default DPI: read from the DPI environment variable (which load_dotenv() may
# have filled in from a .env file). Fall back to 100 when the variable is unset,
# is not a valid integer, or is not strictly positive.

try:
    # The string default keeps int() on a str in every case, so the only
    # failure mode left is a malformed value (ValueError).
    pc_dpi = int(os.getenv('DPI', '100'))
except ValueError:
    pc_dpi = 100
if pc_dpi <= 0:
    pc_dpi = 100


# Data adjustment and augmentation :

## 1 - Size adjust for model input
- Cropping the image to remove the scale: this ensures that the model is not trained on the scale's position, label or text. We have room to spare since the images are all high-resolution.
- The adjusted images will be reduced to a more manageable size to be fed to the model (400px*400px), square crop focused on center.

## 2 - Simple 'augmentation' :
- The initial image aspect ratio is 4:3 @ 4K. That means the center crop will leave details on the left and right side of the image out of the output.
- We can both recover that information and enlarge our dataset while staying meaningful: taking one sample from each of the left and right sides of the image turns 1 source image into 3 output images.
- There doesn't seem to be a constant scale between the images, so we might as well take zoomed-in crops. We will base our sampling on 400×400 px crops (the model's input size), choosing the crop whose average pixel values are closest to those of the centre image.

## 3 - Data augmentation :
- Will run standard data augmentation techniques to enlarge the dataset while keeping it meaningful and keeping overfitting in mind
- Since the images … *(TODO: this point was left unfinished — complete it.)*

In [2]:
# Sample image used to try out each processing step in the cells below.
test_image_path = "../imgs/MA184.jpg"

# Opened with PIL; the resulting object is reused by the following cells.
test_image = Image.open(fp=test_image_path)


In [None]:
# Show the sample image before any cropping.
display(test_image)


In [None]:
# Fraction of the width/height trimmed from each edge
crop_percentage = 0.075

# Corners of the crop box, truncated to whole pixels
width, height = test_image.size
left, upper = int(width * crop_percentage), int(height * crop_percentage)
right, lower = int(width * (1 - crop_percentage)), int(height * (1 - crop_percentage))

# PIL box order is (left, upper, right, lower)
crop_box = (left, upper, right, lower)
cropped_image = test_image.crop(crop_box)

display(cropped_image)


<i>Bye scale!</i>

In [None]:
# NOTE: getbbox() returns the bounding box of the non-zero regions of the image;
# it matches the full image extent only when non-zero pixels reach every edge.
# Use cropped_image.size if the dimensions themselves are wanted.
display(cropped_image.getbbox())


In [6]:
def rm_scale(image, crop_percentage = 0.075):
    """
    Trim the same fraction off every edge of an image and return the crop.

    :param image: Object exposing ``size`` as (width, height) and ``crop(box)``
    :param crop_percentage: Fraction of the width/height removed from each side
    :return: The result of ``image.crop`` for the computed box
    """
    width, height = image.size
    keep_fraction = 1 - crop_percentage

    # Box order is (left, upper, right, lower); int() truncates to whole pixels.
    trimmed_box = (
        int(width * crop_percentage),
        int(height * crop_percentage),
        int(width * keep_fraction),
        int(height * keep_fraction),
    )

    return image.crop(trimmed_box)


def resize_center(image, output_size=(400, 400)):
    """
    Crop the largest centred square out of an image and resize it.

    :param image: Object exposing ``size`` as (width, height), ``crop(box)`` and,
                  on the crop result, ``resize(size, resample)``
    :param output_size: (width, height) of the returned image
    :return: The centred square crop resized to ``output_size`` with LANCZOS resampling
    """
    width, height = image.size

    # Side of the square is the shorter image dimension.
    side = min(width, height)

    # Integer offsets: a float box such as (424.5, ..., 2975.5, ...) leaves the
    # rounding of each coordinate to the imaging library, and the resulting
    # crop is then not guaranteed to be exactly side x side.
    left = (width - side) // 2
    top = (height - side) // 2

    # crop @ center
    cropped_image = image.crop((left, top, left + side, top + side))

    # resize at output_size resolution
    return cropped_image.resize(output_size, Image.Resampling.LANCZOS)

def crop_strips(image, strip_width=800):
    """
    Cut a full-height vertical strip off the left and the right edge of an image.

    :param image: Object exposing ``size`` as (width, height) and ``crop(box)``
    :param strip_width: Width in pixels of each strip
    :return: Tuple ``(left_strip, right_strip)``
    :raises ValueError: If ``strip_width`` is not between 1 and the image width
    """
    width, height = image.size

    # A strip wider than the image would push the crop box outside the image,
    # so reject it instead of returning a strip that is not real image content.
    if not 0 < strip_width <= width:
        raise ValueError(
            f"strip_width must be between 1 and the image width ({width}), got {strip_width}"
        )

    # Crop the left and right strips
    left_strip = image.crop((0, 0, strip_width, height))
    right_strip = image.crop((width - strip_width, 0, width, height))

    return left_strip, right_strip


def get_average_grid_px(image):
    """
    Average the first three colour channels over each cell of a 3x3 grid.

    The image must be square. Cells are visited column by column (left offset
    in the outer loop, top offset in the inner loop). Pixels beyond
    ``3 * (width // 3)`` are not covered by any cell.

    :param image: Square image exposing ``size`` and ``crop(box)``; each crop
                  must convert to an array with at least three channels
    :return: List of 9 tuples ``(avg_r, avg_g, avg_b)``
    """
    width, height = image.size

    # Ensure the image is square
    assert width == height, "The image must be square."

    cells_per_side = 3  # 3x3 grid
    cell_px = width // cells_per_side

    def _cell_mean_rgb(col, row):
        # Bounding box of the grid cell, then per-channel mean of its pixels
        x0, y0 = col * cell_px, row * cell_px
        pixels = np.array(image.crop((x0, y0, x0 + cell_px, y0 + cell_px)))
        return tuple(np.mean(pixels[:, :, channel]) for channel in range(3))

    return [
        _cell_mean_rgb(col, row)
        for col in range(cells_per_side)
        for row in range(cells_per_side)
    ]


def get_crops_staircase_pattern(strip, crop_size=(400, 400)):
    """
    Cut a strip into rows of crops, one aligned left and one aligned right.

    Rows are stacked from the top in steps of the crop height. When the leftover
    height at the bottom is at least 100 px, one extra row aligned to the bottom
    edge is added (it overlaps the previous row). The right-aligned crop is only
    taken when the strip is wider than the crop.

    :param strip: Image exposing ``size`` as (width, height) and ``crop(box)``
    :param crop_size: (width, height) of each crop
    :return: List of crops, row by row, left crop before right crop
    """
    width, height = strip.size
    crop_width, crop_height = crop_size

    # Top offsets of the full rows that fit in the strip
    row_tops = list(range(0, height - crop_height + 1, crop_height))

    # Extra bottom-aligned row if there's at least 100px remaining
    if height % crop_height >= 100:
        row_tops.append(height - crop_height)

    has_right_column = width > crop_width

    crops = []
    for top in row_tops:
        bottom = top + crop_height
        crops.append(strip.crop((0, top, crop_width, bottom)))
        if has_right_column:
            crops.append(strip.crop((width - crop_width, top, width, bottom)))

    return crops


def calculate_difference(original_values, crop_values):
    """
    Total absolute difference between two sets of grid averages.

    :param original_values: Array-like of reference values
    :param crop_values: Array-like of the same shape (or broadcastable) to compare
    :return: Sum of the element-wise absolute differences
    """
    reference = np.array(original_values)
    candidate = np.array(crop_values)
    absolute_gap = np.abs(reference - candidate)
    return np.sum(absolute_gap)


def get_best_crop(original_image, strip, crop_size = (400, 400)):
    """
    Pick the crop of a strip whose 3x3 grid averages are closest to the reference image.

    :param original_image: Reference image passed to ``get_average_grid_px``
    :param strip: Strip handed to ``get_crops_staircase_pattern``
    :param crop_size: (width, height) of the candidate crops
    :return: The candidate with the smallest difference (the earliest one on a
             tie), or ``None`` when no candidate scores below infinity
    """
    reference_grid = get_average_grid_px(image=original_image)

    winner = None
    lowest_delta = float("inf")

    for candidate in get_crops_staircase_pattern(strip=strip, crop_size=crop_size):
        candidate_grid = get_average_grid_px(image=candidate)
        delta = calculate_difference(original_values=reference_grid, crop_values=candidate_grid)
        # Strict "<" means a later candidate with an equal score never replaces the winner
        if delta < lowest_delta:
            winner, lowest_delta = candidate, delta

    return winner


def process_image(image_path: str, save_directory: str) -> None:
    """
    Produce three JPEG samples from one source image and write them to disk.

    Steps: trim the borders (``rm_scale``), take the resized centre square
    (``resize_center``), pick from each side strip the crop closest to the
    centre image (``get_best_crop``), rotate each side crop by a random multiple
    of 90 degrees, then save ``<name>l.jpg``, ``<name>r.jpg`` and ``<name>c.jpg``.

    The rotation uses NumPy's global random state, so results vary between runs
    unless that state is seeded by the caller.

    :param image_path: Path of the source image
    :param save_directory: Directory receiving the three output files
    :raises ValueError: If no crop could be selected from one of the side strips
    """
    # Derive the stem with os.path so that names containing extra dots
    # ("a.b.jpg" -> "a.b") and platform-specific separators are handled.
    image_name = os.path.splitext(os.path.basename(image_path))[0]

    # The context manager closes the source file once its pixels are copied.
    # Converting to RGB guarantees the three channels that
    # get_average_grid_px indexes, whatever the mode of the source file.
    with Image.open(fp=image_path) as source:
        # crop to remove scale :
        image = rm_scale(image=source.convert("RGB"))

    # center main image
    centered_image = resize_center(image=image)

    left_strip, right_strip = crop_strips(image=image)

    side_crops = []
    for suffix, strip in (("l", left_strip), ("r", right_strip)):
        best = get_best_crop(original_image=centered_image, strip=strip)
        if best is None:
            raise ValueError(
                f"No usable crop found in the '{suffix}' strip of {image_path!r}"
            )
        side_crops.append((suffix, best))

    # Draw both angles after both crops are selected (left first, then right).
    rotated = []
    for suffix, crop in side_crops:
        angle = int(np.random.choice(np.arange(0, 360, 90)))
        if angle != 0:
            crop = crop.rotate(angle)
        rotated.append((suffix, crop))

    # Saving :
    for suffix, output in rotated + [("c", centered_image)]:
        # os.path.join copes with a missing trailing separator and with ""
        output_path = os.path.join(save_directory, image_name + suffix + ".jpg")
        output.save(fp=output_path, format="JPEG")


def augment(images: Image.Image, output_nbr: int = 0, sub_divide: bool = False, augment_subs: bool = False) -> list:
    """
    Placeholder: augmentation is not implemented in this function.

    The body is a bare ``...``, so a call returns ``None`` whatever the
    arguments are (despite the ``-> list`` annotation).

    NOTE(review): the meaning of the parameters is not defined anywhere in this
    notebook; the augmentation that is actually run further down is
    ``augment_mp`` imported from ``scripts.augment_batch``.
    """
    ...

In [None]:
# Centre square of the border-trimmed sample, resized to the default 400x400.
centered_image = resize_center(image=cropped_image)
display(centered_image)


In [8]:
# Side strips of the border-trimmed sample (default strip_width of 800 px).
# Despite their names, left_crop / right_crop hold the full-height strips.
left_crop, right_crop = crop_strips(image=cropped_image)

In [None]:
# Bounding box of the non-zero pixels of the left strip, then the strip itself.
print(left_crop.getbbox())
display(left_crop)


In [None]:
# Bounding box of the non-zero pixels of the right strip, then the strip itself.
print(right_crop.getbbox())
display(right_crop)


In [11]:
# 3x3 grid of per-cell average RGB values for the centre crop.
# NOTE(review): this variable is not read by any later cell visible here.
average_px_test_center = get_average_grid_px(image=resize_center(image=cropped_image))


In [12]:
# Candidate 400x400 crops (the default crop_size) taken from each side strip.
left_crops = get_crops_staircase_pattern(strip=left_crop)
right_crops = get_crops_staircase_pattern(strip=right_crop)


In [None]:
# Visual check of every candidate crop from the left strip.
for crop in left_crops:
    display(crop)

In [None]:
# Visual check of every candidate crop from the right strip.
for crop in right_crops:
    display(crop)

In [None]:
# Left-strip crop whose grid averages are closest to the centre image.
left_best = get_best_crop(original_image=centered_image, strip=left_crop)
display(left_best)


In [None]:
# Right-strip crop whose grid averages are closest to the centre image.
right_best = get_best_crop(original_image=centered_image, strip=right_crop)
display(right_best)


In [None]:
# Test mp image processing :
# process_batch comes from scripts.process_batch_images, not from the functions
# defined in this notebook.
# NOTE(review): presumably it applies the same per-image pipeline to every file
# in image_directory using 8 workers — confirm in that module.

image_dir = "../imgs/"
save_dir = "../data/processed_images/"

process_batch(image_directory=image_dir, save_directory=save_dir, max_workers=8)


# Augmentations on minority classes :
- Flip
- Random Crop
- Random rotate
- Scaling

Number of augmentations per image, by class size:
- High-minority classes (above 100 images): 1 augmentation
- Mid-minority classes: 2 augmentations
- Low-minority classes: 4 augmentations

In [18]:
# load updated df :
# NOTE: read_pickle unpickles the file, which can execute arbitrary code —
# only load pickles that come from a trusted source.
df = pd.read_pickle("../data/work_met_img_type_2.pkl")


In [19]:
def get_image_count(df):
    """
    Count the number of images per ``mtype``.

    :param df: DataFrame with an 'mtype' column and an 'images' column holding,
               per row, a list of images (a scalar counts as one image)
    :return: Series mapping each mtype that has at least one image to its image
             count, sorted by descending count
    """
    exploded = df.explode("images")

    # explode() turns an empty list into a single NaN row; drop those rows so
    # that a work without images contributes zero instead of one.
    exploded = exploded.dropna(subset=["images"])

    return exploded["mtype"].value_counts()


In [None]:
# Images per mtype as loaded from the pickle above.
get_image_count(df=df)


In [None]:
# Run the augmentation imported from scripts.augment_batch with its defaults.
# NOTE(review): what it reads and where it writes is defined in that module —
# confirm the paths there before re-running.
augment_mp()


In [24]:
import os
import pandas as pd

def update_dataframe_with_augmented_images(df, augmented_images_dir):
    """
    Updates the DataFrame to include augmented images with their corresponding work_name and mtype.

    An augmented image belongs to an original image when its file name starts
    with the original's base name (no directory, no extension) followed by an
    underscore. The resulting lists keep the original entries first, followed by
    the matching augmented files in alphabetical order, without duplicates, so
    the output is the same on every run.

    :param df: Original DataFrame with columns 'work_name', 'mtype', 'images'
    :param augmented_images_dir: Directory containing augmented images
    :return: The updated DataFrame (its 'images' column is replaced in place)
    """
    # Collect all augmented image filenames; sorted so the result does not
    # depend on the order in which the OS lists the directory.
    augmented_image_files = sorted(
        f for f in os.listdir(augmented_images_dir)
        if os.path.isfile(os.path.join(augmented_images_dir, f))
    )

    def _with_augmented(images):
        # Ensure the cell value is a list
        originals = images if isinstance(images, list) else [images]

        # Work on a copy: the caller's list objects are never mutated
        merged = list(originals)
        for orig_image_name in originals:
            # Skip non-string placeholders (e.g. NaN for a row without image)
            if not isinstance(orig_image_name, str):
                continue
            # Remove extension and directory from original image name
            prefix = os.path.splitext(os.path.basename(orig_image_name))[0] + '_'
            merged.extend(
                aug_image for aug_image in augmented_image_files
                if aug_image.startswith(prefix)
            )

        # Order-preserving de-duplication (a set would shuffle the entries)
        return list(dict.fromkeys(merged))

    df['images'] = df['images'].apply(_with_augmented)

    return df

# Usage example:

# Assuming df is already loaded and augmented images are in 'augmented_images_dir'
# NOTE(review): this is the same directory used as save_dir for the processed
# images above — confirm that the augmented files are written there too.
augmented_images_dir = "../data/processed_images/"


# The function returns the DataFrame it was given, so df_augment and df are the
# same object: df itself now carries the augmented image names.
df_augment = update_dataframe_with_augmented_images(df, augmented_images_dir)


In [None]:
# Quick look at the first rows after adding the augmented image names.
df_augment.head()


In [None]:
# Images per mtype after augmentation, to compare with the counts shown earlier.
get_image_count(df=df_augment)
