# Prerequisite

In [None]:
import os
import re
import json
import shutil
import random
from PIL import Image, ImageDraw, ImageFont

from torchvision.datasets import ImageNet
import torchvision.transforms as transforms

os.chdir('/hpc2hdd/home/erjiaxiao/erjia/LLaVA')     # Set the new working directory; the relative paths used below are resolved against it

# Generate Examples - Species - Repetition

In [None]:
def add_text_img(image, text, num_prints, font_path, font_size, font_color):
    """Draw ``text`` on ``image`` ``num_prints`` times as a centred vertical stack.

    Every copy is horizontally centred using the rendered text length, and
    consecutive copies are placed 30 pixels apart so that the stack as a whole
    sits around the vertical centre. Drawing happens on the image that was
    passed in, and that same object is returned. With ``num_prints <= 0``
    nothing is drawn.
    """
    line_pitch = 30     # fixed vertical distance between consecutive copies

    canvas = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    rendered_width = int(canvas.textlength(text, font=typeface))
    left = (image.width - rendered_width) // 2
    top = (image.height - num_prints * line_pitch) // 2

    for line_no in range(num_prints):
        canvas.text((left, top + line_no * line_pitch), text, fill=font_color, font=typeface)

    return image

# Rendering settings for the species / repetition experiment.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)
max_num_prints = 4

# ImageNet validation split; every image goes through Resize(resolution).
dataset = ImageNet(
    root='datasets/ImageNet/',
    split='val',
    transform=transforms.Compose([transforms.Resize(resolution)]),
)

image_root = 'images'
species_repetition_dir = 'species'

# Start from empty output folders, one per repetition count (…-r0, …-r1, …).
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, f'{species_repetition_dir}-r{r}')
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.mkdir(out_dir)

# `threshold` caps how many images are taken per label; `duplicate` records
# how many images of each label have been taken so far.
threshold = 1
duplicate = {}
max_num = 500
count = 0

# Walk the dataset until `max_num` images have been produced.
for i, (image, label) in enumerate(dataset):

    if count >= max_num:
        break

    # Map the class index to the first name listed for that class.
    label = dataset.classes[label][0]

    # Skip labels that already reached the cap, otherwise count this use.
    times_used = duplicate.get(label)
    if times_used is not None and times_used >= threshold:
        continue
    duplicate[label] = 1 if times_used is None else times_used + 1

    # Draw a random class name until it differs from the true label.
    mislabel = label
    while mislabel == label:
        mislabel = random.choice(dataset.classes)[0]

    # Save one variant per repetition count; r copies of the wrong label are drawn.
    file_name = f'{i}-{label}-{mislabel}.jpg'
    for r in range(max_num_prints):
        img = add_text_img(
            image=image.copy(),
            text=mislabel,
            num_prints=r,
            font_path=font_path,
            font_size=font_size,
            font_color=font_color,
        )
        img.save(os.path.join(image_root, f'{species_repetition_dir}-r{r}', file_name))

    count += 1

# Generate Examples - Color - Repetition

In [None]:
# Object Color
def add_text_img(image, text, num_prints, font_path, font_size, font_color):
    """Stamp ``num_prints`` horizontally centred copies of ``text`` onto ``image``.

    The copies are stacked top to bottom with a fixed 30-pixel step, and the
    first one is positioned so the stack is centred vertically. The passed-in
    image is drawn on directly and returned.
    """
    row_step = 30

    pen = ImageDraw.Draw(image)
    loaded_font = ImageFont.truetype(font_path, font_size)

    x = (image.width - int(pen.textlength(text, font=loaded_font))) // 2
    y = (image.height - num_prints * row_step) // 2

    # Advance a running y coordinate instead of precomputing all positions.
    for _ in range(num_prints):
        pen.text((x, y), text, fill=font_color, font=loaded_font)
        y += row_step

    return image

# Read the DAQUAR question/answer text file.
with open('datasets/DAQUAR_QA.txt', 'r') as file:
    text = file.read()

# Collect (question, image name, colour) triples for every "what color" question.
# NOTE(review): the answer is read from the line right after the question, which
# assumes the file alternates question and answer lines — confirm against the dataset.
lines = text.strip().split('\n')
color_questions = []
for i, line in enumerate(lines):
    if "what color" not in line:
        continue

    image_match = re.search(r'image\d+', line)
    # Skip malformed entries instead of crashing: a question without an image
    # id, or a question on the very last line with no answer after it.
    if image_match is None or i + 1 >= len(lines):
        continue

    question = re.sub(r'image\d+', 'image', line).strip()
    image_name = image_match.group(0)
    # Strip once here so the value compared with `colors` below and the value
    # written into the file name are the same string.
    color = lines[i + 1].strip()

    # Skip answers that list more than one colour.
    if ',' in color:
        continue

    # Skip answers of two characters or fewer.
    if len(color) <= 2:
        continue

    color_questions.append((question, image_name, color))

# Pool of distinct colour answers that wrong colours are sampled from.
colors = list({color for _, _, color in color_questions})

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)
max_num_prints = 4

image_root = 'images'
color_repetition_dir = 'color'

# Start from empty output folders, one per repetition count.
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, color_repetition_dir + '-' + 'r' + str(r))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

for i, (question, image_name, color) in enumerate(color_questions):

    # Sample the wrong colour only from colours that differ from the answer;
    # this cannot loop forever when the pool holds a single colour.
    wrong_colors = [candidate for candidate in colors if candidate != color]
    if not wrong_colors:
        continue
    miscolor = random.choice(wrong_colors)

    image_path = f"datasets/DAQUAR/{image_name}.png"
    # Close the file handle right away; convert to RGB because palette or
    # alpha PNGs cannot be written as JPEG.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    file_name = str(i) + '-' + question + '-' + color + '-' + miscolor + '.jpg'
    for r in range(max_num_prints):
        img = add_text_img(image=image.copy(), text=miscolor, num_prints=r, font_path=font_path, font_size=font_size, font_color=font_color)
        img.save(os.path.join(image_root, color_repetition_dir + '-' + 'r' + str(r), file_name))


In [None]:
# Visual7W
def add_text_img(image, text, num_prints, font_path, font_size, font_color):
    """Overlay ``text`` on ``image`` ``num_prints`` times, one copy per row.

    Rows are 30 pixels apart; the rows are centred vertically as a group and
    each copy is centred horizontally. The input image object itself is drawn
    on and handed back.
    """
    spacing = 30

    drawer = ImageDraw.Draw(image)
    text_font = ImageFont.truetype(font_path, font_size)

    width_px = int(drawer.textlength(text, font=text_font))
    origin_x = (image.width - width_px) // 2
    origin_y = (image.height - num_prints * spacing) // 2

    row_tops = (origin_y + spacing * row for row in range(num_prints))
    for row_top in row_tops:
        drawer.text((origin_x, row_top), text, fill=font_color, font=text_font)

    return image

# Rendering settings for the large colour / repetition experiment (Visual7W).
font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_num_prints = 2
max_num = 5000

image_root = 'images'
color_repetition_dir = 'color-large'

# Start from empty output folders, one per repetition count.
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, color_repetition_dir + '-' + 'r' + str(r))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

visual7w_json = 'datasets/dataset_v7w_telling.json'
visual7w_img_folder = 'datasets/Visual7W/'

with open(visual7w_json, 'r') as file:
    data = json.load(file)

# k counts the samples actually written and is also the file-name index.
k = 0
for entry in data["images"]:
    # Stop the outer loop too once enough samples exist; otherwise every
    # remaining image entry would still be visited for nothing.
    if k >= max_num:
        break

    for qa_pair in entry["qa_pairs"]:
        if k >= max_num:
            break

        question = qa_pair["question"]
        if 'what color' not in question.lower():
            continue

        answer = qa_pair["answer"]
        # Non-empty choices that differ from the correct answer.
        wrong_choices = [choice for choice in qa_pair["multiple_choices"] if choice and choice != answer]
        if not wrong_choices:
            continue
        wrong_ans = random.choice(wrong_choices)

        # All three strings end up in the file name, so none may contain '/'.
        if '/' in question or '/' in answer or '/' in wrong_ans:
            continue

        image_path = os.path.join(visual7w_img_folder, 'v7w_' + str(qa_pair["image_id"]) + '.jpg')
        file_name = str(k) + '-' + question + '-' + answer + '-' + wrong_ans + '.jpg'

        # The context manager releases the source file once all variants are saved.
        with Image.open(image_path) as image:
            for r in range(max_num_prints):
                img = add_text_img(image=image.copy(), text=wrong_ans, num_prints=r, font_path=font_path, font_size=font_size, font_color=font_color)
                img.save(os.path.join(image_root, color_repetition_dir + '-' + 'r' + str(r), file_name))

        k += 1

# Generate Examples - Counting - Repetition

In [None]:
# CountBench Dataset
def add_text_img(image, text, num_prints, font_path, font_size, font_color):
    """Write ``text`` onto ``image`` ``num_prints`` times in a centred column.

    The column uses a fixed 30-pixel step between copies. Each copy is centred
    horizontally from its rendered length, and the column is centred
    vertically. The image argument is drawn on and returned as is.
    """
    step = 30   # vertical distance between two consecutive copies

    brush = ImageDraw.Draw(image)
    face = ImageFont.truetype(font_path, font_size)

    text_px = int(brush.textlength(text, font=face))
    x0 = (image.width - text_px) // 2
    y0 = (image.height - num_prints * step) // 2

    for index in range(num_prints):
        anchor = (x0, y0 + index * step)
        brush.text(anchor, text, fill=font_color, font=face)

    return image

# Digit string -> English word, for the counts 1..10.
number_to_word = {
    '1': 'one',
    '2': 'two',
    '3': 'three',
    '4': 'four',
    '5': 'five',
    '6': 'six',
    '7': 'seven',
    '8': 'eight',
    '9': 'nine',
    '10': 'ten'
}

countbench_dir = 'datasets/LabeledCountBench/'

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_num_prints = 4
resolution = (336, 336)

image_root = 'images'
counting_repetition_dir = 'counting'

# Start from empty output folders, one per repetition count.
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, counting_repetition_dir + '-' + 'r' + str(r))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

# Process every image file in the folder. The listing is sorted so that the
# index `i` used in the output file names is the same on every run.
for i, filename in enumerate(sorted(os.listdir(countbench_dir))):
    if not filename.endswith(('.jpg', '.png', '.jpeg', '.gif')):
        continue

    image_path = os.path.join(countbench_dir, filename)
    # Convert to RGB before resizing: PNG/GIF sources may be palette or alpha
    # images, which cannot be written as JPEG. The file handle is closed here.
    with Image.open(image_path) as source:
        image = source.convert('RGB').resize(resolution)

    # NOTE(review): file names are assumed to look like <...>_<label>_<count>.<ext> — confirm.
    stem_parts = filename.split('.')[0].split('_')
    label = stem_parts[-2]
    counting = stem_parts[-1]

    # Pick a wrong count uniformly from 1..10, excluding the true count.
    miscounting = random.choice([digits for digits in number_to_word if digits != counting])

    # Counts outside 1..10 are written as 'Invalid' in the file name.
    counting = number_to_word.get(counting, 'Invalid')
    miscounting = number_to_word.get(miscounting, 'Invalid')

    file_name = str(i) + '-' + label + '-' + counting + '-' + miscounting + '.jpg'
    for r in range(max_num_prints):
        img = add_text_img(image=image.copy(), text=miscounting, num_prints=r, font_path=font_path, font_size=font_size, font_color=font_color)
        img.save(os.path.join(image_root, counting_repetition_dir + '-' + 'r' + str(r), file_name))

In [None]:
# Tallyqa Dataset
def add_text_img(image, text, num_prints, font_path, font_size, font_color):
    """Print ``text`` ``num_prints`` times on ``image``, centred as a block.

    Lines are laid out 30 pixels apart, starting high enough that the block is
    vertically centred; every line is horizontally centred. The image passed
    in is the one that gets drawn on and returned.
    """
    line_gap = 30

    painter = ImageDraw.Draw(image)
    font_obj = ImageFont.truetype(font_path, font_size)

    measured = int(painter.textlength(text, font=font_obj))
    x_pos = (image.width - measured) // 2
    y_pos = (image.height - num_prints * line_gap) // 2

    remaining = num_prints
    while remaining > 0:
        painter.text((x_pos, y_pos), text, fill=font_color, font=font_obj)
        y_pos += line_gap
        remaining -= 1

    return image

# Digit string -> English word, for the counts 1..10.
number_to_word = {
    '1': 'one',
    '2': 'two',
    '3': 'three',
    '4': 'four',
    '5': 'five',
    '6': 'six',
    '7': 'seven',
    '8': 'eight',
    '9': 'nine',
    '10': 'ten'
}

vg_dir = 'datasets/Visual_Genome/'
tallyqa_json = 'datasets/tallyqa/test.json'

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_num_prints = 2
resolution = (336, 336)
max_num = 5000

image_root = 'images'
counting_repetition_dir = 'counting-large'

# Start from empty output folders, one per repetition count.
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, counting_repetition_dir + '-' + 'r' + str(r))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

# Only the JSON parsing needs the file; it is closed before image processing starts.
with open(tallyqa_json, 'r') as file:
    data = json.load(file)

# `count` tracks written samples; the file names use the dataset index `i`.
count = 0
for i, item in enumerate(data):

    if count >= max_num:
        break

    answer = item["answer"]
    question = item["question"]

    # Only counts that have a word form (1..10) are used.
    if answer > 10 or answer <= 0:
        continue

    # The question becomes part of the file name, so it must not contain '/'.
    if '/' in question:
        continue

    image_path = os.path.join(vg_dir, item["image"])
    with Image.open(image_path) as source:
        image = source.resize(resolution)

    counting = str(answer)
    # Pick a wrong count uniformly from 1..10, excluding the true count.
    miscounting = random.choice([digits for digits in number_to_word if digits != counting])

    counting = number_to_word.get(counting, 'Invalid')
    miscounting = number_to_word.get(miscounting, 'Invalid')

    file_name = str(i) + '-' + question + '-' + counting + '-' + miscounting + '.jpg'
    for r in range(max_num_prints):
        img = add_text_img(image=image.copy(), text=miscounting, num_prints=r, font_path=font_path, font_size=font_size, font_color=font_color)
        img.save(os.path.join(image_root, counting_repetition_dir + '-' + 'r' + str(r), file_name))

    count += 1

# Generate Examples - Numerical - Repetition

In [None]:
def add_equation(text, font_path, num_prints, font_size, font_color, bg_color, resolution):
    """Create a blank ``resolution``-sized RGB image and draw ``text`` on it.

    The background is filled with ``bg_color``. ``text`` is drawn
    ``num_prints`` times, horizontally centred, in rows 30 pixels apart that
    are centred vertically as a group. Returns the new image.
    """
    row_step = 30

    page = Image.new('RGB', resolution, bg_color)
    pen = ImageDraw.Draw(page)
    typeface = ImageFont.truetype(font_path, font_size)

    left = (page.width - int(pen.textlength(text, font=typeface))) // 2
    top = (page.height - num_prints * row_step) // 2

    for row in range(num_prints):
        pen.text((left, top + row * row_step), text, fill=font_color, font=typeface)

    return page

def add_wrong_answer(image, text, num_prints, font_path, font_size, font_color):
    """Draw ``text`` ``num_prints`` times on ``image``, 100 pixels above the centred position.

    The layout matches a vertically centred stack of rows 30 pixels apart,
    shifted up by 100 pixels; every row is horizontally centred. The image
    passed in is drawn on and returned.
    """
    row_step = 30
    upward_shift = 100  # move the whole stack above the image centre

    pen = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    left = (image.width - int(pen.textlength(text, font=typeface))) // 2
    top = (image.height - num_prints * row_step) // 2 - upward_shift

    for row in range(num_prints):
        pen.text((left, top + row * row_step), text, fill=font_color, font=typeface)

    return image

# Rendering settings for the numerical / repetition experiment.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
equation_font_color = 'black'
mislabel_font_color = 'red'
bg_color = 'white'
resolution = (336, 336)
max_num_prints = 4
max_pic_num = 500

image_root = 'images'
species_repetition_dir = 'numerical'   # output folder prefix (the variable name is kept from the species cell)

# Start from empty output folders, one per repetition count.
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, species_repetition_dir + '-' + 'r' + str(r))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

operators = ['+', '-', 'x', '÷']

for i in range(max_pic_num):

    num1 = random.randint(1, 1000)
    num2 = random.randint(1, 1000)
    chosen_operator = random.choice(operators)

    if chosen_operator == '+':
        correct_answer = num1 + num2
    elif chosen_operator == '-':
        # Put the larger operand first so the result is never negative.
        num1, num2 = max(num1, num2), min(num1, num2)
        correct_answer = num1 - num2
    elif chosen_operator == 'x':
        correct_answer = num1 * num2
    else:
        # Division: replace num2 by a random divisor of num1 so the quotient
        # is an integer (1 always divides, so the list is never empty).
        divisors = [d for d in range(1, num1 + 1) if num1 % d == 0]
        num2 = random.choice(divisors)
        correct_answer = num1 // num2

    # Wrong answer within 500 of the correct one, clamped to be non-negative.
    # Redraw while it coincides with the correct answer (offset 0, or a clamp
    # to 0 when the correct answer is 0), so the overlay is always wrong.
    wrong_answer = correct_answer
    while wrong_answer == correct_answer:
        wrong_answer = max(0, correct_answer + random.randint(-500, 500))

    equation = f"{num1} {chosen_operator} {num2} ="

    # The equation image is the same for every repetition count: render it
    # once and overlay the wrong answer on copies.
    equation_img = add_equation(text=equation, num_prints=1, font_path=font_path, font_size=font_size, font_color=equation_font_color, bg_color=bg_color, resolution=resolution)
    file_name = str(i) + '-' + str(correct_answer) + '-' + str(wrong_answer) + '.jpg'
    for r in range(max_num_prints):
        img = add_wrong_answer(image=equation_img.copy(), text=str(wrong_answer), num_prints=r, font_path=font_path, font_size=font_size, font_color=mislabel_font_color)
        img.save(os.path.join(image_root, species_repetition_dir + '-' + 'r' + str(r), file_name))

# Generate Examples - Complex - Repetition

In [None]:
def add_text_img(image, text, num_prints, font_path, font_size, font_color):
    """Render ``text`` ``num_prints`` times on ``image`` as a centred stack of rows.

    Rows are spaced 30 pixels apart; the stack is centred vertically and each
    row horizontally. The image given is drawn on and returned.
    """
    row_height = 30

    surface = ImageDraw.Draw(image)
    font_handle = ImageFont.truetype(font_path, font_size)

    half_free_width = (image.width - int(surface.textlength(text, font=font_handle))) // 2
    first_row_top = (image.height - num_prints * row_height) // 2

    for n in range(num_prints):
        surface.text(
            (half_free_width, first_row_top + n * row_height),
            text,
            fill=font_color,
            font=font_handle,
        )

    return image

# Rendering settings for the complex-reasoning / repetition experiment.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_num_prints = 4
max_pic_num = 500
resolution = (336, 336)

image_root = 'images'
complex_repetition_dir = 'complex'

# Start from empty output folders, one per repetition count.
for r in range(max_num_prints):
    out_dir = os.path.join(image_root, complex_repetition_dir + '-' + 'r' + str(r))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

dataset_path = 'datasets/COCO2017/val2017/'
json_path = 'datasets/aokvqa_v1p0_val.json'

# Map the integer parsed from each .jpg file name to the path of that file.
img_dict = {}
for filename in os.listdir(dataset_path):
    if filename.endswith('.jpg'):
        prefix = int(filename.split('.')[0])
        img_dict[prefix] = dataset_path + filename

# Load the question records.
with open(json_path, 'r') as file:
    data = json.load(file)

# k counts the samples actually written and is also the file-name index.
k = 0
for item in data:

    if k >= max_pic_num:
        break

    image_id = item['image_id']
    question = item['question']
    choices = item['choices']
    correct_choice_idx = item['correct_choice_idx']

    # The question is part of the file name, so it must not contain '/'.
    if '/' in question:
        continue

    # Choices may contain '/' as well; only the part before it is used.
    label = choices[correct_choice_idx].split('/')[0]

    # Sample the wrong label from the choices that differ from the correct
    # one. This replaces the bounded retry loop, which also skipped a sample
    # when its last retry had actually produced a valid wrong label.
    wrong_labels = [choice.split('/')[0] for choice in choices]
    wrong_labels = [candidate for candidate in wrong_labels if candidate != label]
    if not wrong_labels:
        continue
    mislabel = random.choice(wrong_labels)

    # Open the image only once the sample is known to be usable.
    image_path = img_dict[image_id]
    with Image.open(image_path) as source:
        img = source.resize(resolution)

    file_name = str(k) + '-' + question + '-' + label + '-' + mislabel + '.jpg'
    for r in range(max_num_prints):
        img_mislabel = add_text_img(image=img.copy(), text=mislabel, num_prints=r, font_path=font_path, font_size=font_size, font_color=font_color)
        img_mislabel.save(os.path.join(image_root, complex_repetition_dir + '-' + 'r' + str(r), file_name))

    k += 1

# Generate Examples - Species - Position

In [None]:
def add_text_img(image, text, font_path, font_size, font_color, position, resolution):
    """Draw ``text`` once on ``image``, centred on ``position``.

    ``position`` is an ``(x, y)`` pair. The text is centred on it using the
    rendered text length and a nominal height of 10 pixels, then kept from
    starting left of or above the image and from running past
    ``resolution[0]`` on the right. The image passed in is drawn on and
    returned.
    """
    nominal_height = 10

    canvas = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)
    rendered_width = int(canvas.textlength(text, font=typeface))

    centre_x, centre_y = position

    # Clamp against the left and top edges first ...
    left = max(0, centre_x - rendered_width // 2)
    top = max(0, centre_y - nominal_height // 2)

    # ... then pull the text back if it would cross the right edge.
    left = min(left, resolution[0] - rendered_width)

    canvas.text((left, top), text, fill=font_color, font=typeface)

    return image

# Rendering settings for the species / position experiment.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)

image_root = 'images'
species_position_dir = 'species'

# Split the image into an evenly sized grid of cells.
num_rows = 4
num_cols = 4
cell_width = resolution[0] // num_cols
cell_height = resolution[1] // num_rows

# Start from empty output folders, one per grid cell (…-p<row><col>).
for row in range(num_rows):
    for col in range(num_cols):
        out_dir = os.path.join(image_root, species_position_dir + '-' + 'p' + str(row) + str(col))
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

dataset = ImageNet(root='datasets/ImageNet/', split='val', transform=transforms.Compose([transforms.Resize(resolution)]))

# `threshold` caps how many images are taken per label; `duplicate` records
# how many images of each label have been taken so far.
threshold = 1
duplicate = {}
max_num = 500
count = 0

# Walk the dataset until `max_num` images have been produced.
for i, (image, label) in enumerate(dataset):

    if count >= max_num:
        break

    # Map the class index to the first name listed for that class.
    label = dataset.classes[label][0]

    if label in duplicate:
        if duplicate[label] < threshold:
            # Cap not reached yet for this label: count this use.
            duplicate[label] += 1
        else:
            # Cap reached: move on to the next image.
            continue
    else:
        # First image of this label.
        duplicate[label] = 1

    # Pick a random class name and redraw until it differs from the true
    # label; without the redraw the overlay could show the correct class.
    mislabel = random.choice(dataset.classes)[0]
    while mislabel == label:
        mislabel = random.choice(dataset.classes)[0]

    # Save one variant per grid cell, with the text centred in that cell.
    file_name = str(i) + '-' + label + '-' + mislabel + '.jpg'
    for row in range(num_rows):
        for col in range(num_cols):
            x = col * cell_width + cell_width // 2
            y = row * cell_height + cell_height // 2
            img = add_text_img(image=image.copy(), text=mislabel, font_path=font_path, font_size=font_size, font_color=font_color, position=(x, y), resolution=resolution)
            img.save(os.path.join(image_root, species_position_dir + '-' + 'p' + str(row) + str(col), file_name))

    count += 1

# Generate Examples - Color - Position

In [None]:
def add_text_img(image, text, font_path, font_size, font_color, position, resolution):
    """Place a single copy of ``text`` on ``image`` around the ``(x, y)`` ``position``.

    The text is centred on the point using its rendered length and a nominal
    10-pixel height. Its origin is never allowed to be negative, and it is
    shifted left when it would extend beyond ``resolution[0]``. The image
    passed in is drawn on and returned.
    """
    pen = ImageDraw.Draw(image)
    loaded_font = ImageFont.truetype(font_path, font_size)

    width_px = int(pen.textlength(text, font=loaded_font))
    height_px = 10

    anchor_x = position[0] - width_px // 2
    anchor_y = position[1] - height_px // 2

    # Keep the origin inside the left and top edges.
    if anchor_x < 0:
        anchor_x = 0
    if anchor_y < 0:
        anchor_y = 0

    # Keep the text inside the right edge.
    overflow = anchor_x + width_px - resolution[0]
    if overflow > 0:
        anchor_x = resolution[0] - width_px

    pen.text((anchor_x, anchor_y), text, fill=font_color, font=loaded_font)

    return image

# Read the DAQUAR question/answer text file.
with open('datasets/DAQUAR_QA.txt', 'r') as file:
    text = file.read()

# Collect (question, image name, colour) triples for every "what color" question.
# NOTE(review): the answer is read from the line right after the question, which
# assumes the file alternates question and answer lines — confirm against the dataset.
lines = text.strip().split('\n')
color_questions = []
for i, line in enumerate(lines):
    if "what color" not in line:
        continue

    image_match = re.search(r'image\d+', line)
    # Skip malformed entries instead of crashing: a question without an image
    # id, or a question on the very last line with no answer after it.
    if image_match is None or i + 1 >= len(lines):
        continue

    question = re.sub(r'image\d+', 'image', line).strip()
    image_name = image_match.group(0)
    # Strip once here so the value compared with `colors` below and the value
    # written into the file name are the same string.
    color = lines[i + 1].strip()

    # Skip answers that list more than one colour.
    if ',' in color:
        continue

    # Skip answers of two characters or fewer.
    if len(color) <= 2:
        continue

    color_questions.append((question, image_name, color))

# Pool of distinct colour answers that wrong colours are sampled from.
colors = list({color for _, _, color in color_questions})

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)

image_root = 'images'
species_position_dir = 'color'     # output folder prefix (the variable name is kept from the species cell)

# Split the image into an evenly sized grid of cells.
num_rows = 4
num_cols = 4
cell_width = resolution[0] // num_cols
cell_height = resolution[1] // num_rows

# Start from empty output folders, one per grid cell (…-p<row><col>).
for row in range(num_rows):
    for col in range(num_cols):
        out_dir = os.path.join(image_root, species_position_dir + '-' + 'p' + str(row) + str(col))
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

for i, (question, image_name, color) in enumerate(color_questions):

    # Sample the wrong colour only from colours that differ from the answer;
    # this cannot loop forever when the pool holds a single colour.
    wrong_colors = [candidate for candidate in colors if candidate != color]
    if not wrong_colors:
        continue
    miscolor = random.choice(wrong_colors)

    image_path = f"datasets/DAQUAR/{image_name}.png"
    # Close the file handle right away; convert to RGB because palette or
    # alpha PNGs cannot be written as JPEG.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    # These images are not resized, so the grid is recomputed from each image's own size.
    resolution = image.size
    cell_width = resolution[0] // num_cols
    cell_height = resolution[1] // num_rows

    # Save one variant per grid cell, with the text centred in that cell.
    file_name = str(i) + '-' + question + '-' + color + '-' + miscolor + '.jpg'
    for row in range(num_rows):
        for col in range(num_cols):
            x = col * cell_width + cell_width // 2
            y = row * cell_height + cell_height // 2
            img = add_text_img(image=image.copy(), text=miscolor, font_path=font_path, font_size=font_size, font_color=font_color, position=(x, y), resolution=resolution)
            img.save(os.path.join(image_root, species_position_dir + '-' + 'p' + str(row) + str(col), file_name))

# Generate Examples - Counting - Position

In [None]:
def add_text_img(image, text, font_path, font_size, font_color, position, resolution):
    """Draw ``text`` a single time on ``image``, centred at ``position``.

    Centring uses the rendered text length horizontally and a nominal height
    of 10 pixels vertically. The origin is clamped to be non-negative and is
    then moved left if the text would pass ``resolution[0]``. The image passed
    in is drawn on and returned.
    """
    brush = ImageDraw.Draw(image)
    face = ImageFont.truetype(font_path, font_size)

    text_w = int(brush.textlength(text, font=face))
    text_h = 10
    px, py = position

    # Left/top clamp first, right-edge clamp second (the order matters when
    # the text is wider than the image).
    draw_x = min(max(0, px - text_w // 2), resolution[0] - text_w)
    draw_y = max(0, py - text_h // 2)

    brush.text((draw_x, draw_y), text, fill=font_color, font=face)

    return image

# Digit string -> English word, for the counts 1..10.
number_to_word = {
    '1': 'one',
    '2': 'two',
    '3': 'three',
    '4': 'four',
    '5': 'five',
    '6': 'six',
    '7': 'seven',
    '8': 'eight',
    '9': 'nine',
    '10': 'ten'
}

countbench_dir = 'datasets/LabeledCountBench/'

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_num_prints = 4      # not used by the position experiment in this cell
resolution = (336, 336)

image_root = 'images'
counting_position_dir = 'counting'

# Split the image into an evenly sized grid of cells.
num_rows = 4
num_cols = 4
cell_width = resolution[0] // num_cols
cell_height = resolution[1] // num_rows

# Start from empty output folders, one per grid cell (…-p<row><col>).
for row in range(num_rows):
    for col in range(num_cols):
        out_dir = os.path.join(image_root, counting_position_dir + '-' + 'p' + str(row) + str(col))
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

# Process every image file in the folder. The listing is sorted so that the
# index `i` used in the output file names is the same on every run.
for i, filename in enumerate(sorted(os.listdir(countbench_dir))):
    if not filename.endswith(('.jpg', '.png', '.jpeg', '.gif')):
        continue

    image_path = os.path.join(countbench_dir, filename)
    # Convert to RGB before resizing: PNG/GIF sources may be palette or alpha
    # images, which cannot be written as JPEG. The file handle is closed here.
    with Image.open(image_path) as source:
        image = source.convert('RGB').resize(resolution)

    # NOTE(review): file names are assumed to look like <...>_<label>_<count>.<ext> — confirm.
    stem_parts = filename.split('.')[0].split('_')
    label = stem_parts[-2]
    counting = stem_parts[-1]

    # Pick a wrong count uniformly from 1..10, excluding the true count.
    miscounting = random.choice([digits for digits in number_to_word if digits != counting])

    # Counts outside 1..10 are written as 'Invalid' in the file name.
    counting = number_to_word.get(counting, 'Invalid')
    miscounting = number_to_word.get(miscounting, 'Invalid')

    # Save one variant per grid cell, with the text centred in that cell.
    file_name = str(i) + '-' + label + '-' + counting + '-' + miscounting + '.jpg'
    for row in range(num_rows):
        for col in range(num_cols):
            x = col * cell_width + cell_width // 2
            y = row * cell_height + cell_height // 2
            img = add_text_img(image=image.copy(), text=miscounting, font_path=font_path, font_size=font_size, font_color=font_color, position=(x, y), resolution=resolution)
            img.save(os.path.join(image_root, counting_position_dir + '-' + 'p' + str(row) + str(col), file_name))

# Generate Examples - Numerical - Position

In [None]:
def add_equation(text, font_path, num_prints, font_size, font_color, bg_color, resolution):
    """Build a new RGB image of size ``resolution`` and write ``text`` on it.

    The image is filled with ``bg_color``. ``text`` is written ``num_prints``
    times in horizontally centred rows, 30 pixels apart, with the rows centred
    vertically as a group. Returns the newly created image.
    """
    line_gap = 30

    sheet = Image.new('RGB', resolution, bg_color)
    painter = ImageDraw.Draw(sheet)
    font_obj = ImageFont.truetype(font_path, font_size)

    measured = int(painter.textlength(text, font=font_obj))
    x_pos = (sheet.width - measured) // 2
    y_pos = (sheet.height - num_prints * line_gap) // 2

    for _ in range(num_prints):
        painter.text((x_pos, y_pos), text, fill=font_color, font=font_obj)
        y_pos += line_gap

    return sheet

def add_wrong_answer(image, text, font_path, font_size, font_color, position, resolution):
    """Draw ``text`` once on ``image``, centred on the ``(x, y)`` ``position``.

    The text is centred with its rendered length and a nominal 10-pixel
    height. Its origin is clamped to be non-negative and then moved left when
    the text would go past ``resolution[0]``. The image passed in is drawn on
    and returned.
    """
    nominal_height = 10

    canvas = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)
    rendered_width = int(canvas.textlength(text, font=typeface))

    centre_x, centre_y = position
    left = max(0, centre_x - rendered_width // 2)
    top = max(0, centre_y - nominal_height // 2)

    # Pull the text back when it would cross the right edge.
    right_limit = resolution[0] - rendered_width
    if left > right_limit:
        left = right_limit

    canvas.text((left, top), text, fill=font_color, font=typeface)

    return image

# Rendering settings for the numerical / position experiment.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
equation_font_color = 'black'
mislabel_font_color = 'red'
bg_color = 'white'
resolution = (336, 336)
max_pic_num = 500

image_root = 'images'
position_dir = 'numerical'

# Split the image into an evenly sized grid of cells.
num_rows = 4
num_cols = 4
cell_width = resolution[0] // num_cols
cell_height = resolution[1] // num_rows

# Start from empty output folders, one per grid cell (…-p<row><col>).
for row in range(num_rows):
    for col in range(num_cols):
        out_dir = os.path.join(image_root, position_dir + '-' + 'p' + str(row) + str(col))
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

operators = ['+', '-', 'x', '÷']

for i in range(max_pic_num):

    num1 = random.randint(1, 1000)
    num2 = random.randint(1, 1000)
    chosen_operator = random.choice(operators)

    if chosen_operator == '+':
        correct_answer = num1 + num2
    elif chosen_operator == '-':
        # Put the larger operand first so the result is never negative.
        num1, num2 = max(num1, num2), min(num1, num2)
        correct_answer = num1 - num2
    elif chosen_operator == 'x':
        correct_answer = num1 * num2
    else:
        # Division: replace num2 by a random divisor of num1 so the quotient
        # is an integer (1 always divides, so the list is never empty).
        divisors = [d for d in range(1, num1 + 1) if num1 % d == 0]
        num2 = random.choice(divisors)
        correct_answer = num1 // num2

    # Wrong answer within 500 of the correct one, clamped to be non-negative.
    # Redraw while it coincides with the correct answer (offset 0, or a clamp
    # to 0 when the correct answer is 0), so the overlay is always wrong.
    wrong_answer = correct_answer
    while wrong_answer == correct_answer:
        wrong_answer = max(0, correct_answer + random.randint(-500, 500))

    equation = f"{num1} {chosen_operator} {num2} ="

    # The equation image is the same for every grid cell: render it once and
    # overlay the wrong answer on copies.
    equation_img = add_equation(text=equation, num_prints=1, font_path=font_path, font_size=font_size, font_color=equation_font_color, bg_color=bg_color, resolution=resolution)
    file_name = str(i) + '-' + str(correct_answer) + '-' + str(wrong_answer) + '.jpg'
    for row in range(num_rows):
        for col in range(num_cols):
            x = col * cell_width + cell_width // 2
            y = row * cell_height + cell_height // 2
            img = add_wrong_answer(image=equation_img.copy(), text=str(wrong_answer), font_path=font_path, font_size=font_size, font_color=mislabel_font_color, position=(x, y), resolution=resolution)
            img.save(os.path.join(image_root, position_dir + '-' + 'p' + str(row) + str(col), file_name))

# Generate Examples - Complex - Position

In [None]:
def add_text_img(image, text, font_path, font_size, font_color, position, resolution):
    """Write ``text`` once on ``image`` so that it is centred on ``position``.

    ``position`` is an ``(x, y)`` point. Centring uses the rendered text
    length and a nominal height of 10 pixels. The origin is first clamped to
    be non-negative, then shifted left if the text would extend past
    ``resolution[0]``. The image passed in is drawn on and returned.
    """
    drawer = ImageDraw.Draw(image)
    text_font = ImageFont.truetype(font_path, font_size)

    span = int(drawer.textlength(text, font=text_font))
    nominal_height = 10
    target_x, target_y = position

    origin_x = target_x - span // 2
    origin_x = origin_x if origin_x > 0 else 0     # left-edge clamp
    origin_y = target_y - nominal_height // 2
    origin_y = origin_y if origin_y > 0 else 0     # top-edge clamp

    # Right-edge clamp, applied after the left-edge clamp.
    if origin_x + span > resolution[0]:
        origin_x = resolution[0] - span

    drawer.text((origin_x, origin_y), text, fill=font_color, font=text_font)

    return image

# Rendering settings for the complex-reasoning / position experiment.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_pic_num = 500
resolution = (336, 336)

image_root = 'images'
position_dir = 'complex'

# Split the image into an evenly sized grid of cells.
num_rows = 4
num_cols = 4
cell_width = resolution[0] // num_cols
cell_height = resolution[1] // num_rows

# Start from empty output folders, one per grid cell (…-p<row><col>).
for row in range(num_rows):
    for col in range(num_cols):
        out_dir = os.path.join(image_root, position_dir + '-' + 'p' + str(row) + str(col))
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)    # also creates `image_root` when it does not exist yet

dataset_path = 'datasets/COCO2017/val2017/'
json_path = 'datasets/aokvqa_v1p0_val.json'

# Map the integer parsed from each .jpg file name to the path of that file.
img_dict = {}
for filename in os.listdir(dataset_path):
    if filename.endswith('.jpg'):
        prefix = int(filename.split('.')[0])
        img_dict[prefix] = dataset_path + filename

# Load the question records.
with open(json_path, 'r') as file:
    data = json.load(file)

# `count` tracks the samples actually written and is also the file-name index.
count = 0
for item in data:

    if count >= max_pic_num:
        break

    image_id = item['image_id']
    question = item['question']
    choices = item['choices']
    correct_choice_idx = item['correct_choice_idx']

    # The question is part of the file name, so it must not contain '/'.
    if '/' in question:
        continue

    # Choices may contain '/' as well; only the part before it is used.
    label = choices[correct_choice_idx].split('/')[0]

    # Sample the wrong label from the choices that differ from the correct
    # one. An unbounded redraw loop would never finish for a record whose
    # choices all reduce to the same string; such records are skipped.
    wrong_labels = [choice.split('/')[0] for choice in choices]
    wrong_labels = [candidate for candidate in wrong_labels if candidate != label]
    if not wrong_labels:
        continue
    mislabel = random.choice(wrong_labels)

    # Open the image only once the sample is known to be usable.
    image_path = img_dict[image_id]
    with Image.open(image_path) as source:
        img = source.resize(resolution)

    # Save one variant per grid cell, with the text centred in that cell.
    file_name = str(count) + '-' + question + '-' + label + '-' + mislabel + '.jpg'
    for row in range(num_rows):
        for col in range(num_cols):
            x = col * cell_width + cell_width // 2
            y = row * cell_height + cell_height // 2
            img_mislabel = add_text_img(image=img.copy(), text=mislabel, font_path=font_path, font_size=font_size, font_color=font_color, position=(x, y), resolution=resolution)
            img_mislabel.save(os.path.join(image_root, position_dir + '-' + 'p' + str(row) + str(col), file_name))

    count += 1

# Generate Examples - Species - Font Color

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Draw ``text`` once near the centre of ``image`` and return the image.

    The text is drawn directly onto ``image`` (no copy is made here), so
    callers that need the original untouched must pass a copy.

    Args:
        image: PIL image to draw on.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        The same ``image`` object, now carrying the text.
    """
    canvas = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    # Vertical placement uses a fixed nominal height of 10 px, whatever
    # font_size is; only the width is measured from the rendered text.
    nominal_height = 10
    rendered_width = int(canvas.textlength(text, font=typeface))

    # Clamp at 0 so text wider than the image starts at the left edge.
    left = max(0, (image.width - rendered_width) // 2)
    top = max(0, (image.height - nominal_height) // 2)

    canvas.text((left, top), text, fill=font_color, font=typeface)
    return image

font_path = 'fonts/arial_bold.ttf'
font_size = 15
resolution = (336, 336)

# ImageNet validation split; each image goes through Resize(resolution) on load.
dataset = ImageNet(root='datasets/ImageNet/', split='val', transform=transforms.Compose([transforms.Resize(resolution)]))

image_root = 'images'

# Font colour name -> RGB triple. Names prefixed with 'l' / 'd' are the lighter /
# darker variant of the base hue. The name is used in the output directory name.
font_colors = {
    'red': (255, 0, 0),
    'orange': (255, 128, 0),
    'yellow': (255, 255, 0),
    'green': (0, 255, 0),
    'cyan': (0, 255, 255),
    'blue': (0, 0, 255),
    'purple': (128, 0, 255),
    'pink': (255, 0, 255),
    'lred': (255, 128, 128),
    'dred': (128, 0, 0),
    'lorange': (255, 192, 128),
    'dorange': (128, 64, 0),
    'lyellow': (255, 255, 128),
    'dyellow': (128, 128, 0),
    'lgreen': (128, 255, 128),
    'dgreen': (0, 128, 0),
    'lcyan': (128, 255, 255),
    'dcyan': (0, 128, 128),
    'lblue': (128, 128, 255),
    'dblue': (0, 0, 128),
    'lpurple': (192, 128, 255),
    'dpurple': (64, 0, 128),
    'lpink': (255, 128, 255),
    'dpink': (128, 0, 128),
    'grey': (128, 128, 128),
    'white': (255, 255, 255),
    'black': (0, 0, 0),
}

species_font_color_dir = 'species'

# Start every run from an empty "species-fc<colour>" directory per font colour.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for fc in font_colors:
    out_dir = os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# At most `threshold` images are kept per label; `duplicate` counts how many
# images of each label have been kept so far.
threshold = 1
duplicate = {}
max_num = 500
count = 0

# Walk the dataset until max_num images have been written.
for i, (image, label) in enumerate(dataset):

    if count >= max_num:
        break

    # Turn the class index into a name (first entry of dataset.classes[label]).
    label = dataset.classes[label][0]

    if label in duplicate:
        if duplicate[label] < threshold:
            duplicate[label] += 1
        else:
            # This label already has its quota of images; skip the picture.
            continue
    else:
        duplicate[label] = 1

    # Pick the misleading overlay text from a *different* class. Drawing from
    # all classes without this check can return the true label itself, which
    # would produce a sample whose overlay is not misleading at all.
    mislabel = random.choice(dataset.classes)[0]
    while mislabel == label:
        mislabel = random.choice(dataset.classes)[0]

    # One output image per font colour, all with the same overlay text.
    file_name = str(i) + '-' + label + '-' + mislabel + '.jpg'
    for fc, rgb in font_colors.items():
        img_mislabel_colored = add_text_img(image=image.copy(), text=mislabel, font_path=font_path, font_size=font_size, font_color=rgb)
        img_mislabel_colored.save(os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc, file_name))

    count += 1

# Generate Examples - Color - Font Color

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Write ``text`` once around the middle of ``image`` and hand the image back.

    Drawing is done on ``image`` itself; pass a copy to keep the original.

    Args:
        image: PIL image that receives the text.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        ``image`` after drawing.
    """
    pen = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    width_px = int(pen.textlength(text, font=typeface))
    # The vertical offset assumes a constant 10 px text height.
    height_px = 10

    # max(0, ...) keeps the anchor inside the image for over-long text.
    anchor = (
        max(0, (image.width - width_px) // 2),
        max(0, (image.height - height_px) // 2),
    )
    pen.text(anchor, text, fill=font_color, font=typeface)

    return image

# Font colour name -> RGB triple. Names prefixed with 'l' / 'd' are the lighter /
# darker variant of the base hue. The name is used in the output directory name.
font_colors = {
    'red': (255, 0, 0),
    'orange': (255, 128, 0),
    'yellow': (255, 255, 0),
    'green': (0, 255, 0),
    'cyan': (0, 255, 255),
    'blue': (0, 0, 255),
    'purple': (128, 0, 255),
    'pink': (255, 0, 255),
    'lred': (255, 128, 128),
    'dred': (128, 0, 0),
    'lorange': (255, 192, 128),
    'dorange': (128, 64, 0),
    'lyellow': (255, 255, 128),
    'dyellow': (128, 128, 0),
    'lgreen': (128, 255, 128),
    'dgreen': (0, 128, 0),
    'lcyan': (128, 255, 255),
    'dcyan': (0, 128, 128),
    'lblue': (128, 128, 255),
    'dblue': (0, 0, 128),
    'lpurple': (192, 128, 255),
    'dpurple': (64, 0, 128),
    'lpink': (255, 128, 255),
    'dpink': (128, 0, 128),
    'grey': (128, 128, 128),
    'white': (255, 255, 255),
    'black': (0, 0, 0),
}

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)

image_root = 'images'
species_font_color_dir = 'color'

# Start every run from an empty "color-fc<colour>" directory per font colour.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for fc in font_colors:
    out_dir = os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# Read the whole question/answer file.
with open('datasets/DAQUAR_QA.txt', 'r') as file:
    text = file.read()

# Collect (question, image name, answer) for every "what color" question; the
# answer is taken from the line that follows the question.
lines = text.strip().split('\n')
color_questions = []
for i, line in enumerate(lines):
    if "what color" not in line:
        continue

    image_match = re.search(r'image\d+', line)
    # Skip a question that names no image or has no answer line after it,
    # instead of failing on `.group` / `lines[i + 1]`.
    if image_match is None or i + 1 >= len(lines):
        continue

    question = re.sub(r'image\d+', 'image', line)
    image_name = image_match.group(0)
    # Strip here so the answer compares equal to the entries of `colors` below.
    color = lines[i + 1].strip()

    # Answers listing several colours are not used.
    if ',' in color:
        continue

    if len(color) <= 2:
        continue

    color_questions.append((question, image_name, color))

# Every distinct single-colour answer; wrong colours are drawn from this pool.
colors = list({color for _, _, color in color_questions})

# Enumerate the samples so each output file gets its own leading index. Reusing
# the `i` left over from the parsing loop would give every file the same index
# and let samples with equal question/answer/overlay overwrite each other.
for i, (question, img, color) in enumerate(color_questions):

    image_path = f"datasets/DAQUAR/{img}.png"
    # RGB so the RGB fill tuples apply and the result can be written as JPEG.
    image = Image.open(image_path).convert('RGB')

    # Wrong colours only; skip the sample if the pool offers no alternative
    # (a rejection loop would never terminate in that case).
    candidates = [c for c in colors if c != color]
    if not candidates:
        continue
    miscolor = random.choice(candidates)

    # One output image per font colour, all with the same overlay text.
    file_name = str(i) + '-' + question + '-' + color + '-' + miscolor + '.jpg'
    for fc, rgb in font_colors.items():
        img_mislabel_colored = add_text_img(image=image.copy(), text=miscolor, font_path=font_path, font_size=font_size, font_color=rgb)
        img_mislabel_colored.save(os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc, file_name))


# Generate Examples - Counting - Font Color

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Stamp ``text`` once around the centre of ``image``.

    The stamp is applied in place on ``image``; the same object is returned.

    Args:
        image: PIL image to draw on.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        The ``image`` that was passed in, with the text drawn on it.
    """
    drawer = ImageDraw.Draw(image)
    loaded_font = ImageFont.truetype(font_path, font_size)

    img_w, img_h = image.width, image.height
    txt_w = int(drawer.textlength(text, font=loaded_font))
    txt_h = 10  # fixed nominal height; not derived from font_size

    # Centre the text, but never start left of / above the image origin.
    x0 = max(0, (img_w - txt_w) // 2)
    y0 = max(0, (img_h - txt_h) // 2)
    drawer.text((x0, y0), text, fill=font_color, font=loaded_font)

    return image

# Font colour name -> RGB triple. Names prefixed with 'l' / 'd' are the lighter /
# darker variant of the base hue. The name is used in the output directory name.
font_colors = {
    'red': (255, 0, 0),
    'orange': (255, 128, 0),
    'yellow': (255, 255, 0),
    'green': (0, 255, 0),
    'cyan': (0, 255, 255),
    'blue': (0, 0, 255),
    'purple': (128, 0, 255),
    'pink': (255, 0, 255),
    'lred': (255, 128, 128),
    'dred': (128, 0, 0),
    'lorange': (255, 192, 128),
    'dorange': (128, 64, 0),
    'lyellow': (255, 255, 128),
    'dyellow': (128, 128, 0),
    'lgreen': (128, 255, 128),
    'dgreen': (0, 128, 0),
    'lcyan': (128, 255, 255),
    'dcyan': (0, 128, 128),
    'lblue': (128, 128, 255),
    'dblue': (0, 0, 128),
    'lpurple': (192, 128, 255),
    'dpurple': (64, 0, 128),
    'lpink': (255, 128, 255),
    'dpink': (128, 0, 128),
    'grey': (128, 128, 128),
    'white': (255, 255, 255),
    'black': (0, 0, 0),
}

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)

image_root = 'images'
species_font_color_dir = 'counting'

# Start every run from an empty "counting-fc<colour>" directory per font colour.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for fc in font_colors:
    out_dir = os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# Digit string (as it appears in the file name) -> English number word.
number_to_word = {
    '1': 'one',
    '2': 'two',
    '3': 'three',
    '4': 'four',
    '5': 'five',
    '6': 'six',
    '7': 'seven',
    '8': 'eight',
    '9': 'nine',
    '10': 'ten'
}

countbench_dir = 'datasets/LabeledCountBench/'

# Process every image file in the folder.
for i, filename in enumerate(os.listdir(countbench_dir)):
    if not filename.endswith(('.jpg', '.png', '.jpeg', '.gif')):
        continue

    image_path = os.path.join(countbench_dir, filename)
    # PNG/GIF inputs may be palette-based or carry an alpha channel. Convert to
    # RGB so the RGB fill tuples apply and the result can be written as JPEG.
    image = Image.open(image_path).convert('RGB')
    image = image.resize(resolution)

    # The file stem is split on '_': the last field is the count and the one
    # before it the object label.
    name_fields = filename.split('.')[0].split('_')
    label = name_fields[-2]
    counting = name_fields[-1]

    # Draw a wrong count in 1..10; compared as digit strings, before the
    # conversion to words below.
    miscounting = str(random.randint(1, 10))
    while miscounting == counting:
        miscounting = str(random.randint(1, 10))

    # Counts outside the table are written as 'Invalid'.
    counting = number_to_word.get(counting, 'Invalid')
    miscounting = number_to_word.get(miscounting, 'Invalid')

    # One output image per font colour, all with the same overlay text.
    file_name = str(i) + '-' + label + '-' + counting + '-' + miscounting + '.jpg'
    for fc, rgb in font_colors.items():
        img_mislabel_colored = add_text_img(image=image.copy(), text=miscounting, font_path=font_path, font_size=font_size, font_color=rgb)
        img_mislabel_colored.save(os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc, file_name))

# Generate Examples - Numerical - Font Color

In [None]:
def add_equation(text, font_path, num_prints, font_size, font_color, bg_color, resolution):
    """Create a blank image and print ``text`` on it ``num_prints`` times.

    The lines are stacked 30 px apart and the stack is centred on the image.

    Args:
        text: String to render (the equation).
        font_path: Path to a TrueType font file.
        num_prints: How many times ``text`` is printed, one line each.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.
        bg_color: Background colour of the new image.
        resolution: Size of the new image, passed to ``Image.new``.

    Returns:
        The newly created RGB image.
    """
    sheet = Image.new('RGB', resolution, bg_color)
    pen = ImageDraw.Draw(sheet)
    typeface = ImageFont.truetype(font_path, font_size)

    line_pitch = 30  # vertical distance between consecutive prints
    left = (sheet.width - int(pen.textlength(text, font=typeface))) // 2
    # Offset the first line so the whole stack of lines is centred vertically.
    top = (sheet.height - num_prints * line_pitch) // 2

    for row in range(num_prints):
        pen.text((left, top + row * line_pitch), text, fill=font_color, font=typeface)

    return sheet

def add_wrong_answer(image, text, num_prints, font_path, font_size, font_color):
    """Print ``text`` on ``image`` ``num_prints`` times, 100 px above centre.

    Drawing is done directly on ``image``; the same object is returned.

    Args:
        image: PIL image to draw on.
        text: String to render (the wrong answer).
        num_prints: How many times ``text`` is printed, one line each.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        ``image`` after drawing.
    """
    pen = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    line_pitch = 30  # vertical distance between consecutive prints
    left = (image.width - int(pen.textlength(text, font=typeface))) // 2
    # Centred stack of lines, then shifted up by 100 px.
    top = (image.height - num_prints * line_pitch) // 2 - 100

    for row in range(num_prints):
        pen.text((left, top + row * line_pitch), text, fill=font_color, font=typeface)

    return image

# Font colour name -> RGB triple. Names prefixed with 'l' / 'd' are the lighter /
# darker variant of the base hue. The name is used in the output directory name.
font_colors = {
    'red': (255, 0, 0),
    'orange': (255, 128, 0),
    'yellow': (255, 255, 0),
    'green': (0, 255, 0),
    'cyan': (0, 255, 255),
    'blue': (0, 0, 255),
    'purple': (128, 0, 255),
    'pink': (255, 0, 255),
    'lred': (255, 128, 128),
    'dred': (128, 0, 0),
    'lorange': (255, 192, 128),
    'dorange': (128, 64, 0),
    'lyellow': (255, 255, 128),
    'dyellow': (128, 128, 0),
    'lgreen': (128, 255, 128),
    'dgreen': (0, 128, 0),
    'lcyan': (128, 255, 255),
    'dcyan': (0, 128, 128),
    'lblue': (128, 128, 255),
    'dblue': (0, 0, 128),
    'lpurple': (192, 128, 255),
    'dpurple': (64, 0, 128),
    'lpink': (255, 128, 255),
    'dpink': (128, 0, 128),
    'grey': (128, 128, 128),
    'white': (255, 255, 255),
    'black': (0, 0, 0),
}

font_path = 'fonts/arial_bold.ttf'
font_size = 15
equation_font_color = 'black'
mislabel_font_color = 'red'
bg_color = 'white'
resolution = (336, 336)
max_pic_num = 500

image_root = 'images'
species_font_color_dir = 'numerical'

# Start every run from an empty "numerical-fc<colour>" directory per font colour.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for fc in font_colors:
    out_dir = os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

operators = ['+', '-', 'x', '÷']

for i in range(max_pic_num):

    num1 = random.randint(1, 1000)
    num2 = random.randint(1, 1000)
    chosen_operator = random.choice(operators)
    correct_answer = 0

    if chosen_operator == '+':
        correct_answer = num1 + num2
    elif chosen_operator == '-':
        # Put the larger operand first so the result is never negative.
        num1, num2 = max(num1, num2), min(num1, num2)
        correct_answer = num1 - num2
    elif chosen_operator == 'x':
        correct_answer = num1 * num2
    elif chosen_operator == '÷':
        # Replace num2 by a divisor of num1 so the quotient is an integer.
        # num1 >= 1, so the list always contains at least 1.
        divisors = [d for d in range(1, num1 + 1) if num1 % d == 0]
        num2 = random.choice(divisors)
        correct_answer = num1 // num2

    # Draw a non-negative answer within 500 of the correct one, and redraw
    # until it really is wrong: an offset of 0, or clamping to 0 when the
    # correct answer is 0, would otherwise reproduce the correct answer.
    wrong_answer = correct_answer
    while wrong_answer == correct_answer:
        wrong_answer = max(0, correct_answer + random.randint(-500, 500))

    equation = f"{num1} {chosen_operator} {num2} ="

    # The equation image is the same for every font colour, so render it once
    # and give each colour its own copy to draw the wrong answer on.
    equation_img = add_equation(text=equation, num_prints=1, font_path=font_path, font_size=font_size, font_color=equation_font_color, bg_color=bg_color, resolution=resolution)
    file_name = str(i) + '-' + str(correct_answer) + '-' + str(wrong_answer) + '.jpg'
    for fc, rgb in font_colors.items():
        img = add_wrong_answer(image=equation_img.copy(), text=str(wrong_answer), num_prints=1, font_path=font_path, font_size=font_size, font_color=rgb)
        img.save(os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc, file_name))

# Generate Examples - Complex - Font Color

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Render ``text`` once around the centre of ``image`` and return it.

    ``image`` is modified in place; callers wanting to keep the original
    should pass ``image.copy()``.

    Args:
        image: PIL image to draw on.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        The modified ``image``.
    """
    surface = ImageDraw.Draw(image)
    face = ImageFont.truetype(font_path, font_size)

    # Width comes from the rendered text; height is a fixed 10 px estimate.
    text_w = int(surface.textlength(text, font=face))
    text_h = 10

    # Non-negative top-left corner that centres a text_w x text_h box.
    origin_x = max(0, (image.width - text_w) // 2)
    origin_y = max(0, (image.height - text_h) // 2)

    surface.text((origin_x, origin_y), text, fill=font_color, font=face)
    return image

# Font colour name -> RGB triple. Names prefixed with 'l' / 'd' are the lighter /
# darker variant of the base hue. The name is used in the output directory name.
font_colors = {
    'red': (255, 0, 0),
    'orange': (255, 128, 0),
    'yellow': (255, 255, 0),
    'green': (0, 255, 0),
    'cyan': (0, 255, 255),
    'blue': (0, 0, 255),
    'purple': (128, 0, 255),
    'pink': (255, 0, 255),
    'lred': (255, 128, 128),
    'dred': (128, 0, 0),
    'lorange': (255, 192, 128),
    'dorange': (128, 64, 0),
    'lyellow': (255, 255, 128),
    'dyellow': (128, 128, 0),
    'lgreen': (128, 255, 128),
    'dgreen': (0, 128, 0),
    'lcyan': (128, 255, 255),
    'dcyan': (0, 128, 128),
    'lblue': (128, 128, 255),
    'dblue': (0, 0, 128),
    'lpurple': (192, 128, 255),
    'dpurple': (64, 0, 128),
    'lpink': (255, 128, 255),
    'dpink': (128, 0, 128),
    'grey': (128, 128, 128),
    'white': (255, 255, 255),
    'black': (0, 0, 0),
}

image_root = 'images'
species_font_color_dir = 'complex'

# Start every run from an empty "complex-fc<colour>" directory per font colour.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for fc in font_colors:
    out_dir = os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
max_pic_num = 500
resolution = (336, 336)

dataset_path = 'datasets/COCO2017/val2017/'
json_path = 'datasets/aokvqa_v1p0_val.json'

# Maps the integer parsed from each file name stem to that image's path.
img_dict = {}

# Index every .jpg file in the dataset folder.
for filename in os.listdir(dataset_path):
    if filename.endswith('.jpg'):
        prefix = int(filename.split('.')[0])
        img_dict[prefix] = dataset_path + filename

# Load the question records from the JSON file.
with open(json_path, 'r') as file:
    data = json.load(file)

# `count` is the number of samples written so far; it caps the run at
# max_pic_num and is the leading index in every output file name.
count = 0
for i, item in enumerate(data):

    if count >= max_pic_num:
        break

    image_id = item['image_id']
    question = item['question']
    choices = item['choices']
    correct_choice_idx = item['correct_choice_idx']

    # The question text becomes part of the file name, so questions that
    # contain '/' are skipped because they would break the output path.
    if '/' in question:
        continue

    image_path = img_dict[image_id]
    # Convert to RGB so the RGB fill tuples are valid for every input image,
    # including non-RGB (e.g. grayscale) files, rather than dropping those
    # images through the exception handler below.
    img = Image.open(image_path).convert('RGB')
    img = img.resize(resolution)

    # Choices may contain '/' as well; only the part before it is kept.
    label = choices[correct_choice_idx].split('/')[0]
    # Draw the overlay text from the choices that differ from the correct one;
    # skip the item if there is none (a rejection loop would never end).
    wrong_choices = [choice.split('/')[0] for choice in choices]
    wrong_choices = [choice for choice in wrong_choices if choice != label]
    if not wrong_choices:
        continue
    mislabel = random.choice(wrong_choices)

    # One output image per font colour. A sample that cannot be rendered or
    # written (for example an over-long file name) is reported and skipped
    # without consuming an index.
    file_name = str(count) + '-' + question + '-' + label + '-' + mislabel + '.jpg'
    try:
        for fc, rgb in font_colors.items():
            img_mislabel = add_text_img(image=img.copy(), text=mislabel, font_path=font_path, font_size=font_size, font_color=rgb)
            img_mislabel.save(os.path.join(image_root, species_font_color_dir + '-' + 'fc' + fc, file_name))
    except (OSError, TypeError, ValueError) as e:
        print(f"Skipping image {image_id}: {e}")
        continue
    count += 1

# Generate Examples - Species - Transparency

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Overlay ``text`` on ``image`` by alpha compositing and return the result.

    The text is drawn on a separate, fully transparent layer which is then
    composited over an RGBA conversion of ``image``, so an alpha component in
    ``font_color`` controls how opaque the text appears.

    Args:
        image: PIL image used as the background.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        A new RGBA image with the text composited around the centre.
    """
    base = image.convert("RGBA")
    # Transparent layer of the same size; only the text pixels receive alpha.
    overlay = Image.new('RGBA', base.size, (255, 255, 255, 0))

    typeface = ImageFont.truetype(font_path, font_size)
    pen = ImageDraw.Draw(overlay)

    # Vertical placement assumes a fixed 10 px text height.
    nominal_height = 10
    rendered_width = int(pen.textlength(text, font=typeface))

    # Clamp at 0 so over-wide text starts at the left edge.
    left = max(0, (base.width - rendered_width) // 2)
    top = max(0, (base.height - nominal_height) // 2)

    pen.text((left, top), text, fill=font_color, font=typeface)
    return Image.alpha_composite(base, overlay)

font_path = 'fonts/arial_bold.ttf'
font_size = 15
resolution = (336, 336)

# ImageNet validation split; each image goes through Resize(resolution) on load.
dataset = ImageNet(root='datasets/ImageNet/', split='val', transform=transforms.Compose([transforms.Resize(resolution)]))

image_root = 'images'

# Split the 0..255 alpha range into equal steps.
num_segments = 5
segment_size = 255 // num_segments

# Five alpha values: 51, 102, 153, 204, 255.
alpha_values = [(i+1) * segment_size for i in range(num_segments)]
specie_transparency_dir = 'species'

# Start every run from an empty "species-t<alpha>" directory per alpha value.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for value in alpha_values:
    out_dir = os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# At most `threshold` images are kept per label; `duplicate` counts how many
# images of each label have been kept so far.
threshold = 1
duplicate = {}
max_num = 500
count = 0

# Walk the dataset until max_num images have been written.
for i, (image, label) in enumerate(dataset):

    if count >= max_num:
        break

    # Turn the class index into a name (first entry of dataset.classes[label]).
    label = dataset.classes[label][0]

    if label in duplicate:
        if duplicate[label] < threshold:
            duplicate[label] += 1
        else:
            # This label already has its quota of images; skip the picture.
            continue
    else:
        duplicate[label] = 1

    # Pick the misleading overlay text from a *different* class. Drawing from
    # all classes without this check can return the true label itself, which
    # would produce a sample whose overlay is not misleading at all.
    mislabel = random.choice(dataset.classes)[0]
    while mislabel == label:
        mislabel = random.choice(dataset.classes)[0]

    # One output image per alpha value: white text with that opacity.
    file_name = str(i) + '-' + label + '-' + mislabel + '.png'
    for value in alpha_values:
        img = add_text_img(image=image.copy(), text=mislabel, font_path=font_path, font_size=font_size, font_color=(255, 255, 255, value))
        img.save(os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value), file_name))

    count += 1

# Generate Examples - Color - Transparency

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Composite ``text`` over ``image`` and return the combined RGBA image.

    ``image`` is converted to RGBA; the text is drawn on a transparent layer
    of the same size and the two are merged with ``Image.alpha_composite``.

    Args:
        image: PIL image used as the background.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        A new RGBA image with the text around the centre.
    """
    background = image.convert("RGBA")
    text_layer = Image.new('RGBA', background.size, (255, 255, 255, 0))

    loaded_font = ImageFont.truetype(font_path, font_size)
    drawer = ImageDraw.Draw(text_layer)

    bg_w, bg_h = background.width, background.height
    txt_w = int(drawer.textlength(text, font=loaded_font))
    txt_h = 10  # fixed nominal height; not derived from font_size

    # Centre the text, but never start left of / above the image origin.
    anchor = (max(0, (bg_w - txt_w) // 2), max(0, (bg_h - txt_h) // 2))
    drawer.text(anchor, text, fill=font_color, font=loaded_font)

    return Image.alpha_composite(background, text_layer)

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)

image_root = 'images'
specie_transparency_dir = 'color'

# Split the 0..255 alpha range into equal steps.
num_segments = 5
segment_size = 255 // num_segments

# Five alpha values: 51, 102, 153, 204, 255.
alpha_values = [(i+1) * segment_size for i in range(num_segments)]

# Start every run from an empty "color-t<alpha>" directory per alpha value.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for value in alpha_values:
    out_dir = os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# Read the whole question/answer file.
with open('datasets/DAQUAR_QA.txt', 'r') as file:
    text = file.read()

# Collect (question, image name, answer) for every "what color" question; the
# answer is taken from the line that follows the question.
lines = text.strip().split('\n')
color_questions = []
for i, line in enumerate(lines):
    if "what color" not in line:
        continue

    image_match = re.search(r'image\d+', line)
    # Skip a question that names no image or has no answer line after it,
    # instead of failing on `.group` / `lines[i + 1]`.
    if image_match is None or i + 1 >= len(lines):
        continue

    question = re.sub(r'image\d+', 'image', line)
    image_name = image_match.group(0)
    # Strip here so the answer compares equal to the entries of `colors` below.
    color = lines[i + 1].strip()

    # Answers listing several colours are not used.
    if ',' in color:
        continue

    if len(color) <= 2:
        continue

    color_questions.append((question, image_name, color))

# Every distinct single-colour answer; wrong colours are drawn from this pool.
colors = list({color for _, _, color in color_questions})

# Enumerate the samples so each output file gets its own leading index. Reusing
# the `i` left over from the parsing loop would give every file the same index
# and let samples with equal question/answer/overlay overwrite each other.
for i, (question, image_name, color) in enumerate(color_questions):

    image_path = f"datasets/DAQUAR/{image_name}.png"
    image = Image.open(image_path)

    # Wrong colours only; skip the sample if the pool offers no alternative
    # (a rejection loop would never terminate in that case).
    candidates = [c for c in colors if c != color]
    if not candidates:
        continue
    miscolor = random.choice(candidates)

    # One output image per alpha value: white text with that opacity.
    file_name = str(i) + '-' + question + '-' + color + '-' + miscolor + '.png'
    for value in alpha_values:
        img = add_text_img(image=image.copy(), text=miscolor, font_path=font_path, font_size=font_size, font_color=(255, 255, 255, value))
        img.save(os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value), file_name))

# Generate Examples - Counting - Transparency

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Blend ``text`` onto ``image`` through a transparent text layer.

    Args:
        image: PIL image used as the background; it is converted to RGBA.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        A new RGBA image: the text layer alpha-composited over the background.
    """
    rgba_image = image.convert("RGBA")
    # The layer starts fully transparent (alpha 0) everywhere.
    layer = Image.new('RGBA', rgba_image.size, (255, 255, 255, 0))

    face = ImageFont.truetype(font_path, font_size)
    surface = ImageDraw.Draw(layer)

    # Width comes from the rendered text; height is a fixed 10 px estimate.
    text_w = int(surface.textlength(text, font=face))
    text_h = 10

    # Non-negative top-left corner that centres a text_w x text_h box.
    origin_x = max(0, (rgba_image.width - text_w) // 2)
    origin_y = max(0, (rgba_image.height - text_h) // 2)
    surface.text((origin_x, origin_y), text, fill=font_color, font=face)

    blended = Image.alpha_composite(rgba_image, layer)
    return blended

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)

image_root = 'images'
specie_transparency_dir = 'counting'

# Split the 0..255 alpha range into equal steps.
num_segments = 5
segment_size = 255 // num_segments

# Five alpha values: 51, 102, 153, 204, 255.
alpha_values = [segment_size * step for step in range(1, num_segments + 1)]

# Start every run from an empty "counting-t<alpha>" directory per alpha value.
for value in alpha_values:
    out_dir = os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.mkdir(out_dir)

# Digit string (as it appears in the file name) -> English number word.
number_to_word = {
    str(number): word
    for number, word in enumerate(
        ('one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten'),
        start=1,
    )
}

countbench_dir = 'datasets/LabeledCountBench/'

# Process every image file in the folder.
for i, filename in enumerate(os.listdir(countbench_dir)):
    if not filename.endswith(('.jpg', '.png', '.jpeg', '.gif')):
        continue

    image_path = os.path.join(countbench_dir, filename)
    image = Image.open(image_path)
    image = image.resize(resolution)

    # The file stem is split on '_': the last field is the count and the one
    # before it the object label.
    name_fields = filename.split('.')[0].split('_')
    label = name_fields[-2]
    counting = name_fields[-1]

    # Draw a wrong count in 1..10; compared as digit strings, before the
    # conversion to words below.
    miscounting = str(random.randint(1, 10))
    while miscounting == counting:
        miscounting = str(random.randint(1, 10))

    # Counts outside the table are written as 'Invalid'.
    counting = number_to_word.get(counting, 'Invalid')
    miscounting = number_to_word.get(miscounting, 'Invalid')

    # One output image per alpha value: white text with that opacity.
    for value in alpha_values:
        target_dir = os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value))
        img = add_text_img(image=image.copy(), text=miscounting, font_path=font_path, font_size=font_size, font_color=(255, 255, 255, value))
        img.save(os.path.join(target_dir, str(i) + '-' + label + '-' + counting + '-' + miscounting + '.png'))

# Generate Examples - Numerical - Transparency

In [None]:
def add_equation(text, font_path, num_prints, font_size, font_color, bg_color, resolution):
    """Build a new image filled with ``bg_color`` and print ``text`` on it.

    ``text`` is printed ``num_prints`` times on lines 30 px apart; the block
    of lines is centred on the image.

    Args:
        text: String to render (the equation).
        font_path: Path to a TrueType font file.
        num_prints: How many times ``text`` is printed, one line each.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.
        bg_color: Background colour of the new image.
        resolution: Size of the new image, passed to ``Image.new``.

    Returns:
        The newly created RGB image.
    """
    page = Image.new('RGB', resolution, bg_color)
    drawer = ImageDraw.Draw(page)
    loaded_font = ImageFont.truetype(font_path, font_size)

    row_step = 30  # vertical distance between consecutive prints
    text_w = int(drawer.textlength(text, font=loaded_font))
    x0 = (page.width - text_w) // 2
    # Offset the first line so the whole block of lines is centred vertically.
    y0 = (page.height - num_prints * row_step) // 2

    line_index = 0
    while line_index < num_prints:
        drawer.text((x0, y0 + line_index * row_step), text, fill=font_color, font=loaded_font)
        line_index += 1

    return page

def add_wrong_answer(image, text, num_prints, font_path, font_size, font_color):
    """Composite ``text`` over ``image``, 100 px above the centre.

    The text is drawn ``num_prints`` times (lines 30 px apart) on a fully
    transparent layer, which is then alpha-composited over an RGBA conversion
    of ``image``.

    Args:
        image: PIL image used as the background.
        text: String to render (the wrong answer).
        num_prints: How many times ``text`` is printed, one line each.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        A new RGBA image with the text composited on top.
    """
    base = image.convert("RGBA")
    overlay = Image.new('RGBA', base.size, (255, 255, 255, 0))

    typeface = ImageFont.truetype(font_path, font_size)
    pen = ImageDraw.Draw(overlay)

    line_pitch = 30  # vertical distance between consecutive prints
    left = (base.width - int(pen.textlength(text, font=typeface))) // 2
    # Centred block of lines, then shifted up by 100 px.
    top = (base.height - num_prints * line_pitch) // 2 - 100

    for row in range(num_prints):
        pen.text((left, top + row * line_pitch), text, fill=font_color, font=typeface)

    return Image.alpha_composite(base, overlay)

font_path = 'fonts/arial_bold.ttf'
font_size = 15
equation_font_color = 'black'
mislabel_font_color = 'red'
bg_color = 'white'
resolution = (336, 336)
max_pic_num = 500

image_root = 'images'
specie_transparency_dir = 'numerical'

# Split the 0..255 alpha range into equal steps.
num_segments = 5
segment_size = 255 // num_segments

# Five alpha values: 51, 102, 153, 204, 255.
alpha_values = [(i+1) * segment_size for i in range(num_segments)]

# Start every run from an empty "numerical-t<alpha>" directory per alpha value.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for value in alpha_values:
    out_dir = os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

operators = ['+', '-', 'x', '÷']

for i in range(max_pic_num):

    num1 = random.randint(1, 1000)
    num2 = random.randint(1, 1000)
    chosen_operator = random.choice(operators)
    correct_answer = 0

    if chosen_operator == '+':
        correct_answer = num1 + num2
    elif chosen_operator == '-':
        # Put the larger operand first so the result is never negative.
        num1, num2 = max(num1, num2), min(num1, num2)
        correct_answer = num1 - num2
    elif chosen_operator == 'x':
        correct_answer = num1 * num2
    elif chosen_operator == '÷':
        # Replace num2 by a divisor of num1 so the quotient is an integer.
        # num1 >= 1, so the list always contains at least 1.
        divisors = [d for d in range(1, num1 + 1) if num1 % d == 0]
        num2 = random.choice(divisors)
        correct_answer = num1 // num2

    # Draw a non-negative answer within 500 of the correct one, and redraw
    # until it really is wrong: an offset of 0, or clamping to 0 when the
    # correct answer is 0, would otherwise reproduce the correct answer.
    wrong_answer = correct_answer
    while wrong_answer == correct_answer:
        wrong_answer = max(0, correct_answer + random.randint(-500, 500))

    equation = f"{num1} {chosen_operator} {num2} ="

    # The equation image is the same for every alpha value, so render it once
    # and composite the red wrong answer over a copy for each opacity.
    equation_img = add_equation(text=equation, num_prints=1, font_path=font_path, font_size=font_size, font_color=equation_font_color, bg_color=bg_color, resolution=resolution)
    file_name = str(i) + '-' + str(correct_answer) + '-' + str(wrong_answer) + '.png'
    for value in alpha_values:
        img = add_wrong_answer(image=equation_img.copy(), text=str(wrong_answer), num_prints=1, font_path=font_path, font_size=font_size, font_color=(255, 0, 0, value))
        img.save(os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value), file_name))

# Generate Examples - Complex - Transparency

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Return an RGBA copy of ``image`` with ``text`` composited near its centre.

    The text goes onto its own transparent layer first, so an alpha component
    in ``font_color`` determines how strongly the text shows through.

    Args:
        image: PIL image used as the background.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        A new RGBA image produced by ``Image.alpha_composite``.
    """
    canvas_img = image.convert("RGBA")
    glyph_layer = Image.new('RGBA', canvas_img.size, (255, 255, 255, 0))

    font_obj = ImageFont.truetype(font_path, font_size)
    layer_draw = ImageDraw.Draw(glyph_layer)

    measured_width = int(layer_draw.textlength(text, font=font_obj))
    assumed_height = 10  # constant; vertical centring ignores font_size

    # Clamp both coordinates at 0 so the text never starts outside the image.
    pos_x = max(0, (canvas_img.width - measured_width) // 2)
    pos_y = max(0, (canvas_img.height - assumed_height) // 2)

    layer_draw.text((pos_x, pos_y), text, fill=font_color, font=font_obj)
    return Image.alpha_composite(canvas_img, glyph_layer)

font_path = 'fonts/arial_bold.ttf'
font_size = 15
font_color = 'white'
resolution = (336, 336)
max_pic_num = 500

image_root = 'images'
specie_transparency_dir = 'complex'

# Split the 0..255 alpha range into equal steps.
num_segments = 5
segment_size = 255 // num_segments

# Five alpha values: 51, 102, 153, 204, 255.
alpha_values = [(i+1) * segment_size for i in range(num_segments)]

# Start every run from an empty "complex-t<alpha>" directory per alpha value.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for value in alpha_values:
    out_dir = os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

dataset_path = 'datasets/COCO2017/val2017/'
json_path = 'datasets/aokvqa_v1p0_val.json'

# Maps the integer parsed from each file name stem to that image's path.
img_dict = {}

# Index every .jpg file in the dataset folder.
for filename in os.listdir(dataset_path):
    if filename.endswith('.jpg'):
        prefix = int(filename.split('.')[0])
        img_dict[prefix] = dataset_path + filename

# Load the question records from the JSON file.
with open(json_path, 'r') as file:
    data = json.load(file)

# `count` is the number of samples actually written. Limiting and numbering by
# it (not by the raw record index) means skipped questions do not shrink the
# dataset below max_pic_num, and matches the other "complex" cells.
count = 0
for i, item in enumerate(data):

    if count >= max_pic_num:
        break

    image_id = item['image_id']
    question = item['question']
    choices = item['choices']
    correct_choice_idx = item['correct_choice_idx']

    # The question text becomes part of the file name, so questions that
    # contain '/' are skipped because they would break the output path.
    if '/' in question:
        continue

    image_path = img_dict[image_id]
    img = Image.open(image_path)
    img = img.resize(resolution)

    # Choices may contain '/' as well; only the part before it is kept.
    label = choices[correct_choice_idx].split('/')[0]
    # Draw the overlay text from the choices that differ from the correct one;
    # skip the item if there is none (a rejection loop would never end).
    wrong_choices = [choice.split('/')[0] for choice in choices]
    wrong_choices = [choice for choice in wrong_choices if choice != label]
    if not wrong_choices:
        continue
    mislabel = random.choice(wrong_choices)

    # One output image per alpha value: white text with that opacity.
    file_name = str(count) + '-' + question + '-' + label + '-' + mislabel + '.png'
    for value in alpha_values:
        img_mislabel = add_text_img(image=img.copy(), text=mislabel, font_path=font_path, font_size=font_size, font_color=(255, 255, 255, value))
        img_mislabel.save(os.path.join(image_root, specie_transparency_dir + '-' + 't' + str(value), file_name))
    count += 1

# Generate Examples - Species - Font Size

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Draw ``text`` once around the centre of ``image`` at ``font_size``.

    The text is drawn on ``image`` itself and that same object is returned.

    Args:
        image: PIL image to draw on.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour for the text.

    Returns:
        ``image`` with the text drawn on it.
    """
    painter = ImageDraw.Draw(image)
    font_obj = ImageFont.truetype(font_path, font_size)

    # Only the width follows the font size; the height used for vertical
    # centring stays at 10 px for every size.
    measured_width = int(painter.textlength(text, font=font_obj))
    assumed_height = 10

    # Clamp both coordinates at 0 so the text never starts outside the image.
    pos_x = max(0, (image.width - measured_width) // 2)
    pos_y = max(0, (image.height - assumed_height) // 2)

    painter.text((pos_x, pos_y), text, fill=font_color, font=font_obj)
    return image

font_path = 'fonts/arial_bold.ttf'
font_color = 'white'
resolution = (336, 336)

# Font sizes are max_fontsize split into num_segments equal steps.
max_fontsize = 15
num_segments = 5
segment_size = max_fontsize // num_segments

# ImageNet validation split; each image goes through Resize(resolution) on load.
dataset = ImageNet(root='datasets/ImageNet/', split='val', transform=transforms.Compose([transforms.Resize(resolution)]))

image_root = 'images'

# Five font sizes: 3, 6, 9, 12, 15.
font_sizes = [(i+1) * segment_size for i in range(num_segments)]
specie_fontsize_dir = 'species'

# Start every run from an empty "species-fs<size>" directory per font size.
# makedirs (rather than mkdir) also creates image_root when it does not exist.
for fs in font_sizes:
    out_dir = os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs))
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# At most `threshold` images are kept per label; `duplicate` counts how many
# images of each label have been kept so far.
threshold = 1
duplicate = {}
max_num = 500
count = 0

# Walk the dataset until max_num images have been written.
for i, (image, label) in enumerate(dataset):

    if count >= max_num:
        break

    # Turn the class index into a name (first entry of dataset.classes[label]).
    label = dataset.classes[label][0]

    if label in duplicate:
        if duplicate[label] < threshold:
            duplicate[label] += 1
        else:
            # This label already has its quota of images; skip the picture.
            continue
    else:
        duplicate[label] = 1

    # Pick the misleading overlay text from a *different* class. Drawing from
    # all classes without this check can return the true label itself, which
    # would produce a sample whose overlay is not misleading at all.
    mislabel = random.choice(dataset.classes)[0]
    while mislabel == label:
        mislabel = random.choice(dataset.classes)[0]

    # One output image per font size, all with the same overlay text.
    file_name = str(i) + '-' + label + '-' + mislabel + '.jpg'
    for fs in font_sizes:
        img = add_text_img(image=image.copy(), text=mislabel, font_path=font_path, font_size=fs, font_color=font_color)
        img.save(os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs), file_name))

    count += 1

# Generate Examples - Color - Font Size

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Draw `text` once, roughly centred, onto `image` and return that image.

    The image is modified in place; callers that need the original untouched
    should pass a copy.

    Args:
        image: PIL image to draw on.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size passed to `ImageFont.truetype`.
        font_color: Fill colour for the text.

    Returns:
        The same `image` object, with the text drawn on it.
    """
    canvas = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    rendered_width = int(canvas.textlength(text, font=typeface))
    # Fixed nominal height used for the vertical offset; it does not depend
    # on font_size.
    nominal_height = 10

    # Clamp at zero so text wider/taller than the image starts at the edge.
    left = max(0, (image.width - rendered_width) // 2)
    top = max(0, (image.height - nominal_height) // 2)

    canvas.text((left, top), text, fill=font_color, font=typeface)
    return image

# Rendering settings shared by every image generated in this cell.
font_path = 'fonts/arial_bold.ttf'
font_color = 'white'
resolution = (336, 336)

# Font sizes are multiples of segment_size, up to max_fontsize.
max_fontsize = 15
num_segments = 5
segment_size = max_fontsize // num_segments

image_root = 'images'
specie_fontsize_dir = 'color'

# Build the list of font sizes (num_segments entries).
font_sizes = [(i+1) * segment_size for i in range(num_segments)]

# Recreate one empty output directory per font size.
for fs in font_sizes:
    fs_dir = os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs))
    if os.path.exists(fs_dir):
        shutil.rmtree(fs_dir)
    # makedirs (rather than mkdir) also creates image_root when it is missing.
    os.makedirs(fs_dir)

# Read the question/answer file.
with open('datasets/DAQUAR_QA.txt', 'r') as file:
    text = file.read()

# Collect (question, image name, answer colour) triples. The answer is read
# from the line that follows each "what color" question.
lines = text.strip().split('\n')
color_questions = []
for i, line in enumerate(lines):
    if "what color" not in line:
        continue

    # A question on the very last line has no answer line after it.
    if i + 1 >= len(lines):
        continue

    # Without an imageNNN token the source picture cannot be located.
    image_match = re.search(r'image\d+', line)
    if image_match is None:
        continue

    question = re.sub(r'image\d+', 'image', line)
    image_name = image_match.group(0)
    # Strip once here so the comparison with the distractor colour and the
    # output file name both use the same normalised value.
    color = lines[i + 1].strip()

    # Skip answers that list several colours.
    if ',' in color:
        continue

    # Skip answers that are too short to be a colour word.
    if len(color) <= 2:
        continue

    color_questions.append((question, image_name, color))

# Distinct answer colours; sorted so the candidate order is reproducible.
colors = sorted({color for _, _, color in color_questions})

# Enumerate the questions so every output file gets its own index prefix
# (the index left over from the parsing loop above would be the same for all).
for i, (question, img, color) in enumerate(color_questions):

    image_path = f"datasets/DAQUAR/{img}.png"

    # Choose a wrong colour; skip the sample if no other colour is available
    # instead of looping forever.
    wrong_colors = [c for c in colors if c != color]
    if not wrong_colors:
        continue
    miscolor = random.choice(wrong_colors)

    # Close the source file once every font-size variant has been written.
    with Image.open(image_path) as image:
        for fs in font_sizes:
            img = add_text_img(image=image.copy(), text=miscolor, font_path=font_path, font_size=fs, font_color=font_color)
            img.save(os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs), str(i) + '-' + question + '-' + color + '-' + miscolor + '.jpg'))


# Generate Examples - Counting - Font Size

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Render `text` a single time near the centre of `image`.

    Draws directly on the given image (no copy is made) and hands the same
    object back.

    Args:
        image: PIL image that receives the text.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size passed to `ImageFont.truetype`.
        font_color: Fill colour for the text.

    Returns:
        `image`, after drawing.
    """
    pen = ImageDraw.Draw(image)
    loaded_font = ImageFont.truetype(font_path, font_size)

    # Horizontal offset uses the measured text width; the vertical offset uses
    # a constant 10 px height regardless of font_size.
    width_px = int(pen.textlength(text, font=loaded_font))
    height_px = 10

    # Offsets never go negative, so oversized text is anchored at the edge.
    origin = (
        max(0, (image.width - width_px) // 2),
        max(0, (image.height - height_px) // 2),
    )

    pen.text(origin, text, fill=font_color, font=loaded_font)
    return image

# Rendering settings shared by every image generated in this cell.
font_path = 'fonts/arial_bold.ttf'
font_color = 'white'
resolution = (336, 336)

# Font sizes are multiples of segment_size, up to max_fontsize.
max_fontsize = 15
num_segments = 5
segment_size = max_fontsize // num_segments

image_root = 'images'
specie_fontsize_dir = 'counting'

# Build the list of font sizes (num_segments entries).
font_sizes = [(i+1) * segment_size for i in range(num_segments)]

# Recreate one empty output directory per font size.
for fs in font_sizes:
    fs_dir = os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs))
    if os.path.exists(fs_dir):
        shutil.rmtree(fs_dir)
    # makedirs (rather than mkdir) also creates image_root when it is missing.
    os.makedirs(fs_dir)

# Digit string -> English word, used for both the true and the wrong count.
number_to_word = {
    '1': 'one',
    '2': 'two',
    '3': 'three',
    '4': 'four',
    '5': 'five',
    '6': 'six',
    '7': 'seven',
    '8': 'eight',
    '9': 'nine',
    '10': 'ten'
}

countbench_dir = 'datasets/LabeledCountBench/'
image_extensions = ('.jpg', '.png', '.jpeg', '.gif')

# Process every image file in the folder. The listing is sorted because
# os.listdir returns entries in arbitrary order, which would make the index
# prefix of the output files irreproducible.
for i, filename in enumerate(sorted(os.listdir(countbench_dir))):
    if not filename.lower().endswith(image_extensions):
        continue

    # The file stem is parsed as "..._<label>_<count>".
    name_parts = os.path.splitext(filename)[0].split('_')
    if len(name_parts) < 2:
        # No label/count pair can be extracted from this name.
        continue
    label, counting = name_parts[-2], name_parts[-1]

    # Pick a wrong count between 1 and 10 that differs from the true one.
    miscounting = random.choice([n for n in number_to_word if n != counting])

    # Counts outside 1..10 keep the original 'Invalid' placeholder.
    counting = number_to_word.get(counting, 'Invalid')
    miscounting = number_to_word[miscounting]

    image_path = os.path.join(countbench_dir, filename)
    # Close the source file after decoding. Convert to RGB because PNG/GIF
    # inputs may use palette or alpha modes that cannot be written as JPEG.
    with Image.open(image_path) as source:
        image = source.convert('RGB').resize(resolution)

    # Render the same wrong count once per font size.
    for fs in font_sizes:
        img = add_text_img(image=image.copy(), text=miscounting, font_path=font_path, font_size=fs, font_color=font_color)
        img.save(os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs), str(i) + '-' + label + '-' + counting + '-' + miscounting + '.jpg'))

# Generate Examples - Numerical - Font Size

In [None]:
def add_equation(text, font_path, num_prints, font_size, font_color, bg_color, resolution):
    """Create a blank image and print `text` on it `num_prints` times.

    The copies are stacked vertically, 30 px apart, horizontally centred by
    the measured text width, with the stack centred vertically.

    Args:
        text: String to render (the equation).
        font_path: Path to a TrueType font file.
        num_prints: Number of stacked copies to draw.
        font_size: Font size passed to `ImageFont.truetype`.
        font_color: Fill colour for the text.
        bg_color: Background colour of the new image.
        resolution: (width, height) of the new image.

    Returns:
        The newly created RGB image.
    """
    canvas = Image.new('RGB', resolution, bg_color)
    pen = ImageDraw.Draw(canvas)
    typeface = ImageFont.truetype(font_path, font_size)

    width_px = int(pen.textlength(text, font=typeface))
    row_pitch = 30  # fixed vertical distance between consecutive copies

    left = (canvas.width - width_px) // 2
    # Offset of the first row so that the whole stack sits in the middle.
    first_row = (canvas.height - num_prints * row_pitch) // 2

    for row in range(num_prints):
        pen.text((left, first_row + row * row_pitch), text, fill=font_color, font=typeface)

    return canvas

def add_wrong_answer(image, text, num_prints, font_path, font_size, font_color):
    """Print `text` on `image` `num_prints` times, 100 px above the centre.

    The copies are stacked vertically, 30 px apart, and horizontally centred
    by the measured text width. The image is modified in place.

    Args:
        image: PIL image to draw on.
        text: String to render (the wrong answer).
        num_prints: Number of stacked copies to draw.
        font_path: Path to a TrueType font file.
        font_size: Font size passed to `ImageFont.truetype`.
        font_color: Fill colour for the text.

    Returns:
        The same `image` object, with the text drawn on it.
    """
    pen = ImageDraw.Draw(image)
    typeface = ImageFont.truetype(font_path, font_size)

    width_px = int(pen.textlength(text, font=typeface))
    row_pitch = 30  # fixed vertical distance between consecutive copies

    left = (image.width - width_px) // 2
    # Centre the stack vertically, then shift it 100 px upwards.
    first_row = (image.height - num_prints * row_pitch) // 2 - 100

    for row in range(num_prints):
        pen.text((left, first_row + row * row_pitch), text, fill=font_color, font=typeface)

    return image

# Rendering settings. The equation is always drawn at font_size; only the
# wrong answer varies across font_sizes.
font_path = 'fonts/arial_bold.ttf'
font_size = 15
equation_font_color = 'black'
mislabel_font_color = 'red'
bg_color = 'white'
resolution = (336, 336)
max_pic_num = 500

# Font sizes are multiples of segment_size, up to max_fontsize.
max_fontsize = 15
num_segments = 5
segment_size = max_fontsize // num_segments

image_root = 'images'
specie_fontsize_dir = 'numerical'

# Build the list of font sizes (num_segments entries).
font_sizes = [(i+1) * segment_size for i in range(num_segments)]

# Recreate one empty output directory per font size.
for fs in font_sizes:
    fs_dir = os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs))
    if os.path.exists(fs_dir):
        shutil.rmtree(fs_dir)
    # makedirs (rather than mkdir) also creates image_root when it is missing.
    os.makedirs(fs_dir)

operators = ['+', '-', 'x', '÷']

for i in range(max_pic_num):

    num1 = random.randint(1, 1000)
    num2 = random.randint(1, 1000)
    chosen_operator = random.choice(operators)
    correct_answer = 0

    if chosen_operator == '+':
        correct_answer = num1 + num2
    elif chosen_operator == '-':
        # Order the operands so the difference is never negative.
        num1, num2 = max(num1, num2), min(num1, num2)
        correct_answer = num1 - num2
    elif chosen_operator == 'x':
        correct_answer = num1 * num2
    elif chosen_operator == '÷':
        # Replace num2 with a random divisor of num1 so the quotient is an
        # integer. num1 >= 1, so the list always contains at least 1.
        factors = [d for d in range(1, num1 + 1) if num1 % d == 0]
        num2 = random.choice(factors)
        correct_answer = num1 // num2

    # Draw a nearby, non-negative wrong answer. Redraw while it coincides
    # with the correct one (a zero offset, or clamping to 0 when the correct
    # answer is 0, would otherwise print the right result).
    wrong_answer = correct_answer
    while wrong_answer == correct_answer:
        wrong_answer = max(0, correct_answer + random.randint(-500, 500))

    equation = f"{num1} {chosen_operator} {num2} ="

    # The equation image is the same for every font size, so render it once
    # and draw each wrong-answer variant on its own copy.
    equation_img = add_equation(text=equation, num_prints=1, font_path=font_path, font_size=font_size, font_color=equation_font_color, bg_color=bg_color, resolution=resolution)

    for fs in font_sizes:
        img = add_wrong_answer(image=equation_img.copy(), text=str(wrong_answer), num_prints=1, font_path=font_path, font_size=fs, font_color=mislabel_font_color)
        img.save(os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs), str(i) + '-' + str(correct_answer) + '-' + str(wrong_answer) + '.jpg'))

# Generate Examples - Complex - Font Size

In [None]:
def add_text_img(image, text, font_path, font_size, font_color):
    """Overlay one copy of `text` around the middle of `image`.

    The drawing happens on the image that is passed in, which is also the
    return value.

    Args:
        image: PIL image to draw on.
        text: String to render.
        font_path: Path to a TrueType font file.
        font_size: Font size passed to `ImageFont.truetype`.
        font_color: Fill colour for the text.

    Returns:
        The same `image` object, with the text drawn on it.
    """
    drawer = ImageDraw.Draw(image)
    face = ImageFont.truetype(font_path, font_size)

    measured_width = int(drawer.textlength(text, font=face))
    assumed_height = 10  # constant; not derived from font_size

    # Centre by the measured width and the assumed height.
    x_offset = (image.width - measured_width) // 2
    y_offset = (image.height - assumed_height) // 2

    # Negative offsets (text larger than the image) are clamped to the edge.
    x_offset = max(0, x_offset)
    y_offset = max(0, y_offset)

    drawer.text((x_offset, y_offset), text, fill=font_color, font=face)
    return image

# Rendering settings shared by every image generated in this cell.
font_path = 'fonts/arial_bold.ttf'
font_color = 'white'
resolution = (336, 336)
max_pic_num = 500

# Font sizes are multiples of segment_size, up to max_fontsize.
max_fontsize = 15
num_segments = 5
segment_size = max_fontsize // num_segments

image_root = 'images'
specie_fontsize_dir = 'complex'

# Build the list of font sizes (num_segments entries).
font_sizes = [(i+1) * segment_size for i in range(num_segments)]

# Recreate one empty output directory per font size.
for fs in font_sizes:
    fs_dir = os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs))
    if os.path.exists(fs_dir):
        shutil.rmtree(fs_dir)
    # makedirs (rather than mkdir) also creates image_root when it is missing.
    os.makedirs(fs_dir)

dataset_path = 'datasets/COCO2017/val2017/'
json_path = 'datasets/aokvqa_v1p0_val.json'

# Map the integer parsed from each .jpg file name to that file's path.
img_dict = {}
for filename in os.listdir(dataset_path):
    if filename.endswith('.jpg'):
        prefix = int(filename.split('.')[0])
        img_dict[prefix] = os.path.join(dataset_path, filename)

# Load the question annotations.
with open(json_path, 'r') as file:
    data = json.load(file)

# Only the first max_pic_num entries are considered. Skipped entries still
# count towards that limit, so fewer images may be produced.
for i, item in enumerate(data):

    if i >= max_pic_num:
        break

    image_id = item['image_id']
    question = item['question']
    choices = item['choices']
    correct_choice_idx = item['correct_choice_idx']

    # The question becomes part of the file name, so a '/' would break the path.
    if '/' in question:
        continue

    # Choices are truncated at the first '/' for the same reason.
    label = choices[correct_choice_idx].split('/')[0]

    # Pick the wrong answer from the choices that differ from the label after
    # truncation. Skip the entry when none is left instead of retrying forever.
    distractors = [choice.split('/')[0] for choice in choices]
    distractors = [choice for choice in distractors if choice != label]
    if not distractors:
        continue
    mislabel = random.choice(distractors)

    image_path = img_dict[image_id]
    # resize() returns a new image, so the source file can be closed right away.
    with Image.open(image_path) as source:
        img = source.resize(resolution)

    # Render the same wrong answer once per font size.
    for fs in font_sizes:
        img_mislabel = add_text_img(image=img.copy(), text=mislabel, font_path=font_path, font_size=fs, font_color=font_color)
        img_mislabel.save(os.path.join(image_root, specie_fontsize_dir + '-' + 'fs' + str(fs), str(i) + '-' + question + '-' + label + '-' + mislabel + '.jpg'))