## 生成过滤配置文件

  生成以下过滤配置文件（均写入 `filter_config/` 目录）：
  - `optimal_frames_data.json`：记录每个视频的最佳提取帧（将视频划分为若干帧区间，在每个区间内选出与该视频类别文本 CLIP 相似度最高的帧）；
  - `poor_quality_video_list.pkl`：列出需要过滤的低分辨率视频；
  - `mismatched_video_pairs.pkl`：列出缺少对应音频文件、需要过滤的视频。

In [None]:
pip install git+https://github.com/openai/CLIP.git

In [None]:
import functools
import json
import os
import pickle

import clip
import cv2
import pandas as pd
import torch
from PIL import Image


@functools.lru_cache(maxsize=8)
def _load_class_lookup(csv_file):
    """Read ``csv_file`` once and return a ``{ytid: class}`` dict.

    Only the first row for each ``ytid`` is kept, so lookups return the first
    matching row. ``ytid`` is read as text so IDs that look numeric (e.g. with
    leading zeros) are not converted to integers by pandas. The result is
    cached per path for the lifetime of the process, so edits to the CSV made
    after the first lookup are not picked up.
    """
    df = pd.read_csv(csv_file, dtype={'ytid': str})
    df = df.drop_duplicates(subset='ytid', keep='first')
    return dict(zip(df['ytid'], df['class']))


def get_class_from_csv(video_name, csv_file):
    """Return the class label for ``video_name`` from ``csv_file``, or None.

    The YouTube ID is taken as the first 11 characters of ``video_name``; the
    file name is assumed to start with the ytid (e.g. ``<ytid>.mp4``).

    Args:
        video_name: Video file name whose first 11 characters are the ytid.
        csv_file: Path to a CSV with ``ytid`` and ``class`` columns.

    Returns:
        The ``class`` value of the first row whose ``ytid`` matches, or None
        (after printing a message) when there is no such row.
    """
    ytid = video_name[:11]
    # The CSV is parsed once per path instead of once per video.
    class_text = _load_class_lookup(csv_file).get(ytid)
    if class_text is None:
        print(f"No matching entry for {video_name} in the CSV file.")
    return class_text


def _read_clip_frames(video, count, preprocess):
    """Read up to ``count`` consecutive frames from ``video`` and CLIP-preprocess them.

    Returns a list of preprocessed image tensors. The list is shorter than
    ``count`` when ``video.read()`` fails (end of stream or a decode error).
    """
    frames = []
    for _ in range(count):
        ret, frame = video.read()
        if not ret or frame is None:
            break
        # OpenCV decodes frames as BGR arrays; CLIP's preprocess takes an RGB PIL image.
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append(preprocess(frame))
    return frames


def _best_frame_in_interval(video, count, model, preprocess, text_features, device, batch_size):
    """Scan the next ``count`` frames of ``video`` and pick the best match for the text.

    ``text_features`` must already be L2-normalised. Frames are encoded in
    chunks of ``batch_size`` so a long interval never has to be held in memory
    (or sent to the device) all at once.

    Returns ``(best_offset, frames_read)``, where ``best_offset`` is the 0-based
    offset inside the interval of the frame with the highest cosine similarity
    (``None`` if no frame could be read).
    """
    best_offset, best_score = None, float("-inf")
    frames_read = 0
    while frames_read < count:
        wanted = min(batch_size, count - frames_read)
        batch = _read_clip_frames(video, wanted, preprocess)
        if batch:
            with torch.no_grad():
                image_features = model.encode_image(torch.stack(batch).to(device))
            # Normalise so the dot product is a cosine similarity; raw CLIP image
            # embeddings differ in norm, which would otherwise skew the argmax.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            scores = (image_features @ text_features.T).squeeze(-1)
            idx = int(torch.argmax(scores))
            score = float(scores[idx])
            if score > best_score:
                best_offset, best_score = frames_read + idx, score
            frames_read += len(batch)
        if len(batch) < wanted:
            break  # stream ended (or failed to decode) before the interval was complete
    return best_offset, frames_read


def generate_optimal_frames_data_json(video_dir, csv_file, output_dir='filter_config',
                                      num_intervals=10, batch_size=64):
    """Select one representative frame per interval of every video and save them as JSON.

    Each video in ``video_dir`` is split into ``num_intervals`` contiguous,
    nearly equal frame intervals that together cover the whole clip. In each
    interval, the selected frame is the one whose CLIP (ViT-B/32) image
    embedding has the highest cosine similarity to the video's class text from
    ``csv_file``.

    The result is written to ``<output_dir>/optimal_frames_data.json`` as
    ``{video_stem: [[start, end], ...]}`` with ``end = start + 1``. Videos with
    no frames, no CSV entry, or for which processing raises are skipped and
    reported via ``print``.

    Args:
        video_dir: Directory containing the video files.
        csv_file: CSV with ``ytid`` and ``class`` columns (see ``get_class_from_csv``).
        output_dir: Directory the JSON file is written to; created if missing.
        num_intervals: Number of intervals (and thus selected frames) per video.
        batch_size: Maximum number of frames CLIP encodes in one forward pass.

    Raises:
        ValueError: If ``num_intervals`` or ``batch_size`` is less than 1.
    """
    if num_intervals < 1:
        raise ValueError(f"num_intervals must be >= 1, got {num_intervals}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Load the CLIP model and its image preprocessing transform.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    model, preprocess = clip.load("ViT-B/32", device=device)

    best_frames_dict = {}

    for video_file in sorted(os.listdir(video_dir)):
        video_path = os.path.join(video_dir, video_file)
        video = None
        try:
            video = cv2.VideoCapture(video_path)
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            print(f"Processing video: {video_file}, Total frames: {total_frames}")

            if total_frames <= 0:
                print(f"Skipping {video_file} as it has no frames.")
                continue

            # Class text that the frames of this video are compared against.
            class_text = get_class_from_csv(video_file, csv_file)
            if class_text is None:
                print(f"Skipping {video_file} due to missing class information.")
                continue

            # The text embedding is computed once per video and reused for every frame.
            text_tokenized = clip.tokenize([class_text]).to(device)
            with torch.no_grad():
                text_features = model.encode_text(text_tokenized)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Interval boundaries: total_frames * i // num_intervals covers every
            # frame exactly once, even when total_frames < num_intervals or is
            # not a multiple of it (empty intervals are simply skipped).
            bounds = [total_frames * i // num_intervals for i in range(num_intervals + 1)]

            best_frames = []
            # Frames are read sequentially from the start; seeking before every
            # frame is slow and frame-inaccurate for many codecs.
            for low, high in zip(bounds, bounds[1:]):
                best_offset, frames_read = _best_frame_in_interval(
                    video, high - low, model, preprocess, text_features, device, batch_size)
                if best_offset is not None:
                    best_frame = low + best_offset
                    best_frames.append((best_frame, best_frame + 1))
                if frames_read < high - low:
                    break  # reported frame count exceeded the readable frames

            best_frames_dict[os.path.splitext(video_file)[0]] = best_frames

        except Exception as e:
            print(f"Error processing {video_file}: {e}")
        finally:
            if video is not None:
                video.release()

    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'optimal_frames_data.json')
    with open(output_path, 'w') as f:
        json.dump(best_frames_dict, f)


def generate_poor_quality_video_list_pkl(video_dir, output_dir='filter_config',
                                         min_width=640, min_height=480):
    """Pickle the names of videos whose resolution is below ``min_width`` x ``min_height``.

    Writes ``<output_dir>/poor_quality_video_list.pkl`` containing a sorted
    list of file names with their extension removed. A file OpenCV cannot
    open presumably reports a 0x0 resolution and is therefore listed too.

    Args:
        video_dir: Directory containing the video files.
        output_dir: Directory the pickle is written to; created if missing.
        min_width: Videos narrower than this (pixels) are poor quality.
        min_height: Videos shorter than this (pixels) are poor quality.
    """
    low_quality_videos = []
    for video_file in sorted(os.listdir(video_dir)):
        video = cv2.VideoCapture(os.path.join(video_dir, video_file))
        try:
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            video.release()

        # Default threshold: anything below 640x480 counts as low quality.
        if width < min_width or height < min_height:
            low_quality_videos.append(os.path.splitext(video_file)[0])

    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'poor_quality_video_list.pkl'), 'wb') as f:
        pickle.dump(low_quality_videos, f)


def generate_mismatched_video_pairs_pkl(video_dir, audio_dir, output_dir='filter_config'):
    """Pickle the names of videos that have no audio file with the same stem.

    File names are compared without their extension. Audio files without a
    matching video are not reported. Writes a sorted list to
    ``<output_dir>/mismatched_video_pairs.pkl``.

    Args:
        video_dir: Directory containing the video files.
        audio_dir: Directory containing the audio files.
        output_dir: Directory the pickle is written to; created if missing.
    """
    def _stems(directory):
        # File names with the extension stripped.
        return {os.path.splitext(name)[0] for name in os.listdir(directory)}

    # Sorted so the pickle content is deterministic across runs.
    unmatched_videos = sorted(_stems(video_dir) - _stems(audio_dir))

    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'mismatched_video_pairs.pkl'), 'wb') as f:
        pickle.dump(unmatched_videos, f)


if __name__ == "__main__":
    # Dataset locations, relative to the current working directory.
    video_dir = "./data/VGGSound/video"
    audio_dir = "./data/VGGSound/audio"
    csv_file_path = "./data/VGGSound/vggsound2.csv"

    # Build each filter config in turn; every generator writes under filter_config/ by default.
    generate_optimal_frames_data_json(video_dir=video_dir, csv_file=csv_file_path)
    generate_poor_quality_video_list_pkl(video_dir=video_dir)
    generate_mismatched_video_pairs_pkl(video_dir=video_dir, audio_dir=audio_dir)