In [1]:
import random
import numpy as np
from PIL import Image, ImageDraw

def augment_image_sequence(image_sequence, num_frames, mix_alpha=0.5, mask_probability=0.3):
    """
    对长度为 B*N 的图像序列进行增强。
    
    Parameters:
    - image_sequence: list of list of PIL.Image objects (B sequences, each containing N frames)
    - num_frames: int, 每个序列包含的帧数 (N)
    - mix_alpha: float, 混合帧时的加权比例（默认为 0.5）
    - mask_probability: float, 遮挡的概率（默认为 0.3）

    Returns:
    - augmented_sequences: list of list of PIL.Image objects (增强后的图像序列)
    """
    augmented_sequences = []
    
    for sequence in image_sequence:  # 遍历每个序列 (B)
        augmented_sequence = sequence.copy()  # 深拷贝，避免原数据被修改
        
        # 1. 时间扰动（缺失帧 + 重复帧补充）
        missing_indices = random.sample(range(num_frames - 1), k=random.randint(1, (num_frames - 1) // 3))
        for idx in missing_indices:
            replacement_idx = random.randint(0, num_frames - 2)  # 从剩余帧随机挑选补充帧
            augmented_sequence[idx] = sequence[replacement_idx].copy()
        
        # 2. 小范围时间顺序调换
        swap_start = random.randint(0, num_frames - 3)  # 随机选择起始位置
        swap_end = min(swap_start + random.randint(2, 4), num_frames - 2)  # 限制范围不超过 N-1
        sub_seq = augmented_sequence[swap_start:swap_end]
        random.shuffle(sub_seq)  # 打乱顺序
        augmented_sequence[swap_start:swap_end] = sub_seq
        
        # 3. 混合帧（Frame Mixing）
        for i in range(num_frames - 2):  # 只对前 N-2 的帧进行混合
            next_idx = (i + 1) % (num_frames - 1)  # 下一个帧索引（环形）
            mixed_frame = Image.blend(sequence[i], sequence[next_idx], alpha=mix_alpha)
            augmented_sequence[i] = mixed_frame
        
        # 4. 遮挡（Temporal Masking）
        for idx in range(num_frames - 1):  # 遍历前 N-1 帧
            if random.random() < mask_probability:
                augmented_sequence[idx] = apply_temporal_mask(augmented_sequence[idx])
        
        # 将增强后的序列添加到结果中
        augmented_sequences.append(augmented_sequence)
    
    return augmented_sequences

def apply_temporal_mask(image):
    """
    为图像添加随机遮挡（遮挡为矩形区域）。
    
    Parameters:
    - image: PIL.Image object, 输入图像。
    
    Returns:
    - masked_image: PIL.Image object, 添加遮挡后的图像。
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    mask_width = random.randint(width // 8, width // 4)  # 遮挡宽度
    mask_height = random.randint(height // 8, height // 4)  # 遮挡高度
    top_left_x = random.randint(0, width - mask_width)
    top_left_y = random.randint(0, height - mask_height)
    bottom_right_x = top_left_x + mask_width
    bottom_right_y = top_left_y + mask_height
    
    # 绘制黑色矩形
    draw.rectangle([top_left_x, top_left_y, bottom_right_x, bottom_right_y], fill=(0, 0, 0))
    return image

# 示例：使用图像增强代码
if __name__ == "__main__":
    # 假设有一个长度为 B*N 的图像序列
    B, N = 4, 10  # 批大小 B，序列长度 N
    image_sequences = [[Image.new("RGB", (224, 224), (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
                        for _ in range(N)] for _ in range(B)]
    
    # 对序列进行增强
    augmented_sequences = augment_image_sequence(image_sequences, num_frames=N)

    # 显示增强后的序列
    for b_idx, sequence in enumerate(augmented_sequences):
        print(f"Batch {b_idx + 1}:")
        for t_idx, frame in enumerate(sequence):
            frame.show(title=f"Frame {t_idx + 1}")

Batch 1:


display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.

Batch 2:
Batch 3:
Batch 4:


display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.


display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
display: unable to open X server `' @ error/display.c/DisplayImageCommand/410.
