<a href="https://colab.research.google.com/github/WaliSiddiqui1/BPM/blob/main/Feature_extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import json
import cv2
import torch
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from google.colab import drive
import re
import time

!pip install -q webvtt-py
!pip install -q transformers
!pip install -q rouge
!pip install -q nltk
!pip install -q yolov5

!pip install -q huggingface-hub>=0.30.0 --upgrade

import nltk
nltk.download('punkt')

import webvtt
drive.mount('/content/drive')

# Root of the project tree on the mounted Drive; every other path hangs off it.
BASE_DIR = '/content/drive/MyDrive/NSVA_Results/'

# Locations that the code below lists for input files.
RAW_VIDEOS_DIR = os.path.join(BASE_DIR, 'raw_videos')
VTT_CAPTIONS_DIR = os.path.join(BASE_DIR, 'video_captions')

# Locations for generated artefacts.
FEATURES_DIR = os.path.join(BASE_DIR, 'features')
ANNOTATIONS_DIR = os.path.join(BASE_DIR, 'annotations')
CHECKPOINTS_DIR = os.path.join(BASE_DIR, 'checkpoints')
RESULTS_DIR = os.path.join(BASE_DIR, 'results')
METADATA_DIR = os.path.join(BASE_DIR, 'metadata')

# Create the whole tree up front; exist_ok makes re-running the cell harmless.
for directory in (
    RAW_VIDEOS_DIR,
    VTT_CAPTIONS_DIR,
    FEATURES_DIR,
    ANNOTATIONS_DIR,
    CHECKPOINTS_DIR,
    RESULTS_DIR,
    METADATA_DIR,
):
    os.makedirs(directory, exist_ok=True)

# One sub-directory per feature family stored under FEATURES_DIR.
for feature_type in ('timesformer', 'ball', 'player', 'basket', 'court'):
    os.makedirs(os.path.join(FEATURES_DIR, feature_type), exist_ok=True)

# Report how many video files (.mp4/.avi/.mkv) are present in RAW_VIDEOS_DIR.
if os.path.exists(RAW_VIDEOS_DIR):
    video_files = [
        entry
        for entry in os.listdir(RAW_VIDEOS_DIR)
        if entry.endswith(('.mp4', '.avi', '.mkv'))
    ]
    print(f"Found {len(video_files)} video files")

# Report how many caption files ending in '.en.vtt' are present.
if os.path.exists(VTT_CAPTIONS_DIR):
    vtt_files = [
        entry
        for entry in os.listdir(VTT_CAPTIONS_DIR)
        if entry.endswith('.en.vtt')
    ]
    print(f"Found {len(vtt_files)} VTT files")

def create_file_mapping():
    """Pair each '.en.vtt' caption file with a video file and persist the pairs.

    Matching runs in two passes:

    1. Exact: the caption name minus '.en.vtt' equals the video name minus
       its extension. If several videos share a base name, the first one in
       directory-listing order wins.
    2. Fuzzy (only for captions left unmatched): the video with the highest
       character-overlap score above 0.7 is chosen and the pairing is printed.

    The mapping is written to METADATA_DIR/file_mapping.txt, one
    "<vtt>|<video>" pair per line.

    Returns:
        dict mapping VTT file name -> video file name. Captions with no
        match are omitted.
    """
    caption_suffix = '.en.vtt'
    vtt_files = [f for f in os.listdir(VTT_CAPTIONS_DIR) if f.endswith(caption_suffix)]
    video_files = [f for f in os.listdir(RAW_VIDEOS_DIR) if f.endswith(('.mp4', '.avi', '.mkv'))]

    # Compute every video's base name once instead of once per caption file.
    video_bases = [(video_file, os.path.splitext(video_file)[0]) for video_file in video_files]

    # Index by base name for O(1) exact lookups; setdefault keeps the first
    # listed video when two videos share a base name.
    video_by_base = {}
    for video_file, video_base in video_bases:
        video_by_base.setdefault(video_base, video_file)

    mapping = {}
    unmatched_vtts = []

    for vtt_file in vtt_files:
        vtt_base = vtt_file[:-len(caption_suffix)]
        exact_match = video_by_base.get(vtt_base)
        if exact_match is not None:
            mapping[vtt_file] = exact_match
        else:
            unmatched_vtts.append(vtt_file)

    for vtt_file in unmatched_vtts:
        vtt_base = vtt_file[:-len(caption_suffix)]

        best_match = None
        best_score = 0

        for video_file, video_base in video_bases:
            # Fraction of caption-name characters that occur anywhere in the
            # video name, normalised by the longer of the two names.
            # NOTE(review): this ignores character order (e.g. "game_12" and
            # "game_21" score 1.0) and several captions can resolve to the
            # same video -- confirm printed matches by eye.
            common_chars = sum(1 for c in vtt_base if c in video_base)
            score = common_chars / max(len(vtt_base), len(video_base))

            if score > best_score and score > 0.7:
                best_match = video_file
                best_score = score

        if best_match:
            mapping[vtt_file] = best_match
            print(f"Matched '{vtt_file}' to '{best_match}'")

    # Explicit UTF-8 so non-ASCII file names are written the same way
    # regardless of the platform's default encoding.
    with open(os.path.join(METADATA_DIR, 'file_mapping.txt'), 'w', encoding='utf-8') as f:
        for vtt, video in mapping.items():
            f.write(f"{vtt}|{video}\n")

    return mapping

# Build the caption-to-video mapping as a module-level global when this cell
# runs. main() calls create_file_mapping() again and uses its own local copy.
file_mapping = create_file_mapping()

def process_vtt_files(file_mapping):
    """Convert the mapped VTT caption files into JSON annotation files.

    Each cue of each caption file becomes one record. Two kinds of output are
    written under ANNOTATIONS_DIR: a combined 'annotations.json' and one
    '<video_id>_captions.json' per video. A failure while handling a file is
    printed and the loop moves on; records collected before the failure
    are kept.

    Args:
        file_mapping: dict of VTT file name -> video file name.

    Returns:
        dict with a single key 'sentences' holding a list of
        {'video_id', 'caption', 'start', 'end'} records (times in seconds).
    """
    sentences = []
    captions_by_video = {}

    for vtt_name, video_name in tqdm(file_mapping.items()):
        # The video id is the video file name without its extension.
        vid = os.path.splitext(video_name)[0]
        source_path = os.path.join(VTT_CAPTIONS_DIR, vtt_name)

        try:
            cues = webvtt.read(source_path)

            # Start a fresh list only once the file has been read successfully.
            records = []
            captions_by_video[vid] = records

            for cue in cues:
                begin, finish = cue.start, cue.end
                cue_text = cue.text.strip()

                begin_s = convert_time_to_seconds(begin)
                finish_s = convert_time_to_seconds(finish)

                records.append({
                    'video_id': vid,
                    'start_time': begin,
                    'end_time': finish,
                    'start_seconds': begin_s,
                    'end_seconds': finish_s,
                    'caption': cue_text
                })
                sentences.append({
                    'video_id': vid,
                    'caption': cue_text,
                    'start': begin_s,
                    'end': finish_s
                })

        except Exception as err:
            print(f"Error processing {vtt_name}: {str(err)}")

    annotations = {'sentences': sentences}

    # Combined file covering every video.
    with open(os.path.join(ANNOTATIONS_DIR, 'annotations.json'), 'w') as out:
        json.dump(annotations, out, indent=2)

    # One detailed file per video.
    for vid, records in captions_by_video.items():
        per_video_path = os.path.join(ANNOTATIONS_DIR, f'{vid}_captions.json')
        with open(per_video_path, 'w') as out:
            json.dump(records, out, indent=2)

    return annotations

def convert_time_to_seconds(time_str):
    """Convert a colon-separated timestamp to seconds.

    Accepts 'HH:MM:SS(.fff)' as well as the shorter 'MM:SS(.fff)' and
    'SS(.fff)' forms (WebVTT allows the hours field to be omitted).

    Args:
        time_str: Timestamp string such as '00:01:02.500' or '01:02.500'.

    Returns:
        The timestamp as a float number of seconds.

    Raises:
        ValueError: If the string has more than three fields or a field is
            not numeric.
    """
    fields = time_str.strip().split(':')
    if len(fields) > 3:
        raise ValueError(f"Invalid timestamp: {time_str!r}")

    *leading, seconds = fields

    # Walk the leading fields right-to-left: minutes first, then hours.
    whole_seconds = 0
    for multiplier, field in zip((60, 3600), reversed(leading)):
        whole_seconds += int(field) * multiplier

    # Integer part first, then one float addition, so three-field inputs give
    # exactly the same result as int(h) * 3600 + int(m) * 60 + float(s).
    return whole_seconds + float(seconds)

# Module-level list of file names in RAW_VIDEOS_DIR ending in .mp4, .avi or .mkv.
video_files = [f for f in os.listdir(RAW_VIDEOS_DIR) if f.endswith(('.mp4', '.avi', '.mkv'))]

def create_dataset_splits(annotations, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, min_test=35):
    """Split the annotated videos into disjoint train/val/test lists.

    The test split receives max(test_ratio * n, min_test) videos, capped at
    the number of videos available. The remaining videos are divided between
    train and val in the proportion train_ratio : val_ratio. The result is
    written to METADATA_DIR/splits.json.

    Args:
        annotations: dict with a 'sentences' list whose items carry a
            'video_id' key.
        train_ratio: Relative weight of the train split.
        val_ratio: Relative weight of the val split.
        test_ratio: Fraction of all videos targeted for the test split.
        min_test: Minimum number of test videos; when fewer videos exist,
            all of them go to the test split.

    Returns:
        dict with keys 'train', 'val' and 'test', each a list of video ids.
    """
    import random

    # Sort the unique ids: set iteration order for strings differs between
    # interpreter runs, which would make the seeded shuffle irreproducible.
    video_ids = sorted({s['video_id'] for s in annotations['sentences']})

    # A private generator gives a fixed shuffle without reseeding the global
    # `random` state used elsewhere.
    rng = random.Random(42)
    shuffled_videos = list(video_ids)
    rng.shuffle(shuffled_videos)

    n_videos = len(shuffled_videos)

    # Never ask for more test videos than exist. Without the cap `remaining`
    # goes negative and the negative slice bounds below make train and test
    # overlap.
    test_size = min(n_videos, max(int(n_videos * test_ratio), min_test))

    remaining = n_videos - test_size

    n_train = int(remaining * (train_ratio / (train_ratio + val_ratio)))
    n_val = remaining - n_train

    train_videos = shuffled_videos[:n_train]
    val_videos = shuffled_videos[n_train:n_train + n_val]
    test_videos = shuffled_videos[n_train + n_val:]

    splits = {
        'train': train_videos,
        'val': val_videos,
        'test': test_videos
    }

    splits_file = os.path.join(METADATA_DIR, 'splits.json')
    with open(splits_file, 'w') as f:
        json.dump(splits, f, indent=2)

    print(f"Created dataset splits: {len(train_videos)} train, {len(val_videos)} val, {len(test_videos)} test")
    return splits

def resume_processing(file_mapping, splits):
    """Pick the next batch of not-yet-processed videos for each split.

    Video ids listed in METADATA_DIR/processed_videos.txt (one per line) are
    treated as done. From what is left, at most 5 train, 5 val and 30 test
    videos are selected, keeping the order of the split lists.

    Args:
        file_mapping: Not used by this function.
        splits: dict with 'train', 'val' and 'test' lists of video ids.

    Returns:
        Tuple (to_process, processed_videos): a dict of the selected ids per
        split, and the set of ids read from the state file (empty when the
        file does not exist).
    """
    state_path = os.path.join(METADATA_DIR, 'processed_videos.txt')

    processed_videos = set()
    if os.path.exists(state_path):
        with open(state_path, 'r') as handle:
            processed_videos = set(handle.read().splitlines())

    # Upper bound on how many videos of each split are handled per run.
    batch_limits = {'train': 5, 'val': 5, 'test': 30}

    to_process = {}
    for split_name, limit in batch_limits.items():
        pending = [vid for vid in splits[split_name] if vid not in processed_videos]
        # Slicing past the end is harmless, so no explicit min() is needed.
        to_process[split_name] = pending[:limit]

    return to_process, processed_videos

from transformers import ViTFeatureExtractor, ViTModel

def load_yolo_model():
    """Load the 'yolov5s' model from the 'ultralytics/yolov5' torch.hub repo.

    Returns:
        The loaded model with its `classes` filter set to [0, 32], the two
        class ids FeatureExtractor selects as players (0) and the ball (32).
    """
    detector = torch.hub.load('ultralytics/yolov5', 'yolov5s')
    # Keep only the two classes the feature extractor looks for.
    detector.classes = [0, 32]
    return detector

def load_vit_model():
    """Load the 'google/vit-base-patch16-224' preprocessor and encoder.

    Returns:
        Tuple (feature_extractor, model): the image preprocessor and a
        ViTModel created with add_pooling_layer=False.
    """
    checkpoint = 'google/vit-base-patch16-224'

    # ViTFeatureExtractor is a deprecated alias of ViTImageProcessor; prefer
    # the current class and fall back only on transformers releases that do
    # not ship it yet.
    try:
        from transformers import ViTImageProcessor as processor_cls
    except ImportError:
        processor_cls = ViTFeatureExtractor

    feature_extractor = processor_cls.from_pretrained(checkpoint)
    model = ViTModel.from_pretrained(checkpoint, add_pooling_layer=False)
    return feature_extractor, model

class FeatureExtractor:
    """Compute per-frame visual features for videos and save them as .npy files.

    For each sampled frame the extractor runs the YOLO detector, crops the
    detected ball and players, and embeds those crops -- plus the upper third
    of the frame and an edge map of the whole frame -- with the ViT model.
    Arrays are written under FEATURES_DIR/<feature_type>/ and finished videos
    are appended to METADATA_DIR/processed_videos.txt.
    """

    # Width of every stored embedding vector.
    FEATURE_DIM = 768
    # Number of player slots stored per frame.
    MAX_PLAYERS = 5
    # Detector class ids used to select player and ball boxes.
    PLAYER_CLASS_ID = 0
    BALL_CLASS_ID = 32

    def __init__(self, yolo_model, vit_extractor, vit_model, file_mapping):
        """Store the models and index the videos by id.

        Args:
            yolo_model: Callable detector; its result must expose `xyxy`.
            vit_extractor: Image preprocessor called with
                `images=..., return_tensors="pt"`.
            vit_model: Encoder whose output exposes `last_hidden_state`.
            file_mapping: dict of VTT file name -> video file name.
        """
        self.yolo = yolo_model
        self.vit_extractor = vit_extractor
        self.vit_model = vit_model
        self.file_mapping = file_mapping

        # A video id is the video file name without its extension.
        self.video_id_to_file = {}
        for vtt_file, video_file in file_mapping.items():
            video_id = os.path.splitext(video_file)[0]
            self.video_id_to_file[video_id] = video_file

        # Ids recorded by earlier runs, one per line in the state file.
        self.processed_videos = set()
        state_file = os.path.join(METADATA_DIR, 'processed_videos.txt')
        if os.path.exists(state_file):
            with open(state_file, 'r') as f:
                self.processed_videos = set(f.read().splitlines())

    def extract_frames(self, video_path, sample_rate=8, max_frames=None):
        """Decode a video and return a subsample of its frames as RGB arrays.

        Every `interval`-th frame is kept, where
        interval = max(1, int(fps / sample_rate)).

        Args:
            video_path: Path of the video file.
            sample_rate: Divisor applied to the reported FPS to get the
                sampling interval.
            max_frames: Stop decoding once this many frames are collected.
                None or a non-positive value means no limit.

        Returns:
            List of RGB frames; empty if nothing could be read.
        """
        limit = max_frames if max_frames is not None and max_frames > 0 else None

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)

            interval = max(1, int(fps / sample_rate))
            frame_count = 0

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % interval == 0:
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    # Stop early so a long video is not held in memory only
                    # to be truncated by the caller.
                    if limit is not None and len(frames) >= limit:
                        break

                frame_count += 1
        finally:
            # Release the capture even if decoding raises.
            cap.release()

        return frames

    def detect_objects(self, frame):
        """Run the detector on one RGB frame and return its raw result."""
        pil_image = Image.fromarray(frame)
        return self.yolo(pil_image)

    def _embed_image(self, rgb_array):
        """Return the ViT embedding of an RGB array as a 1-D numpy vector."""
        pil_image = Image.fromarray(rgb_array)
        inputs = self.vit_extractor(images=pil_image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.vit_model(**inputs)

        # Token 0 of the last hidden state (presumably the ViT [CLS] token),
        # first and only item of the batch.
        return outputs.last_hidden_state[:, 0].cpu().numpy()[0]

    @staticmethod
    def _crop_box(frame, box):
        """Crop `frame` to a detection box clipped to the frame, or None if empty."""
        x1, y1, x2, y2 = box[:4].int().cpu().numpy()

        h, w = frame.shape[:2]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)

        if x1 >= x2 or y1 >= y2:
            return None

        return frame[y1:y2, x1:x2]

    def extract_ball_features(self, frame, detections):
        """Embed the highest-confidence ball detection of a frame.

        Args:
            frame: RGB frame the detections refer to.
            detections: Detector result; rows of `xyxy[0]` are read as
                (x1, y1, x2, y2, confidence, class).

        Returns:
            Vector of length 768; all zeros when no ball is detected or its
            box is empty after clipping to the frame.
        """
        boxes = detections.xyxy[0]
        ball_detections = boxes[boxes[:, 5] == self.BALL_CLASS_ID]

        if len(ball_detections) == 0:
            return np.zeros(self.FEATURE_DIM)

        best_ball = ball_detections[torch.argmax(ball_detections[:, 4])]
        ball_crop = self._crop_box(frame, best_ball)
        if ball_crop is None:
            return np.zeros(self.FEATURE_DIM)

        return self._embed_image(ball_crop)

    def extract_player_features(self, frame, detections):
        """Embed up to five player detections of a frame.

        The first five player rows are used, in the order the detector
        returned them. Slots without a usable detection stay zero.

        Args:
            frame: RGB frame the detections refer to.
            detections: Detector result; see extract_ball_features.

        Returns:
            Array of shape (5, 768).
        """
        boxes = detections.xyxy[0]
        player_detections = boxes[boxes[:, 5] == self.PLAYER_CLASS_ID]

        player_features = np.zeros((self.MAX_PLAYERS, self.FEATURE_DIM))

        for i, player in enumerate(player_detections[:self.MAX_PLAYERS]):
            player_crop = self._crop_box(frame, player)
            if player_crop is None:
                continue

            player_features[i] = self._embed_image(player_crop)

        return player_features

    def extract_basket_features(self, frame, detections):
        """Embed the upper third of the frame.

        Args:
            frame: RGB frame.
            detections: Not used by this method.

        Returns:
            Vector of length 768.
        """
        h, w = frame.shape[:2]
        upper_frame = frame[:h // 3, :]
        return self._embed_image(upper_frame)

    def generate_court_segmentation(self, frame):
        """Embed a Canny edge map of the frame.

        The frame is converted to grayscale, blurred with a 5x5 Gaussian,
        passed through Canny (thresholds 50/150) and expanded back to three
        channels before embedding.

        Returns:
            Vector of length 768.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blurred, 50, 150)

        edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
        return self._embed_image(edges_rgb)

    def extract_features(self, video_id, max_frames=100):
        """Extract and save all feature arrays for one video.

        Args:
            video_id: Video file name without extension.
            max_frames: Maximum number of sampled frames to process.

        Returns:
            True if the video was already processed or was processed now;
            False if no file is mapped to `video_id` or no frame could be
            read.
        """
        if video_id in self.processed_videos:
            print(f"Video {video_id} already processed. Skipping.")
            return True

        print(f"Processing video {video_id}")

        video_file = self.video_id_to_file.get(video_id)
        if video_file is None:
            # Report the problem instead of failing with a bare KeyError.
            print(f"No video file mapped for {video_id}. Skipping.")
            return False
        video_path = os.path.join(RAW_VIDEOS_DIR, video_file)

        frames = self.extract_frames(video_path, max_frames=max_frames)
        if len(frames) == 0:
            print(f"Failed to extract frames from {video_path}")
            return False

        frames = frames[:max_frames]
        n_frames = len(frames)

        timesformer_features = np.zeros((n_frames, self.FEATURE_DIM))
        ball_features = np.zeros((n_frames, self.FEATURE_DIM))
        player_features = np.zeros((n_frames, self.MAX_PLAYERS, self.FEATURE_DIM))
        basket_features = np.zeros((n_frames, self.FEATURE_DIM))
        court_features = np.zeros((n_frames, self.FEATURE_DIM))

        # Arrays that get periodic "_temp" snapshots while the video runs.
        checkpointed = {
            'ball': ball_features,
            'player': player_features,
            'basket': basket_features,
            'court': court_features,
        }

        for i, frame in enumerate(tqdm(frames)):
            detections = self.detect_objects(frame)

            ball_features[i] = self.extract_ball_features(frame, detections)
            player_features[i] = self.extract_player_features(frame, detections)
            basket_features[i] = self.extract_basket_features(frame, detections)
            court_features[i] = self.generate_court_segmentation(frame)

            # NOTE(review): placeholder -- this is unseeded random noise, not
            # the output of a TimeSformer model, and it differs on every run.
            timesformer_features[i] = np.random.randn(self.FEATURE_DIM) * 0.1

            # Snapshot the partial results every 10 frames.
            if (i + 1) % 10 == 0:
                for feature_type, values in checkpointed.items():
                    np.save(os.path.join(FEATURES_DIR, feature_type, f"{video_id}_temp.npy"), values[:i + 1])

        np.save(os.path.join(FEATURES_DIR, 'timesformer', f"{video_id}.npy"), timesformer_features)
        for feature_type, values in checkpointed.items():
            np.save(os.path.join(FEATURES_DIR, feature_type, f"{video_id}.npy"), values)

        # The final arrays are on disk, so the snapshots can go.
        for feature_type in checkpointed:
            temp_file = os.path.join(FEATURES_DIR, feature_type, f"{video_id}_temp.npy")
            if os.path.exists(temp_file):
                os.remove(temp_file)

        self.processed_videos.add(video_id)
        with open(os.path.join(METADATA_DIR, 'processed_videos.txt'), 'a') as f:
            f.write(f"{video_id}\n")

        return True

def create_captions_csv(annotations):
    """Write every caption to METADATA_DIR/captions.csv.

    The CSV has one row per entry of annotations['sentences'] with the
    columns 'video_id', 'caption' and 'feature_file' ('<video_id>.npy').

    Args:
        annotations: dict with a 'sentences' list whose items carry
            'video_id' and 'caption' keys.
    """
    rows = [
        {
            'video_id': entry['video_id'],
            'caption': entry['caption'],
            'feature_file': f"{entry['video_id']}.npy"
        }
        for entry in annotations['sentences']
    ]

    csv_path = os.path.join(METADATA_DIR, 'captions.csv')
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    print(f"Exported {len(rows)} captions to CSV")

def main():
    """Run the pipeline: map files, build annotations and splits, extract features.

    Steps, in order: create the caption/video mapping, convert the VTT files
    to annotations, create the dataset splits, export the captions CSV,
    select the next batch of unprocessed videos, load the models, and extract
    features for the selected train, val and test videos. Videos for which
    extraction returned False are listed at the end.
    """
    import gc

    file_mapping = create_file_mapping()

    annotations = process_vtt_files(file_mapping)

    splits = create_dataset_splits(annotations)

    create_captions_csv(annotations)

    to_process, processed_videos = resume_processing(file_mapping, splits)

    yolo_model = load_yolo_model()
    vit_extractor, vit_model = load_vit_model()

    extractor = FeatureExtractor(yolo_model, vit_extractor, vit_model, file_mapping)
    # Share the set returned by resume_processing with the extractor.
    extractor.processed_videos = processed_videos

    failed_videos = []

    for split_name in ('train', 'val', 'test'):
        # The counter restarts for each split, so the cache is cleared after
        # every fifth video within a split.
        for i, video_id in enumerate(to_process[split_name]):
            if not extractor.extract_features(video_id):
                failed_videos.append(video_id)

            if (i + 1) % 5 == 0:
                print("Clearing memory cache")
                torch.cuda.empty_cache()
                gc.collect()

    if failed_videos:
        print(f"Feature extraction failed for {len(failed_videos)} video(s): {failed_videos}")

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()