In [1]:
import os
import glob
import numpy as np
import torch
from torch_geometric.data import Data
from torchvision import models, transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors


In [2]:

def get_edge_info(num_frame):
    skip = skip_factor
    node_source = []
    node_target = []
    edge_attr = []
    for i in range(num_frame):
        for j in range(num_frame):
            frame_diff = i - j
            if abs(frame_diff) <= tauf:
                node_source.append(i)
                node_target.append(j)
                edge_attr.append(np.sign(frame_diff))
            elif skip:
                if (frame_diff % skip == 0) and (abs(frame_diff) <= skip * tauf):
                    node_source.append(i)
                    node_target.append(j)
                    edge_attr.append(np.sign(frame_diff))
    return node_source, node_target, edge_attr


In [3]:

def extract_features(image_path):
    resnet = models.resnet50(pretrained=True)
    modules = list(resnet.children())[:-1]
    resnet = torch.nn.Sequential(*modules)
    resnet.eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img = Image.open(image_path).convert('RGB')
    img_tensor = preprocess(img).unsqueeze(0)

    with torch.no_grad():
        features = resnet(img_tensor)

    return features.view(features.size(0), -1).numpy()


In [4]:

def generate_temporal_graph(image_folder):
    image_files = sorted(glob.glob(os.path.join(image_folder, '*.png')))
    num_frame = len(image_files)

    features = []
    
    print(f"Extracting features from {num_frame} images in {image_folder}...")
    
    # Extract features for each image file
    for img_file in image_files:
        feature_vector = extract_features(img_file)
        features.append(feature_vector)

    features = np.array(features).squeeze()

    # Get edge information within the same video
    node_source, node_target, edge_attr = get_edge_info(num_frame)

    return features, node_source, node_target, edge_attr


In [5]:

def connect_videos_with_knn(all_graphs):
    all_features = []
    
    # Collect all features from all videos
    for features, _, _, _ in all_graphs:
        all_features.append(features)
    
    all_features = np.vstack(all_features)  # Combine features from all videos
    
    # Create a k-NN model to find nearest neighbors across videos
    knn_model = NearestNeighbors(n_neighbors=5)  # You can adjust n_neighbors as needed
    knn_model.fit(all_features)

    edges_from_other_videos = []

    # Create edges based on nearest neighbors
    start_index = 0  # To keep track of where each video's features start in the combined array

    print("Connecting images between videos using k-NN...")
    
    for video_index, (video_features, _, _, _) in enumerate(all_graphs):
        num_video_frames = video_features.shape[0]
        
        # Find nearest neighbors for each feature vector in the current video across all videos
        distances, indices = knn_model.kneighbors(video_features)

        for i in range(num_video_frames):
            for neighbor_index in indices[i]:
                # Ensure that we do not connect nodes within the same video
                if neighbor_index // num_video_frames != video_index:
                    edges_from_other_videos.append((start_index + i, neighbor_index))

        start_index += num_video_frames

        print(f"Processed video {video_index + 1}/{len(all_graphs)}: {num_video_frames} frames")

    return edges_from_other_videos


In [6]:

root_data = './data/longstick'
tauf = 5
skip_factor = 10

dataset_folder = '/tmp/xirl/datasets/xmagical/train/longstick'
output_graphs_path = os.path.join(root_data, 'graphs')
os.makedirs(output_graphs_path, exist_ok=True)

all_graphs = []


In [7]:

# Iterate through each video folder (0/, 1/, etc.)
video_folders = os.listdir(dataset_folder)
total_videos = len(video_folders)

for video_index, video_folder in enumerate(video_folders):
    full_video_path = os.path.join(dataset_folder, video_folder)
    
    if os.path.isdir(full_video_path):
        features, node_source, node_target, edge_attr = generate_temporal_graph(full_video_path)
        
        # Store graph data as a tuple (features, source nodes, target nodes, edge attributes)
        all_graphs.append((features, node_source, node_target, edge_attr))
        
        print(f"Finished processing video {video_index + 1}/{total_videos}: {video_folder}")


Extracting features from 34 images in /tmp/xirl/datasets/xmagical/train/longstick/326...




Finished processing video 1/878: 326
Extracting features from 29 images in /tmp/xirl/datasets/xmagical/train/longstick/99...
Finished processing video 2/878: 99
Extracting features from 28 images in /tmp/xirl/datasets/xmagical/train/longstick/448...
Finished processing video 3/878: 448
Extracting features from 50 images in /tmp/xirl/datasets/xmagical/train/longstick/868...
Finished processing video 4/878: 868
Extracting features from 72 images in /tmp/xirl/datasets/xmagical/train/longstick/465...
Finished processing video 5/878: 465
Extracting features from 40 images in /tmp/xirl/datasets/xmagical/train/longstick/136...
Finished processing video 6/878: 136
Extracting features from 78 images in /tmp/xirl/datasets/xmagical/train/longstick/323...
Finished processing video 7/878: 323
Extracting features from 30 images in /tmp/xirl/datasets/xmagical/train/longstick/683...
Finished processing video 8/878: 683
Extracting features from 48 images in /tmp/xirl/datasets/xmagical/train/longstick/4

In [8]:
# Connect images between videos using k-NN
edges_between_videos = connect_videos_with_knn(all_graphs)

# Combine all graphs into one large graph
combined_features = np.vstack([graph[0] for graph in all_graphs])  # Combine feature vectors
combined_node_source = []
combined_node_target = []
combined_edge_attr = []

# Add edges from individual videos to combined graph
for idx, (features, node_source, node_target, edge_attr) in enumerate(all_graphs):
    offset = sum([graph[0].shape[0] for graph in all_graphs[:idx]])  # Calculate offset for current video
    combined_node_source.extend([source + offset for source in node_source])
    combined_node_target.extend([target + offset for target in node_target])
    combined_edge_attr.extend(edge_attr)

# Add inter-video edges to combined graph
combined_node_source.extend([source for source, _ in edges_between_videos])
combined_node_target.extend([target for _, target in edges_between_videos])

# Create final graph data object with combined data
final_graph_data = Data(
    x=torch.tensor(combined_features, dtype=torch.float32),
    edge_index=torch.tensor([combined_node_source, combined_node_target], dtype=torch.long),
)

# Save the final combined graph data
torch.save(final_graph_data, os.path.join(output_graphs_path, 'combined_graph.pt'))

print('Combined graph generation completed and saved as combined_graph.pt.')

Connecting images between videos using k-NN...
Processed video 1/877: 34 frames
Processed video 2/877: 29 frames
Processed video 3/877: 28 frames
Processed video 4/877: 50 frames
Processed video 5/877: 72 frames
Processed video 6/877: 40 frames
Processed video 7/877: 78 frames
Processed video 8/877: 30 frames
Processed video 9/877: 48 frames
Processed video 10/877: 42 frames
Processed video 11/877: 26 frames
Processed video 12/877: 32 frames
Processed video 13/877: 32 frames
Processed video 14/877: 40 frames
Processed video 15/877: 25 frames
Processed video 16/877: 60 frames
Processed video 17/877: 43 frames
Processed video 18/877: 49 frames
Processed video 19/877: 47 frames
Processed video 20/877: 36 frames
Processed video 21/877: 40 frames
Processed video 22/877: 57 frames
Processed video 23/877: 31 frames
Processed video 24/877: 26 frames
Processed video 25/877: 29 frames
Processed video 26/877: 42 frames
Processed video 27/877: 36 frames
Processed video 28/877: 39 frames
Processed 