In [None]:
import cv2
import torch
import numpy as np
import torchvision.transforms.functional as tr
import torchvision.transforms.v2.functional as trv2
import pandas as pd
import csv
import random
from os import listdir, makedirs
from os.path import isfile, join

In [None]:
# Raw input videos are read from here. The trailing slash matters: file paths
# in this notebook are built by plain string concatenation onto this value.
src_dir = "/home/tyler/Downloads/NumaGuard-main/data/videos/"
# Preprocessed frame tensors, masks and the index CSV are written here.
dest_dir = "/home/tyler/Downloads/NumaGuard-main/data/videos_2/"
# Despite the "_dir" suffix this is the full path of the index CSV file itself.
csv_save_dir = dest_dir + "data.csv"

In [None]:
# Target spatial size, passed to torchvision's resize as (IMAGE_WIDTH, IMAGE_HEIGHT).
# NOTE(review): torchvision documents `size` as (height, width), so resized frames
# have 192 rows and 256 columns - the names look swapped relative to that
# convention. They are used consistently throughout this notebook; confirm intent.
IMAGE_WIDTH = 192
IMAGE_HEIGHT = 256
# Number of frame slots in every saved clip; shorter clips are zero-padded to this length.
# (A preceding `FRAME_CAP = 100` was a dead store, overwritten on the next line.)
FRAME_CAP = 30
# Sampling step handed to get_frames / preview_frames. Despite the name it is NOT
# frames-per-second: frame k is read at timestamp k * FPS milliseconds.
FPS = 200

In [None]:
def pad_frames(frames):
    """Zero-pad a stack of frames along the first axis up to FRAME_CAP entries.

    Args:
        frames: tensor whose first axis indexes frames; it must be assignable
            into a (n, 3, IMAGE_WIDTH, IMAGE_HEIGHT) slice, with n <= FRAME_CAP.

    Returns:
        A freshly allocated tensor of shape
        (FRAME_CAP, 3, IMAGE_WIDTH, IMAGE_HEIGHT) holding ``frames`` in its
        leading slots and zeros in the remaining ones.
    """
    n_real = frames.shape[0]
    padded = torch.zeros(FRAME_CAP, 3, IMAGE_WIDTH, IMAGE_HEIGHT)
    # Only the leading n_real slots receive data; the tail stays zero.
    padded[:n_real] = frames
    return padded

In [None]:
def truncate_frames(frames):
    """Return at most the first FRAME_CAP entries of a 4-D frame stack.

    Args:
        frames: 4-D tensor whose first axis indexes frames.

    Returns:
        A slice of ``frames`` limited to FRAME_CAP entries along the first
        axis; the other three axes are kept whole.
    """
    leading = slice(None, FRAME_CAP)
    whole = slice(None)
    return frames[leading, whole, whole, whole]

In [None]:
def preprocess_image(image):
    """Convert one decoded frame into a batched, normalised tensor.

    Args:
        image: 3-D array in (rows, cols, channels) layout, as passed in by
            get_frames from ``cv2.VideoCapture.read``.

    Returns:
        Tensor of shape (1, channels, IMAGE_WIDTH, IMAGE_HEIGHT): the frame
        moved to channels-first layout, resized, divided by 255 and then
        histogram-equalised.
    """
    channels_first = torch.tensor(image).permute(2, 0, 1)

    # Resize first, then scale by 1/255, exactly in that order.
    scaled = tr.resize(channels_first, (IMAGE_WIDTH, IMAGE_HEIGHT), antialias=True) / 255

    equalized = trv2.equalize(scaled)

    # Leading axis of size 1 so frames can be concatenated along it.
    return equalized.unsqueeze(0)

In [None]:
def preprocess_image_preview_version(image):
    """Resize and equalise one decoded frame, keeping an image-style layout.

    Unlike ``preprocess_image`` there is no division by 255 and no batch axis;
    the result goes back to (rows, cols, channels) so it can be written out
    as an image file.

    Args:
        image: 3-D array in (rows, cols, channels) layout.

    Returns:
        NumPy array of shape (IMAGE_WIDTH, IMAGE_HEIGHT, channels).
    """
    channels_first = torch.tensor(image).permute(2, 0, 1)

    resized = tr.resize(channels_first, (IMAGE_WIDTH, IMAGE_HEIGHT), antialias=True)

    equalized = trv2.equalize(resized)

    # Back to channels-last before leaving torch.
    return equalized.permute(1, 2, 0).numpy()

In [None]:
def preview_frames(fps: int, src: str):
    """Debug helper: sample frames from a video, dump them as JPEGs and print them.

    Frame ``k`` is read at timestamp ``k * fps`` milliseconds, run through
    ``preprocess_image_preview_version`` and written to ``./tests/{k}.jpg``.
    At most FRAME_CAP frames are previewed, because the mask has exactly that
    many slots. The stacked frames and the mask are printed, not returned.

    Fixes over the previous version:
      * frame 0 is no longer included twice (a priming read followed by a
        seek back to 0 ms duplicated it);
      * a video that cannot be opened raises a clear error instead of failing
        inside tensor conversion;
      * sampling stops at FRAME_CAP instead of raising IndexError on the mask;
      * the capture handle is always released.

    Args:
        fps: despite the name, the sampling step in milliseconds.
        src: path of the video file to preview.

    Raises:
        IOError: if the video cannot be opened.
    """
    vidcap = cv2.VideoCapture(src)
    if not vidcap.isOpened():
        raise IOError(f"Could not open video: {src}")

    makedirs("./tests", exist_ok=True)

    sampled = []
    input_masks = torch.zeros((FRAME_CAP), dtype=torch.int8)
    count = 0

    try:
        # The mask only has FRAME_CAP slots, so never sample past that.
        while count < FRAME_CAP:
            vidcap.set(cv2.CAP_PROP_POS_MSEC, (count * fps))

            success, image = vidcap.read()
            if not success:
                break

            image = preprocess_image_preview_version(image)

            cv2.imwrite(f"./tests/{count}.jpg", image)

            sampled.append(torch.tensor(image).unsqueeze(0))
            input_masks[count] = 1

            count += 1
    finally:
        vidcap.release()

    # Single concatenation; the empty float tensor keeps the result's dtype
    # and shape the same as before, even when no frame was read.
    frames = torch.cat([torch.zeros((0, IMAGE_WIDTH, IMAGE_HEIGHT, 3))] + sampled)

    print(frames)

    print(f"Frames Shape: {frames.shape}")

    print(input_masks)

    print(f"Input Masks Shape: {input_masks.shape}")

In [None]:
def get_frames(fps: int, src: str):
    """Sample a video into a fixed-length frame tensor plus a validity mask.

    Frame ``k`` is read at timestamp ``k * fps`` milliseconds and passed
    through ``preprocess_image``. Clips shorter than FRAME_CAP are zero-padded
    with ``pad_frames``.

    Fixes over the previous version:
      * frame 0 is no longer stored twice (a priming read followed by a seek
        back to 0 ms duplicated it), so slot ``k`` of the mask now describes
        slot ``k`` of the frames and the last real frame is no longer masked
        out;
      * an unreadable video raises a descriptive error instead of failing
        inside tensor conversion;
      * over-long videos raise a descriptive error instead of an accidental
        IndexError on the mask;
      * the capture handle is always released;
      * frames are concatenated once rather than once per frame.

    Args:
        fps: despite the name, the sampling step in milliseconds.
        src: path of the video file.

    Returns:
        ``(frames, input_mask)``. ``frames`` has FRAME_CAP entries along its
        first axis, each shaped like one ``preprocess_image`` result.
        ``input_mask`` is an int8 tensor of shape (FRAME_CAP,) holding 1 for
        real frames and 0 for padding.

    Raises:
        IOError: if the video cannot be opened.
        ValueError: if no frame can be read, or if the video yields more than
            FRAME_CAP sampled frames.
    """
    vidcap = cv2.VideoCapture(src)
    if not vidcap.isOpened():
        raise IOError(f"Could not open video: {src}")

    sampled = []
    input_mask = torch.zeros((FRAME_CAP), dtype=torch.int8)
    count = 0

    try:
        while True:
            vidcap.set(cv2.CAP_PROP_POS_MSEC, (count * fps))

            success, image = vidcap.read()
            if not success:
                break

            # A readable frame beyond the last slot means the clip is too long.
            if count >= FRAME_CAP:
                raise ValueError(
                    f"Video yields more than {FRAME_CAP} sampled frames: {src}"
                )

            sampled.append(preprocess_image(image))
            input_mask[count] = 1

            count += 1
    finally:
        vidcap.release()

    if not sampled:
        raise ValueError(f"No frames could be read from: {src}")

    frames = torch.cat(sampled)

    if frames.shape[0] < FRAME_CAP:
        frames = pad_frames(frames)

    return frames, input_mask

In [None]:
def create_record(files):
    """Render the given values as one comma-separated line.

    Each value is converted with ``str``; no quoting or escaping is applied.

    Args:
        files: iterable of values forming one record.

    Returns:
        The values joined by commas, or an empty string for an empty iterable.
    """
    return ",".join(map(str, files))

In [None]:
def write_csv(records, path=None):
    """Write pre-formatted CSV lines to disk, one record per line.

    ``records`` are already comma-joined strings (see ``create_record``), so
    they are written verbatim, each followed by a single newline.

    The previous version pushed all records through ``csv.writer`` as the
    fields of one row with a newline as the delimiter. That relied on
    ``quotechar=""``, which Python 3.11+ rejects with TypeError, and it left
    a stray ``\\r\\n`` after the last record while every other line ended in
    ``\\n``.

    Args:
        records: iterable of strings, one per output line.
        path: destination file; defaults to the module-level ``csv_save_dir``.
    """
    if path is None:
        path = csv_save_dir

    # newline="\n" disables newline translation so every line ends in "\n".
    with open(path, 'w', newline="\n") as myfile:
        myfile.writelines(f"{record}\n" for record in records)

In [None]:
def preprocess_batch():
    """Preprocess every file in ``src_dir`` and write tensors plus an index CSV.

    For each regular file in ``src_dir`` (visited in sorted order, so the
    index given to a file is the same on every run) the frames and mask from
    ``get_frames`` are saved to ``{dest_dir}{i}`` and ``{dest_dir}{i}_mask``.
    A ``"{i},{pin}"`` line is then added to the index, where ``pin`` is the
    part of the file name before the first underscore. Files that fail are
    reported, counted and skipped; their index is left unused.

    Changes from the previous version:
      * ``except Exception`` replaces a bare ``except:``, so Ctrl-C still
        interrupts the run and the reason for each skip is printed;
      * the index line is only recorded after both tensors were saved;
      * ``dest_dir`` is created if it does not exist.
    """
    makedirs(dest_dir, exist_ok=True)

    records = ["file_name,pin"]
    valid_records = 0
    invalid_records = 0

    onlyfiles = sorted(f for f in listdir(src_dir) if isfile(join(src_dir, f)))

    for i, file_name in enumerate(onlyfiles):
        print(f"i: {i} file_name: {file_name}", end="\r")

        try:
            frames, input_mask = get_frames(FPS, f"{src_dir}{file_name}")

            torch.save(frames, f"{dest_dir}{i}")
            torch.save(input_mask, f"{dest_dir}{i}_mask")
        except Exception as exc:
            invalid_records += 1
            # Leading newline so the message is not overwritten by the "\r" progress line.
            print(f"\nSkipping {file_name}: {type(exc).__name__}: {exc}")
            continue

        pin = file_name.split("_")[0]
        records.append(create_record([i, pin]))
        valid_records += 1

    write_csv(records)

    print()

    print(f"Valid Records: {valid_records}\tInvalid Records: {invalid_records}")

In [None]:
def balance_data(input_csv=None, output_csv=None):
    """Undersample missed shots so both classes have the same number of rows.

    Reads a CSV with a ``shot_made`` column, keeps every row where it equals 1
    and adds an equally sized random sample (without replacement) of the rows
    where it equals 0. The result is written without the index column, made
    shots first.

    Fixes over the previous version:
      * with fewer missed than made shots the rejection-sampling loop never
        terminated; this now raises ValueError;
      * the output keeps the input's column names (the NumPy round trip used
        to replace the header with 0, 1, 2, ...);
      * rows are combined in one concat instead of ``np.append`` per row.

    Args:
        input_csv: CSV to read; defaults to ``{src_dir}clean.csv``.
        output_csv: CSV to write; defaults to ``{src_dir}clean_2.csv``.

    Raises:
        ValueError: if there are fewer missed shots than made shots.
    """
    if input_csv is None:
        input_csv = f"{src_dir}clean.csv"
    if output_csv is None:
        output_csv = f"{src_dir}clean_2.csv"

    df = pd.read_csv(input_csv, header=0)
    made_shots = df[df["shot_made"] == 1]
    missed_shots = df[df["shot_made"] == 0]

    print(made_shots.shape)

    if len(missed_shots) < len(made_shots):
        raise ValueError(
            f"Cannot balance: only {len(missed_shots)} missed shots "
            f"for {len(made_shots)} made shots"
        )

    # Positional indices into missed_shots, drawn without replacement.
    selected = random.sample(range(len(missed_shots)), len(made_shots))

    n_df = pd.concat([made_shots, missed_shots.iloc[selected]], ignore_index=True)
    n_df.to_csv(output_csv, index=False)

    print(n_df.shape)

    print(set(selected))

In [None]:
# Run the full preprocessing pass: reads every file in src_dir and writes the
# frame tensors, masks and index CSV under dest_dir.
preprocess_batch()

In [None]:
# preview_frames(FPS, "/home/tyler/Downloads/NumaGuard-main/data/videos/4021_20240402_160547_748042.mp4")

In [None]:
# balance_data()