In [1]:
# Load libraries and constants
import os
from glob import glob
import re
from datetime import datetime
import urllib
from googleapiclient.discovery import build
import pandas as pd

# Path of the YAML secrets file (despite the name this is the file, not a directory).
DEFAULT_SECRETS_DIR = os.path.join('../configs', 'secrets.yaml')
# Lower publication bound for the search; sent to the API as `publishedAfter`.
PUBLISHED_AFTER = datetime.strptime("1/1/2017", '%d/%m/%Y').isoformat() + 'Z'
MAX_PAGE_SIZE = 50 # no benefit if this number is anything other than 50

# All data locations are relative to the notebook's working directory.
# The transcript, YOLO model and train/val/test paths used to be absolute
# C:\Users\... paths that only resolved on a single machine.
# NOTE(review): assumes "../data" is the same folder the absolute paths pointed
# at (".../guiagents/data") - confirm before re-running.
VIDEO_DOWNLOAD_DIRECTORY="../data/youtube_dataset/videos"
VIDEO_DATA_PATH="../data/youtube_dataset/video_data.json"
TRANSCRIPT_DOWNLOAD_DIRECTORY = "../data/youtube_dataset/transcripts"
CACHE_PATH = "../data/cache"
GOOGLE_API_CLIENT_SECRETS_FILE = "../configs/client_secret_google_api.json"
YOLO_MODEL_PATH = "../data/yolo_element_detector/training/train11/weights/best.pt"
YOLO_TRACKING_FRAME_STEP = 3 # every how many frames apply yolo tracking
YOLO_CONFIDENCE_THRESH = 0.5 # detections at or below this confidence are discarded
YOLO_RESULTS_DIR = "../data/youtube_dataset/mouse_tracking"
TRAIN_DATA_PATH = "../data/youtube_dataset/video_data_train.json"
VAL_DATA_PATH = "../data/youtube_dataset/video_data_val.json"
TEST_DATA_PATH = "../data/youtube_dataset/video_data_test.json"

def load_video_data(data_path=VIDEO_DATA_PATH):
    """Load the video metadata table from a JSON file.

    Args:
        data_path: Path of a JSON file written with ``orient='index'``.
            Defaults to ``VIDEO_DATA_PATH``.

    Returns:
        A DataFrame whose index is named ``id`` and whose ``publish_date``
        column is converted with ``pd.to_datetime``, or ``None`` when
        ``data_path`` does not exist.
    """
    # The path argument used to be ignored in favour of VIDEO_DATA_PATH, so
    # other files (e.g. the train/val/test splits) could never be loaded.
    if not os.path.exists(data_path):
        return None
    vid_df = pd.read_json(data_path, orient='index')
    vid_df.index.name = 'id'
    vid_df.publish_date = pd.to_datetime(vid_df.publish_date)
    return vid_df

def store_video_data(vid_df: pd.DataFrame, data_path=VIDEO_DATA_PATH):
    """Write the video metadata table to ``data_path`` as index-oriented JSON.

    ``publish_date`` is serialised through ``isoformat()``. The conversion is
    done on a copy, so the caller's DataFrame keeps its datetime column
    (previously it was silently overwritten with strings).

    Args:
        vid_df: Metadata table with a ``publish_date`` column of datetimes.
        data_path: Destination JSON file. Defaults to ``VIDEO_DATA_PATH``.
    """
    serializable_df = vid_df.copy()
    serializable_df['publish_date'] = serializable_df['publish_date'].apply(lambda x: x.isoformat())
    serializable_df.to_json(data_path, indent=4, orient="index")

def parse_video_file(video_path):
    """Split a downloaded video file name of the form ``[<id>]_<name>``.

    Args:
        video_path: Path of the video file; only the base name is inspected.

    Returns:
        A dict with the keys ``id`` (text between the leading brackets) and
        ``video_name`` (everything after ``]_``, including the extension).

    Raises:
        ValueError: If the base name does not follow the ``[<id>]_<name>`` pattern.
    """
    file_name = os.path.basename(video_path)
    # Non-greedy id group: a title that itself contains "]_" must not be
    # swallowed into the id.
    match = re.match(r"\[(.+?)\]_(.*)", file_name)
    if match is None:
        raise ValueError(f"'{file_name}' does not follow the '[<id>]_<name>' naming scheme")
    return {
        'id': match.group(1),
        'video_name': match.group(2)
    }

#def list_all_video_files(video_download_dir=VIDEO_DOWNLOAD_DIRECTORY):
    #downloaded_ids = glob(os.path.join(video_download_dir, "*.mp4"))
    #downloaded_ids = set(map(lambda x: x['id'], downloaded_ids))

In [74]:
# load video_data if it already exists
import pandas as pd
from IPython.display import display

# Preview the first rows of the stored metadata; nothing is shown when
# load_video_data found no file and returned None.
vid_df = load_video_data(VIDEO_DATA_PATH)
if vid_df is not None:
    display(vid_df.head(5))

Unnamed: 0_level_0,publish_date,video_title,video_description,channel_title,video_tags,view_count,like_count,comments_count,channel_subs_count,channel_view_count,channel_vid_count,channel_id,query
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
S-nHYzK-BVg,2017-09-25 13:12:22+00:00,Beginner's Guide to Microsoft Word,"If you like this video, here's my entire playl...",Technology for Teachers and Students,"[microsoft word, word tutorial, using ms word,...",6973801,88277.0,1891.0,1480000,141834118,522,UCYUPLUCkMiUgiyVuluCc7tQ,windows tutorial
DHq3bqowzW0,2018-08-31 01:32:47+00:00,How to Clean C Drive In Windows 10 (Make Your...,"This video shows you, How to Clean C Drive (Lo...",Geeks Tutorial,"[how to clean local disk c windows 10, how to ...",9910893,202666.0,8536.0,558000,84677874,109,UCU1K6P1M8hq-TBnDlHQtO7A,windows tutorial
EJHKuwBhdB4,2024-02-07 14:45:00+00:00,How to Create a Windows 10 Installation USB wi...,"In this tutorial video, I'll show you how to c...",Memory's Tech Tips,"[how to create a windows 10 installation usb, ...",29995,354.0,38.0,13100,3192330,394,UCpFxsy-mzKIIX14aOH-veXg,windows tutorial
fFe3iESppag,2022-04-13 22:32:33+00:00,How to Clean C Drive In Windows 11 (Make Your ...,"This video shows you, How to Clean C Drive (Lo...",Geeks Tutorial,"[how to clean local disk c windows 11, how to ...",440892,7991.0,291.0,558000,84677874,109,UCU1K6P1M8hq-TBnDlHQtO7A,windows tutorial
ttiA0zRbzko,2018-05-05 15:19:14+00:00,How to Speed Up Your Windows 10 Performance! (...,"This video shows you, How to speed up any Wind...",Geeks Tutorial,"[how to speed up windows 10, speed up windows ...",4112471,85921.0,4556.0,558000,84677874,109,UCU1K6P1M8hq-TBnDlHQtO7A,windows tutorial


## Search YouTube

In [56]:
import yaml

# Search terms sent to the YouTube search endpoint, one request per entry.
# Every entry is listed once ("PowerPoint presentation tutorial" used to appear
# twice, which repeated the same search).
queries = [
  "windows tutorial",
  "linux tutorial",
  "macos tutorial",
  "pandas python tutorial",
  "how to create an account on tutorial",
  "how to download tutorial computer",
  "PowerPoint presentation tutorial",
  "Git and GitHub tutorial",
  "Building a website with WordPress tutorial",
  "Creating macros in Excel tutorial",
  "CAD design with AutoCAD tutorial",
  "Animating with Adobe Animate tutorial",
  "how to install C++ tutorial",
  "how to install java tutorial",
  "GIMP photo manipulation tutorial",
  "Creating music with FL Studio tutorial",
  "Creating animations in Maya tutorial",
  "Configuring a VPN tutorial",
  "Audacity tutorial",
  "Learning to code with Scratch tutorial",
  "zoom online meetings tutorial",
  "how to use zoom tutorial"
]

# Read the API key from the YAML secrets file.
with open(DEFAULT_SECRETS_DIR, 'r') as secrets_file:
    secrets = yaml.safe_load(secrets_file)

yt_api_key = secrets['youtube_api_key']
videos = []

# Queries that are already part of the stored dataset are not searched again.
# vid_df is None when no dataset has been stored yet, in which case every
# query is new (the membership test used to fail on None).
used_queries = set() if vid_df is None else set(vid_df['query'].unique())

with build('YouTube', "v3", developerKey=yt_api_key) as yt:
    for query in queries:

        if query in used_queries: # check if query already used
            continue

        search_result_list = yt.search().list(
            part='snippet',
            publishedAfter=PUBLISHED_AFTER,
            order='relevance',
            q=query,
            type='video',
            videoCategoryId='26',
            videoCaption="closedCaption",
            maxResults=MAX_PAGE_SIZE,
            relevanceLanguage='en',
        ).execute()

        video_ids = [search_result['id']['videoId'] for search_result in search_result_list.get('items', [])]
        if not video_ids:
            # nothing found for this query, so there are no details to request
            continue

        video_results = yt.videos().list(part=['id', 'snippet', 'statistics'], id=video_ids).execute()
        video_results_dict = {video_data['id']: video_data for video_data in video_results.get('items', [])}
        if not video_results_dict:
            continue

        # Each channel is requested once, even if it owns several of the videos.
        channel_ids = sorted({video_data['snippet']['channelId'] for video_data in video_results_dict.values()})
        channel_results = yt.channels().list(part=['statistics'], id=channel_ids).execute()
        channel_results_dict = {channel_data['id']: channel_data for channel_data in channel_results.get('items', [])}

        # Walk the ids in search order, but take the channel from the video's own
        # snippet. Pairing the two lists by position attaches the wrong channel
        # as soon as the details response omits or reorders a video.
        for video_id in video_ids:
            video_data = video_results_dict.get(video_id)
            if video_data is None:
                continue
            snippet = video_data['snippet']
            video_statistics = video_data.get('statistics', {})
            channel_id = snippet['channelId']
            channel_statistics = channel_results_dict.get(channel_id, {}).get('statistics', {})
            videos.append(
                {
                    "id": video_data['id'],
                    "publish_date": snippet.get('publishedAt', None),
                    "video_title": snippet.get('title', None),
                    "video_description": snippet.get('description', None),
                    "channel_title": snippet.get('channelTitle', None),
                    "video_tags": snippet.get('tags', None),
                    "view_count" : video_statistics.get('viewCount', None),
                    "like_count" : video_statistics.get('likeCount', None),
                    "comments_count": video_statistics.get('commentCount', None),
                    "channel_subs_count": channel_statistics.get('subscriberCount', None),
                    "channel_view_count": channel_statistics.get('viewCount', None),
                    "channel_vid_count": channel_statistics.get('videoCount', None),
                    "channel_id": channel_id,
                    "query": query
                }
            )

In [75]:
import pandas as pd

# saving video data
# Build the table of newly found videos. `videos` is empty when every query was
# already covered; drop_duplicates(subset='id') would raise on a frame without
# an 'id' column, so that case is handled separately.
if videos:
    vid_df_new = pd.DataFrame(videos).drop_duplicates(subset='id', keep='first').set_index('id')
    vid_df_new.publish_date = pd.to_datetime(vid_df_new.publish_date)
    if vid_df is not None:
        # keep only videos that are not part of the stored dataset yet
        vid_df_new = vid_df_new.loc[~vid_df_new.index.isin(vid_df.index)]
else:
    vid_df_new = pd.DataFrame()

if vid_df is None:
    # first run: the new videos are the whole dataset (stays None if there are none)
    vid_df = vid_df_new if len(vid_df_new) > 0 else None
elif len(vid_df_new) > 0:
    vid_df = pd.concat([vid_df, vid_df_new])

print(f"{len(vid_df_new)} new videos have been found on youtube.")
print(f"{0 if vid_df is None else len(vid_df)} is the new number of videos in dataset.")

if vid_df is not None:
    store_video_data(vid_df, VIDEO_DATA_PATH)

0 new videos have been found on youtube.
1053 is the new number of videos in dataset.


## Download the videos

In [76]:
# Reload the stored video metadata so the notebook can be started from this section.
# load_video_data returns None when the metadata file does not exist.
vid_df = load_video_data(VIDEO_DATA_PATH)

In [77]:
# Monkey patch to make pytube work again
# pytube is occasionally broken and needs monkey patching
from pytube import YouTube
from pytube.innertube import InnerTube

def bypass_age_gate(self):
    """Attempt to update the vid_info by bypassing the age gate.

    Each InnerTube client is tried in turn. The first response whose
    playability status is not 'UNPLAYABLE' is stored in ``self._vid_info``.
    A client that raises is reported and skipped, so one failing client no
    longer prevents the remaining ones from being tried.
    """
    # Client names in the order they are tried ('ANDROID_EMBED' was listed twice).
    clients = [
        'ANDROID_EMBED', 'IOS', 'ANDROID', 'WEB_EMBED',
        'IOS_EMBED', 'WEB_MUSIC',
        'IOS_MUSIC', 'WEB_CREATOR', 'ANDROID_CREATOR',
        'IOS_CREATOR', 'MWEB', 'TV_EMBED', 'WEB'
    ]
    print("Clients List:\n\n", clients,"\n")
    success_client = None
    for client in clients:
        try:
            innertube = InnerTube(
                client=client,
                use_oauth=self.use_oauth,
                allow_cache=self.allow_oauth_cache
            )
            innertube_response = innertube.player(self.video_id)
            playability_status = innertube_response['playabilityStatus'].get('status', None)
        except Exception as e:
            # report the failure and fall through to the next client
            print(f"Error: {e}")
            continue

        # Print the status of each client
        print(f"Client: {client}, Status: {playability_status}")

        # If the video is accessible, update _vid_info and exit the loop
        if playability_status != 'UNPLAYABLE':
            self._vid_info = innertube_response
            success_client = client
            print(f"Chosen client: {client}")
            break

    if not success_client:
        # Nothing is changed on the object here; the failure is only reported.
        print("No successful client found. Performing generic action...")

# Replace pytube's own method with the patched version.
YouTube.bypass_age_gate = bypass_age_gate

In [None]:
import re
from glob import glob
from tqdm import tqdm
from pytube import YouTube

# Ids of the videos that are already on disk, taken from the "[<id>]_" file prefix.
downloaded_ids = {
    parse_video_file(video_path)['id']
    for video_path in glob(os.path.join(VIDEO_DOWNLOAD_DIRECTORY, "*.mp4"))
}

for vid_id in tqdm(vid_df.index):
    # do not re-download them
    if vid_id in downloaded_ids:
        continue
    try:
        video_handle = YouTube(f"https://www.youtube.com/watch?v={vid_id}", use_oauth=True, allow_oauth_cache=True)

        video_stream = video_handle.streams.filter(
            progressive=True,
            file_extension='mp4',
            resolution='720p'
        ).first()

        if video_stream is None:
            # this used to be skipped silently; report which video was affected
            print(f"{vid_id}: no progressive 720p mp4 stream available, skipping.")
            continue

        video_stream.download(filename_prefix=f"[{vid_id}]_", output_path=VIDEO_DOWNLOAD_DIRECTORY, skip_existing=True)
    except Exception as e:
        # Best effort: one failing video must not stop the batch. The id is
        # included so the failure can be traced back to a video.
        print(f"{vid_id}: download failed: {e}")

### Extend the video data with video details

In [88]:
import cv2
from tqdm import tqdm
# Extend data with video details.
# The properties should describe exactly the files that were downloaded,
# which is why they are read from the files and not from the YouTube API.

vid_df = load_video_data(VIDEO_DATA_PATH)

# video id -> path of the downloaded file
downloaded_paths = {
    parse_video_file(video_path)['id']: video_path
    for video_path in glob(os.path.join(VIDEO_DOWNLOAD_DIRECTORY, "*.mp4"))
}

# Values recorded for a video that is missing on disk or cannot be opened.
unavailable_fields = {
    "video_available": False,
    "fps": None,
    "frame_count": None,
    "frame_width": None,
    "frame_height": None
}

vid_details = []
for vid_id in tqdm(vid_df.index):
    vid_path = downloaded_paths.get(vid_id, None)
    if not vid_path:
        vid_details.append({"id": vid_id, **unavailable_fields})
        continue

    vid_cap = cv2.VideoCapture(vid_path)
    if vid_cap.isOpened():
        vid_details.append({
            "id": vid_id,
            "video_available": True,
            "fps": vid_cap.get(cv2.CAP_PROP_FPS),
            "frame_count": int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            "frame_width": int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "frame_height": int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        })
    else:
        vid_details.append({"id": vid_id, **unavailable_fields})
    vid_cap.release()

100%|██████████| 1053/1053 [00:07<00:00, 141.89it/s]


In [127]:
# One row of file properties per video id.
vid_details_df = pd.DataFrame(vid_details).set_index('id')

unavailable_count = (~vid_details_df['video_available']).sum()
print(f"{unavailable_count} videos were unavailable")

# Drop detail columns left over from an earlier run before merging the fresh ones.
stale_columns = vid_details_df.columns.intersection(vid_df.columns)
vid_df = pd.merge(
    vid_df.drop(columns=stale_columns),
    vid_details_df,
    left_index=True,
    right_index=True,
)
store_video_data(vid_df, VIDEO_DATA_PATH)

40 videos were unavailable


## Download transcripts (WIP)

In [2]:
vid_df = load_video_data(VIDEO_DATA_PATH)

# video id -> path of the downloaded video file
downloaded_vid_paths = {
    parse_video_file(video_path)['id']: video_path
    for video_path in glob(os.path.join(VIDEO_DOWNLOAD_DIRECTORY, "*.mp4"))
}


def get_id_from_transcript_filename(x):
    """Return the file name of `x` without directory and extension."""
    return os.path.splitext(os.path.basename(x))[0]


# video id -> path of the transcript json that was already downloaded
downloaded_transcript_paths = {
    get_id_from_transcript_filename(transcript_path): transcript_path
    for transcript_path in glob(os.path.join(TRANSCRIPT_DOWNLOAD_DIRECTORY, "*.json"))
}

NameError: name 'load_video_data' is not defined

In [None]:
from yt_dlp import YoutubeDL
from urllib.request import urlopen
from tqdm import tqdm
import json

from yt_dlp import YoutubeDL
from urllib.request import urlopen
from tqdm import tqdm
import json

# Options passed to YoutubeDL(...) in the transcript download loop of this cell.
# NOTE(review): yt-dlp appears to document the automatic-caption and language
# options as `writeautomaticsub` and `subtitleslangs`; the spellings
# "writeautosubtitles" and "subtitlelangs" used here may be ignored - confirm
# against the yt-dlp option list before relying on them.
yt_opts = {
    # no media download is requested through these options
    "skip_download": True,
    #"cookiefile": r"C:\Users\emaid\Desktop\guiagents\data\youtube.com_cookies.txt",
    # browser named here is the cookie source (alternative to the cookie file above)
    "cookiesfrombrowser": ('chrome', ),
    #"listsubtitles": True,
    "writeautosubtitles": True,
    # 'json3' is also the encoding the download loop filters the caption entries by
    "subtitlesformat": 'json3',
    "subtitlelangs": ['all'],
    "nocheckcertificate": True,
    "quiet": True,
    
}

def restructure_caption(caption_dict):
    """Add a ``text`` entry holding the concatenated text of a caption event.

    The text of every segment is read from its ``utf8`` key. Previously the
    first value of each segment dict was used, which only worked while
    ``utf8`` happened to be the first key.

    Args:
        caption_dict: One caption event; its optional ``segs`` entry is a
            list of segment dicts.

    Returns:
        The same dict object, extended in place with the ``text`` key. The
        text is empty when the event has no segments.
    """
    segs = caption_dict.get('segs') or []
    caption_dict['text'] = "".join(seg.get('utf8', '') for seg in segs)
    return caption_dict

# Download one transcript json per downloaded video. Manually created English
# subtitles are preferred; automatic English captions are the fallback.
with YoutubeDL(yt_opts) as yt: # if it doesnt auto load cookies then we could break this into yt = You.. and yt.close() when the exception happens
    for vid_id in tqdm(vid_df.index):
        if vid_id not in downloaded_vid_paths: # only videos we have downloaded so far make sense
            continue
        if vid_id in downloaded_transcript_paths: # no need to download twice
            continue

        transcript = dict()
        info = yt.extract_info(f"https://www.youtube.com/watch?v={vid_id}")
        subtitles = info.get('subtitles') or {}
        automatic_captions = info.get('automatic_captions') or {}

        # English subtitle tracks, plain 'en' first, then regional variants
        # ('en-GB', ...) in their original order (sorted() is stable).
        langs = sorted(
            (lang for lang in subtitles if lang == 'en' or lang.startswith('en-')),
            key=lambda lang: lang != 'en'
        )
        if langs: # if there is an english subtitle
            en_sub = subtitles[langs[0]]
            transcript['generated'] = False
        else:
            en_sub = automatic_captions.get('en', None)
            if not en_sub:
                print(f'No automatic english captions found for {vid_id}!')
                continue
            transcript['generated'] = True

        json_encodings = [encoding for encoding in en_sub if encoding.get('ext') == 'json3']
        if not json_encodings:
            print(f'{vid_id} is missing the json encoding!')
            continue

        # Close the HTTP response deterministically and fall back to utf-8 when
        # the server sends no charset (decode(None) raises a TypeError).
        with urlopen(json_encodings[0]['url']) as url_response:
            charset = url_response.info().get_content_charset() or 'utf-8'
            ret_obj = json.loads(url_response.read().decode(charset)) # get json from response

        if 'events' not in ret_obj:
            print(f'Could not retreive events for {vid_id}!')
            continue
        transcript['transcript'] = list(map(restructure_caption, ret_obj['events']))

        with open(os.path.join(TRANSCRIPT_DOWNLOAD_DIRECTORY, f"{vid_id}.json"), "w") as transcript_file:
            json.dump(transcript, transcript_file, indent=4)

In [12]:
# This is the old code for downloading transcripts; it currently does not work.
# If the library introduces a fix, it can be used again.
#from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound, NoTranscriptAvailable, TranscriptsDisabled
#import traceback
#from tqdm import tqdm
#
#transcripts = dict()
#
#for vid_id in tqdm(vid_df.index[:3]):
#    try:
#        transcript_list = YouTubeTranscriptApi.list_transcripts(vid_id)
#        ts_handle = transcript_list.find_transcript(['en', 'en-GB', 'en-US', 'de', 'fr','es'])
#        transcript = ts_handle.fetch()
#        transcripts[vid_id] = {'is_generated': ts_handle.is_generated, 'availability': "available", 'message':"", 'transcript': transcript} 
#    except NoTranscriptFound as e:
#        transcripts[vid_id] = {'is_generated': None, 'availability': "not_found", 'message': traceback.format_exc(), 'transcript': None} 
#    except NoTranscriptAvailable as e:
#        transcripts[vid_id] = {'is_generated': None, 'availability': "not_available", 'message': traceback.format_exc(), 'transcript': None} 
#    except TranscriptsDisabled as e:
#        transcripts[vid_id] = {'is_generated': None, 'availability': "disabled", 'message': traceback.format_exc(), 'transcript': None} 

100%|██████████| 3/3 [00:02<00:00,  1.34it/s]


## Extract images for data cleaning classifier

In [97]:
import cv2
from glob import glob
import os
from tqdm import tqdm
import numpy as np
# Draw random frames uniformly from the frames of all available videos.
# The videos are laid end to end on one global frame axis; every drawn global
# frame number is then mapped back to (video id, frame inside that video).
# NOTE: the draw is not seeded, so every run selects different frames.

number_of_datapoints = 10000
vid_df = load_video_data(VIDEO_DATA_PATH)
# sort_index() returns a new frame, so the filtered selection is not modified in place
vid_df = vid_df[vid_df['video_available']].sort_index()

frame_counts = vid_df['frame_count'].astype(int).to_numpy()

# create interval ends (exclusive) of every video on the global frame axis
ends = frame_counts.cumsum()

# create interval starts (inclusive)
starts = ends - frame_counts

# choose random frames from all possible frames
total_frames = int(ends[-1])
choices = np.sort(np.random.randint(0, total_frames, number_of_datapoints))

# Position of the video containing each choice: the first interval whose end is
# greater than the choice. This replaces a (videos x choices) boolean matrix
# that was melted into a frame with one row per matrix cell.
video_positions = np.searchsorted(ends, choices, side='right')

frames_to_extract = pd.DataFrame({
    "id": vid_df.index[video_positions],
    "frame": choices - starts[video_positions]
})
frames_to_extract

Unnamed: 0,id,frame
0,-1N0L-FDWCs,214
1,-1N0L-FDWCs,311
2,-1N0L-FDWCs,934
3,-1N0L-FDWCs,3710
4,-1N0L-FDWCs,7107
...,...,...
9995,zkiEtOay5ZA,1781
9996,zkiEtOay5ZA,2184
9997,zkiEtOay5ZA,3010
9998,zpwgKLGQDQg,2355


In [79]:
# Scratch cell: look up which video a global frame number falls into.
frame_num = 5000
# Start offset of every video: cumulative sum of (frame_count - 5), shifted by one row.
# NOTE(review): the reason for subtracting 5 frames per video is not visible here - confirm.
v = (vid_df['frame_count'] - 5).cumsum().shift(1, fill_value=0)
# Position of the last start offset that is <= frame_num.
i = v.searchsorted(frame_num, side='right') - 1
#frame_num - v.iloc[i]
# This expression is not the last one in the cell, so its value is not displayed.
vid_df.index[i]
v

id
-1N0L-FDWCs           0.0
-Ee1zsYFPgU      113587.0
-F5TrmNdDLo      115993.0
-QwrEVbyZkQ      123943.0
-dcwIpH6GVs      129154.0
                  ...    
zkJkRCZ4_9o    20715686.0
zkb1sZJNdyw    20734810.0
zkiEtOay5ZA    20786850.0
zpwgKLGQDQg    20790149.0
zrZjecLrMWc    20797005.0
Name: frame_count, Length: 1013, dtype: float64

In [None]:
# TODO
# The sampling cell above creates a dataframe with two columns: the id of a video and a frame to be extracted from that video (one row per frame)
# Probably the simplest approach is to write a VideoCapture iterator that reads the needed frames and saves them

## Running cursor tracker

In [15]:
from ultralytics import YOLO
import torch as pt
from tqdm import tqdm
import cv2

yolo_model = YOLO(YOLO_MODEL_PATH)

vid_df = load_video_data(VIDEO_DATA_PATH)

# video id -> path of the downloaded file
downloaded_paths = {
    parse_video_file(video_path)['id']: video_path
    for video_path in glob(os.path.join(VIDEO_DOWNLOAD_DIRECTORY, "*.mp4"))
}

# Ids that already have a tracking csv (file name without path and extension).
# A video without a single accepted detection gets no csv and is therefore
# processed again on every run.
tracked_vid_ids = {
    os.path.splitext(os.path.basename(csv_path))[0]
    for csv_path in glob(os.path.join(YOLO_RESULTS_DIR, "*.csv"))
}

# to_csv does not create missing directories
os.makedirs(YOLO_RESULTS_DIR, exist_ok=True)

for vid_id in tqdm(vid_df.index):

    vid_path = downloaded_paths.get(vid_id)
    if vid_path is None:
        # TODO log file not found
        continue
    if vid_id in tracked_vid_ids:
        continue

    mouse_positions = []

    # load video
    cap = cv2.VideoCapture(vid_path)
    try:
        i = 0
        while True:
            ret, img = cap.read()
            if not ret:
                break
            if i % YOLO_TRACKING_FRAME_STEP == 0:
                # NOTE(review): Ultralytics appears to document numpy inputs as BGR
                # (OpenCV order); confirm that this RGB conversion is intended.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                result = yolo_model.predict(img, verbose=False)[0]
                confidences = result.boxes.conf

                # Keep the most confident box of the frame if it clears the threshold.
                if confidences.numel() > 0:
                    best = pt.argmax(confidences)
                    best_conf = confidences[best].item()
                    if best_conf > YOLO_CONFIDENCE_THRESH:
                        x, y, w, h = result.boxes.xywh[best].cpu().tolist()
                        mouse_positions.append({
                            'x': x,
                            'y': y,
                            'width': w,
                            # spelling kept: the data cleaning cell reads the 'heigth' column
                            'heigth':h,
                            'frame': i,
                            'conf': best_conf
                        })
            i += 1
    finally:
        # release the capture even if decoding or prediction raises
        cap.release()

    if len(mouse_positions) > 0:
        mouse_df = pd.DataFrame(mouse_positions).set_index("frame")
        mouse_df.to_csv(os.path.join(YOLO_RESULTS_DIR, f"{vid_id}.csv"), float_format='%.3f')

100%|██████████| 1053/1053 [26:08:51<00:00, 89.39s/it]    


In [2]:
#YOLO_RESULTS_DIR = "../data/youtube_dataset/mouse_tracking"

# Read every mouse tracking csv; the id is the file name without its extension.
tracking_csv_pattern = os.path.join(YOLO_RESULTS_DIR, '*.csv')
for mt_path in glob(tracking_csv_pattern):
    vid_id, _ = os.path.splitext(os.path.basename(mt_path))
    mt = pd.read_csv(mt_path)



## Train Validation and Test split

In [2]:
# Reload the stored video metadata so the split can be run on its own.
vid_df = load_video_data(VIDEO_DATA_PATH)

In [3]:
import numpy as np
# Fixed seed so the shuffle, and with it the split, is the same on every run.
np.random.seed(1234)

# Shuffle a copy of the ids; the index of vid_df itself stays untouched.
ids = vid_df.index.to_numpy().copy()
np.random.shuffle(ids)

# Fractions for train / validation / test, turned into cumulative cut positions.
splits = np.array([0.7, 0.15, 0.15])
splits = np.round(np.cumsum(splits) * len(ids)).astype(int)

# The last cut position equals len(ids), so the first two cuts already
# yield the three parts.
train_ids, val_ids, test_ids = np.split(ids, splits[:-1])

train_vid_df, val_vid_df, test_vid_df = (
    vid_df.loc[part_ids] for part_ids in (train_ids, val_ids, test_ids)
)

In [4]:
# Write each part of the split to its own json file.
for split_df, split_path in (
    (train_vid_df, TRAIN_DATA_PATH),
    (val_vid_df, VAL_DATA_PATH),
    (test_vid_df, TEST_DATA_PATH),
):
    store_video_data(split_df, data_path=split_path)

## Cleaning data

In [None]:
from tqdm import tqdm
from functools import reduce

# Cut the mouse tracking results into fixed-length sequences in which the
# cursor is reliably detected.
FRAME_STEP = YOLO_TRACKING_FRAME_STEP  # the tracker only looked at every n-th frame
MIN_CONFIDENCE = 0.85                  # detections at or below this confidence are dropped
SEQUENCE_LENGTH = 150                  # frames per sequence; also minimum detections / interval length
SEQUENCE_STRIDE = 50                   # frames between the starts of consecutive sequences
ROLLING_WINDOW = 100                   # samples averaged by the rolling mean
MIN_MOVING_RATIO = 0.5                 # rolling share of "moving" samples that must be exceeded
MIN_PRESENT_RATIO = 0.85               # rolling share of samples with a detection that must be exceeded


def interval_to_seqs(interval):
    """Return the (start, end) frame pairs of the windows cut from one interval.

    Window starts run from ``interval[0]`` in steps of SEQUENCE_STRIDE while
    they stay below ``interval[1] - SEQUENCE_LENGTH``.
    NOTE(review): an interval of exactly SEQUENCE_LENGTH frames yields no
    window because the stop value is exclusive - confirm this is intended.
    """
    steps = np.arange(interval[0], interval[1] - SEQUENCE_LENGTH, SEQUENCE_STRIDE)
    return np.stack((steps, steps + SEQUENCE_LENGTH), axis=1)


sequences = []
for mt_path in tqdm(glob(os.path.join(YOLO_RESULTS_DIR, '*.csv'))):
    vid_id = os.path.splitext(os.path.basename(mt_path))[0]
    mt = pd.read_csv(mt_path)
    mt = mt[mt['conf'] > MIN_CONFIDENCE]
    if len(mt) < SEQUENCE_LENGTH:
        continue

    # one slot per sampled frame (every FRAME_STEP-th frame of the video)
    frames = np.arange(0, np.max(mt['frame']) + 1, FRAME_STEP, dtype=int)
    sample_positions = (mt['frame'] // FRAME_STEP).to_numpy()

    # moving
    # NOTE(review): this compares the distance of (x, y) from the image origin
    # with twice the box height, not the displacement between frames - confirm.
    is_moving = (np.sqrt(mt['x']**2 + mt['y']**2) > 2*mt['heigth']).to_numpy()
    mouse_moving = np.zeros(len(frames))
    mouse_moving[sample_positions[is_moving]] = 1
    mouse_moving_rolling = np.convolve(mouse_moving, np.ones(ROLLING_WINDOW)/ROLLING_WINDOW, 'same') - MIN_MOVING_RATIO

    # is present
    mouse_present = np.zeros(len(frames))
    mouse_present[sample_positions] = 1
    mouse_present_rolling = np.convolve(mouse_present, np.ones(ROLLING_WINDOW)/ROLLING_WINDOW, 'same') - MIN_PRESENT_RATIO

    # a sample is clean when both rolling ratios are above their thresholds
    mouse = np.minimum(mouse_present_rolling, mouse_moving_rolling)

    # Pad with zeros so every clean run has a rising and a falling edge, then
    # read the edge positions pairwise as (start, end) in frame numbers.
    ispos = np.concatenate(([0], (mouse > 0).view(np.int8), [0]))
    absdiff = np.abs(np.diff(ispos))
    clean_intervals = np.where(absdiff == 1)[0]*FRAME_STEP
    clean_intervals = clean_intervals.reshape(-1, 2)
    big_enough = (np.diff(clean_intervals, axis=1) >= SEQUENCE_LENGTH)
    clean_intervals = clean_intervals[big_enough.flatten()]
    if len(clean_intervals) == 0:
        continue

    # Windows of ALL clean intervals; previously only the windows of the first
    # interval were kept and the rest of the video was discarded.
    seqs = np.concatenate([interval_to_seqs(interval) for interval in clean_intervals])

    for interval in seqs:
        mt_interval = mt.query("frame >= @interval[0] and frame <= @interval[1]")
        # debug output for windows without a detection before their end frame
        if mt.query("frame >= @interval[0] and frame < @interval[1]").empty:
            print('HI')
            print(interval)
        sequences.append((vid_id, interval, mt_interval))