原 MNIST 数据集中各个数字的图片数量不同，这里通过下采样来均衡数据集，使每个标签的图片数量一致。

In [1]:
#显示每个label的数量

import pandas as pd
import os
def print_label_counts(csv_path):
    """Print and return how many rows each label has in a CSV file.

    Args:
        csv_path: Path of a CSV file that contains a ``label`` column.

    Returns:
        The ``value_counts()`` Series of the ``label`` column (labels as the
        index, row counts as the values).
    """
    frame = pd.read_csv(csv_path)
    per_label = frame['label'].value_counts()
    # One print call with a newline separator emits the same two lines.
    print("Label counts:", per_label, sep="\n")
    return per_label



# Root directory of the MNIST data on this machine.
base_dir = '/root/autodl-tmp/xin/datasets/MNIST'
# CSV that holds the training labels (read through its 'label' column).
train_csv_path = os.path.join(base_dir, 'train_labels.csv')
# Print the per-label counts and keep them for the balancing step.
label_counts = print_label_counts(train_csv_path)


Label counts:
label
1    6742
7    6265
3    6131
2    5958
9    5949
0    5923
6    5918
8    5851
4    5842
5    5421
Name: count, dtype: int64


In [5]:

def balance_dataset_by_deleting(data, images_per_label, random_state=None):
    """Down-sample every label to exactly ``images_per_label`` rows.

    Args:
        data: DataFrame with a ``label`` column.
        images_per_label: Number of rows to keep for each distinct label.
        random_state: Optional seed (or random generator) forwarded to
            ``DataFrame.sample`` so the selection can be reproduced.
            ``None`` keeps the original non-deterministic behaviour.

    Returns:
        A DataFrame holding ``images_per_label`` randomly chosen rows per
        label, grouped by label in order of first appearance. The original
        row index and column dtypes are preserved.

    Raises:
        ValueError: Raised by ``DataFrame.sample`` when a label has fewer
            than ``images_per_label`` rows.
    """
    # Collect the per-label samples first and concatenate once: growing a
    # DataFrame inside the loop re-copies all accumulated rows on every
    # iteration and relies on concatenating with an empty frame.
    sampled_groups = [
        data[data['label'] == label].sample(n=images_per_label, random_state=random_state)
        for label in data['label'].unique()
    ]
    if not sampled_groups:
        # No labels at all: return an empty frame that keeps the columns.
        return data.iloc[0:0].copy()
    return pd.concat(sampled_groups)

# Every label is cut down to the size of the rarest label.
images_per_label = label_counts.min()
data = pd.read_csv(train_csv_path)
balanced_data = balance_dataset_by_deleting(data, images_per_label)
 

# Save the balanced dataset.
# NOTE(review): this path is relative to the current working directory, while
# later cells read os.path.join(base_dir, 'balanced_train_labels.csv') --
# confirm the notebook runs with base_dir as its working directory.
balanced_csv_path = 'balanced_train_labels.csv'
# index=False drops the original row index, so the saved file no longer
# records which row of train_labels.csv each kept sample came from.
balanced_data.to_csv(balanced_csv_path, index=False)

# Report where the file was written.
print("Balanced data saved to:", balanced_csv_path)
#print_label_counts(balanced_csv_path)

Balanced data saved to: balanced_train_labels.csv


给原黑白图片上色（无 bias）：每个数字对应的每种颜色各有 1807 张图片（5421 ÷ 3 = 1807）。

In [None]:
def save_colored_images_and_update_csv(train_imgs, train_labels, colors, base_dir, csv_path,
                                       images_per_color_per_label=1807):
    """Assign each image a colour index, save it as PNG and log it to a CSV.

    For every label the colours are filled in list order: the first colour
    receives images until it reaches ``images_per_color_per_label``, then the
    second one, and so on. Images that arrive after all colours of their
    label are full are skipped.

    NOTE(review): the chosen colour only appears in the file name and in the
    CSV row; the pixel data is saved unchanged. Confirm whether the image
    itself was meant to be tinted here.

    Args:
        train_imgs: Iterable of image arrays accepted by ``Image.fromarray``.
        train_labels: Iterable of labels aligned with ``train_imgs``; each
            label must be one of 0..9.
        colors: List of hashable colour values (they are used as dict keys).
        base_dir: Directory that receives the PNG files.
        csv_path: CSV file to append to; rows are
            ``[file name, label, colour index]`` and no header is written.
        images_per_color_per_label: Maximum number of images per
            (label, colour) pair. Defaults to 1807.
    """
    # Per-label, per-colour usage counters.
    color_usage = {label: {color: 0 for color in colors} for label in range(10)}

    os.makedirs(base_dir, exist_ok=True)

    # Open the CSV once instead of reopening it for every single image.
    with open(csv_path, 'a', newline='') as file:
        writer = csv.writer(file)

        for img, label in zip(train_imgs, train_labels):
            usage = color_usage[label]

            # First colour (in list order) that still has room for this label.
            chosen = next(
                ((index, color) for index, color in enumerate(colors)
                 if usage[color] < images_per_color_per_label),
                None,
            )
            if chosen is None:
                continue  # every colour quota of this label is already full
            color_index, chosen_color = chosen
            usage[chosen_color] += 1

            # The file name encodes the label, the colour index and a running number.
            img_filename = f"label_{label}_color_{color_index}_{usage[chosen_color]}.png"
            Image.fromarray(img).save(os.path.join(base_dir, img_filename))

            writer.writerow([img_filename, label, color_index])

# Run the save routine.
# NOTE(review): train_imgs, train_labels, colors and csv_path are not defined
# in any cell above this one in the visible notebook -- confirm which cells
# must be executed first.
save_colored_images_and_update_csv(train_imgs, train_labels, colors, base_dir, csv_path)


In [None]:
import csv
import os
import numpy as np
from PIL import Image, ImageOps
from collections import defaultdict
def apply_color_to_image(image, color):
    """Tint a grayscale image with a single colour.

    Args:
        image: PIL Image object; it is converted to mode ``'L'`` first.
        color: RGB tuple used for white pixels, e.g. ``(255, 0, 0)`` for red.
            Black pixels stay black.

    Returns:
        The colourised PIL image produced by ``ImageOps.colorize``.
    """
    grayscale = image.convert('L')
    return ImageOps.colorize(grayscale, black="black", white=color)
def process_images_and_update_csv(image_dir, csv_path, output_csv_path,
                                  output_image_dir="/root/autodl-tmp/xin/datasets/MNIST/colored-train-images",
                                  color_per_label_limit=2000):
    """Colour the images listed in a CSV and write a CSV with the colour used.

    Every row of ``csv_path`` is given the first colour (red, green, blue in
    that order) whose per-label quota is not yet full. The tinted image is
    saved as ``colored_<image_name>`` and one row is written to
    ``output_csv_path``. Rows whose label has exhausted all three colours are
    skipped.

    Args:
        image_dir: Directory that contains the source images.
        csv_path: Input CSV with ``image_name`` and ``label`` columns.
        output_csv_path: CSV to create (overwritten if it exists) with the
            columns ``image_name``, ``label`` and ``color`` (colour index).
        output_image_dir: Directory that receives the coloured images. It is
            created if missing. Defaults to the location that was previously
            hard-coded.
        color_per_label_limit: Maximum number of images per
            (label, colour) pair. Defaults to 2000.
    """
    # RGB colours: red, green, blue. The list position is the colour index.
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

    # Read all label rows up front.
    with open(csv_path, mode='r', newline='') as file:
        image_data = list(csv.DictReader(file))

    os.makedirs(output_image_dir, exist_ok=True)

    # label -> colour index -> number of images coloured so far.
    label_color_counts = defaultdict(lambda: defaultdict(int))

    with open(output_csv_path, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['image_name', 'label', 'color'])
        writer.writeheader()

        for item in image_data:
            label = int(item['label'])
            image_name = item['image_name']

            # Choose the colour before touching the file so that skipped rows
            # never open an image at all.
            color_index = next(
                (index for index in range(len(colors))
                 if label_color_counts[label][index] < color_per_label_limit),
                None,
            )
            if color_index is None:
                continue  # all colour quotas of this label are full
            label_color_counts[label][color_index] += 1

            colored_name = f"colored_{image_name}"
            # The context manager closes the source file handle; the tinted
            # image returned by the helper is a separate in-memory image.
            with Image.open(os.path.join(image_dir, image_name)) as image:
                colored_image = apply_color_to_image(image, colors[color_index])
            colored_image.save(os.path.join(output_image_dir, colored_name))

            writer.writerow({'image_name': colored_name, 'label': label, 'color': color_index})
# Set up the base directory and the file paths.
base_dir = '/root/autodl-tmp/xin/datasets/MNIST'
image_dir = os.path.join(base_dir, 'train-images')
# NOTE(review): an earlier cell writes 'balanced_train_labels.csv' relative to
# the working directory -- confirm the file really exists under base_dir.
csv_path = os.path.join(base_dir,  'balanced_train_labels.csv')
# The updated CSV is written inside the 'train-images' directory.
output_csv_path = os.path.join(base_dir, 'train-images', 'updated_train_labels_with_colors.csv')

# Colour the images and write the updated CSV.
process_images_and_update_csv(image_dir, csv_path, output_csv_path)


给原黑白图片上色（有 bias）：每个数字都有一种占 50% 的主色，其余两种颜色各占 25%。

In [None]:
import csv
import gzip
import os

import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
from PIL import Image

# Parse a gzip-compressed MNIST IDX file.
def parse_mnist(minst_file_addr: str) -> np.ndarray:
    """Load a gzip-compressed IDX file (MNIST images or labels) as uint8.

    The array layout is taken from the IDX header instead of being guessed
    from the file name, so a path whose directory happens to contain the word
    "label" no longer makes an image file parse as labels.

    Args:
        minst_file_addr: Path of the ``.gz`` IDX file.

    Returns:
        A read-only ``uint8`` array shaped as declared in the header: ``(N,)``
        for a label file (8-byte header) and ``(N, rows, cols)`` for an image
        file (16-byte header).

    Raises:
        ValueError: If the header is malformed, the element type is not
            unsigned byte, or the payload size does not match the header.
    """
    with gzip.open(minst_file_addr, 'rb') as f:
        # Magic number: two zero bytes, a data-type code, the dimension count.
        magic = f.read(4)
        if len(magic) != 4 or magic[0] != 0 or magic[1] != 0:
            raise ValueError(f"{minst_file_addr}: not an IDX file (bad magic number)")
        if magic[2] != 0x08:
            raise ValueError(f"{minst_file_addr}: only unsigned-byte IDX data is supported")
        ndim = magic[3]

        # One big-endian 32-bit size per dimension follows the magic number.
        dims_raw = f.read(4 * ndim)
        if len(dims_raw) != 4 * ndim:
            raise ValueError(f"{minst_file_addr}: truncated IDX header")
        shape = tuple(int(size) for size in np.frombuffer(dims_raw, dtype='>u4'))

        data = np.frombuffer(f.read(), dtype=np.uint8)

    expected_size = int(np.prod(shape, dtype=np.int64))
    if data.size != expected_size:
        raise ValueError(
            f"{minst_file_addr}: header declares shape {shape} "
            f"({expected_size} bytes) but the payload has {data.size} bytes"
        )
    return data.reshape(shape)

# Build a custom colormap.
def create_custom_cmap(base_color):
    """Return a linear colormap that runs from black to ``base_color``.

    Args:
        base_color: Colour at the top end of the map, e.g. ``(1, 0, 0)``.

    Returns:
        The ``LinearSegmentedColormap`` named ``'custom'`` built from the two
        endpoints.
    """
    black = (0, 0, 0)
    return mcolors.LinearSegmentedColormap.from_list('custom', [black, base_color])

# Build one black-to-colour colormap per base colour.
red_cmap = create_custom_cmap((1, 0, 0))  # red
blue_cmap = create_custom_cmap((0, 0, 1))  # blue
green_cmap = create_custom_cmap((0, 1, 0))  # green
# List position is the colour index: 0 = red, 1 = blue, 2 = green.
colors = [red_cmap, blue_cmap, green_cmap]
 
# Colour share per label, listed in the same order as `colors`
# (red, blue, green).
color_ratios = {
    0: [0.5, 0.25, 0.25],  # label 0: 50% red (dominant), 25% blue, 25% green
    1: [0.5, 0.25, 0.25],  # label 1: 50% red (dominant), 25% blue, 25% green
    2: [0.5, 0.25, 0.25],  # label 2: 50% red (dominant), 25% blue, 25% green
    3: [0.25, 0.25, 0.5],  # label 3: 50% green (dominant), 25% red, 25% blue
    4: [0.25, 0.25, 0.5],  # label 4: 50% green (dominant), 25% red, 25% blue
    5: [0.25, 0.5, 0.25],  # label 5: 50% blue (dominant), 25% red, 25% green
    6: [0.25, 0.5, 0.25],  # label 6: 50% blue (dominant), 25% red, 25% green
    7: [0.25, 0.5, 0.25],  # label 7: 50% blue (dominant), 25% red, 25% green
    8: [0.25, 0.25, 0.5],  # label 8: 50% green (dominant), 25% red, 25% blue
    9: [0.5, 0.25, 0.25],  # label 9: 50% red (dominant), 25% blue, 25% green
}


# Save the coloured images and append one CSV row per saved image.
def save_colored_images_and_update_csv(train_imgs, train_labels, colors, color_ratios, base_dir, csv_path):
    """Colour images according to per-label colour ratios and log them to a CSV.

    For each label, colour ``i`` receives images until its count reaches
    ``label_count * color_ratios[label][i]``; colours are filled in index
    order. Images that arrive after every quota of their label is full are
    skipped.

    Args:
        train_imgs: Iterable of grayscale image arrays (uint8 values 0..255
            are assumed when a colormap is applied).
        train_labels: Labels aligned with ``train_imgs``; each must be 0..9.
        colors: One entry per colour index. A callable entry (for example a
            matplotlib colormap) is applied to the image before saving; a
            non-callable entry leaves the pixels unchanged.
        color_ratios: Mapping ``label -> list of ratios``, one per colour.
        base_dir: Directory that receives the PNG files (created if missing).
        csv_path: CSV file to append to; rows are
            ``[file name, label, colour index]`` and no header is written.
    """
    color_count = len(colors)

    # Number of images per label; the quotas are fractions of these counts.
    label_counts = pd.Series(train_labels).value_counts()

    # Per-label, per-colour usage counters.
    color_usage = {label: {color: 0 for color in range(color_count)} for label in range(10)}

    os.makedirs(base_dir, exist_ok=True)

    # Open the CSV once instead of reopening it for every single image.
    with open(csv_path, 'a', newline='') as file:
        writer = csv.writer(file)

        for img, label in zip(train_imgs, train_labels):
            total_images = label_counts[label]
            usage = color_usage[label]

            # First colour index whose quota for this label is not yet full.
            chosen_color = next(
                (color_index for color_index, ratio in enumerate(color_ratios[label])
                 if usage[color_index] < total_images * ratio),
                None,
            )
            if chosen_color is None:
                continue  # every colour quota of this label is already full

            usage[chosen_color] += 1

            # Actually colour the pixels: previously `colors` was only used
            # for its length, so every "coloured" image was saved grayscale.
            cmap = colors[chosen_color]
            if callable(cmap):
                # Scale to [0, 1], map to uint8 RGBA, then drop the alpha channel.
                rgba = cmap(np.asarray(img) / 255.0, bytes=True)
                pixels = np.ascontiguousarray(rgba[..., :3])
            else:
                pixels = img

            # The file name encodes the label, the colour index and a running number.
            img_filename = f"label_{label}_color_{chosen_color}_{usage[chosen_color]}.png"
            Image.fromarray(pixels).save(os.path.join(base_dir, img_filename))

            writer.writerow([img_filename, label, chosen_color])

# Load the raw MNIST images and labels.
train_imgs = parse_mnist("MNIST/train-images-idx3-ubyte.gz")
train_labels = parse_mnist("MNIST/train-labels-idx1-ubyte.gz")

# Load the balanced label table. Only its per-label size is used below: the
# file was written with index=False, so its row index is just 0..N-1 and does
# NOT identify positions in the raw MNIST arrays. Indexing the arrays with it
# (as was done before) merely took the first N images, which are not balanced.
balanced_data = pd.read_csv(os.path.join(base_dir, 'balanced_train_labels.csv'))
balanced_images_per_label = int(balanced_data['label'].value_counts().min())

# Keep a class-balanced subset: draw the same number of positions for every
# label directly from the parsed labels. The fixed seed makes the selection
# reproducible; sorting keeps the images in their original order.
sampling_rng = np.random.default_rng(0)
balanced_indices = np.sort(np.concatenate([
    sampling_rng.choice(
        np.flatnonzero(train_labels == label),
        size=balanced_images_per_label,
        replace=False,
    )
    for label in np.unique(train_labels)
]))
train_imgs_balanced = train_imgs[balanced_indices]
train_labels_balanced = train_labels[balanced_indices]

# Output directory and CSV path (placeholders -- set these before running).
output_base_dir = '/path/to/save/images'
output_csv_path = '/path/to/save/labels.csv'

# Run the save routine.
save_colored_images_and_update_csv(train_imgs_balanced, train_labels_balanced, colors, color_ratios, output_base_dir, output_csv_path)
