### Notebook to generate the text-to-game music dataset
Pipeline:
1. Search YouTube for videos using the YouTube Data API key
2. Cut random 30 s clips from each video and check them for music (keep clips with at least 25 s of music)
3. Take screenshots of each clip and caption them with Kosmos-2
4. Feed the captions to Llama-2-7b-chat to generate a text prompt


Please ensure you have the following:
1. A Hugging Face access token from an account that has been granted access to the Meta Llama 2 repositories. Request access here: https://huggingface.co/meta-llama/Llama-2-7b

2. A YouTube Data API v3 key

In [None]:
# Credentials used later in the notebook (replace both placeholders before running):
#   hg_access - Hugging Face access token with access to the gated meta-llama repositories
#   yt_API    - YouTube Data API key
hg_access, yt_API = 'FILL IN YOUR ACCESS CODE HERE', 'FILL IN YOUR ACCESS CODE HERE'

#### Setup

Please run the following cell to set up inaSpeechSegmenter.

By default, it will be cloned into the notebook's current working directory.

If you would like to install it into a different folder, pass that folder to `setup_inaSpeechSegmenter` and update the path in the import cell further below accordingly.

In [None]:
import os

def setup_inaSpeechSegmenter(directory:str=''):
    """
    Function to download and install inaSpeechSegmenter. This command makes use of subprocesses to run terminal commands.
    Alternatively, please follow the installation instructions here: https://github.com/ina-foss/inaSpeechSegmenter

    The repository is cloned into ``directory`` (skipped if a clone is already there)
    and installed with the pip of the Python interpreter running this notebook.
    On return the working directory is left inside the cloned ``inaSpeechSegmenter``
    folder, which the import cell further down relies on.

    Args:
        directory (str, optional): folder to clone inaSpeechSegmenter into. Defaults to '' (current working directory).

    Returns:
        0 for success and 1 for failure
    """
    import subprocess
    import sys

    # Resolve the default *before* touching the filesystem: os.makedirs('') raises
    # FileNotFoundError, which previously crashed the default call outside the try.
    if directory == '':
        directory = os.getcwd()

    try:
        if not os.path.isdir(directory):
            os.makedirs(directory)
            print("Setup directory not found, directory created")

        os.chdir(directory)

        # Re-running the cell must not fail just because the clone already exists.
        if not os.path.isdir("inaSpeechSegmenter"):
            clone = subprocess.run(["git", "clone", "https://github.com/ina-foss/inaSpeechSegmenter.git"])
            if clone.returncode != 0:
                raise RuntimeError("git clone failed")

        os.chdir("inaSpeechSegmenter")

        # `python -m pip` installs into this kernel's environment; a bare `pip`
        # on PATH may belong to a different interpreter.
        install = subprocess.run([sys.executable, "-m", "pip", "install", "."])
        if install.returncode != 0:
            raise RuntimeError("pip install failed")

        # The bundled test run is informational only; its failure does not undo the install.
        tests = subprocess.run([sys.executable, "setup.py", "test"])
        if tests.returncode != 0:
            print("Warning: inaSpeechSegmenter test run did not complete successfully")
        return 0

    except Exception as e:
        print(f"Failed to download inaSpeechSegmenter, error: {e}")
        return 1

In [None]:
# Install inaSpeechSegmenter using the function's default directory argument ('').
setup_inaSpeechSegmenter(directory='')

#### Generating text-to-game music dataset

In [None]:
#Imports
import sys
import subprocess
import random
import re
import os
import shutil 
import cv2
import torch
import csv
import pandas as pd
from PIL import Image

from googleapiclient.discovery import build
from isodate import parse_duration
from transformers import  AutoProcessor, AutoModelForVision2Seq, AutoTokenizer, AutoModelForCausalLM, GenerationConfig

In [None]:
### Importing inaSpeechSegmenter

# The package is imported from inside its cloned folder, so move there unless the
# working directory already ends with 'inaSpeechSegmenter'.
# PLEASE update the path below if you cloned it somewhere else.
current_path = os.getcwd()
if not current_path.endswith('inaSpeechSegmenter'):
    os.chdir('inaSpeechSegmenter') #UPDATE YOUR PATH HERE

from inaSpeechSegmenter import Segmenter
from inaSpeechSegmenter.export_funcs import seg2csv, seg2textgrid

##### YouTube Link Fetcher and Downloader

In [None]:
def fetch_videos(api_key:str, search_term:str="no commentary walkthrough", video_duration:list=[240, 3600], limit:int=10):
    """
    Function to retrieve YouTube video links for a search term, keeping only videos within a duration range

    Args:
        api_key (str): YouTube Data API key
        search_term (str, optional): youtube search term. Defaults to "no commentary walkthrough"
        video_duration (list, optional): [min, max] accepted duration in seconds, both bounds inclusive.
            Defaults to [240, 3600] -> 4 to 60 minutes
        limit (int, optional): maximum num of links. Defaults to 10.

    Returns:
        list of youtube url links, at most ``limit`` long (shorter if the search runs out of results)
    """
    # API maximum for a single search request.
    page_size = 50
    min_duration, max_duration = video_duration[0], video_duration[1]

    # Initialize the YouTube API client
    youtube = build('youtube', 'v3', developerKey=api_key)

    video_urls = []
    page_token = None

    while len(video_urls) < limit:
        # Always request a full page. Many results are discarded by the duration
        # filter below, and a search call costs the same quota whatever maxResults
        # is, so asking only for the "remaining" count just cost extra pages.
        # Result order is unchanged, so the same first `limit` matches are returned.
        search_response = youtube.search().list(
            q=search_term,
            part='id,snippet',
            maxResults=page_size,
            type='video',
            pageToken=page_token
        ).execute()

        video_ids = [item['id']['videoId'] for item in search_response.get('items', [])]

        if not video_ids:
            print("No more videos found, exiting")
            break  # Exit if no more videos are found

        # Fetch details for each video to filter by precise duration
        videos_response = youtube.videos().list(
            part='contentDetails',
            id=','.join(video_ids)
        ).execute()

        for item in videos_response.get('items', []):
            duration = parse_duration(item['contentDetails']['duration']).total_seconds()
            if min_duration <= duration <= max_duration:
                video_urls.append(f"https://www.youtube.com/watch?v={item['id']}")
                if len(video_urls) >= limit:
                    break

        page_token = search_response.get('nextPageToken')
        if not page_token:
            break  # Exit the loop if there are no more pages to fetch

    return video_urls

In [None]:
def download_extract_clips(url:str, 
                            output_dir:str, 
                            number_of_clips:int=10, 
                            clip_length:int=30):
    """
    Function to download a video and cut randomly placed clips from it
    Note: Function uses os and subprocesses (yt-dlp, ffprobe and ffmpeg must be on PATH)

    The full-length download is always deleted before returning. If any step fails,
    the clips already cut from this video are deleted too, so a failed video leaves
    nothing behind in ``output_dir``.

    Args:
        url (str): youtube url link
        output_dir (str): folder where the clips will go (created if missing)
        number_of_clips (int, optional): number of clips to cut per video. Defaults to 10.
        clip_length (int, optional): length of each clip in seconds. Defaults to 30.

    Returns:
        clip_paths (list): list of clip filenames, or None if any step fails
            (errors, including a missing downloaded file, are printed rather than raised)
    """
    downloaded_filename = None  # set once the download target is known; cleaned up in `finally`
    clip_paths = []
    completed = False

    try:
        # Ensure the output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # yt-dlp format code; 18 is YouTube's progressive mp4 stream (360p)
        format_preference = "18"

        # Fetch the video title for naming
        command_get_title = ["yt-dlp", "--get-title", url]
        result_title = subprocess.run(command_get_title, capture_output=True, text=True, check=True, encoding='utf-8')

        title = result_title.stdout.strip()
        clean_title = re.sub(r'[^\w\-_\. ]', '_', title)
        output_template = os.path.join(output_dir, f"{clean_title}.%(ext)s")

        # Download the video
        print("Downloading video...")
        command_download = ["yt-dlp", "-f", format_preference, "-o", output_template, url]
        subprocess.run(command_download, capture_output=True, text=True, check=True)

        downloaded_filename = os.path.normpath(output_template.replace('%(ext)s', 'mp4'))

        print(f"Downloaded filename: {downloaded_filename}")

        if not os.path.exists(downloaded_filename):
            raise FileNotFoundError(f"Expected downloaded file not found: {downloaded_filename}")

        # Get video duration; check=True so an ffprobe failure is reported as such
        # instead of surfacing later as float('') -> ValueError.
        command_duration = ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", downloaded_filename]
        result_duration = subprocess.run(command_duration, capture_output=True, text=True, check=True)
        duration_seconds = float(result_duration.stdout.strip())

        # Cut clips at random start times
        for i in range(number_of_clips):
            start_time = random.randint(0, max(int(duration_seconds - clip_length), 0))
            clip_filename = os.path.join(output_dir, f"{clean_title}_clip_{i}.mp4")

            # "-y" overwrites a clip left over from an earlier run instead of
            # letting ffmpeg stop at its interactive "overwrite?" prompt.
            command_extract = ["ffmpeg", "-y", "-ss", str(start_time), "-t", str(clip_length), "-i", downloaded_filename, "-c:v", "libx264", "-c:a", "aac", clip_filename]
            subprocess.run(command_extract, check=True, capture_output=True, text=True)

            clip_paths.append(clip_filename)

        completed = True
        return clip_paths

    except subprocess.CalledProcessError as e:
        print(f"An error occurred during subprocess execution: {e.stderr}.\nSkipping Video")
        return None

    except Exception as e:
        print(f"An unexpected error occurred: {e}\nSkipping Video")
        return None

    finally:
        # Always drop the full-length source; on failure also drop partial clips,
        # which the caller never receives and would otherwise be orphaned on disk.
        leftovers = [] if completed else list(clip_paths)
        if downloaded_filename is not None:
            leftovers.append(downloaded_filename)
        for path in leftovers:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as err:
                print(f"Could not remove {path}: {err}")

#### Check video audio for music

In [None]:
def check_music_length(media_path:str, csv_path:str='myseg.csv', music_length_threshold:int=25):
    """
    Function to run Segmenter on a video clip to check how much music is inside, deleting the clip if there is too little

    Args:
        media_path (str): path to video clip
        csv_path (str, optional): path where the segmentation csv (labels/start/stop) is written. Defaults to myseg.csv
        music_length_threshold (int, optional): minimum total music duration in seconds. Defaults to 25s

    Returns:
        media_path: path of the video clip if its total music duration is at least the threshold, else None.
            The clip file is deleted from disk when it falls below the threshold.
    """
    # Constructing a Segmenter loads its neural-network models, which is slow.
    # Build it once on the first call and reuse it for every later clip.
    seg = getattr(check_music_length, '_segmenter', None)
    if seg is None:
        seg = Segmenter()
        check_music_length._segmenter = seg

    segmentation = seg(media_path)
    seg2csv(segmentation, csv_path)

    try:
        # Skip the first row (presumably the header written by seg2csv) and name the columns explicitly
        segmentation_df = pd.read_csv(csv_path, sep='\t', names=['labels', 'start', 'stop'], skiprows=1)
    except ValueError:
        print("Error reading the segmentation CSV. Please check the format.")
        return None

    try:
        segmentation_df['start'] = segmentation_df['start'].astype(float)
        segmentation_df['stop'] = segmentation_df['stop'].astype(float)
    except ValueError as e:
        print(f"Error converting start/stop times to float: {e}")
        return None

    # Total music time = sum of the lengths of all segments labelled 'music'
    music_segments = segmentation_df[segmentation_df['labels'] == 'music']
    total_music_duration = (music_segments['stop'] - music_segments['start']).sum()

    if total_music_duration >= music_length_threshold:
        print(f"{media_path} contains enough music. It will not be deleted.")
        return media_path

    # Not enough music: delete the clip so it is not used in the dataset
    try:
        os.remove(media_path)
        print(f"Deleted {media_path} due to insufficient music duration.")
    except OSError as e:
        print(f"Error deleting file {media_path}: {e}")
    return None

#### Screenshot Videos and Caption Screenshots

In [None]:
def capture_screenshots(video_path:str, interval:list=[5,10,15,20,25], save_directory:str=None):
    """
    Captures screenshots from a video at specified times and returns a dictionary with custom keys for each path.

    Args:
        video_path (str): Path to the video file.
        interval (list): List of times in seconds at which to capture the screenshots. Defaults to [5,10,15,20,25].
        save_directory (str, optional): Folder the screenshots are written to (created if missing).
            Defaults to None, meaning the folder that contains ``video_path``.

    Returns:
        A dictionary with keys like 'path1', 'path2', etc. (numbered by position in ``interval``),
        pointing to the file paths of the screenshots that were captured AND written successfully.
    """
    screenshots_paths = {}

    if save_directory is None:
        # Default next to the clip. This used to be a hard-coded, machine-specific
        # Windows path on which cv2.imwrite silently failed everywhere else.
        save_directory = os.path.dirname(os.path.abspath(video_path))
    os.makedirs(save_directory, exist_ok=True)

    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        print("Error: Could not open video.")
        return screenshots_paths

    # Prefix names with the clip name so screenshots of different clips never collide.
    clip_name = os.path.splitext(os.path.basename(video_path))[0]

    try:
        # FPS converts a time in seconds into a frame index.
        fps = video.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # Without a valid FPS every time would map to frame 0.
            print("Error: Could not read video FPS.")
            return screenshots_paths

        for index, time in enumerate(interval, start=1):
            video.set(cv2.CAP_PROP_POS_FRAMES, int(time * fps))
            success, frame = video.read()
            if not success:
                print(f"Error: Could not capture screenshot at {time}s")
                continue

            file_path = os.path.join(save_directory, f"{clip_name}_screenshot_{time}s.jpg")
            # cv2.imwrite signals failure through its return value, so check it
            # instead of handing a non-existent file to the caller.
            if not cv2.imwrite(file_path, frame):
                print(f"Error: Could not write screenshot to {file_path}")
                continue

            screenshots_paths[f"path{index}"] = file_path
    finally:
        # Release the capture even if reading or writing raised.
        video.release()

    return screenshots_paths


In [None]:
def caption_image(model:AutoModelForVision2Seq, processor:AutoProcessor, image_path:str, text_input:str="detailed", device:str='cuda'):
    """
    Caption one screenshot with a grounding vision-language model (Kosmos-2 by default).

    Args:
        model (AutoModelForVision2Seq): captioning model, normally Kosmos-2
        processor (AutoProcessor): processor/tokenizer paired with ``model``
        image_path (str): path to the image file
        text_input (str, optional): "brief" or "detailed" (case-insensitive) select a preset
            grounding prompt; any other text is used as the prompt after a "<grounding>" tag.
            Defaults to "detailed"
        device (str, optional): device the processed inputs are moved to, cuda or cpu. Defaults to cuda

    Returns:
        caption (str): cleaned caption text with the detailed-prompt prefix removed
    """
    rgb_image = Image.open(image_path).convert("RGB")

    presets = {
        "brief": "<grounding>An image of",
        "detailed": "<grounding>Describe this image in detail:",
    }
    lowered = text_input.lower()
    prompt = presets[lowered] if lowered in presets else f"<grounding>{text_input}"

    batch = processor(text=prompt, images=rgb_image, return_tensors="pt").to(device)

    generation_kwargs = dict(
        pixel_values=batch["pixel_values"],
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=batch["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    output_ids = model.generate(**generation_kwargs)

    raw_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    # The first element of the post-processed result is used as the caption text.
    cleaned = processor.post_process_generation(raw_text)[0]

    return cleaned.replace("Describe this image in detail: ", "")



#### Generate Text Prompt

In [None]:
class PromptTemplate:
    r"""
    Build Llama-2-chat formatted prompts from a system prompt and a conversation history.

    Layout produced for a conversation (one <s> ... </s> pair per completed turn):
        <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user_1} [/INST] {reply_1} </s><s>[INST] {user_2} [/INST]
    The system block is omitted when ``system_prompt`` is None. With no messages at all,
    a system-only prompt ``<s>[INST] <<SYS>>\n{system}\n<</SYS>> [/INST]`` is returned.
    """
    system_prompt = None

    def __init__(self, system_prompt=None):
        self.system_prompt = system_prompt
        # Per-instance histories. These used to be class attributes, so every
        # PromptTemplate shared (and leaked history into) the same two lists.
        self.user_messages = []
        self.model_replies = []

    def add_user_message(self, message: str, return_prompt=True):
        """Append a user turn; return the rebuilt prompt when ``return_prompt`` is true."""
        self.user_messages.append(message)
        if return_prompt:
            return self.build_prompt()

    def add_model_reply(self, reply: str, includes_history=True, return_reply=True):
        """
        Record the model's reply to the latest user message.

        If ``includes_history`` is true, the current prompt text is removed from ``reply``
        (for generation output that echoes the prompt). Raises ValueError when there is
        no unanswered user message.
        """
        # Validate before mutating so a bad call cannot leave the history inconsistent.
        if len(self.user_messages) != len(self.model_replies) + 1:
            raise ValueError(
                "Number of user messages does not equal number of system replies."
            )
        reply_ = reply.replace(self.build_prompt(), "") if includes_history else reply
        self.model_replies.append(reply_)
        if return_reply:
            return reply_

    def get_user_messages(self, strip=True):
        return [x.strip() for x in self.user_messages] if strip else self.user_messages

    def get_model_replies(self, strip=True):
        return [x.strip() for x in self.model_replies] if strip else self.model_replies

    def clear_chat_history(self):
        self.user_messages.clear()
        self.model_replies.clear()

    def build_prompt(self):
        """
        Return the full prompt string for the current history.

        Raises:
            ValueError: if there are messages but not exactly one more user message than replies.
        """
        if not self.user_messages and not self.model_replies:
            # System-only prompt. No trailing </s>: generation must continue after
            # [/INST], not after an end-of-sequence token.
            return f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>> [/INST]"

        if len(self.user_messages) != len(self.model_replies) + 1:
            raise ValueError(
                "Error: Expected len(user_messages) = len(model_replies) + 1. Add a new user message!"
            )

        # The system block is folded into the first user turn, followed by a blank line.
        sys_block = (
            f"<<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n"
            if self.system_prompt is not None
            else ""
        )

        turns = []
        # zip stops at the shorter list, so this covers only the answered turns.
        for i, (user_message, model_reply) in enumerate(zip(self.user_messages, self.model_replies)):
            prefix = sys_block if i == 0 else ""
            turns.append(f"<s>[INST] {prefix}{user_message} [/INST] {model_reply} </s>")

        # The final, unanswered user turn; it carries the system block if it is the only turn.
        last_prefix = sys_block if not self.model_replies else ""
        turns.append(f"<s>[INST] {last_prefix}{self.user_messages[-1]} [/INST]")

        return "".join(turns)

In [None]:
def generate_text_prompt(llama_model:AutoModelForCausalLM, llama_tokenizer:AutoTokenizer, caption_list:list, context_prompt:str='', device:str='cuda'):
    """
    Function to generate the text prompt for a clip with llama from its screenshot captions

    Args:
        llama_model (AutoModelForCausalLM): llama model
        llama_tokenizer (AutoTokenizer): llama tokenizer
        caption_list (list): list of captions from screenshots (any length >= 1)
        context_prompt (str): llama context prompt, if '' the default prompt built from caption_list is used. Defaults to ''
        device (str, optional): device the encoded prompt is moved to, cuda or cpu. Defaults to cuda

    Returns:
        response (str): the generated text, stripped of surrounding whitespace

    Raises:
        ValueError: if the default prompt is requested but caption_list is empty
    """

    if context_prompt == '':
        if not caption_list:
            raise ValueError("caption_list is empty; cannot build the default context prompt")
        # Join however many captions exist: indexing [0]..[4] raised IndexError whenever
        # a screenshot failed. For exactly 5 captions the prompt text is unchanged.
        context = ",".join(caption_list)
        n_captions = len(caption_list)
        context_prompt=f"""context={context}. 
        imagine you are writing a description of a video. You are capturing the emotions caused by the environment and atmosphere. You should write a 30 word paragraph that takes the {n_captions} context and summarise all of them into that one paragraph. Talk about the environment, the vibes and the emotions. Immediately starts describing and do not mention the source. I will give you example prompts, follow their formatting but not the content. Give me just the paragraph and nothing else. 
        example prompts:
        A small rural village. The atmosphere is peaceful, with citizens doing their daily tasks. However, there is an underlying tension in the air, like something is about to go down. The overall emotion of the video is one of suspense and intrigue."""

    # Initialize the prompt generator with the given context
    prompt = PromptTemplate(context_prompt)
    true_prompt = prompt.build_prompt()

    # `return_full_text` was dropped: it is a pipeline option, not a generate() option.
    config = GenerationConfig(
        max_new_tokens=1024,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        temperature=0.1,
    )

    # The prompt string already contains <s>, so no extra special tokens are added.
    encoded_input = llama_tokenizer.encode(true_prompt, return_tensors='pt', add_special_tokens=False).to(device)
    results = llama_model.generate(encoded_input, generation_config=config)

    # Decode only the newly generated tokens rather than splitting the full text on
    # "[/INST]", which would truncate a reply that itself contains that marker.
    new_tokens = results[0][encoded_input.shape[-1]:]
    response = llama_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    #The response might not be consistent with the format, please explore and edit accordingly

    return response

#### Overarching function

In [None]:
def generate_textmusic_dataset(yt_api_key:str,
                               kosmos_model:AutoModelForVision2Seq,
                               kosmos_processor:AutoProcessor,
                               llama_model:AutoModelForCausalLM,
                               llama_tokenizer:AutoTokenizer,
                               music_output_dir:str, 
                               dataset_csv_dir:str,
                               music_csv_path:str='myseg.csv',
                               search_term:str="no commentary walkthrough", 
                               video_duration:list=[240, 3600], 
                               video_limit:int=10,
                               clips_per_video:int=10, 
                               clips_length:int=30,
                               music_length_threshold:int=25,
                               screenshot_interval:list=[5,10,15,20,25],
                               kosmos_text_input:str="detailed", 
                               llama_context_prompt:str='',
                               device:str='cuda'
                               ):
    """
    Function to generate the dataset: fetch videos, cut clips, keep clips with enough
    music, caption screenshots and append one (clip path, text prompt) row per kept clip
    to the dataset csv.

    Args:
        yt_api_key (str): yt api key
        kosmos_model (AutoModelForVision2Seq): kosmos model
        kosmos_processor (AutoProcessor): kosmos tokenizer
        llama_model (AutoModelForCausalLM): llama model
        llama_tokenizer (AutoTokenizer): llama tokenizer
        
        music_output_dir (str): path to the dataset folder
        dataset_csv_dir (str): path to the csv containing the filepath of video and the text prompt (appended to; header written if empty)
        music_csv_path (str, optional): music csv path for music split in each clip. Defaults to 'myseg.csv'.
        
        search_term (str, optional): youtube search term. Defaults to "no commentary walkthrough".
        video_duration (list, optional): accepted videos will be within this duration in seconds. Defaults to [240, 3600].
        video_limit (int, optional): maximum num of videos to be processed. Defaults to 10.
        
        clips_per_video (int, optional): number of clips to cut per video. Defaults to 10.
        clips_length (int, optional): length of each clip in seconds. Defaults to 30.
        
        music_length_threshold (int, optional): threshold for how long music must be in seconds. Defaults to 25.
        screenshot_interval (list, optional): List of times in seconds at which to capture the screenshots. Defaults to [5,10,15,20,25].
        
        kosmos_text_input (str, optional): grounding prompts. Using brief, or detailed will use pre-defined prompts. Defaults to detailed
        llama_context_prompt (str, optional): llama context prompt, if '' the default prompt will be used. Defaults to ''.
        device (str, optional): cuda or cpu. Defaults to cuda
    """
    base_dir = os.path.dirname(dataset_csv_dir)
    # dirname is '' for a bare filename, and os.makedirs('') would raise.
    if base_dir and not os.path.exists(base_dir):
        os.makedirs(base_dir)

    with open(dataset_csv_dir, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # Check if the file is empty to decide on writing headers
        file.seek(0, os.SEEK_END)
        if file.tell() == 0:
            writer.writerow(['Video Path', 'Caption'])

        #extract youtube links
        video_urls = fetch_videos(yt_api_key, search_term, video_duration, video_limit)
        for url in video_urls:
            #download and crop videos
            clip_paths_list = download_extract_clips(url, music_output_dir, clips_per_video, clips_length)
            # download_extract_clips returns None when the video could not be processed
            if not clip_paths_list:
                continue

            for clip_path in clip_paths_list:
                #check music length (clips with too little music are deleted and return None)
                clip_path = check_music_length(clip_path, music_csv_path, music_length_threshold)
                if clip_path is None:
                    continue

                #extract screenshots and caption them
                screenshot_paths = capture_screenshots(clip_path, screenshot_interval)
                caption_list = []
                for screenshot in screenshot_paths.values():
                    try:
                        caption_list.append(caption_image(kosmos_model, kosmos_processor, screenshot, kosmos_text_input, device))
                    finally:
                        # Remove the temporary screenshot even if captioning raised
                        os.remove(screenshot)

                if not caption_list:
                    print(f"No screenshots could be captioned for {clip_path}, skipping")
                    continue

                #text prompt
                text_prompt = generate_text_prompt(llama_model, llama_tokenizer, caption_list, llama_context_prompt, device)
                writer.writerow([clip_path, text_prompt])
                # Flush per row so finished rows survive if a long run is interrupted
                file.flush()

### Usage

In [None]:
# Pick the device first so both models use it (falls back to CPU when CUDA is unavailable)
device = "cuda:0" if torch.cuda.is_available() else "cpu" 

kosmos = "microsoft/kosmos-2-patch14-224"
kosmos_model = AutoModelForVision2Seq.from_pretrained(kosmos).to(device)
kosmos_processor = AutoProcessor.from_pretrained(kosmos)

# The Hugging Face token variable is hg_access (the previous hg_key was undefined -> NameError)
llama_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hg_access, torch_dtype=torch.bfloat16, device_map="auto")
llama_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hg_access)
print(f"Model running on {device}")

In [None]:
# Pass the device chosen when loading the models; the function's default is 'cuda',
# which does not match a CPU-only setup.
generate_textmusic_dataset(yt_API, 
                           kosmos_model, 
                           kosmos_processor,
                           llama_model,
                           llama_tokenizer,
                           music_output_dir='../../input/text-music-dataset/',
                           dataset_csv_dir='../../input/text-music-dataset/data.csv',
                           device=device)