#### **Import Libraries**

In [None]:
# Change the working directory to the project root (edit this path for your machine).
# Presumably needed so the project-local `extra.TransNetV2` import in the next cell
# can be found — assumes `extra/` lives directly under this root.
%cd D:/AIC2024

In [None]:
import os
import json
from tqdm.auto import tqdm
from extra.TransNetV2.inference.transnetv2 import TransNetV2

#### **Shot Extraction Model**

In [None]:
class ShotExtractor:
    """Run a shot-boundary model over every video of a dataset and save scenes as JSON.

    Expected layout (as reconstructed in ``parse_video_path``)::

        <video_dir>/Videos_<part>/video/<prefix>_<video_id>.mp4

    The data part name is the text after the last underscore of each top-level
    directory in ``video_dir`` (``Videos_L01`` -> ``L01``).

    Attributes:
        model: Object exposing ``predict_video(path)`` (returning a 3-tuple whose
            second item is the single-frame predictions) and
            ``predictions_to_scenes(single_frame_predictions)``; this notebook
            passes a TransNetV2 instance.
        video_dir: Root directory holding the ``Videos_<part>`` folders.
        all_video_paths: ``{data_part: {video_id: video_path}}`` index built by
            ``parse_video_path``.
    """

    def __init__(self, model, video_dir='./AIC_video'):
        self.model = model
        self.video_dir = video_dir
        self.parse_video_path()

    @staticmethod
    def _video_id(file_name):
        """Return the id of ``<prefix>_<id>.mp4`` (``L01_V001.mp4`` -> ``V001``)."""
        # Strip '.mp4' only as a suffix; str.replace would also remove it mid-name.
        if file_name.endswith('.mp4'):
            file_name = file_name[:-len('.mp4')]
        return file_name.split('_')[-1]

    def parse_video_path(self):
        """Index every video file under ``video_dir`` into ``self.all_video_paths``.

        Top-level entries that are not directories (e.g. a leftover
        ``Videos_L01.zip`` archive or ``.DS_Store``) are ignored, as are entries
        inside a ``video`` folder that are not regular files.

        Raises:
            FileNotFoundError: if ``video_dir`` or the reconstructed
                ``Videos_<part>/video`` folder of a discovered part is missing.
        """
        self.all_video_paths = dict()
        for part in sorted(os.listdir(self.video_dir)):
            # Without this guard a stray file becomes a bogus data part and the
            # listdir below fails on a path that cannot exist.
            if not os.path.isdir(os.path.join(self.video_dir, part)):
                continue
            self.all_video_paths[part.split('_')[-1]] = dict()
        for data_part, videos in self.all_video_paths.items():
            data_part_path = f'{self.video_dir}/Videos_{data_part}/video'
            for file_name in sorted(os.listdir(data_part_path)):
                if not os.path.isfile(os.path.join(data_part_path, file_name)):
                    continue
                videos[self._video_id(file_name)] = f'{data_part_path}/{file_name}'

    def __call__(self, save_dir='SceneJSON', *, overwrite=True):
        """Extract shots for every indexed video and write one JSON file per video.

        Each result goes to ``<save_dir>/<data_part>/<video_id>.json`` and holds
        ``model.predictions_to_scenes(...).tolist()`` — presumably a list of
        ``[start_frame, end_frame]`` pairs for TransNetV2 (TODO confirm).

        Args:
            save_dir: Root output directory; created (with part subfolders) if missing.
            overwrite: When False, videos whose JSON already exists are skipped so
                an interrupted run can be resumed. Defaults to True, which
                recomputes every video.
        """
        os.makedirs(save_dir, exist_ok=True)
        for data_part, video_path_dict in self.all_video_paths.items():
            part_dir = f"{save_dir}/{data_part}"
            os.makedirs(part_dir, exist_ok=True)
            for video_id in tqdm(video_path_dict, desc=f'Shot Extracting {data_part}'):
                out_path = f"{part_dir}/{video_id}.json"
                if not overwrite and os.path.exists(out_path):
                    continue
                video_path = video_path_dict[video_id]
                _, single_frame_predictions, _ = self.model.predict_video(video_path)
                scenes = self.model.predictions_to_scenes(single_frame_predictions)
                # Write to a temp file then rename, so a crash mid-write never leaves
                # a truncated JSON that a resumed run would mistake for a finished one.
                tmp_path = out_path + '.tmp'
                with open(tmp_path, 'w') as f:
                    json.dump(scenes.tolist(), f)
                os.replace(tmp_path, out_path)

#### **Inference**

In [None]:
# Change the working directory to the dataset root (edit this path for your machine).
# The relative paths used below — './AIC_video' for input videos and './SceneJSON'
# for output — are resolved against this directory.
%cd D:/AIC2024/dataset

In [None]:
# Instantiate the TransNetV2 shot-boundary model once; the extractor reuses it
# for every video it indexes under ./AIC_video.
model = TransNetV2()
shot_extractor = ShotExtractor(model=model, video_dir='./AIC_video')

In [None]:
# Root folder for the per-video scene JSON files (the extractor creates it if missing).
save_dir = './SceneJSON'
shot_extractor(save_dir)