In [None]:
import math
from matplotlib import pyplot as plt
import numpy as np
import cv2
import pandas as pd
import soundfile as sf
import os
from pathlib import Path
from sklearn.metrics import confusion_matrix

import torch

from cnn_model import HeatmapFusionCNN
from getAudioSaliency import compute_audio_saliency_heatmap_vectorized, precompute_integrals
from getVideoLabels import filterDf, getModeTileIndex
from getVideoSaliency import compute_video_saliency_heatmap_vectorized

def normalize_heatmaps(heatmaps):
    """Normalize every heatmap in a stack to the [0, 1] range independently.

    Parameters:
        heatmaps: array-like indexed as (heatmap, row, col); the min and max
                  are taken over axes 1 and 2, i.e. per heatmap.

    Returns:
        Array of the same shape with each heatmap rescaled by its own min/max.
        A constant heatmap (max == min) becomes all zeros instead of NaN.
    """
    heatmaps = np.asarray(heatmaps)

    # Per-heatmap extrema; keepdims so they broadcast back over (row, col)
    h_mins = np.min(heatmaps, axis=(1, 2), keepdims=True)
    h_maxs = np.max(heatmaps, axis=(1, 2), keepdims=True)

    # A flat heatmap has zero range; dividing by it would produce NaN/inf that
    # then propagates into the model. Use 1 there so the result is simply 0.
    ranges = h_maxs - h_mins
    safe_ranges = np.where(ranges == 0, 1, ranges)

    return (heatmaps - h_mins) / safe_ranges


def getFrame(cap, output_height, output_width, frame_idx):
    """
    Seek to one frame of an opened video capture and return it resized.

    Parameters:
        cap: opened cv2.VideoCapture to read from
        output_height: height of the returned frame (pixels)
        output_width: width of the returned frame (pixels)
        frame_idx: index of the frame to seek to before reading

    Returns:
        The frame resized to (output_height, output_width) with bilinear interpolation.

    Raises:
        ValueError: if the frame could not be read (e.g. index past the end of the video)
    """

    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()

    # Without this check a failed read hands None to cv2.resize, which fails
    # with an opaque OpenCV assertion instead of naming the missing frame.
    if not ret or frame is None:
        raise ValueError(f"Could not read frame {frame_idx} from video")

    # cv2.resize takes the target size as (width, height)
    resized_frame = cv2.resize(frame, (output_width, output_height), interpolation=cv2.INTER_LINEAR)

    return resized_frame

def tile_index_to_coords(idx, numCols):
    """Map a row-major linear tile index to its (x, y) = (column, row) position."""
    # divmod gives (idx // numCols, idx % numCols) in one step
    row, col = divmod(idx, numCols)
    return col, row

def tile_distance(pred_idx, true_idx, numCols):
    """Return the Euclidean distance, in tiles, between two linear tile indices.

    Columns wrap around (the left and right edges are treated as adjacent);
    rows do not.
    """
    pred_col, pred_row = tile_index_to_coords(pred_idx, numCols)
    true_col, true_row = tile_index_to_coords(true_idx, numCols)

    # Horizontal gap: take the shorter way around the wrap
    col_gap = abs(pred_col - true_col)
    wrapped_gap = numCols - col_gap
    if wrapped_gap < col_gap:
        col_gap = wrapped_gap

    # Vertical gap: plain difference, no wrapping
    row_gap = abs(pred_row - true_row)

    return (col_gap ** 2 + row_gap ** 2) ** 0.5

def printAndWriteLine(printedLine, file):
    """Write the line, newline-terminated, to `file`, then echo it to stdout."""
    terminated_line = printedLine + "\n"
    file.write(terminated_line)
    print(printedLine)


def process_360_video(video_name, video_path, audio_path, output_path, model_path,
                      csv_path, erp_height=1920, erp_width=3840,
                      sample_every_n_frames=5, numHeatmaps=7,
                      cols = 16, rows = 9, device = "cpu"):
    """
    Evaluate the heatmap-fusion CNN on one 360 video.

    For each sampled frame this builds audio and video saliency heatmaps,
    normalizes them, runs the model to predict a tile index, and compares it
    with the tile returned by getModeTileIndex for the same timestamp. A
    per-frame log plus a confusion matrix, average tile distance and accuracy
    are printed and written to output_path as text.

    Parameters:
        video_name: name passed to filterDf to select this video's rows
        video_path: path to ERP format 360 video
        audio_path: path to first-order ambisonic audio file (4 channels: W, X, Y, Z)
        output_path: where to write the text log of results
        model_path: path to the saved HeatmapFusionCNN state dict
        csv_path: path to the CSV passed to filterDf to obtain the labels
        erp_height: height frames are resized to (pixels)
        erp_width: width frames are resized to (pixels)
        sample_every_n_frames: sample every N frames
        numHeatmaps: used to derive the counts handed to the audio and video
                     saliency helpers (see NOTE in the loop)
        cols: number of tile columns
        rows: number of tile rows
        device: torch device, given either as a string ("cpu", "cuda") or a torch.device

    Raises:
        ValueError: if the audio is not 4-channel or the video cannot be opened
    """

    # Accept a plain string (the default) as well as a torch.device; the loop
    # below reads device.type, which a str does not have.
    device = torch.device(device)

    # Load audio
    print("Loading ambisonic audio...")
    audio_data, audio_samplerate = sf.read(audio_path)

    # Check for 4 channels
    if len(audio_data.shape) == 1:
        raise ValueError("Audio is mono. Expected 4-channel first-order ambisonics.")
    elif audio_data.shape[1] != 4:
        raise ValueError(f"Audio has {audio_data.shape[1]} channels. Expected 4-channel first-order ambisonics (W, X, Y, Z).")

    # Split into channels
    W = audio_data[:, 0]
    X = audio_data[:, 1]
    Y = audio_data[:, 2]
    Z = audio_data[:, 3]

    print(f"Audio shape: {audio_data.shape}")
    print(f"Audio sample rate: {audio_samplerate} Hz")
    print("Successfully loaded 4-channel first-order ambisonics audio")

    # Open video to get metadata
    print("Opening video...")
    cap = cv2.VideoCapture(video_path)

    # try/finally so the capture is released even if any later step raises
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"Video FPS: {video_fps}")
        print(f"Total frames: {total_frames}")
        print(f"Video dimensions: {video_width}x{video_height}")

        # Check if resizing is needed
        need_resize = video_width != erp_width or video_height != erp_height
        if need_resize:
            print(f"Video will be resized from {video_width}x{video_height} to {erp_width}x{erp_height}")

        # Precompute integrals for coarse tiles (20x20 degrees)
        tile_cache = precompute_integrals(tile_size_deg=20)

        # Number of sampled frames to evaluate
        num_sampled_frames = 150

        # Sample i reads frame sample_every_n_frames * (i + 1); the last valid
        # frame index is total_frames - 1, so never sample past it. Skipped
        # when the container reports no frame count.
        if total_frames > 0:
            max_samples = max((total_frames - 1) // sample_every_n_frames, 0)
            num_sampled_frames = min(num_sampled_frames, max_samples)

        (labelDf, participants) = filterDf(csv_path, video_name, video_name)

        print(f"Processing {num_sampled_frames} frames...")

        print("Loading model...")
        # Load the model state
        model = HeatmapFusionCNN()  # Create a new model instance first
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)  # Move to appropriate device
        # Inference mode: a freshly constructed module is in training mode
        model.eval()
        print("Model loaded!")

        numCorrect = 0
        numTotal = 0

        predictedLabels = []
        trueLabels = []
        totalDistance = 0

        # Make sure the results directory exists before opening the log
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as file:

            # Frames are fetched one pair at a time through getFrame instead of
            # keeping the whole video in memory
            for sampled_frame_idx in range(num_sampled_frames):
                frame_idx = sample_every_n_frames * (sampled_frame_idx + 1)

                prevFrame = getFrame(cap, erp_height, erp_width, frame_idx - 1)
                frame = getFrame(cap, erp_height, erp_width, frame_idx)

                printedLine = f"Processing frame {frame_idx}/{total_frames} (sample {sampled_frame_idx}/{num_sampled_frames})"
                printAndWriteLine(printedLine, file)

                # Compute audio and video saliency heatmaps and stack them along axis 0.
                # NOTE(review): the audio helper receives numHeatmaps-2 and the video
                # helper numHeatmaps-7 - presumably heatmap counts; with the default
                # numHeatmaps=7 the video value is 0. Confirm the intended split.
                audio_heatmaps = compute_audio_saliency_heatmap_vectorized(
                    W, X, Y, Z, audio_samplerate,
                    frame_idx, video_fps,
                    erp_height, erp_width,
                    tile_cache, sample_every_n_frames,
                    numHeatmaps-2, tile_size_deg=20)
                video_heatmaps = compute_video_saliency_heatmap_vectorized(
                    prevFrame, frame, frame_idx, video_fps,
                    erp_height, erp_width,
                    tile_cache, sample_every_n_frames,
                    numHeatmaps-7, tile_size_deg=20)
                saliency_heatmaps = np.concatenate([audio_heatmaps, video_heatmaps], axis=0)

                # Normalize heatmap
                saliency_heatmaps = normalize_heatmaps(saliency_heatmaps)

                heatmaps = torch.from_numpy(saliency_heatmaps).float().to(device)

                # Run inference on a batch of one
                with torch.no_grad():
                    outputs = model(heatmaps.unsqueeze(0))
                    predicted_tile = outputs[0].argmax(dim=0).item()

                targetTime = frame_idx / video_fps

                actual_tile = getModeTileIndex(targetTime, labelDf, participants, rows, cols)

                printedLine = f"Predicted tile was {predicted_tile}, actual tile was {actual_tile}!"
                printAndWriteLine(printedLine, file)

                predictedLabels.append(predicted_tile)
                trueLabels.append(actual_tile)

                if predicted_tile == actual_tile:
                    numCorrect += 1

                numTotal += 1

                printedLine = f"Num correct thus far is {numCorrect}, num total thus far is {numTotal}"
                printAndWriteLine(printedLine, file)

                distance = tile_distance(predicted_tile, actual_tile, cols)

                printedLine = f"Euclidean distance from true was {distance}"
                printAndWriteLine(printedLine, file)

                totalDistance += distance

                # Drop the tensor references first so the cache flush can
                # actually return their memory
                del heatmaps, outputs
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

            if numTotal == 0:
                # Nothing to summarize; avoids dividing by zero and an empty
                # label set for the confusion matrix
                printAndWriteLine("No frames were processed; skipping summary metrics.", file)
            else:
                classes_present = np.unique(np.concatenate([predictedLabels, trueLabels]))

                # sklearn convention: rows are true labels, columns are predictions
                cm = confusion_matrix(trueLabels, predictedLabels, labels=classes_present)

                # Format with labels
                cm_str = "Confusion Matrix (Predicted vs True):\n"
                cm_str += f"Classes: {classes_present}\n"
                cm_str += str(cm)

                printAndWriteLine(cm_str, file)

                printedLine = f"Avg distance was: {float(totalDistance) / numTotal:.2f}"
                printAndWriteLine(printedLine, file)

                printedLine = f"Accuracy was: {float(numCorrect) / numTotal:.2f}"
                printAndWriteLine(printedLine, file)
    finally:
        cap.release()


    
if __name__ == "__main__":
    # All paths below are relative to the directory two levels up
    os.chdir("./../..")

    # ---- Configuration - modify as needed ----
    FILE_NAME = "0004"

    # Frame geometry and sampling
    ERP_WIDTH = 1920   # width
    ERP_HEIGHT = 960   # height
    SAMPLE_RATE = 5    # sample every 5 frames

    # Model input and tile grid
    NUM_HEATMAPS = 9
    TILE_COLS = 16
    TILE_ROWS = 9

    # Input / output locations
    VIDEO_PATH = f"Data/Pre-Processed-Data/{FILE_NAME}/{FILE_NAME}_mono_60fps.mp4"  # ERP format 360 video
    AUDIO_PATH = f"Data/Pre-Processed-Data/{FILE_NAME}/{FILE_NAME}.wav"
    INPUT_CSV_PATH = f"Data/Pre-Processed-Data/head_data/head_video_{FILE_NAME}.csv"
    OUTPUT_PATH = f"FinalTestingResults/{FILE_NAME}_Results.txt"
    MODEL_PATH = "cnn_model.pth"

    # Prefer the GPU when one is available
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Run the pipeline
    process_360_video(
        video_name=FILE_NAME,
        video_path=VIDEO_PATH,
        audio_path=AUDIO_PATH,
        output_path=OUTPUT_PATH,
        model_path=MODEL_PATH,
        csv_path=INPUT_CSV_PATH,
        erp_height=ERP_HEIGHT,
        erp_width=ERP_WIDTH,
        sample_every_n_frames=SAMPLE_RATE,
        numHeatmaps=NUM_HEATMAPS,
        cols=TILE_COLS,
        rows=TILE_ROWS,
        device=DEVICE,
    )

Repo path was: c:\Users\mahd\Documents\FOV Prediction\Scripts\finalCode\../../U-2-Net-Repo
Using device: cpu
U2 Net Model ready!
Successfully imported core RAFT modules from your fork.
Loading RAFT cpu model...
RAFT model loaded on cpu successfully.
Loading ambisonic audio...
Audio shape: (2880000, 4)
Audio sample rate: 48000 Hz
Successfully loaded 4-channel first-order ambisonics audio
Opening video...
Video FPS: 60.0
Total frames: 3598
Video dimensions: 3840x1920
Video will be resized from 3840x1920 to 1920x960
Precomputing integrals for 20° tiles...
Integral precomputation complete!
Reading CSV file: Data/Pre-Processed-Data/head_data/head_video_0004.csv
Processing 2 frames...
Loading model...
Model loaded!
Processing frame 5/3598 (sample 0/2)


  src = F.upsample(src,size=tar.shape[2:],mode='bilinear')



Trying scale 1.0...
Padded tensor shapes: t1=torch.Size([1, 3, 960, 1920]), t2=torch.Size([1, 3, 960, 1920])


  with autocast(enabled=self.args.mixed_precision):
  with autocast(enabled=self.args.mixed_precision):
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  with autocast(enabled=self.args.mixed_precision):


✓ Flow computed successfully at scale 1.0
before fusion: 0.0 1.0
After fusion: -0.8454725742340088 1.0559604167938232
After conv1: -0.7022640109062195 0.7890627980232239
After conv2: -4.248074054718018 4.467042446136475
After conv3: -2.8335306644439697 3.3739211559295654
After pool: 0.0 1.8778979778289795
After fc1: -1.508062720298767 5.066052436828613
After fc2 (output): -4.561481952667236 4.552745342254639
Predicted tile was 75, actual tile was 65!
Num correct thus far is 0, num total thus far is 1
Euclidean distance from true was 6.0
Processing frame 10/3598 (sample 1/2)


  src = F.upsample(src,size=tar.shape[2:],mode='bilinear')



Trying scale 1.0...
Padded tensor shapes: t1=torch.Size([1, 3, 960, 1920]), t2=torch.Size([1, 3, 960, 1920])


  with autocast(enabled=self.args.mixed_precision):
  with autocast(enabled=self.args.mixed_precision):
  with autocast(enabled=self.args.mixed_precision):


✓ Flow computed successfully at scale 1.0
before fusion: 0.0 1.0
After fusion: -0.8078544735908508 0.9146549701690674
After conv1: -0.6568062901496887 0.6594387888908386
After conv2: -3.9136910438537598 2.9638423919677734
After conv3: -2.5999231338500977 3.427844285964966
After pool: 0.0 2.238872766494751
After fc1: -1.2006548643112183 4.9988694190979
After fc2 (output): -4.796249866485596 8.050904273986816
Predicted tile was 77, actual tile was 66!
Num correct thus far is 0, num total thus far is 2
Euclidean distance from true was 5.0
Confusion Matrix (Predicted vs True):
Classes: [65 66 75 77]
[[0 0 1 0]
 [0 0 0 1]
 [0 0 0 0]
 [0 0 0 0]]
Avg distance was: 5.50
Accuracy was: 0.00
