#### **Import Libraries**

In [1]:
import os
import json
import open_clip
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Every relative path used below ('./keyframe', './distillation', './clip/...')
# resolves against this working directory.
# NOTE(review): hard-coded absolute local path -- adjust for your machine.
%cd D:/AIC2024/dataset

D:\AIC2024\dataset


#### **Parsing Data Path**

In [3]:
def parse_data_path(feature_dir='./keyframe'):
    """Index a two-level directory tree as ``{part: {entry_id: entry_path}}``.

    Args:
        feature_dir: Root directory whose immediate sub-directories are the
            "parts". Defaults to ``'./keyframe'``.

    Returns:
        dict: Maps each part name (inserted in sorted order) to a dict that
        maps an entry id -- the entry name truncated at its first ``'.'`` --
        to the path ``f'{feature_dir}/{part}/{entry}'``. Entries of
        ``feature_dir`` that are not directories are ignored.

    Raises:
        FileNotFoundError: If ``feature_dir`` does not exist.
    """
    all_feature_paths = dict()
    for feature_part in sorted(os.listdir(feature_dir)):
        feature_part_path = f'{feature_dir}/{feature_part}'
        # Stray files sitting next to the part folders (e.g. Thumbs.db) cannot
        # be listed; they used to abort the whole scan with NotADirectoryError.
        if not os.path.isdir(feature_part_path):
            continue
        part_paths = dict()
        for feature_path in sorted(os.listdir(feature_part_path)):
            feature_id = feature_path.split('.')[0]
            # Joined with a literal '/' so the returned strings keep the exact
            # form the original implementation produced.
            part_paths[feature_id] = f'{feature_part_path}/{feature_path}'
        all_feature_paths[feature_part] = part_paths
    return all_feature_paths

In [4]:
# Nested mapping {part: {id: path}} built from the './keyframe' tree; the
# inference loop later lists each stored path as a directory of keyframes.
all_video_paths = parse_data_path(feature_dir='./keyframe')

#### **CLIP ViT-L/14 Model**

In [9]:
# Architecture and checkpoint tag passed to open_clip. The architecture string
# gets its own name so `model` only ever refers to the loaded network.
model_name = 'ViT-L-14'
pretrained = 'laion2b_s32b_b82k'
device = "cuda" if torch.cuda.is_available() else "cpu"
# The second return value is discarded; `preprocess` is the transform applied
# to every image before encoding.
model, _, preprocess = open_clip.create_model_and_transforms(model_name, device=device, pretrained=pretrained)
# open_clip returns the model in train mode; this notebook only runs inference.
model.eval()

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


In [10]:
def encode_images(image_paths, batch_size):
    """Encode images into L2-normalised feature vectors, one batch at a time.

    Relies on the notebook-level ``model``, ``preprocess`` and ``device``.

    Args:
        image_paths: Iterable of image file paths, in the order the features
            should be returned.
        batch_size: Number of images loaded, preprocessed and encoded together.
            Only one batch is resident on ``device`` at any time.

    Returns:
        tuple: ``(id2image_fps, video_features)`` where ``id2image_fps`` maps
        the positional index of each path to the path itself, and
        ``video_features`` is a list with one flat ``float32`` NumPy vector per
        image, in the same order. Empty input yields ``({}, [])``.

    Raises:
        ValueError: If ``batch_size`` is 0 (from ``range``).
    """
    image_paths = list(image_paths)
    id2image_fps = dict(enumerate(image_paths))
    video_features = []
    with torch.no_grad():
        for start_index in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[start_index:start_index + batch_size]
            batch_images = []
            for image_path in batch_paths:
                # Close each file handle as soon as the image is turned into a
                # tensor; long videos would otherwise leak one handle per frame.
                with Image.open(image_path) as image:
                    batch_images.append(preprocess(image))
            # Only the current batch is stacked and moved to the device, so
            # memory use is bounded by batch_size rather than by video length.
            batch = torch.stack(batch_images, dim=0).to(device)
            image_features = model.encode_image(batch)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # One device-to-host transfer per batch instead of one per image.
            batch_features = image_features.cpu().numpy().astype(np.float32)
            # flatten() copies, so every vector is an independent array.
            video_features.extend(row.flatten() for row in batch_features)
    return id2image_fps, video_features

#### **Keyframe Distillation**

In [11]:
def cosine_similarity(vector_a, vector_b):
    """Return the cosine of the angle between two 1-D vectors.

    Args:
        vector_a: First vector (array-like).
        vector_b: Second vector (array-like) of the same length.

    Returns:
        ``dot(a, b) / (||a|| * ||b||)``. When either vector has zero norm the
        cosine is undefined; ``0.0`` is returned instead of dividing by zero.
    """
    norm_product = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
    # Guard against a zero vector: the division would emit a RuntimeWarning
    # and produce NaN.
    if norm_product == 0:
        return 0.0
    return np.dot(vector_a, vector_b) / norm_product

def write_json_file(content_list, save_json_file_path):
    """Serialise ``content_list`` as JSON into ``save_json_file_path``.

    The target file is opened in text write mode, so existing content is
    replaced.
    """
    with open(save_json_file_path, mode='w') as json_handle:
        json.dump(obj=content_list, fp=json_handle)
        
def sorted_by_id(keyframe_paths):
    """Return the paths ordered by the integer encoded in their file name.

    The id of a path is the text between its last ``'/'`` and the first
    ``'.'`` after it, converted with ``int``. Paths with equal ids keep their
    input order.

    Raises:
        ValueError: If a file name does not start with an integer.
    """
    def numeric_id(path):
        # Last '/'-separated component, truncated at its first '.'.
        filename = path.split('/')[-1]
        return int(filename.split('.')[0])

    return sorted(keyframe_paths, key=numeric_id)

def keyframe_distillation(video_features, id2image_fps, compare_length=2, threshold=0.9):
    """Drop frames that are near-duplicates of recently kept frames.

    Frames are visited in order. A frame is kept unless its cosine similarity
    to any of the last ``compare_length`` kept frames exceeds ``threshold``.

    Args:
        video_features: Sequence of feature vectors, one per frame.
        id2image_fps: Mapping from a frame's position in ``video_features`` to
            its image path.
        compare_length: How many of the most recently kept frames each new
            frame is compared with. ``0`` (or a negative value) means no
            comparison, so every frame is kept.
        threshold: Similarity above which a frame counts as a duplicate.

    Returns:
        tuple: ``(distilled_features, distilled_image_paths)`` -- the kept
        feature vectors and their image paths, in the original order.
    """
    distilled_ids = []
    distilled_features = []
    for feature_id, feature_vector in enumerate(video_features):
        # A bare `[-compare_length:]` with compare_length == 0 is `[0:]`, i.e.
        # the whole list, so an empty window has to be special-cased.
        if compare_length > 0:
            compare_features = distilled_features[-compare_length:]
        else:
            compare_features = []
        # any() stops at the first match, like the original early `break`.
        is_duplicate = any(
            cosine_similarity(compare_vector, feature_vector) > threshold
            for compare_vector in compare_features
        )
        if not is_duplicate:
            distilled_ids.append(feature_id)
            distilled_features.append(feature_vector)
    distilled_image_paths = [id2image_fps[distilled_id] for distilled_id in distilled_ids]
    return distilled_features, distilled_image_paths

#### **Inference**

In [12]:
# Output roots, all relative to the current working directory:
#   distillation_save_dir -> per-video JSON list of kept keyframe paths
#   feature_save_dir      -> per-video .npy feature arrays
#   id2image_save_dir     -> per-video JSON index -> keyframe path mapping
distillation_save_dir='./distillation'
feature_save_dir = './clip/clip-vit-l14-laion2b/features'
id2image_save_dir = './clip/clip-vit-l14-laion2b/id2image'
# exist_ok=True replaces the separate exists() check and keeps the cell
# safe to re-run.
for save_dir in (distillation_save_dir, feature_save_dir, id2image_save_dir):
    os.makedirs(save_dir, exist_ok=True)

In [13]:
# Number of keyframes sent through the image encoder per forward pass.
batch_size = 32
for video_part, video_path_dict in all_video_paths.items():
    video_ids = video_path_dict.keys()
    for video_id in tqdm(video_ids, desc=f'Distilling Part {video_part}'):

        # Each stored path is a directory holding the keyframes of one video.
        video_id_path = video_path_dict[video_id]
        keyframe_image_paths = [video_id_path + '/' + keyframe_image_path for keyframe_image_path in os.listdir(video_id_path)]
        # Numeric order matters: distillation compares each frame with the
        # frames kept just before it.
        sorted_keyframe_image_paths = sorted_by_id(keyframe_image_paths)
        id2image_fps, video_features = encode_images(sorted_keyframe_image_paths, batch_size)

        distilled_features, distilled_keyframe_image_paths = keyframe_distillation(
            video_features=video_features,
            id2image_fps=id2image_fps,
            compare_length=2,
            threshold=0.9
        )

        # Re-index the surviving keyframes 0..k-1.
        distilled_id2image_fps = {id:keyframe_path for id, keyframe_path in enumerate(distilled_keyframe_image_paths)}
        os.makedirs(f'{feature_save_dir}/{video_part}', exist_ok=True)
        # Fix: save the distilled features, not the full `video_features`.
        # The id2image JSON written below is indexed by distilled keyframes, so
        # row i of this array must belong to entry i of that mapping.
        np.save(f'{feature_save_dir}/{video_part}/{video_id}.npy', distilled_features)

        os.makedirs(f'{id2image_save_dir}/{video_part}', exist_ok=True)
        with open(f'{id2image_save_dir}/{video_part}/{video_id}.json', 'w') as f:
            f.write(json.dumps(distilled_id2image_fps))

        # Also store the kept keyframe paths as a plain JSON list.
        save_part_dir = f'{distillation_save_dir}/{video_part}'
        os.makedirs(save_part_dir, exist_ok=True)
        save_json_file_path = save_part_dir + '/' + f'{video_id}.json'
        write_json_file(distilled_keyframe_image_paths, save_json_file_path)

Distilling Part L01: 100%|██████████| 31/31 [16:46:46<00:00, 1948.59s/it]  
Distilling Part L02:   0%|          | 0/31 [03:06<?, ?it/s]
