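Creates images in which the sample stickers are pasted onto the stock images, at the most relevant grid cell that was previously calculated for each sign class. For each sign, three versions are created (large, medium and small sign); they differ in the sign-to-sticker ratio, i.e. in which sticker file is pasted.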

In [None]:
import json
from collections import Counter
from pathlib import Path

from PIL import Image, ImageDraw, ImageFilter

In [None]:
# Sticker file (relative to the raw sticker folder) to paste for every
# sign shape, sticker design and sign size. Each entry encodes a sign-to-sticker
# ratio: the smaller the sign, the larger (or equal) the sticker file that is
# pasted onto the stock image. File names presumably follow "<width>x<height>.png"
# — TODO confirm against the sticker folder.
sign_shapes = {
    "circle": {
        "bochum": dict(large="75x75.png", medium="93x93.png", small="133x133.png"),
        "fortuna": dict(large="95x197.png", medium="119x247.png", small="170x352.png"),
    },
    "triangle": {
        "bochum": dict(large="44x44.png", medium="62x62.png", small="89x89.png"),
        "fortuna": dict(large="57x117.png", medium="79x164.png", small="113x235.png"),
    },
    "hexagon": {
        # Medium and small hexagon signs deliberately share one sticker file here.
        "bochum": dict(large="53x53.png", medium="62x62.png", small="62x62.png"),
        "fortuna": dict(large="68x141.png", medium="79x164.png", small="79x164.png"),
    },
    "square": {
        "bochum": dict(large="67x67.png", medium="93x93.png", small="133x133.png"),
        "fortuna": dict(large="85x176.png", medium="119x247.png", small="170x352.png"),
    },
}

In [None]:
def get_cell_dict(path_csv: str, num_classes: int = 43) -> dict:
    """
    Returns a dictionary with the most relevant grid cell for each class

    The file is parsed as JSON (despite the ``.csv`` suffix produced by the
    masking process). It must map zero-padded class ids ("00000", "00001",
    ...) to a dict whose values are lists of cells. The first element of each
    list is taken as that entry's top cell (presumably the lists are sorted
    by relevance — TODO confirm against the masking process). For every
    class, the cell that is ranked first most often wins; ties go to the
    cell that was encountered first.

    path_csv : str
        Path to the result file from the masking process
    num_classes : int, optional
        Number of classes to read, with ids 0 .. num_classes - 1 (default 43)

    Returns
    -------
    dict
        Maps each zero-padded class id to its most relevant cell

    Raises
    ------
    KeyError
        If a class id in ``range(num_classes)`` is missing from the file
    IndexError
        If a class has no entries or an entry has an empty cell list
    """
    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(path_csv, encoding="utf-8") as json_file:
        data = json.load(json_file)

    most_relevant_cell_dict = {}
    for class_idx in range(num_classes):
        class_id = str(class_idx).zfill(5)
        # Top-ranked cell of every entry recorded for this class.
        top_cells = [cell_lst[0] for cell_lst in data[class_id].values()]
        # most_common(1) -> [(cell, count)] for the most frequent top cell.
        most_relevant_cell_dict[class_id] = Counter(top_cells).most_common(1)[0][0]

    return most_relevant_cell_dict

In [None]:
def get_x_y_coordinates(cell: int, grid_cols: int = 8, cell_size: int = 100):
    """
    Calculates the pixel coordinates of the top-left corner of the passed cell

    Cells are numbered row by row, left to right, starting with 0 in the
    top-left corner of a grid that is ``grid_cols`` cells wide.

    cell : int
        Index of the grid cell (must be non-negative)
    grid_cols : int, optional
        Number of cells per grid row (default 8)
    cell_size : int, optional
        Edge length of one cell in pixels (default 100)

    Returns
    -------
    tuple of int
        ``(x_coordinate, y_coordinate)`` in pixels

    Raises
    ------
    ValueError
        If ``cell`` is negative
    """
    if cell < 0:
        raise ValueError(f"cell must be non-negative, got {cell}")
    # Integer divmod avoids the float round-trip of int(cell / grid_cols).
    row, col = divmod(cell, grid_cols)
    return col * cell_size, row * cell_size

In [None]:
# Adjust the paths!

sticker_folder = Path(r"/Users/robin/Downloads/GTSRB_Visualization/data/raw_sticker")
stock_images = [
    x
    for x in Path(
        r"/Users/robin/Downloads/GTSRB_Visualization/data/sticker/original"
    ).glob("**/*")
    # Skip directories and the macOS Finder metadata file.
    if x.is_file() and x.name != ".DS_Store"
]

# Get dict with relevant cells (keyed by zero-padded class id)
cell_dict = get_cell_dict(
    "/Users/robin/Downloads/content-6/results/content/masking_jsons/heatmap_grad_cam_pp__heatmap_masked.csv"
)


for size in ["large", "medium", "small"]:
    for sticker_type in ["fortuna", "bochum"]:
        for stock_image in stock_images:
            # The parent folder name is the class id used as key of cell_dict
            # (presumably the GTSRB class folder, e.g. "00014" — TODO confirm).
            class_id = stock_image.parent.name

            # Top-left corner of the most relevant cell for this class
            x_coordinate, y_coordinate = get_x_y_coordinates(cell_dict[class_id])

            # The file-name prefix before the first "_" selects the sign shape.
            # The sticker must match the current sign size so that the three
            # *_sign output folders actually differ.
            shape = stock_image.stem.split("_")[0]
            sticker_path = sticker_folder.joinpath(sign_shapes[shape][sticker_type][size])

            # Context managers release the file handles on every iteration.
            with Image.open(stock_image) as im1, Image.open(sticker_path) as im2:
                # Shift the sticker back so it ends at the image border instead of
                # reaching over it (uses the real image size, not a fixed 800px).
                x_coordinate = min(x_coordinate, im1.size[0] - im2.size[0])
                y_coordinate = min(y_coordinate, im1.size[1] - im2.size[1])

                # NOTE(review): no mask is passed, so any alpha channel of the
                # sticker PNG is presumably not used for blending — pass im2 as
                # the mask argument if transparent sticker areas should show the
                # stock image.
                im1.paste(im2, (x_coordinate, y_coordinate))

                # Store under generated_sticker/<sticker_type>/<size>_sign/<class>/
                trg = sticker_folder.joinpath(
                    "generated_sticker",
                    sticker_type,
                    size + "_sign",
                    class_id,
                    stock_image.name,
                )
                trg.parent.mkdir(parents=True, exist_ok=True)
                im1.save(str(trg))