In [1]:
import pandas as pd
import numpy as np
import cv2
import os
from sklearn.utils import shuffle
from tqdm import tqdm
import random
from PIL import Image

In [2]:
from sklearn.model_selection import train_test_split

In [3]:
random.seed(456)

In [4]:
# 定义图像增强函数
def random_rotation(img):
    angle = random.randint(-30, 30)
    rows, cols = img.shape[:2]
    M = cv2.getRotationMatrix2D((cols/2, rows/2), angle, 1)
    img = cv2.warpAffine(img, M, (cols, rows))
    return img

def random_scale(img):
    scale = random.uniform(0.8, 1.2)
    img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    return img

def gaussian_blur(img):
    return cv2.GaussianBlur(img, (5, 5), 0)

def color_perturbation(img):
    img = img.astype(np.float32)
    img += np.random.normal(0, 10, img.shape)
    img = np.clip(img, 0, 255)
    return img.astype(np.uint8)

def augment_image(img):
    img = random_rotation(img)
    img = random_scale(img)
    img = gaussian_blur(img)
    img = color_perturbation(img)
    return img


In [5]:
def load_data_to_dataframe(data_dir):
    data = []
    for filename in os.listdir(data_dir):
        if filename.endswith(".jpg"):
            try:
                # 文件名
                parts = filename.split('_')
                if len(parts) < 4:  # 如果分割后的部分不足4个，跳过该文件
                    print(f"跳过文件（命名不规范）: {filename}")
                    continue

                # 提取年龄、性别等信息
                age = int(parts[0])
                gender = int(parts[1])
                # ethnicity = int(parts[2])  # 如果需要种族信息，可以取消注释

                # 读取图像
                img_path = os.path.join(data_dir, filename)
                img = Image.open(img_path).convert('RGB')
                img = img.resize((224, 224))  # 调整图像尺寸
                img = np.array(img)  # 转换为numpy数组

                # 添加到数据列表
                data.append({
                    'age': age,
                    # 'ethnicity': ethnicity,  # 如果需要种族信息，可以取消注释
                    'gender': gender,
                    'img_name': filename,
                    'pixels': img
                })
            except (ValueError, IndexError) as e:
                # 如果解析失败，跳过该文件
                print(f"跳过文件（解析失败）: {filename}, 错误: {e}")
                continue

    # 创建DataFrame
    df = pd.DataFrame(data)
    return df

In [6]:
def augment_dataframe(df):
    augmented_data = []
    for idx, row in df.iterrows():
        img = row['pixels']
        img_augmented = augment_image(img)  
        augmented_data.append(img_augmented)

    df_augmented = pd.DataFrame({
        'age': df['age'],
        # 'ethnicity': df['ethnicity'],
        'gender': df['gender'],
        'img_name': df['img_name'],
        'pixels': augmented_data
    })
    return df_augmented

In [7]:
data_dir = os.path.join('UTKFace')
df = load_data_to_dataframe(data_dir)
#df_augmented = augment_dataframe(df)
# df_augmented.to_csv('augmented_data.csv', index=False)  # 保存为CSV文件


跳过文件（命名不规范）: 39_1_20170116174525125.jpg.chip.jpg
跳过文件（命名不规范）: 61_1_20170109142408075.jpg.chip.jpg
跳过文件（命名不规范）: 61_1_20170109150557335.jpg.chip.jpg


In [8]:
df.head

<bound method NDFrame.head of        age  gender                                img_name  \
0      100       0  100_0_0_20170112213500903.jpg.chip.jpg   
1      100       0  100_0_0_20170112215240346.jpg.chip.jpg   
2      100       1  100_1_0_20170110183726390.jpg.chip.jpg   
3      100       1  100_1_0_20170112213001988.jpg.chip.jpg   
4      100       1  100_1_0_20170112213303693.jpg.chip.jpg   
...    ...     ...                                     ...   
23700    9       1    9_1_3_20161220222856346.jpg.chip.jpg   
23701    9       1    9_1_3_20170104222949455.jpg.chip.jpg   
23702    9       1    9_1_4_20170103200637399.jpg.chip.jpg   
23703    9       1    9_1_4_20170103200814791.jpg.chip.jpg   
23704    9       1    9_1_4_20170103213057382.jpg.chip.jpg   

                                                  pixels  
0      [[[215, 206, 201], [213, 204, 199], [214, 205,...  
1      [[[118, 122, 134], [120, 124, 135], [122, 126,...  
2      [[[221, 222, 226], [231, 233, 238], [233,

In [9]:
#进行数据增强前先划分数据集
# 指定划分比例
train_ratio = 0.8  # 训练集比例
val_ratio = 0.2    # 验证集比例

# 划分数据集
train_df, val_df = train_test_split(df, test_size=val_ratio, random_state=456)

In [10]:
# 输出划分结果
print(f"训练集大小: {len(train_df)}")
print(f"验证集大小: {len(val_df)}")


训练集大小: 18964
验证集大小: 4741


In [11]:
#仅对训练集进行数据增强
df_augmented = augment_dataframe(train_df)
# df_augmented.to_csv('augmented_data.csv', index=False)  # 保存为CSV文件

In [12]:
df_augmented.head()

Unnamed: 0,age,gender,img_name,pixels
16430,44,0,44_0_0_20170117003404757.jpg.chip.jpg,"[[[36, 0, 0], [4, 0, 4], [0, 20, 0], [0, 2, 0]..."
2107,1,0,1_0_3_20161220220036937.jpg.chip.jpg,"[[[39, 20, 7], [53, 25, 39], [105, 61, 57], [1..."
6799,26,1,26_1_0_20170116234741431.jpg.chip.jpg,"[[[6, 3, 4], [1, 4, 0], [0, 4, 4], [0, 0, 0], ..."
4205,24,0,24_0_0_20170119152257171.jpg.chip.jpg,"[[[0, 8, 0], [1, 0, 24], [2, 9, 0], [0, 12, 0]..."
19614,56,1,56_1_0_20170104235032259.jpg.chip.jpg,"[[[0, 4, 0], [0, 7, 6], [0, 20, 0], [6, 0, 2],..."


In [13]:
%store df_augmented
%store val_df

MemoryError: 