## Install dependencies

In [None]:
# Install the pipeline dependencies: Whisper / WhisperX (transcription), pydub and
# yt-dlp (downloading + audio conversion), moviepy, chonkie (semantic chunking)
# and opencv-python (frame extraction).
# NOTE(review): both BetterWhisperX (from git) and the pinned whisperx==3.1.5 are
# installed; presumably they provide the same `whisperx` module and the last
# install wins — confirm which one is intended.
# NOTE(review): the PyPI "ffmpeg" package is presumably not the ffmpeg binary;
# the video segmentation cell invokes the `ffmpeg` executable directly, so a
# system ffmpeg install is still required.
!pip install git+https://github.com/openai/whisper.git
!pip install git+https://github.com/federicotorrielli/BetterWhisperX.git
!pip install pydub yt-dlp moviepy ffmpeg chonkie "chonkie[semantic]" whisperx==3.1.5 opencv-python

## Colab CUDA-ondersteuning (voer dit alleen uit op Google Colab)
- Herstart daarna de runtime en voer alle cellen opnieuw uit.

In [2]:
# Only for Google Colab: replace ctranslate2 4.5.0 with 4.4.0 (presumably to match
# the CUDA/cuDNN libraries of the Colab runtime — confirm). Uncomment, run, then
# restart the runtime and re-run all cells.
# !pip uninstall ctranslate2==4.5.0 -y
# !pip install ctranslate2==4.4.0

## Inputvideo downloaden (voorlopig alleen van YouTube)

In [None]:
import os
from pydub import AudioSegment
import subprocess

def download_and_convert_to_wav(youtube_url, output_wav_path):
    """Download the best audio stream of `youtube_url` and save it as a WAV file.

    The audio is fetched with the `yt-dlp` CLI into a temporary file in the
    current directory, decoded with pydub and exported to `output_wav_path`.
    The temporary file is removed afterwards, also when decoding fails.

    Raises:
        subprocess.CalledProcessError: if yt-dlp exits with a non-zero status.
    """
    print(f"Downloading audio from {youtube_url}...")
    audio_file = "temp_audio.mp4"  # Temporary file

    # A leftover from an interrupted run would make yt-dlp report "already
    # downloaded" and we would convert the wrong audio.
    if os.path.exists(audio_file):
        os.remove(audio_file)

    try:
        # Argument list instead of a shell string: the URL comes from a
        # user-supplied file and must never be interpreted by a shell.
        # "--" stops option parsing so a URL starting with "-" is not an option.
        subprocess.run(
            ["yt-dlp", "-f", "bestaudio", "-o", audio_file, "--", youtube_url],
            check=True,
        )

        print("Converting to WAV format...")
        audio = AudioSegment.from_file(audio_file)
        audio.export(output_wav_path, format="wav")
    finally:
        if os.path.exists(audio_file):
            os.remove(audio_file)

    print(f"Conversion complete! WAV file saved to: {output_wav_path}")

def download_video(youtube_url, output_video_path):
    """Download the best MP4 video stream of `youtube_url` to `output_video_path`.

    Uses the `yt-dlp` CLI with the format selector "bestvideo[ext=mp4]", i.e. a
    video-only stream (the audio is fetched separately as WAV). A failing
    yt-dlp run is reported and swallowed so the caller can continue with the
    next link.

    Raises:
        FileNotFoundError: if yt-dlp succeeded but produced no MP4 file.
    """
    try:
        print(f"Downloading video from {youtube_url} (high quality video only)...")
        output_template = "temp_video.mp4"

        # Drop a stale temp file from an interrupted run, otherwise yt-dlp skips
        # the download and the old video would be renamed to this link's name.
        if os.path.exists(output_template):
            os.remove(output_template)

        command = [
            "yt-dlp", "-f", "bestvideo[ext=mp4]", "-o", output_template, "--", youtube_url
        ]

        subprocess.run(command, check=True)

        if os.path.exists(output_template):
            # os.replace overwrites an existing target on every OS; os.rename
            # raises on Windows when the notebook is re-run.
            os.replace(output_template, output_video_path)
            print(f"Video saved as '{output_video_path}'.")
        else:
            raise FileNotFoundError("Failed to download video as MP4.")

    except subprocess.CalledProcessError as e:
        print(f"Error downloading video: {e}")

def process_links_from_file(file_path):
    """
    Download audio (as WAV) and video (as MP4) for every YouTube link in `file_path`.

    The file holds one link per line. Blank lines are skipped but still count
    towards the numbering: line N produces ``wav_files/output_audio_N.wav`` and
    ``mp4_files/input_video_N.mp4``.
    """
    for folder in ("wav_files", "mp4_files"):
        os.makedirs(folder, exist_ok=True)

    with open(file_path, 'r') as handle:
        raw_lines = handle.readlines()

    for number, raw_line in enumerate(raw_lines, start=1):
        url = raw_line.strip()
        if not url:
            continue
        wav_target = os.path.join("wav_files", f"output_audio_{number}.wav")
        mp4_target = os.path.join("mp4_files", f"input_video_{number}.mp4")
        download_and_convert_to_wav(url, wav_target)
        download_video(url, mp4_target)

# Text file with YouTube links, one URL per line (blank lines are skipped).
input_file = "youtube_links.txt"
process_links_from_file(input_file)

## Transcribe

In [None]:
import whisperx
import gc
import torch
import json
import os

# Folder locations used by the transcription step
wav_folder = "wav_files"                      # WAV files written by the download cell
output_folder = "transcriptions"              # one WhisperX JSON result per WAV file
unsupported_folder = "unsupported_language"   # WAV files whose detected language is rejected
model_dir = "whisper-models"                  # local download directory for the Whisper model

# Create the folders this step writes into
for _folder in (output_folder, unsupported_folder):
    os.makedirs(_folder, exist_ok=True)

# Whisper models that were already loaded, keyed by (device, compute_type), so the
# large-v2 weights are loaded once per session instead of once per audio file.
_WHISPER_MODELS = {}


def _select_device():
    """Return ``(device, compute_type, batch_size)`` suited to this machine."""
    if torch.cuda.is_available():
        print("CUDA wordt gebruikt")
        # Use compute_type "int8" and/or a smaller batch if low on GPU memory
        # (int8 may reduce accuracy).
        return "cuda", "float16", 16
    if torch.backends.mps.is_available():
        # faster-whisper runs on CTranslate2, which only offers "cpu" and "cuda"
        # devices, so Apple Silicon falls back to the CPU.
        print("MPS (Apple Silicon) gedetecteerd maar niet ondersteund: CPU wordt gebruikt")
        return "cpu", "int8", 8
    print("CPU gebruikt")
    return "cpu", "int8", 4


def _load_whisper_model(device, compute_type):
    """Return a cached WhisperX large-v2 model for `device` / `compute_type`.

    Reuses the snapshot previously downloaded into `model_dir` when present;
    otherwise downloads "large-v2" into `model_dir`.
    """
    key = (device, compute_type)
    if key not in _WHISPER_MODELS:
        snapshot_dir = os.path.join(
            model_dir,
            "models--Systran--faster-whisper-large-v2",
            "snapshots",
            "f0fe81560cb8b68660e564f55dd99207059c092e",
        )
        if os.path.isdir(snapshot_dir):
            model = whisperx.load_model(snapshot_dir, device, compute_type=compute_type)
        else:
            # Also covers a `model_dir` that exists without this exact snapshot
            # (e.g. an interrupted download), where a hard-coded path would fail.
            model = whisperx.load_model(
                "large-v2", device, compute_type=compute_type, download_root=model_dir
            )
        _WHISPER_MODELS[key] = model
    return _WHISPER_MODELS[key]


def transcribe(audio_file, supported_languages=("en", "fr", "de", "es")):
    """Transcribe one audio file with WhisperX and save the result as JSON.

    The language is auto-detected. Files whose language is not in
    `supported_languages` are moved to `unsupported_folder` and nothing is
    saved. Otherwise the segments are word-aligned (skipped, with a message,
    when alignment raises ValueError) and written to
    ``<output_folder>/<audio basename>.json``.

    :param audio_file: path to the audio (WAV) file.
    :param supported_languages: language codes to keep (default en/fr/de/es).
    """
    device, compute_type, batch_size = _select_device()
    model = _load_whisper_model(device, compute_type)

    audio = whisperx.load_audio(audio_file)

    # Perform transcription with automatic language detection
    result = model.transcribe(audio, batch_size=batch_size)
    detected_language = result.get("language", "en")

    if detected_language not in supported_languages:
        print(f"Language detected as {detected_language}, moving to unsupported folder.")
        os.replace(audio_file, os.path.join(unsupported_folder, os.path.basename(audio_file)))
        return

    print(f"Detected language: {detected_language}")

    try:
        model_a, metadata = whisperx.load_align_model(language_code=detected_language, device=device)
        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
        # Drop the references *before* collecting; otherwise gc.collect() and
        # empty_cache() cannot release the alignment model's memory.
        del model_a, metadata
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    except ValueError as e:
        # Presumably raised when no alignment model exists for the language.
        print(f"Skipping alignment due to error: {e}")

    # Save as JSON
    base_filename = os.path.splitext(os.path.basename(audio_file))[0]
    output_json_path = os.path.join(output_folder, f"{base_filename}.json")
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)

    print(f"Results saved to {output_json_path}")

# Transcribe every .wav file that the download cell put into `wav_folder`.
for wav_name in os.listdir(wav_folder):
    if not wav_name.endswith(".wav"):
        continue
    transcribe(os.path.join(wav_folder, wav_name))

## Full transcript text

In [None]:
import os
import json

# Input directory containing the WhisperX JSON files
json_folder = 'transcriptions'
# Output directory: one plain-text transcript per JSON file
output_dir = 'individual_texts'

# Create the output directory if it does not exist yet
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory '{output_dir}' is ready.")

# Process each JSON file individually (sorted so the run order is deterministic)
for json_file in sorted(os.listdir(json_folder)):
    if not json_file.endswith('.json'):
        continue
    json_path = os.path.join(json_folder, json_file)
    print(f"Processing file: {json_file}")

    # Skip files that vanished or do not contain valid JSON
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"Error: Failed to process '{json_file}'. Details: {e}")
        continue

    segments = data.get('segments', [])
    if not segments:
        print(f"Warning: No segments found in '{json_file}'.")
        continue

    # Keep the non-empty segment texts in order; they are joined with single spaces.
    texts = []
    for i, segment in enumerate(segments, start=1):
        text = segment.get('text', '').strip()
        if not text:
            print(f"Warning: Segment {i} in '{json_file}' is empty.")
            continue
        texts.append(text)

    # Swap only the extension; str.replace would also rewrite a '.json' elsewhere in the name.
    individual_output_path = os.path.join(output_dir, os.path.splitext(json_file)[0] + '.txt')
    with open(individual_output_path, 'w', encoding='utf-8') as individual_file:
        individual_file.write(" ".join(texts))

    print(f"Transcript saved to '{individual_output_path}'.")

## Save Chonkie chunks with timestamps (JSON)

In [None]:
import os
import json
from chonkie import SDPMChunker

def load_document(file_path: str) -> str:
    """Return the complete UTF-8 text content of *file_path*.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    raise FileNotFoundError(f"The file '{file_path}' does not exist.")

def load_json(file_path: str) -> dict:
    """Load and return the JSON document stored at *file_path*.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        ValueError: if the file is not valid JSON. The underlying
            ``json.JSONDecodeError`` is chained as ``__cause__`` so the
            traceback shows the exact position of the error.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Error: The JSON file '{file_path}' does not exist.")
    with open(file_path, 'r', encoding='utf-8') as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Error: Failed to decode JSON file. Details: {e}") from e

def create_chunker(embedding_model="minishlab/potion-base-8M", chunk_size=512, min_sentences=1):
    """Build the SDPM semantic chunker used to split transcripts into chunks.

    All three arguments are forwarded unchanged to ``SDPMChunker``.
    """
    chunker_options = {
        "embedding_model": embedding_model,
        "chunk_size": chunk_size,
        "min_sentences": min_sentences,
    }
    return SDPMChunker(**chunker_options)

def _match_chunks_to_words(chunks, word_list):
    """Attach word-level timestamps to each chunk by walking the transcript words in order.

    `chunks` are chunker results exposing ``.text``; `word_list` holds
    ``[word, start, end]`` entries in transcript order. Every whitespace-separated
    token of every chunk must equal the next entry of `word_list`.

    Returns a list of dicts with keys "text", "start", "end" and "words".
    Raises ValueError on a token mismatch and IndexError when `word_list` runs out.
    """
    final_chunks = []
    current_word_index = 0
    for chunk in chunks:
        chunk_text = chunk.text
        chunk_word_data = []
        chunk_start = None
        chunk_end = None

        for chunk_word in chunk_text.split():
            if current_word_index >= len(word_list):
                raise IndexError("Ran out of words in word_data to match with chunks.")
            word, start, end = word_list[current_word_index]
            if chunk_word != word:
                raise ValueError(f"Word mismatch at chunk '{chunk_text}': Expected '{word}', found '{chunk_word}'.")

            chunk_word_data.append({"word": word, "start": start, "end": end})
            # Word entries may lack timestamps (start/end None, presumably words
            # WhisperX could not align); only timestamped words set the chunk bounds.
            if chunk_start is None and start is not None:
                chunk_start = start
            if end is not None:
                chunk_end = end
            current_word_index += 1

        final_chunks.append({
            "text": chunk_text,
            "start": chunk_start,
            "end": chunk_end,
            "words": chunk_word_data
        })
    return final_chunks


def process_text_and_json(text_folder: str, json_folder: str, output_folder: str):
    """Chunk every transcript text and attach word timestamps to each chunk.

    For each ``<name>.txt`` in `text_folder`, the matching ``<name>.json``
    (WhisperX output) is read from `json_folder`; its "word_segments" supply
    per-word start/end times. The text is split by the semantic chunker and
    the result is written to ``<output_folder>/<name>_chunks.json``.

    Files without a matching JSON, or whose JSON has no "word_segments"
    (e.g. when alignment was skipped during transcription), are reported and
    skipped so the remaining files are still processed. Chunks in which no
    word has a timestamp get ``"start"``/``"end"`` set to None.

    Raises:
        ValueError / IndexError: if a chunk's words cannot be matched to the
            transcript's word list.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Created lazily and reused: building the chunker loads the embedding model.
    chunker = None

    for text_file in sorted(os.listdir(text_folder)):
        if not text_file.endswith(".txt"):
            continue
        base_name = os.path.splitext(text_file)[0]
        text_path = os.path.join(text_folder, text_file)
        json_path = os.path.join(json_folder, base_name + ".json")

        if not os.path.exists(json_path):
            print(f"Warning: No matching JSON file for {text_file}")
            continue

        text_content = load_document(text_path)
        json_data = load_json(json_path)
        segments = json_data.get('word_segments', [])

        if not segments:
            print(f"Warning: No word segments found in the JSON file {json_path}; skipping {text_file}.")
            continue

        # [word, start, end]; missing timestamps become None (not "") so they can
        # never be mistaken for a real time downstream.
        word_list = [
            [seg.get('word', '').strip(), seg.get('start'), seg.get('end')]
            for seg in segments
            if seg.get('word', '').strip()
        ]

        if chunker is None:
            chunker = create_chunker()
        chunks = chunker.chunk(text_content)
        final_chunks = _match_chunks_to_words(chunks, word_list)

        output_json_path = os.path.join(output_folder, base_name + "_chunks.json")
        with open(output_json_path, 'w', encoding='utf-8') as f:
            json.dump({"chunks": final_chunks}, f, ensure_ascii=False, indent=4)
        print(f"Processed {text_file} and saved to {output_json_path}")

if __name__ == "__main__":
    # Pass the folders directly instead of assigning module-level names: all
    # notebook cells share one namespace, and assigning `output_folder` here
    # overwrote the global that `transcribe` uses as its JSON output folder.
    process_text_and_json(
        text_folder='individual_texts',
        json_folder='transcriptions',
        output_folder='processed_json',
    )

## Video segmentation (Chonkie chunks)
### Het segmenteren van de video op basis van de nieuwe chunks

In [None]:
import json
import os
import subprocess

def create_segments(video_file, result):
    """Cut `video_file` into one MP4 per chunk listed in the chunk JSON at `result`.

    `result` is the path to a chunk file whose top-level "chunks" list holds
    dicts with "start"/"end" times in seconds. Segment i is re-encoded with
    libx264/aac to ``video_segments/<video name>/segment_<i>_<start>_<end>.mp4``.
    Chunks without numeric timestamps are skipped (numbering is kept so file
    names still match chunk indices), and ffmpeg failures are reported instead
    of being silently ignored.

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        ValueError: if the JSON file contains no chunks.
    """
    if not os.path.exists(result):
        raise FileNotFoundError(f"Error: The result file '{result}' does not exist.")

    with open(result, 'r', encoding='utf-8') as f:
        data = json.load(f)

    segments = data.get('chunks', [])
    if not segments:
        raise ValueError("No segments found in the JSON file.")

    video_name = os.path.splitext(os.path.basename(video_file))[0]
    output_dir = os.path.join('video_segments', video_name)
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory '{output_dir}' is ready.")

    for i, segment in enumerate(segments, start=1):
        start = segment.get('start')
        end = segment.get('end')
        # A chunk whose words all lack timestamps has no usable start/end
        # (None or ""); int() and ffmpeg would both fail on it.
        if not isinstance(start, (int, float)) or not isinstance(end, (int, float)):
            print(f"Skipping segment {i}: missing timestamps (start={start!r}, end={end!r}).")
            continue

        output_filename = f"segment_{i}_{int(start)}_{int(end)}.mp4"
        output_path = os.path.join(output_dir, output_filename)

        command = [
            "ffmpeg",
            "-y",
            "-i", video_file,
            "-ss", str(start),
            "-to", str(end),
            "-c:v", "libx264",
            "-c:a", "aac",
            output_path
        ]
        print(f"Creating segment {i}: {start} to {end} seconds for {video_file}.")
        completed = subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        if completed.returncode != 0:
            # Keep only the tail of ffmpeg's log; the actual error is at the end.
            details = completed.stderr.decode("utf-8", errors="replace").strip()[-300:]
            print(f"Error: ffmpeg failed for segment {i} of {video_file}: {details}")
            continue
        print(f"Segment {i} saved as '{output_filename}'.")

    print(f"All segments for {video_file} have been processed.")

def _segment_pairing_key(filename):
    """Key that pairs a downloaded video with its chunk JSON.

    Uses the last integer in the file stem (3 for both "input_video_3.mp4" and
    "output_audio_3_chunks.json"); names without digits fall back to the stem
    with a trailing "_chunks" removed.
    """
    import re

    stem = os.path.splitext(filename)[0]
    if stem.endswith("_chunks"):
        stem = stem[:-len("_chunks")]
    digits = re.findall(r"\d+", stem)
    return int(digits[-1]) if digits else stem


def process_videos_in_directory(video_directory, json_directory):
    """Segment every MP4 in `video_directory` using its chunk JSON from `json_directory`.

    Videos and JSON files are paired by their numeric index (see
    `_segment_pairing_key`) rather than by position in the sorted listings, so a
    missing JSON (e.g. an audio file moved to the unsupported-language folder)
    no longer shifts every following video onto the wrong chunks. Videos
    without a matching JSON are reported and skipped.

    Raises:
        NotADirectoryError: if either directory does not exist.
    """
    if not os.path.isdir(video_directory):
        raise NotADirectoryError(f"Error: The directory '{video_directory}' does not exist.")
    if not os.path.isdir(json_directory):
        raise NotADirectoryError(f"Error: The directory '{json_directory}' does not exist.")

    video_files = sorted(f for f in os.listdir(video_directory) if f.endswith('.mp4'))
    json_by_key = {}
    for json_file in sorted(f for f in os.listdir(json_directory) if f.endswith('.json')):
        json_by_key.setdefault(_segment_pairing_key(json_file), json_file)

    for video_file in video_files:
        json_file = json_by_key.get(_segment_pairing_key(video_file))
        if json_file is None:
            print(f"Warning: No chunk JSON found for {video_file}; skipping.")
            continue
        video_path = os.path.join(video_directory, video_file)
        json_path = os.path.join(json_directory, json_file)
        create_segments(video_path, json_path)

if __name__ == "__main__":
    # Downloaded source videos and the chunk JSONs from the chunking step;
    # replace with your own directories if needed.
    video_directory, json_directory = "mp4_files", "processed_json"
    process_videos_in_directory(video_directory, json_directory)

## Frame Extraction

### Middle frame

In [None]:
import cv2
import os

def extract_middle_frame(video_path, output_dir):
    """Save the middle frame of `video_path` as a JPEG under `output_dir`.

    The output folder mirrors the last two components of `video_path`
    (``<output_dir>/<parent folder>/<video file name>/``) and the image is named
    ``<file name with dots replaced by underscores>_middle_frame.jpg``.
    The video capture is released on every exit path.
    """
    video_name = os.path.basename(video_path).replace(".", "_")
    segment_output_dir = os.path.join(output_dir, *video_path.split(os.sep)[-2:])
    os.makedirs(segment_output_dir, exist_ok=True)

    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Also catches videos that could not be opened (count reported as 0).
        if total_frames <= 0:
            print(f"Warning: No frames found in {video_path}")
            return

        middle_frame_idx = total_frames // 2
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
        success, image = vidcap.read()

        if success:
            frame_path = os.path.join(segment_output_dir, f"{video_name}_middle_frame.jpg")
            cv2.imwrite(frame_path, image)
            print(f"Extracted middle frame from {video_path} to {frame_path}")
        else:
            print(f"Error: Could not read frame {middle_frame_idx} from {video_path}")
    finally:
        vidcap.release()

def process_videos_in_directory(base_directory, output_directory):
    """Extract the middle frame of every video one folder level below `base_directory`.

    Only files with a .mp4/.avi/.mov/.mkv extension (case-insensitive) inside
    direct sub-folders are processed, in sorted order; frames go to
    `output_directory`.

    Raises:
        NotADirectoryError: if `base_directory` does not exist.
    """
    if not os.path.isdir(base_directory):
        raise NotADirectoryError(f"Error: The directory '{base_directory}' does not exist.")

    os.makedirs(output_directory, exist_ok=True)
    print(f"Output directory '{output_directory}' is ready.")

    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
    for entry in sorted(os.listdir(base_directory)):
        folder = os.path.join(base_directory, entry)
        if not os.path.isdir(folder):
            continue
        for name in sorted(os.listdir(folder)):
            if not name.lower().endswith(video_extensions):
                continue
            extract_middle_frame(os.path.join(folder, name), output_directory)

if __name__ == "__main__":
    # Folder with one sub-folder of video segments per source video (replace as needed)
    video_directory = "video_segments"
    # Folder in which the extracted frames are stored
    output_directory = "frames"
    process_videos_in_directory(video_directory, output_directory)

### 1 frame per 3 seconds

In [None]:
import cv2
import os

def extract_frames(video_path, output_dir, interval_seconds=3):
    """Save one frame every `interval_seconds` seconds of `video_path` as JPEGs.

    Frames go to ``<output_dir>/<parent folder>/<video file name>/`` and are
    named ``<file name with dots replaced>_frame_<k>.jpg`` (k = 0, 1, ...),
    starting with the very first frame. The capture is released on every
    exit path.
    """
    video_name = os.path.basename(video_path).replace(".", "_")
    segment_output_dir = os.path.join(output_dir, *video_path.split(os.sep)[-2:])
    os.makedirs(segment_output_dir, exist_ok=True)

    vidcap = cv2.VideoCapture(video_path)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            print(f"Kan de FPS voor {video_path} niet ophalen.")
            return

        # At least 1: for tiny intervals / low frame rates int() yields 0 and the
        # modulo below would raise ZeroDivisionError.
        frame_interval = max(1, int(fps * interval_seconds))
        count = 0
        frame_count = 0
        # grab() advances to the next frame; retrieve() (which per the OpenCV docs
        # decodes and returns the grabbed frame) is only called for saved frames.
        while vidcap.grab():
            if frame_count % frame_interval == 0:
                success, image = vidcap.retrieve()
                if success:
                    frame_path = os.path.join(segment_output_dir, f"{video_name}_frame_{count}.jpg")
                    cv2.imwrite(frame_path, image)
                    count += 1
            frame_count += 1
    finally:
        vidcap.release()

    print(f"Geëxtraheerd {count} frames uit {video_path} naar {segment_output_dir}")

def process_videos_in_directory(base_directory, output_directory, interval_seconds=3):
    """Extract a frame every `interval_seconds` from each video one folder below `base_directory`.

    Only .mp4/.avi/.mov/.mkv files (case-insensitive) in direct sub-folders are
    processed, in sorted order; frames are written under `output_directory`.

    Raises:
        NotADirectoryError: if `base_directory` does not exist.
    """
    if not os.path.isdir(base_directory):
        raise NotADirectoryError(f"Error: De map '{base_directory}' bestaat niet.")

    os.makedirs(output_directory, exist_ok=True)
    print(f"Output directory '{output_directory}' is gereed.")

    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
    for entry in sorted(os.listdir(base_directory)):
        folder = os.path.join(base_directory, entry)
        if not os.path.isdir(folder):
            continue
        videos = [n for n in sorted(os.listdir(folder)) if n.lower().endswith(video_extensions)]
        for name in videos:
            extract_frames(os.path.join(folder, name), output_directory, interval_seconds)

if __name__ == "__main__":
    # Base folder with the video segments (replace with your own segment folder)
    video_directory = "video_segments"
    # Folder in which the extracted frames are stored
    output_directory = "frames"
    process_videos_in_directory(video_directory, output_directory, interval_seconds=3)

## Object Segmentation

In [None]:
# coremltools is used below to load and run the Core ML SAM models.
# NOTE: running Core ML predictions through coremltools requires macOS.
!pip install coremltools

In [None]:
import coremltools as ct
import numpy as np
from PIL import Image
import cv2
import time
import os
import shutil  # Voor het verwijderen van directories
from dataclasses import dataclass
from typing import List, Tuple

@dataclass()
class Point:
    """A single prompt point: an (x, y) position with a foreground/background label."""

    x: float  # horizontal coordinate
    y: float  # vertical coordinate
    label: int  # 0 for background, 1 for foreground

@dataclass()
class BoxPrompt:
    """Axis-aligned bounding box described by two corners.

    (x1, y1) is the top-left and (x2, y2) the bottom-right corner, in
    original-image pixel coordinates.
    """

    x1: float  # left edge
    y1: float  # top edge
    x2: float  # right edge
    y2: float  # bottom edge

class AdvancedAutoSAM:
    def __init__(self, input_size=(1024, 1024)):
        """
        :param input_size: The (width, height) that the SAM image encoder expects
        """
        self.input_size = input_size
        self.original_size = None
        
        self.image_encoder = None
        self.prompt_encoder = None
        self.mask_decoder = None
        
        self.image_embeddings = None  # Will store the results of the image encoder

    def load_models(self, image_encoder_path, prompt_encoder_path, mask_decoder_path):
        """Load the image-encoder, prompt-encoder and mask-decoder Core ML models, in that order."""
        t0 = time.time()
        model_paths = (
            ("image_encoder", image_encoder_path),
            ("prompt_encoder", prompt_encoder_path),
            ("mask_decoder", mask_decoder_path),
        )
        for attribute, path in model_paths:
            setattr(self, attribute, ct.models.MLModel(path))
        print(f"Models loaded in {time.time() - t0:.2f} sec.")

    def _resize_image_for_encoder(self, image_path: str) -> Image.Image:
        """
        Open `image_path` as RGB, record its pre-resize (width, height) in
        `self.original_size` and return it resized to `self.input_size` (LANCZOS).
        """
        rgb = Image.open(image_path).convert("RGB")
        self.original_size = rgb.size
        return rgb.resize(self.input_size, Image.Resampling.LANCZOS)

    def get_image_embedding(self, image_path: str):
        """
        Preprocess the image and run the image encoder. 
        Store the resulting embeddings for future use.
        """
        if self.image_encoder is None:
            raise ValueError("Image encoder model not loaded.")
        start = time.time()
        image = self._resize_image_for_encoder(image_path)
        self.image_embeddings = self.image_encoder.predict({"image": image})
        print(f"Image encoding done in {time.time()-start:.2f} sec.")

    def _transform_box_coords(
        self, box: BoxPrompt, original_size: Tuple[int, int]
    ) -> np.ndarray:
        """
        Map `box` from original-image pixels into the encoder's `input_size` space.

        x coordinates are scaled by input_width / original_width and y
        coordinates by input_height / original_height (matching the
        non-aspect-preserving resize of the encoder input).
        Returns a float32 array of shape (1, 4): [[x1, y1, x2, y2]].
        """
        (src_w, src_h), (dst_w, dst_h) = original_size, self.input_size
        sx = dst_w / float(src_w)
        sy = dst_h / float(src_h)
        scaled_corners = [box.x1 * sx, box.y1 * sy, box.x2 * sx, box.y2 * sy]
        return np.array([scaled_corners], dtype=np.float32)

    def _transform_point_coords(
        self, points: np.ndarray, original_size: Tuple[int, int]
    ) -> np.ndarray:
        """
        Scale point coordinates from original image -> input_size
        Returns shape (1, N, 2).
        """
        ow, oh = original_size
        tw, th = self.input_size

        scale_x = tw / float(ow)
        scale_y = th / float(oh)

        # points shape: (N,2)
        points_s = points.copy()
        points_s[:, 0] *= scale_x
        points_s[:, 1] *= scale_y

        # expand to (1, N, 2)
        points_s = np.expand_dims(points_s, axis=0).astype(np.float32)
        return points_s

    def _predict_mask(self, sparse_embeddings, dense_embeddings):
        """
        Runs the mask decoder with the precomputed image_embeddings 
        plus the given prompt-encoder outputs (sparse & dense).
        Returns a [low_res_mask] of shape ~ (4x?) or (256x256?), 
        along with scores. We pick the highest scoring mask.
        """
        if self.mask_decoder is None:
            raise ValueError("Mask decoder not loaded.")

        out = self.mask_decoder.predict({
            "image_embedding": self.image_embeddings["image_embedding"],
            "sparse_embedding": sparse_embeddings,
            "dense_embedding": dense_embeddings,
            "feats_s0": self.image_embeddings["feats_s0"],
            "feats_s1": self.image_embeddings["feats_s1"],
        })

        # out["scores"] shape: (batch_size, numMasks) => typically (1,3)
        scores = out["scores"]
        best_idx = np.argmax(scores)
        low_res_mask = out["low_res_masks"][0, best_idx]  # shape e.g. (256,256)
        return low_res_mask, float(scores[0, best_idx])

    def _resize_and_binarize_mask(
        self, low_res_mask: np.ndarray, original_size: Tuple[int,int], threshold:float=0.0
    ) -> np.ndarray:
        """
        Upsample a decoder mask to `original_size` (width, height) and threshold it.

        Returns a uint8 array of shape (height, width) holding 255 where the
        bilinearly resized value is greater than `threshold` and 0 elsewhere.
        """
        width, height = original_size
        # cv2.resize takes its target size as (width, height)
        upsampled = cv2.resize(low_res_mask, (width, height), interpolation=cv2.INTER_LINEAR)
        return np.where(upsampled > threshold, 255, 0).astype(np.uint8)

    ############################################################################
    #                  MULTI-SCALE BOUNDING BOX PROPOSAL LOGIC                 #
    ############################################################################

    def _multi_scale_bounding_box_proposals(
        self, image_path:str, scales=(1.0, 0.75, 0.5), edge_thresh=100
    ) -> List[Tuple[BoxPrompt, float]]:
        """
        Simple edge-based bounding-box proposals at several image scales.

        1) For each positive scale in `scales`, downsize the original image.
        2) Run Canny edge detection (thresholds `edge_thresh` and 3*`edge_thresh`).
        3) Take the bounding rectangle of every external contour; rectangles
           with an area below 20 pixels (in scaled space) are dropped.
        4) Map each rectangle back to original-image coordinates and clip it.

        Scales that would shrink the image to zero width or height are skipped.

        :return: list of ``(BoxPrompt, score)``, where score is the rectangle's
            area in the *scaled* image (so boxes found at larger scales score higher).
        :raises ValueError: if the image cannot be read.
        """
        bgr_original = cv2.imread(image_path)
        if bgr_original is None:
            raise ValueError(f"Could not read image at {image_path}")
        oh, ow = bgr_original.shape[:2]

        proposals = []

        for s in scales:
            if s <= 0:
                continue
            # Resize
            w_s = int(ow*s)
            h_s = int(oh*s)
            if w_s < 1 or h_s < 1:
                # cv2.resize rejects a zero-sized target; nothing to detect here.
                continue
            small = cv2.resize(bgr_original, (w_s, h_s), interpolation=cv2.INTER_LINEAR)

            # Convert to grayscale
            gray_small = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
            # Simple edge detection
            edges = cv2.Canny(gray_small, threshold1=edge_thresh, threshold2=3*edge_thresh)

            # Find contours for bounding boxes
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            for cnt in contours:
                x, y, w, h = cv2.boundingRect(cnt)
                area = w*h
                if area < 20:  # discard tiny proposals in scaled space
                    continue

                # Scale box back to original image coordinates:
                # (x / s, y / s) is top-left, ((x+w)/s, (y+h)/s) is bottom-right
                X1 = x / s
                Y1 = y / s
                X2 = (x + w) / s
                Y2 = (y + h) / s

                # Clip to original boundaries
                X1 = max(0, min(X1, ow-1))
                Y1 = max(0, min(Y1, oh-1))
                X2 = max(0, min(X2, ow-1))
                Y2 = max(0, min(Y2, oh-1))

                # Score = rectangle area in scaled space (kept simple on purpose)
                box_score = float(area)

                proposals.append((BoxPrompt(X1, Y1, X2, Y2), box_score))

        return proposals

    def _merge_duplicate_boxes(
        self, boxes_with_scores: List[Tuple[BoxPrompt, float]], iou_thresh=0.5
    ) -> List[BoxPrompt]:
        """
        Greedy non-maximum suppression over scored boxes.

        Boxes are visited from highest to lowest score (ties keep their input
        order). Each box that survives is kept and suppresses every later box
        whose IoU with it is greater than `iou_thresh`.

        :return: the kept boxes as new BoxPrompt instances, highest score first.
        """
        def overlap(a, b):
            # a, b => (x1, y1, x2, y2, ...); IoU of the two rectangles
            inter_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
            inter_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
            inter = inter_w * inter_h
            area_a = (a[2] - a[0]) * (a[3] - a[1])
            area_b = (b[2] - b[0]) * (b[3] - b[1])
            union = area_a + area_b - inter
            if union > 0:
                return inter / union
            return 0.0

        ranked = sorted(
            ((b.x1, b.y1, b.x2, b.y2, score) for b, score in boxes_with_scores),
            key=lambda rec: rec[-1],
            reverse=True,
        )

        alive = [True] * len(ranked)
        kept = []
        for i, candidate in enumerate(ranked):
            if not alive[i]:
                continue
            kept.append(BoxPrompt(candidate[0], candidate[1], candidate[2], candidate[3]))
            for j in range(i + 1, len(ranked)):
                if alive[j] and overlap(candidate, ranked[j]) > iou_thresh:
                    alive[j] = False

        return kept

    ############################################################################
    #                 GENERATING MASKS FROM BOX PROMPTS                        #
    ############################################################################

    def _box_to_corners_as_prompt(
        self, box: BoxPrompt
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Encode a bounding box as two corner points for a 2-point prompt encoder.

        :return: ``(points, labels)`` where
          - points is a float32 array of shape (2, 2): [[x1, y1], [x2, y2]]
            (top-left, bottom-right) in original-image coordinates, and
          - labels is an int32 array of shape (2,): [2, 3].

        In the SAM / SAM 2 prompt encoders, point label 2 marks the top-left and
        label 3 the bottom-right corner of a *box* prompt, while label 1 is a
        foreground click. Labelling the corners as foreground would ask the model
        to segment whatever lies at the corners, which for a bounding box is
        usually background.
        NOTE(review): assumes the Core ML prompt encoder was exported from the
        reference SAM 2 code, which accepts labels 2/3 — confirm for your model.
        """
        points_np = np.array([
            [box.x1, box.y1],
            [box.x2, box.y2]
        ], dtype=np.float32)

        labels_np = np.array([2, 3], dtype=np.int32)  # box top-left, box bottom-right

        return points_np, labels_np

    def generate_masks_from_boxes(
        self,
        boxes: List[BoxPrompt],
        original_size: Tuple[int,int],
        min_mask_area: int = 500,
        iou_merge_threshold: float = 0.8
    ) -> List[np.ndarray]:
        """
        Turn each box prompt into a binary mask and drop near-duplicates.

        For every box, in the given order: encode it as two corner points
        (`_box_to_corners_as_prompt`), run the prompt encoder and mask decoder,
        resize/binarize the best mask to `original_size`, drop it when it has
        fewer than `min_mask_area` foreground pixels, and drop it when its IoU
        with an already accepted mask exceeds `iou_merge_threshold` (so earlier
        boxes win).

        :param boxes: box prompts in original-image coordinates.
        :param original_size: (width, height) of the original image.
        :return: list of uint8 masks (0/255) of shape (height, width).
        :raises ValueError: if a model is missing or `get_image_embedding` has
            not been run.
        """
        required = (self.image_encoder, self.prompt_encoder, self.mask_decoder, self.image_embeddings)
        if any(part is None for part in required):
            raise ValueError("Models not loaded or image embedding not computed.")

        final_masks = []
        accepted_fg = []  # boolean foreground of each accepted mask, reused for IoU checks

        for box in boxes:
            points, labels = self._box_to_corners_as_prompt(box)

            # Transform points from original coords -> encoder input coords
            points_resized = self._transform_point_coords(points, original_size)
            labels_resized = np.expand_dims(labels, axis=0).astype(np.int32)  # shape (1,2)

            # Prompt encoder
            prompt_out = self.prompt_encoder.predict({
                "points": points_resized,
                "labels": labels_resized
            })

            # Mask decoder (its score is not needed here)
            low_res_mask, _score = self._predict_mask(
                prompt_out["sparse_embeddings"], prompt_out["dense_embeddings"]
            )

            # Resize + binarize
            mask_bin = self._resize_and_binarize_mask(
                low_res_mask, original_size, threshold=0.0
            )
            if cv2.countNonZero(mask_bin) < min_mask_area:
                continue

            # Deduplicate: IoU with previously accepted masks
            fg = mask_bin > 0
            is_duplicate = False
            for existing_fg in accepted_fg:
                inter = np.logical_and(existing_fg, fg).sum()
                union = np.logical_or(existing_fg, fg).sum()
                iou_val = float(inter) / float(union) if union > 0 else 0.0
                if iou_val > iou_merge_threshold:
                    is_duplicate = True
                    break

            if not is_duplicate:
                final_masks.append(mask_bin)
                accepted_fg.append(fg)

        return final_masks

    ############################################################################
    #                          MAIN "AUTO" METHOD                               #
    ############################################################################

    def auto_generate_masks(
        self,
        image_path: str,
        scales=(1.0, 0.75, 0.5),
        edge_thresh=100,
        iou_box_thresh=0.5,
        min_mask_area=500,
        iou_merge_threshold: float = 0.8
    ) -> List[np.ndarray]:
        """
        HIGH-LEVEL PIPELINE:
          1) Generate bounding box proposals at multiple scales
          2) Merge duplicates via NMS
          3) Encode the full image once (image_encoder)
          4) For each bounding box => get mask
          5) Merge duplicate masks by IoU
          6) Return final list of binary masks in original resolution

        :param scales: image scales for the box proposals (a tuple default, so
            no mutable list is shared between calls).
        :param edge_thresh: lower Canny threshold used for the proposals.
        :param iou_box_thresh: IoU above which proposals are suppressed by NMS.
        :param min_mask_area: minimum foreground pixel count for a kept mask.
        :param iou_merge_threshold: IoU above which a mask counts as duplicate.
        :return: list of uint8 masks (0/255); empty when no boxes were found.
        """
        # 1) multi-scale bounding box proposals
        proposals_with_scores = self._multi_scale_bounding_box_proposals(
            image_path,
            scales=scales,
            edge_thresh=edge_thresh
        )

        # 2) Non-maximum-suppression for boxes
        merged_boxes = self._merge_duplicate_boxes(
            proposals_with_scores,
            iou_thresh=iou_box_thresh
        )
        print(f"Detected {len(merged_boxes)} bounding boxes after NMS.")

        if not merged_boxes:
            print("No bounding boxes detected. Exiting mask generation.")
            return []

        # 3) Encode the full image once; this also records self.original_size
        self.get_image_embedding(image_path)
        orig_size = tuple(self.original_size)  # (width, height)

        # 4) + 5) Generate masks from bounding boxes and drop duplicates
        final_masks = self.generate_masks_from_boxes(
            merged_boxes,
            original_size=orig_size,
            min_mask_area=min_mask_area,
            iou_merge_threshold=iou_merge_threshold
        )
        print(f"Got {len(final_masks)} final masks after box -> mask filtering.")
        return final_masks

    def save_masks_as_color_overlay(
        self, masks: List[np.ndarray], image_path: str, output_path: str
    ):
        """
        Blend every mask onto the original image in a random colour and save the composite.

        Each pixel covered by a mask becomes a 50/50 blend of the current overlay
        pixel and that mask's colour, so overlapping masks blend on top of each other.

        :param masks: 2D masks (H, W); any non-zero pixel counts as foreground.
            Assumed to share the image's height/width — TODO confirm for all callers.
        :param image_path: Path to the source image (read with cv2, BGR).
        :param output_path: Destination image file.
        :raises ValueError: If the source image cannot be read.
        :raises IOError: If OpenCV reports that writing ``output_path`` failed.
        """
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            raise ValueError(f"Failed to load {image_path}")
        overlay = image_bgr.copy()

        for mask in masks:
            # randint's upper bound is exclusive: 256 allows the full 0..255 range.
            color = np.random.randint(0, 256, size=3, dtype=np.uint8)
            mask_indices = mask > 0
            overlay[mask_indices] = (0.5 * overlay[mask_indices] + 0.5 * color).astype(np.uint8)

        # cv2.imwrite can signal failure (e.g. missing directory) through its
        # return value instead of raising; don't print a false success.
        if not cv2.imwrite(output_path, overlay):
            raise IOError(f"Failed to write overlay to {output_path}")
        print(f"Overlay saved to {output_path} (with {len(masks)} masks).")

    def save_masks_individually(
        self, masks: List[np.ndarray], output_folder: str, base_name: str = "mask"
    ):
        """
        Save each mask as a single-channel PNG named ``{base_name}_{i}.png`` (1-based).

        Masks are normalised to 0/255 before writing, so boolean or 0/1 masks
        become visible black/white images instead of near-black ones; masks
        already in 0/255 are written unchanged.

        :param masks: 2D mask arrays (H, W); any non-zero pixel is foreground.
        :param output_folder: Folder the masks are written to (created if missing).
        :param base_name: Filename prefix for the mask files.
        :raises IOError: If OpenCV reports that writing a mask file failed.
        """
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
            print(f"Created output folder at {output_folder}")

        for idx, mask in enumerate(masks, start=1):
            mask_filename = f"{base_name}_{idx}.png"
            mask_path = os.path.join(output_folder, mask_filename)
            # Map every non-zero value to 255 so the PNG is a viewable binary image.
            mask_uint8 = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
            # cv2.imwrite can signal failure through its return value instead of raising.
            if not cv2.imwrite(mask_path, mask_uint8):
                raise IOError(f"Failed to write mask to {mask_path}")
            print(f"Saved mask {idx} to {mask_path}")

    ############################################################################
    #                     SAVE SEGMENTS FROM MASKS METHOD                      #
    ############################################################################

    def save_segments_individually(
        self, masks: List[np.ndarray], image_path: str, output_folder: str, base_name: str = "segment"
    ):
        """
        Save the image pixels under each mask as a separate PNG with a black background.

        Files are named ``{base_name}_{i}.png`` (1-based) inside ``output_folder``,
        which is created if missing.

        :param masks: 2D masks (H, W); any non-zero pixel is foreground. Assumed to
            share the image's height/width — TODO confirm for all callers.
        :param image_path: Path to the original image (read with cv2, BGR).
        :param output_folder: Folder the segment images are written to.
        :param base_name: Filename prefix for the segment files.
        :raises ValueError: If the original image cannot be read.
        :raises IOError: If OpenCV reports that writing a segment file failed.
        """
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
            print(f"Created output folder at {output_folder}")

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load {image_path}")

        for idx, mask in enumerate(masks, start=1):
            # bitwise_and needs an 8-bit single-channel mask. Normalising via "> 0"
            # keeps every non-zero pixel for bool, 0/1, 0/255 and float masks alike
            # (a plain astype(np.uint8) truncates fractional values such as 0.5 to 0).
            mask_bin = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
            # The segment keeps the image's BGR channels; pixels outside the mask are black.
            segment = cv2.bitwise_and(image, image, mask=mask_bin)
            segment_filename = f"{base_name}_{idx}.png"
            segment_path = os.path.join(output_folder, segment_filename)
            # cv2.imwrite can signal failure through its return value instead of raising.
            if not cv2.imwrite(segment_path, segment):
                raise IOError(f"Failed to write segment to {segment_path}")
            print(f"Saved segment {idx} to {segment_path}")

###############################################################################
#                                USAGE EXAMPLE                                #
###############################################################################

def main():
    """
    Run AdvancedAutoSAM over every frame image and save overlays, masks and segments.

    Loads the CoreML SAM2.1 models, wipes and recreates the three output roots,
    then walks ``frames/input_video_1/``. For every supported image it mirrors
    the image's sub-directory under each output root. A failure on one image
    (mask generation or any individual save step) is printed and processing
    continues.
    """
    sam = AdvancedAutoSAM(input_size=(1024, 1024))
    sam.load_models(
        image_encoder_path="./models/SAM2_1LargeImageEncoderFLOAT16.mlpackage",
        prompt_encoder_path="./models/SAM2_1LargePromptEncoderFLOAT16.mlpackage",
        mask_decoder_path="./models/SAM2_1LargeMaskDecoderFLOAT16.mlpackage",
    )

    frames_root = "frames/input_video_1/"
    overlay_root = "generated_overlays"
    masks_root = "generated_masks"
    segments_root = "generated_segments"
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')

    # Start every run from empty output roots so stale results never linger.
    for out_root in (overlay_root, masks_root, segments_root):
        if os.path.exists(out_root):
            try:
                shutil.rmtree(out_root)
            except Exception as e:
                print(f"Error clearing directory {out_root}: {e}")
                continue
            else:
                print(f"Cleared existing output directory: {out_root}")
        os.makedirs(out_root)
        print(f"Created output directory: {out_root}")

    for dirpath, _subdirs, filenames in os.walk(frames_root):
        for filename in filenames:
            if not filename.lower().endswith(image_exts):
                continue

            image_path = os.path.join(dirpath, filename)
            # Mirror the input sub-directory layout under each output root.
            rel_dir = os.path.relpath(dirpath, frames_root)
            stem = os.path.splitext(filename)[0]
            overlay_dir = os.path.join(overlay_root, rel_dir)
            masks_dir = os.path.join(masks_root, rel_dir, stem + "_masks")
            segments_dir = os.path.join(segments_root, rel_dir, stem + "_segments")

            for out_dir in (overlay_dir, masks_dir, segments_dir):
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
                    print(f"Created output directory: {out_dir}")

            overlay_path = os.path.join(overlay_dir, f"{stem}_overlay.png")

            print(f"\nProcessing image: {image_path}")
            started = time.time()
            try:
                masks = sam.auto_generate_masks(
                    image_path,
                    scales=[1.0, 0.75, 0.5],  # multi-scale proposals
                    edge_thresh=100,          # Canny threshold
                    iou_box_thresh=0.3,       # stricter NMS for boxes
                    min_mask_area=1000,       # discard small masks
                    iou_merge_threshold=0.8,  # drop masks overlapping an accepted one by >80% IoU
                )
            except Exception as e:
                print(f"Error processing {image_path}: {e}")
                continue
            elapsed = time.time() - started
            print(f"Auto-mask generation took {elapsed:.2f} sec, produced {len(masks)} masks.")

            if not masks:
                print("No masks generated for this image. Skipping saving steps.")
                continue

            # Each save step is independent: a failure is reported and the next still runs.
            save_steps = (
                ("overlay", sam.save_masks_as_color_overlay,
                 (masks, image_path, overlay_path), {}),
                ("masks", sam.save_masks_individually,
                 (masks, masks_dir), {"base_name": "mask"}),
                ("segments", sam.save_segments_individually,
                 (masks, image_path, segments_dir), {"base_name": "segment"}),
            )
            for label, step, step_args, step_kwargs in save_steps:
                try:
                    step(*step_args, **step_kwargs)
                except Exception as e:
                    print(f"Error saving {label} for {image_path}: {e}")


if __name__ == "__main__":
    main()

## Image Captioning of the Extracted Frames

In [None]:
import os
import json
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch


# Root directory holding the extracted frames (one subdirectory per video segment).
frames_dir = "frames/input_video_1"  # Adjust this to point at your own frames root directory

# Path of the JSON file the per-frame captions are written to.
output_json_path = "combined_frames.json"


def initialize_captioning_model():
    """
    Load the ViT-GPT2 image-captioning checkpoint.

    :return: ``(model, feature_extractor, tokenizer)``, all loaded from the
        ``nlpconnect/vit-gpt2-image-captioning`` checkpoint, in that order.
    """
    checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    print("Initialiseren van het image captioning model...")
    components = (
        VisionEncoderDecoderModel.from_pretrained(checkpoint),
        ViTImageProcessor.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
    print("Model geïnitieerd.")
    return components


def generate_caption(image_path, model, feature_extractor, tokenizer, device,
                     max_length=16, num_beams=4):
    """
    Generate a short text caption for a single image file.

    :param image_path: Path to the image; it is converted to RGB before encoding.
    :param model: Vision-encoder/decoder model used for ``generate``.
    :param feature_extractor: Image processor producing ``pixel_values``.
    :param tokenizer: Tokenizer used to decode the generated ids.
    :param device: Device the pixel tensor is moved to (e.g. "cuda" or "cpu").
    :param max_length: Maximum generated sequence length (default 16).
    :param num_beams: Beam-search width (default 4).
    :return: The decoded caption, or ``""`` if anything fails (the error is printed).
    """
    try:
        # The context manager closes the file handle deterministically; Pillow can
        # keep it open after loading multi-frame formats such as TIFF.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)

        output_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error generating caption for {image_path}: {e}")
        return ""

def process_all_frames(frames_dir, model, feature_extractor, tokenizer, device):
    """
    Caption every frame image inside each segment subdirectory of ``frames_dir``.

    Segments and frames are processed in sorted-name order. Images placed
    directly in ``frames_dir`` (outside a subdirectory) are ignored.

    :return: One dict per frame with ``video_segment``, ``frame_number``
        (0-based position in the sorted frame list, not a number parsed from the
        filename), ``frame_filename`` and ``caption``. Empty if ``frames_dir``
        does not exist.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    records = []

    if not os.path.isdir(frames_dir):
        print(f"Error: Frames directory {frames_dir} bestaat niet.")
        return records

    segment_names = sorted(
        name for name in os.listdir(frames_dir)
        if os.path.isdir(os.path.join(frames_dir, name))
    )
    n_segments = len(segment_names)
    print(f"Found {n_segments} video segments in '{frames_dir}'.")

    for seg_no, segment in enumerate(segment_names, start=1):
        segment_path = os.path.join(frames_dir, segment)
        frames = sorted(
            name for name in os.listdir(segment_path)
            if name.lower().endswith(image_exts)
        )
        if not frames:
            print(f"error: Geen frames gevonden in segment '{segment}'.")
            continue

        n_frames = len(frames)
        print(f"Processing segment {seg_no}/{n_segments}: '{segment}' with {n_frames} frames.")

        for position, frame_file in enumerate(frames):
            frame_path = os.path.join(segment_path, frame_file)
            caption = generate_caption(frame_path, model, feature_extractor, tokenizer, device)
            records.append({
                "video_segment": segment,
                "frame_number": position,
                "frame_filename": frame_file,
                "caption": caption,
            })
            done = position + 1
            if done % 100 == 0 or done == n_frames:
                print(f"Segment '{segment}': Processed {done}/{n_frames} frames.")

    return records


def save_combined_data(combined_data, output_json_path):
    """
    Write the caption records to ``output_json_path`` as UTF-8 JSON (indent 4).

    Errors are printed, not raised.
    """
    try:
        with open(output_json_path, 'w', encoding='utf-8') as handle:
            json.dump(combined_data, handle, indent=4)
    except Exception as e:
        print(f"Error saving JSON file: {e}")
    else:
        print(f"Combined JSON saved to '{output_json_path}'.")


def main():
    """
    Caption every frame under ``frames_dir`` and save the result to ``output_json_path``.

    Uses CUDA when available, otherwise the CPU.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, feature_extractor, tokenizer = initialize_captioning_model()
    model.to(device)

    combined_data = process_all_frames(
        frames_dir=frames_dir,
        model=model,
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        device=device,
    )

    if combined_data:
        save_combined_data(combined_data, output_json_path)
    else:
        print("Geen gecombineerde data om op te slaan.")


if __name__ == "__main__":
    main()

## Connect Segments and Captions to Chunks

In [None]:
import json
import os
from pathlib import Path
from typing import List, Dict, Any
import re

def load_json(file_path: str) -> Any:
    """Decode and return the JSON value stored UTF-8 encoded in ``file_path``."""
    text = Path(file_path).read_text(encoding='utf-8')
    return json.loads(text)

def save_json(data: Any, file_path: str):
    """Write ``data`` to ``file_path`` as UTF-8 JSON (indent 4) and print a confirmation."""
    target = file_path
    with open(target, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4)
    print(f"JSON opgeslagen als '{target}'.")

def parse_segment_filename(filename: str) -> Dict[str, Any]:
    """
    Parse a video-segment filename of the form ``segment_<index>_<start>_<end>.mp4``.

    The whole name must match, so stray files such as ``segment_1_0_32.mp4.part``
    are rejected instead of being mistaken for the finished segment. ``start``
    and ``end`` may be integers (``segment_1_0_32.mp4``) or decimals
    (``segment_2_1.5_30.25.mp4``).

    :param filename: Bare filename (no directory part).
    :return: ``{"index": int, "start": float, "end": float}``, or ``{}`` if the
        name does not follow the pattern.
    """
    pattern = r"segment_(\d+)_(\d+(?:\.\d+)?)_(\d+(?:\.\d+)?)\.mp4"
    match = re.fullmatch(pattern, filename)
    if not match:
        return {}
    return {
        "index": int(match.group(1)),
        "start": float(match.group(2)),
        "end": float(match.group(3)),
    }

def find_video_segments(chunks: List[Dict[str, Any]], video_segments_dir: str) -> Dict[tuple, str]:
    """
    Map ``(start, end)`` times to segment file paths for the files in ``video_segments_dir``.

    Keys are the float ``start``/``end`` values returned by ``parse_segment_filename``.
    Files that do not follow the naming pattern are reported and skipped.
    If two files share the same ``(start, end)``, the one listed later by
    ``os.listdir`` wins.

    :param chunks: Not used by this function. Presumably kept for interface
        symmetry; verify against callers before removing.
    :param video_segments_dir: Directory containing the segment files.
    :return: Dict ``{(start, end): path}``.
    """
    segment_map = {}
    for name in os.listdir(video_segments_dir):
        info = parse_segment_filename(name)
        if info:
            segment_map[(info['start'], info['end'])] = os.path.join(video_segments_dir, name)
        else:
            print(f" file '{name}' not same as pattern")
    return segment_map

def find_frames_for_segment(frames_dir: str, segment_filename: str) -> List[str]:
    """
    List the frame images extracted for one video segment.

    :param frames_dir: Root frames directory (one subdirectory per segment).
    :param segment_filename: Segment file name, used as the subdirectory name.
    :return: Sorted full paths of the .jpg/.jpeg/.png/.bmp/.tiff files, or ``[]``
        (with a printed error) if the subdirectory does not exist.
    """
    frame_dir = os.path.join(frames_dir, segment_filename)
    if os.path.isdir(frame_dir):
        image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
        return sorted(
            os.path.join(frame_dir, name)
            for name in os.listdir(frame_dir)
            if name.lower().endswith(image_exts)
        )
    print(f"error: Frames directory '{frame_dir}' not exist.")
    return []

def find_object_segments_for_frame(generated_segments_dir: str, segment_filename: str, frame_filename: str) -> List[str]:
    """
    Return the object-segment PNGs saved for one frame, in numeric order.

    Expected layout (example):
        generated_segments/segment_1_0_32.mp4/segment_1_0_32_mp4_frame_000003_segments/segment_1.png

    Files are sorted naturally, so ``segment_2.png`` comes before ``segment_10.png``.
    A plain string sort would put ``segment_10`` first and break the 1..N
    numbering the files were saved with.

    :param generated_segments_dir: Root directory of the saved object segments.
    :param segment_filename: Video segment name (first-level subdirectory).
    :param frame_filename: Frame file name or path; its stem selects the
        ``<stem>_segments`` folder.
    :return: Full paths of the PNG files, or ``[]`` (with a printed warning) if
        the folder is missing.
    """
    frame_basename = os.path.splitext(os.path.basename(frame_filename))[0]
    segments_folder = os.path.join(generated_segments_dir, segment_filename, f"{frame_basename}_segments")
    if not os.path.isdir(segments_folder):
        print(f"Waarschuwing: Object segments directory '{segments_folder}' bestaat niet.")
        return []

    def _natural_key(name: str):
        # re.split with a capturing group alternates text / digit runs, so odd
        # positions are digit runs: compare those numerically. The name itself
        # breaks ties (e.g. "segment_01" vs "segment_1") deterministically.
        parts = re.split(r"(\d+)", name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)], name

    png_names = [f for f in os.listdir(segments_folder) if f.lower().endswith('.png')]
    return [os.path.join(segments_folder, f) for f in sorted(png_names, key=_natural_key)]

def load_captions(captions_json_path: str) -> Dict[str, str]:
    """
    Build a ``"<video_segment>/<frame_filename>" -> caption`` lookup from the captions JSON.

    Entries whose ``video_segment`` or ``frame_filename`` is missing or empty are
    skipped. A missing ``caption`` maps to ``""``. Later duplicates overwrite
    earlier ones.
    """
    caption_map = {
        f"{entry.get('video_segment')}/{entry.get('frame_filename')}": entry.get('caption', "")
        for entry in load_json(captions_json_path)
        if entry.get('video_segment') and entry.get('frame_filename')
    }
    print(f"Loaded captions for {len(caption_map)} frames.")
    return caption_map

def update_chunks_with_segments_and_captions(original_chunks: List[Dict[str, Any]], 
                                segment_map: Dict[tuple, str],
                                frames_dir: str,
                                generated_segments_dir: str,
                                captions_map: Dict[str, str]) -> List[Dict[str, Any]]:
    """
    Attach video segment, frames, object segments and captions to each chunk.

    A chunk's segment is looked up in ``segment_map`` using its start/end times
    truncated with ``int()``. NOTE(review): this only matches if the segment
    filenames were built from truncated times as well — confirm against the
    segmenting step. Chunks without a matching segment are reported and dropped.

    :return: New chunk dicts with ``text``, ``start``, ``end``, ``video_segment``,
        ``frames`` (each with ``frame_path``, ``object_segments``, ``caption``)
        and ``words`` (``[]`` if absent).
    """
    updated_chunks = []
    for chunk in original_chunks:
        start, end = chunk['start'], chunk['end']
        segment_path = segment_map.get((int(start), int(end)))
        if not segment_path:
            print(f"error: no video segment found for chunk start {start} and end {end}.")
            continue

        segment_name = os.path.basename(segment_path)
        frames_info = []
        for frame_path in find_frames_for_segment(frames_dir, segment_name):
            frame_name = os.path.basename(frame_path)
            frames_info.append({
                "frame_path": frame_path,
                "object_segments": find_object_segments_for_frame(
                    generated_segments_dir, segment_name, frame_name
                ),
                # Captions are keyed as "<video_segment>/<frame_filename>".
                "caption": captions_map.get(f"{segment_name}/{frame_name}", ""),
            })

        updated_chunks.append({
            "text": chunk['text'],
            "start": chunk['start'],
            "end": chunk['end'],
            "video_segment": segment_path,
            "frames": frames_info,
            "words": chunk.get('words', []),
        })

    return updated_chunks

def main():
    """
    Enrich every transcript chunk with its video segment, frames, object segments and captions.

    Reads the chunk JSON and the captions JSON, plus the segment, frame and
    object-segment directories (paths below). Writes the enriched chunks to
    ``updated_chunks_with_segments_and_captions.json``. Prints an error and
    returns early if an input is missing or there are no chunks.
    """
    # Input / output paths
    original_json_path = "processed_json/output_audio_1_chunks.json"  # original JSON with chunks
    video_segments_dir = "video_segments/input_video_1"  # directory with the video segments
    frames_dir = "frames/input_video_1"  # root directory with one frame folder per segment
    generated_segments_dir = "generated_segments"  # directory with the object segments
    captions_json_path = "combined_frames.json"  # JSON with image captions
    new_json_path = "updated_chunks_with_segments_and_captions.json"  # new JSON output

    # Check that every required file and directory exists before doing any work.
    required_inputs = (
        (os.path.exists, original_json_path,
         f"Error: Originele JSON bestand '{original_json_path}' bestaat niet."),
        (os.path.isdir, video_segments_dir,
         f"Error: Video segments directory '{video_segments_dir}' bestaat niet."),
        (os.path.isdir, frames_dir,
         f"Error: Frames directory '{frames_dir}' bestaat niet."),
        (os.path.isdir, generated_segments_dir,
         f"Error: Generated segments directory '{generated_segments_dir}' bestaat niet."),
        (os.path.exists, captions_json_path,
         f"Error: Captions JSON bestand '{captions_json_path}' bestaat niet."),
    )
    for exists, path, error_message in required_inputs:
        if not exists(path):
            print(error_message)
            return

    # Load the original chunks
    original_data = load_json(original_json_path)
    chunks = original_data.get('chunks', [])
    if not chunks:
        print("Geen chunks gevonden in de originele JSON.")
        return

    print(f"Loaded {len(chunks)} chunks from '{original_json_path}'.")

    # Map (start, end) times to video segment paths
    segment_map = find_video_segments(chunks, video_segments_dir)
    print(f"Found {len(segment_map)} video segments.")

    # Load captions as a "<segment>/<frame>" -> caption lookup
    captions_map = load_captions(captions_json_path)

    # Attach video segments, frames, object segments and captions
    updated_chunks = update_chunks_with_segments_and_captions(
        chunks,
        segment_map,
        frames_dir,
        generated_segments_dir,
        captions_map
    )

    print(f"Updated {len(updated_chunks)} chunks with segments and captions.")

    save_json({"chunks": updated_chunks}, new_json_path)
    print(f"new JSON with connected segments and captions saved as '{new_json_path}'.")


# Guarded like the other pipeline cells: runs when executed directly (including
# as a notebook cell), but not when this code is imported as a module.
if __name__ == "__main__":
    main()

## Linking Captions to Chunks with Sentence Transformers

In [None]:
!pip install sentence-transformers torch tqdm

In [None]:
import json
import os
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer, util
import torch
from tqdm import tqdm

def load_json(file_path: str) -> Any:
    """Return the JSON value stored UTF-8 encoded in ``file_path``."""
    with open(file_path, encoding='utf-8') as fh:
        raw_text = fh.read()
    return json.loads(raw_text)

def save_json(data: Any, file_path: str):
    """Serialise ``data`` as 4-space-indented JSON into ``file_path`` (UTF-8) and report it."""
    with open(file_path, mode='w', encoding='utf-8') as out:
        json.dump(data, out, indent=4)
    print(f"JSON opgeslagen als '{file_path}'.")

def preprocess_chunks(chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Keep the chunks with non-empty text and reduce them to the fields needed for linking.

    Each result carries ``chunk_id`` (``"<start>_<end>"``, built from both times
    so it identifies the chunk), plus ``text``, ``start`` and ``end``.
    """
    return [
        {
            "chunk_id": f"{chunk['start']}_{chunk['end']}",
            "text": chunk['text'],
            "start": chunk['start'],
            "end": chunk['end'],
        }
        for chunk in chunks
        if chunk.get('text', "")
    ]

def preprocess_captions(captions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Keep captions with non-empty text, normalised to three fields.

    Missing ``video_segment`` / ``frame_filename`` values become ``""``.
    """
    return [
        {
            "video_segment": entry.get('video_segment', ""),
            "frame_filename": entry.get('frame_filename', ""),
            "caption": entry['caption'],
        }
        for entry in captions
        if entry.get('caption', "")
    ]

def compute_embeddings(model: SentenceTransformer, texts: List[str], batch_size: int = 32) -> torch.Tensor:
    """
    Encode ``texts`` with ``model`` and return the embeddings as one tensor.

    ``batch_size`` is passed straight to ``model.encode``; a progress bar is shown.
    """
    encoded = model.encode(
        texts,
        convert_to_tensor=True,
        batch_size=batch_size,
        show_progress_bar=True,
    )
    return encoded

def link_captions_to_chunks(
    chunks: List[Dict[str, Any]],
    captions: List[Dict[str, Any]],
    similarity_threshold: float = 0.3,
    model_name: str = 'all-mpnet-base-v2'
) -> List[Dict[str, Any]]:
    """
    Link every caption to the transcript chunk whose text is most similar.

    Chunk texts and caption texts are embedded with a SentenceTransformer and
    compared with cosine similarity. Each caption produces exactly one record,
    in input order. It is linked to its best-scoring chunk only if that score is
    >= ``similarity_threshold``.

    :param chunks: Dicts with at least ``chunk_id`` and ``text``.
    :param captions: Dicts with ``caption``, ``frame_filename`` and ``video_segment``.
    :param similarity_threshold: Minimum cosine similarity required to link.
    :param model_name: SentenceTransformer model to load (the default,
        all-mpnet-base-v2, supports longer texts).
    :return: One dict per caption. Unlinked captions have ``linked_chunk_id`` and
        ``linked_chunk_text`` set to ``None``. When ``chunks`` is empty, no
        similarity is computed and ``similarity_score`` is ``None`` as well.
    """
    def _make_link(idx, caption, linked_chunk, score):
        # Single place that builds a link record, for linked and unlinked captions alike.
        return {
            "caption_index": idx,
            "frame_filename": caption['frame_filename'],
            "video_segment": caption['video_segment'],
            "caption": caption['caption'],
            "linked_chunk_id": linked_chunk['chunk_id'] if linked_chunk is not None else None,
            "linked_chunk_text": linked_chunk['text'] if linked_chunk is not None else None,
            "similarity_score": score,
        }

    # Nothing to compare: skip loading the model, and avoid argmax over an
    # empty similarity row, which would raise.
    if not captions:
        return []
    if not chunks:
        print("No chunks to link against; all captions stay unlinked.")
        return [_make_link(idx, caption, None, None) for idx, caption in enumerate(captions)]

    model = SentenceTransformer(model_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    chunk_texts = [chunk['text'] for chunk in chunks]
    caption_texts = [caption['caption'] for caption in captions]

    print("Computing embeddings for chunks...")
    chunk_embeddings = compute_embeddings(model, chunk_texts).to(device)
    print("Computing embeddings for captions...")
    caption_embeddings = compute_embeddings(model, caption_texts).to(device)

    print("Calculating cosine similarities...")
    cosine_similarities = util.cos_sim(caption_embeddings, chunk_embeddings)  # (num_captions, num_chunks)

    print("Linking captions to chunks based on similarity...")
    links = []
    for idx, caption in enumerate(tqdm(captions, desc="Linking Captions")):
        sim_scores = cosine_similarities[idx]
        top_result = torch.argmax(sim_scores).item()
        top_score = sim_scores[top_result].item()
        linked_chunk = chunks[top_result] if top_score >= similarity_threshold else None
        links.append(_make_link(idx, caption, linked_chunk, top_score))
    return links

def filter_linked_captions(links: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Keep only the caption links that were matched to a chunk.

    A link counts as matched when ``linked_chunk_id`` is not ``None``.
    """
    kept = list(filter(lambda link: link['linked_chunk_id'] is not None, links))
    print(f"Filtered captions: {len(kept)} out of {len(links)} were successfully linked.")
    return kept

def main():
    """
    Link image captions to transcript chunks and save all links plus the linked subset.

    Writes every link to ``captions_to_chunks_mapping.json`` and only the linked
    captions to ``filtered_captions_to_chunks_mapping.json``. Returns early
    (printing why) if an input file is missing or there is nothing to link.
    """
    chunks_json_path = "processed_json/output_audio_1_chunks.json"      # original JSON with chunks
    captions_json_path = "combined_frames.json"                         # JSON with image captions
    output_mapping_path = "captions_to_chunks_mapping.json"             # output: all links
    filtered_output_path = "filtered_captions_to_chunks_mapping.json"   # output: linked captions only

    # Check that every required input file exists.
    for file in (chunks_json_path, captions_json_path):
        if not os.path.exists(file):
            print(f"Error: Vereist bestand '{file}' bestaat niet.")
            return

    print("Loading JSON files...")
    chunks_data = load_json(chunks_json_path)
    captions_data = load_json(captions_json_path)

    chunks = chunks_data.get('chunks', [])
    captions = captions_data  # combined_frames.json is expected to hold a list

    if not chunks:
        print("Geen chunks gevonden in de hoofd JSON.")
        return
    if not captions:
        print("Geen captions gevonden in de captions JSON.")
        return

    print(f"Loaded {len(chunks)} chunks and {len(captions)} captions.")

    processed_chunks = preprocess_chunks(chunks)
    processed_captions = preprocess_captions(captions)

    # Preprocessing drops entries with empty text, so either list can end up
    # empty even though the raw JSON was not; there is then nothing to embed.
    if not processed_chunks:
        print("No chunks with non-empty text to link against.")
        return
    if not processed_captions:
        print("No non-empty captions to link.")
        return

    # Link captions to chunks via Sentence Transformers
    links = link_captions_to_chunks(processed_chunks, processed_captions, similarity_threshold=0.3)
    print(f"Generated {len(links)} caption links.")

    save_json(links, output_mapping_path)
    print(f"Captions to chunks mapping saved to '{output_mapping_path}'.")

    filtered_links = filter_linked_captions(links)

    save_json(filtered_links, filtered_output_path)
    print(f"Filtered captions to chunks mapping saved to '{filtered_output_path}'.")


if __name__ == "__main__":
    main()