In [1]:
import torch
import pandas as pd
from torch.nn.functional import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from scipy.special import softmax
import os

In [2]:
# --- Grounding / visualization configuration ---

# Minimum narration-video cosine similarity for a neighbouring video feature to be
# absorbed into a narration's grounded interval.
similarity_threshold = 0.3

# Seconds of video represented by each video feature.
feature_duration = 2.1333

# Mode switch: True -> qualitatively display the most similar segments for a single
# narration; False -> ground every narration and save the result.
qual_vis = False

# Options used only when qual_vis is True:
narr_index = 0  # row (by position) of narrations_df to inspect
top_n = 5       # number of highest-similarity segments to report

In [3]:
# Convert timestamps to seconds for plotting
def timestamp_to_seconds(timestamp):
    return timestamp.dt.hour * 3600 + timestamp.dt.minute * 60 + timestamp.dt.second + timestamp.dt.microsecond / 1e6

In [4]:
def generate_frame_based_grounding(narrations_df, similarity_scores, feature_duration, threshold=None):
    """Grow a time interval around each narration over contiguous high-similarity features.

    Starting from the narration's own timestamp, the interval is extended one feature
    at a time to the left and to the right for as long as the next neighbouring
    feature's similarity to that narration is ``>= threshold``. The two sides grow
    independently and stop at the first feature below the threshold, or at the edge
    of the feature sequence.

    Args:
        narrations_df: DataFrame with ``narration_id`` and ``narration_seconds``
            columns. Row ``i`` (by position) must correspond to row ``i`` of
            ``similarity_scores``.
        similarity_scores: 2-D indexable (e.g. a ``torch.Tensor``) of shape
            (num_narrations, num_features) holding narration/feature similarities.
        feature_duration: Seconds of video covered by each feature.
        threshold: Minimum similarity for a feature to be absorbed. When ``None``,
            the notebook-level ``similarity_threshold`` is used.

    Returns:
        Tuple ``(narration_ids, start_times, end_times)`` of equal-length lists;
        times are in seconds.
    """
    if threshold is None:
        threshold = similarity_threshold  # notebook-level config value

    num_features = similarity_scores.shape[1]
    narrations, start_times, end_times = [], [], []
    # enumerate() gives the positional row. The iterrows index label only matches the
    # score row when the frame has a fresh 0..N-1 index.
    for pos, (_, row) in enumerate(narrations_df.iterrows()):
        scores = similarity_scores[pos]
        t = row['narration_seconds']
        center = int(t / feature_duration)

        # Walk integer feature indices rather than repeatedly adding/subtracting
        # feature_duration: accumulated float error could otherwise shift the index
        # recomputed from the running boundary by one.
        left = center
        while 0 < left <= num_features and scores[left - 1] >= threshold:
            left -= 1
        right = center
        while right + 1 < num_features and scores[right + 1] >= threshold:
            right += 1

        narrations.append(row['narration_id'])
        start_times.append(t - (center - left) * feature_duration)
        end_times.append(t + (right - center) * feature_duration)
    return narrations, start_times, end_times

In [5]:
def generate_similarity_scores(narration_features, video_features, use_qual_vis=False):
    """Cosine similarity between narration embeddings and video feature embeddings.

    Args:
        narration_features: 2-D tensor of narration embeddings (num_narrations, dim).
        video_features: Tensor of video feature embeddings, presumably
            (num_features, dim) with the same embedding dim — TODO confirm the shape
            of the saved features.
        use_qual_vis: If True, use ``torch.nn.functional.cosine_similarity`` along
            dim 1. This broadcasts a single (1, dim) narration against every video
            feature and returns a 1-D tensor. Otherwise return the full
            (num_narrations, num_features) similarity matrix.

    Returns:
        Tensor of cosine similarities.
    """
    # Honour the argument; the previous version silently read the global `qual_vis`.
    if use_qual_vis:
        return cosine_similarity(narration_features, video_features, dim=1)
    dot_product = torch.matmul(narration_features, video_features.transpose(-1, -2))
    # Clamp the norms (as cosine_similarity does with eps=1e-8) so an all-zero
    # embedding yields 0 instead of NaN.
    narration_norm = torch.linalg.norm(narration_features, dim=-1, keepdim=True).clamp_min(1e-8)
    video_norm = torch.linalg.norm(video_features, dim=-1, keepdim=True).clamp_min(1e-8)
    return dot_product / (narration_norm * video_norm.transpose(-1, -2))

In [6]:
def generate_qual_vis(narrations_df, similarity_scores, k=None, narration_index=None, segment_duration=None):
    """Print the highest-similarity video segments for one narration, with overlapping ground truth.

    For each of the top-``k`` feature indices, the segment
    ``[index * segment_duration, index * segment_duration + segment_duration]`` is
    printed together with every ground-truth narration whose
    ``[start_seconds, stop_seconds]`` interval overlaps it (closed intervals).

    Args:
        narrations_df: DataFrame with ``narration``, ``start_seconds``,
            ``stop_seconds``, ``start_timestamp`` and ``stop_timestamp`` columns.
        similarity_scores: Tensor of similarities for the inspected narration
            (1-D, or 2-D with a single row). If it has several rows, row
            ``narration_index`` is used.
        k: Number of segments to report (defaults to the global ``top_n``). It is
            capped at the number of available features.
        narration_index: Positional row of ``narrations_df`` being inspected
            (defaults to the global ``narr_index``).
        segment_duration: Seconds per feature (defaults to the global
            ``feature_duration``).
    """
    if k is None:
        k = top_n
    if narration_index is None:
        narration_index = narr_index
    if segment_duration is None:
        segment_duration = feature_duration

    scores = similarity_scores
    # A full (num_narrations, num_features) matrix would otherwise be flattened and
    # mix segments from different narrations.
    if scores.dim() == 2 and scores.shape[0] > 1:
        scores = scores[narration_index]
    scores = scores.flatten()
    k = min(k, scores.numel())  # torch.topk raises if k exceeds the size
    top_indices = torch.topk(scores, k).indices.tolist()

    target = narrations_df.iloc[narration_index]
    print("Current narration:", target['narration'])
    print(f"Ground Truth Interval: {target['start_seconds']:.2f} to {target['stop_seconds']:.2f} seconds")
    for feature_idx in top_indices:
        start = feature_idx * segment_duration
        end = start + segment_duration
        overlapping = narrations_df[(narrations_df['start_seconds'] <= end) & (narrations_df['stop_seconds'] >= start)]
        print(f"High Similarity Interval: {start:.2f} to {end:.2f} seconds")
        if overlapping.empty:
            print("  - No overlapping narrations.")
        else:
            print("Overlapping Ground Truth Narrations:")
            for _, row in overlapping.iterrows():
                print(f"  - {row['narration']} from {row['start_timestamp']} to {row['stop_timestamp']}")
        print()

In [7]:
# Process narrations: keep one video's annotations and express every timestamp in seconds.
ANNOTATIONS_CSV = '/private/home/arjunrs1/epic-kitchens-100-annotations/EPIC_100_train.csv'
FEATURES_ROOT = '/private/home/arjunrs1/EgoVLPv2/EgoVLPv2'
VIDEO_ID = "P01_01"  # TODO: change from hardcoded
MAX_STOP_SECONDS = 1620  # TODO: change from hardcoded

narrations_df = pd.read_csv(ANNOTATIONS_CSV)
# .copy() so the column assignments below act on an independent frame, not a view of
# the full annotation table (avoids SettingWithCopyWarning).
narrations_df = narrations_df[narrations_df.video_id == VIDEO_ID].copy()
narrations_df['start_seconds'] = timestamp_to_seconds(pd.to_datetime(narrations_df['start_timestamp']))
narrations_df['stop_seconds'] = timestamp_to_seconds(pd.to_datetime(narrations_df['stop_timestamp']))
narrations_df['narration_timestamp'] = pd.to_datetime(narrations_df['narration_timestamp'])
narrations_df['narration_seconds'] = timestamp_to_seconds(narrations_df['narration_timestamp'])
narrations_df = narrations_df[narrations_df.stop_seconds <= MAX_STOP_SECONDS].reset_index(drop=True)

# Load video features. map_location='cpu' keeps them on the same device as the
# narration features below, so the similarity matmul cannot hit a device mismatch.
video_features = torch.load(
    os.path.join(FEATURES_ROOT, 'video_features_ng', '00:00:01.089.pt'),  # TODO: change from hardcoded
    map_location='cpu',
)

NARRATION_FEATURES_DIR = os.path.join(FEATURES_ROOT, 'narration_features_ng')


def _load_narration_feature(narration_id):
    """Load one saved narration embedding on the CPU, promoting a 1-D vector to shape (1, D)."""
    feature = torch.load(os.path.join(NARRATION_FEATURES_DIR, f'{narration_id}.pt'), map_location='cpu')
    return feature.unsqueeze(0) if feature.dim() == 1 else feature


# Load narration features: only the inspected narration in qual_vis mode, otherwise one
# row per narration in narrations_df order (score-matrix row i == narrations_df row i).
if qual_vis:
    # Use narr_index (not row 0) so the features match the narration generate_qual_vis prints.
    narration_features = _load_narration_feature(narrations_df.iloc[narr_index]['narration_id'])
else:
    # Each feature is made 2-D first: torch.cat of 1-D tensors would join them into
    # one long vector instead of stacking one row per narration.
    narration_features = torch.cat([_load_narration_feature(nid) for nid in narrations_df['narration_id']])


In [8]:
# Cosine similarity of the loaded narration embedding(s) against every video feature.
similarity_scores = generate_similarity_scores(
    narration_features,
    video_features,
    use_qual_vis=qual_vis,
)

In [9]:
# Either ground every narration (default) or print a qualitative top-N comparison.
if not qual_vis:
    narrations, start_times, end_times = generate_frame_based_grounding(narrations_df, similarity_scores, feature_duration)
else:
    generate_qual_vis(narrations_df, similarity_scores)

In [10]:
# Collect the grounded intervals. Only the grounding branch (qual_vis == False)
# defines narrations/start_times/end_times, so skip this step in qual_vis mode.
if not qual_vis:
    narration_grounded = pd.DataFrame({
        'narration_id': narrations,
        'start_seconds': start_times,
        'stop_seconds': end_times,
    })

In [11]:
# Post-processing: reshape the grounding into the same layout as the audio-grounded
# narrations DataFrame (a list of [start, stop] intervals per narration).
# Skipped in qual_vis mode, where no grounding was computed.
if not qual_vis:
    # A comprehension rather than DataFrame.apply(axis=1): on an empty frame apply can
    # return a DataFrame, which cannot be assigned to a single column.
    narration_grounded['assigned_intervals'] = [
        [[start, stop]]
        for start, stop in zip(narration_grounded['start_seconds'], narration_grounded['stop_seconds'])
    ]
    # Drop before merging so the merge does not produce _x/_y suffixed duplicates of
    # narrations_df's own start_seconds/stop_seconds columns.
    narration_grounded = narration_grounded.drop(['start_seconds', 'stop_seconds'], axis=1)
    merged_df = pd.merge(narrations_df, narration_grounded, on='narration_id')
    merged_df = merged_df[['start_timestamp', 'stop_timestamp', 'narration', 'start_seconds', 'stop_seconds', 'assigned_intervals']]
    merged_df = merged_df.sort_values(by='start_timestamp').reset_index(drop=True)

In [12]:
# Save the grounded narrations; the filename records the settings that produced them.
# Skipped in qual_vis mode, where merged_df is never built.
if not qual_vis:
    video_grounded_narrs_output_dir = "video_grounded_narrations"
    video_grounded_narrs_filename = f"similarity_threshold={similarity_threshold}_feature_duration={feature_duration}.pkl"
    output_dir = os.path.join("/private/home/arjunrs1/epic-sounds-annotations", video_grounded_narrs_output_dir)
    # to_pickle does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    video_grounded_narrations_filepath = os.path.join(output_dir, video_grounded_narrs_filename)
    merged_df.to_pickle(video_grounded_narrations_filepath)

In [13]:
# We computed EgoVLP video features at ~2-second resolution (non-overlapping windows) and EgoVLP narration features. We then verified qualitatively that regions with high
# narration-video similarity correspond to regions where the narration is active, which indicates that the narration/video embeddings are sound. Steps to improve this further:

#TODO: Modify video feature generation:
    #1) Create a separate .csv file with one short (1-second) window per row, consecutive windows overlapping by 0.5 seconds. No other columns
    # are needed in this .csv file.
    #2) Load this .csv file and load each 1-second window of video, using num_segments=16 to subsample the video
    #into 16 frames. NOTE: Verify in the generate_*.py script that this really yields 16 frames.
    #3) Use the EK-100 fine-tuned checkpoint instead of the pre-trained checkpoint.
#TODO: Show that more fine-grained features improve grounding (going from coarse ~2-second features to finer 1- or 1.5-second features).