In [1]:
import cv2
import os
import numpy as np
import pandas as pd
import random
from PIL import Image
import torchvision.transforms as transforms
from collections import defaultdict

In [2]:
# Example usage
annotations = {}
video_path = os.path.join("Annotated_videos", "22-10-20_C2_06.mp4")
annotation_path = os.path.join("Annotations", "22-10-20_C2_06.csv")

In [36]:
# 查看视频帧率的函数
def get_frame_rate(video_path):
    # 打开视频文件
    cap = cv2.VideoCapture(video_path)
    # 获取帧率
    fps = cap.get(cv2.CAP_PROP_FPS)
    # 释放视频文件
    cap.release()
    return fps

In [37]:
# 查看视频帧率（60fps）
fps = get_frame_rate(video_path)
print(f"The frame rate is: {fps} frames/sec.")

The frame rate is: 59.94005994005994 frames/sec.


In [10]:
# 提取函数--从视频中提取带有padding的图像
def extract_images(video_path, annotation_path, padding=10, time_interval=10):

    # 确保保存图像的目录存在，如果不存在则创建
    save_directory = "extracted_images_original"
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    
    # 读取注释数据
    df = pd.read_csv(annotation_path)
    annotations = df.values.tolist()  # 将DataFrame转换为列表的列表


    # 读取视频
    cap = cv2.VideoCapture(video_path)
    frames = {}

    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_interval = int(fps * time_interval)  # 每隔多少帧提取一次

    for annotation in annotations:
        frame_num = annotation[0]

        # 检查是否应该提取此帧
        if frame_num % frame_interval != 0:
            continue

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()

        if not ret:
            continue

        x1, x2, y1, y2 = annotation[5], annotation[6], annotation[7], annotation[8]
        
        # 添加padding
        x1 = max(x1 - padding, 0)
        y1 = max(y1 - padding, 0)
        x2 = min(x2 + padding, frame.shape[1])
        y2 = min(y2 + padding, frame.shape[0])

        cropped_frame = frame[y1:y2, x1:x2]

        # 根据Object ID和帧号将图片保存到指定目录
        obj_id = annotation[2]
        filename = os.path.join(save_directory, "{}_{}.jpg".format(obj_id, frame_num))
        
        cv2.imwrite(filename, cropped_frame)
        frames[frame_num] = cropped_frame
        
    cap.release()
    return frames


In [11]:
# 提取图像（执行过一次了）
extract_images(video_path, annotation_path)

{0: array([[[ 70, 104, 126],
         [ 68, 102, 124],
         [ 71, 105, 127],
         ...,
         [ 34,  51,  60],
         [ 65,  85, 100],
         [108, 128, 143]],
 
        [[ 72, 106, 128],
         [ 71, 105, 127],
         [ 75, 109, 131],
         ...,
         [ 15,  32,  41],
         [ 56,  76,  91],
         [102, 122, 137]],
 
        [[ 75, 109, 131],
         [ 72, 106, 128],
         [ 73, 107, 129],
         ...,
         [  6,  23,  32],
         [ 57,  77,  92],
         [103, 123, 138]],
 
        ...,
 
        [[103, 111, 128],
         [134, 142, 159],
         [175, 183, 200],
         ...,
         [ 97, 115, 117],
         [100, 116, 116],
         [101, 117, 117]],
 
        [[ 53,  58,  76],
         [ 87,  92, 110],
         [131, 136, 154],
         ...,
         [ 98, 116, 118],
         [101, 117, 117],
         [102, 118, 118]],
 
        [[ 22,  27,  45],
         [ 44,  49,  67],
         [ 87,  92, 110],
         ...,
         [ 99, 117, 119],

In [3]:
def create_triplets_from_directory(directory, meerkats_sharing_frame, frame_distance_threshold=1200):
    
    # 列出目录中的所有文件
    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    #print("Example files in directory:", files[:10])  # 输出目录中的前10个文件名

    # 创建一个词典，key为Object ID，value为该ID的所有图片
    object_dict = {}
    for f in files:
        object_id = int(f.split('_')[0])  # 从文件名中提取Object ID
        if object_id not in object_dict:
            object_dict[object_id] = []
        object_dict[object_id].append(f)
    #print(object_dict)
    # 打印我们为该视频找到了多少个不同的物体ID
    #print(f"Found {len(object_dict.keys())} unique object IDs.")

    combined_pairs = []  # 存储anchor-positive-negative的组合

    object_ids = list(object_dict.keys())

    #print(meerkats_sharing_frame)
    for obj_id in object_dict:
        
        if obj_id not in meerkats_sharing_frame:
            #print(obj_id)
            
            #print("The obj_id is not in meerkats_sharing_frame")
            continue  # 如果Object ID不在meerkats_sharing_frame中，则跳过

        images = object_dict[obj_id]
        #print(len(images))
        for i in range(len(images)):

            for j in range(i + 1, len(images)):
                frame_num_i = int(images[i].split('_')[1].split('.')[0])  # 从文件名中提取帧号
                frame_num_j = int(images[j].split('_')[1].split('.')[0])
                
                # 如果两个边界框的帧号相差大于给定的阈值，则为正样本
                if abs(frame_num_i - frame_num_j) >= frame_distance_threshold:
                    anchor = images[i]
                    positive = images[j]
                    anchor_id = int(anchor.split('_')[0])
                    
                    if anchor_id in meerkats_sharing_frame and meerkats_sharing_frame[anchor_id]:
                        # 从meerkats_sharing_frame随机选择一个不同的猫鼬ID，确保列表不为空
                        negative_id = random.choice(meerkats_sharing_frame[anchor_id])
                        while negative_id == anchor_id:  # 确保选择的猫鼬ID与锚点ID不同
                            negative_id = random.choice(meerkats_sharing_frame[anchor_id])

                        negative_id = int(negative_id)
                        #print()

                        if negative_id in object_dict:
                            # 从选择的猫鼬ID的图片列表中随机选择一个图片作为负样本
                            negative = random.choice(object_dict[negative_id])
                            combined_pairs.append((anchor, positive, negative))
    #print(combined_pairs)
    anchor_positive_pairs = [(anchor, positive) for anchor, positive, _ in combined_pairs]
    anchor_negative_pairs = [(anchor, negative) for anchor, _, negative in combined_pairs]

    return anchor_positive_pairs, anchor_negative_pairs


In [4]:
df = pd.read_csv("Annotations/22-10-20_C2_06.csv", header=None)
meerkats = defaultdict(list) #This stores the frames for each meerkat
current_frame = 0
for i, row in df.iterrows():
    #If not occluded and not a pup
    if row[1] == 0 and row[4] ==0:
        meerkats[row[2]] += [row[0]]


meerkats_sharing_frame = defaultdict(list) #This tells you which meerkats are in the same frame together

#For each pair of meerkats
for k in meerkats.keys():
    for j in meerkats.keys():
        #If they aren't the same meerkat and aren't already labelled as appearing in the same frames
        if j != k and j not in meerkats_sharing_frame[k]:
            #Find the number of frames they share together
            same_frames = len(set(meerkats[k]).intersection(set(meerkats[j])))
            #If they appear in more than one frame together
            if same_frames > 0:
                meerkats_sharing_frame[k] += [j]
                meerkats_sharing_frame[j] += [k]

print(meerkats_sharing_frame)

defaultdict(<class 'list'>, {0: [1, 2, 3, 4, 5, 7, 6, 8, 9, 10, 11, 12, 13], 1: [0, 2, 3, 4, 5, 7, 6, 8, 9, 10, 11, 12, 13, 14, 15], 2: [0, 1, 3, 4, 5, 7, 6, 8, 9, 10, 11, 12, 13, 14, 15, 18], 3: [0, 1, 2, 4], 4: [0, 1, 2, 3, 5, 7, 6, 8, 9, 10, 11, 12, 13, 14, 15], 5: [0, 1, 2, 4, 7, 6], 7: [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16], 6: [0, 1, 2, 4, 5, 7, 8, 9], 8: [0, 1, 2, 4, 7, 6, 9, 10], 9: [0, 1, 2, 4, 7, 6, 8, 10, 11], 10: [0, 1, 2, 4, 7, 8, 9, 11, 12, 13, 14, 15, 18, 19, 17, 23, 22], 11: [0, 1, 2, 4, 7, 9, 10, 12, 13, 14, 15], 12: [0, 1, 2, 4, 7, 10, 11, 13], 13: [0, 1, 2, 4, 7, 10, 11, 12, 14, 15, 16], 14: [1, 2, 4, 7, 10, 11, 13, 15], 15: [1, 2, 4, 7, 10, 11, 13, 14], 18: [2, 10], 16: [7, 13], 19: [10, 21, 17, 23, 22, 24, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 40, 39, 41, 42, 43, 44, 45, 47, 48, 49, 50, 51, 52, 54, 55, 53, 56, 57, 59, 61, 60, 62, 63], 17: [10, 19, 23, 22, 24, 26, 28, 29, 30, 31], 23: [10, 19, 17, 22], 22: [10, 19, 17, 23, 24, 26], 21: [19], 24: 

In [5]:
# 保存Anchor, Positive和Negative三元组到CSV文件。
def save_triplets_to_csv(anchor_positive_pairs, anchor_negative_pairs, csv_filename):

    # 保证输入三元组的长度相同
    assert len(anchor_positive_pairs) == len(anchor_negative_pairs)
    
    triplets = {
        "Anchor": [pair[0] for pair in anchor_positive_pairs],
        "Positive": [pair[1] for pair in anchor_positive_pairs],
        "Negative": [pair[1] for pair in anchor_negative_pairs]  # 使用pair[1]，因为Negative对只有两个元素
    }

    df = pd.DataFrame(triplets)
    df.to_csv(csv_filename, index=False)

In [6]:
# 创建数据
directory = 'extracted_images'
anchor_positive_pairs, anchor_negative_pairs = create_triplets_from_directory(directory, meerkats_sharing_frame)

print(len(anchor_positive_pairs))
print(len(anchor_negative_pairs))


# 输出三元组到csv文件中
csv_filename = "extracted_images.csv"
save_triplets_to_csv(anchor_positive_pairs, anchor_negative_pairs, csv_filename)

4348
4348


In [7]:
# 调整图片张量大小一致

directory = "extracted_images_all"
images = [img for img in os.listdir(directory) if img.endswith(".jpg")]

# 初始化最大的宽度和高度
max_width = 0
max_height = 0

# 第一遍：找到最大的宽度和高度
for img_name in images:
    img_path = os.path.join(directory, img_name)
    with Image.open(img_path) as img:
        width, height = img.size
        max_width = max(max_width, width)
        max_height = max(max_height, height)

# 设置转换
resize_transform = transforms.Resize((max_height, max_width))

# 第二遍：调整每张图像的大小
for img_name in images:
    img_path = os.path.join(directory, img_name)
    with Image.open(img_path) as img:
        resized_img = resize_transform(img)
        resized_img.save(img_path)  # 覆盖原始图像
