In [21]:
import numpy as np
import cv2
import networkx as nx
import math
import matplotlib.pyplot as plt
import torch
import json
import os
import glob

from PIL import Image
from transformers import pipeline
from our_utils import *

# Folder holding the per-clip tracking JSON files of one video.
video_name = 'bulk_process_1/Antananarivo_Madagascar - Antananarivo Madagascar walking tour [Antananarivo Madagascar walking tour] - [oVuj3mueH2o]'

print(video_name)

# glob.escape keeps the literal "[" / "]" in the folder name from being
# interpreted as glob character classes.
pattern = glob.escape(video_name) + "/*.json"
json_files = sorted(glob.glob(pattern))

# Zero-based position (in sorted order) of the JSON file to inspect.
JSON_INDEX = 2
if len(json_files) <= JSON_INDEX:
    raise FileNotFoundError(
        f"Expected at least {JSON_INDEX + 1} JSON files matching {pattern!r}, "
        f"found {len(json_files)}"
    )
file = json_files[JSON_INDEX]
print(file)

with open(file, "r") as fp:
    data = json.load(fp)

# `cmap` feeds the uint8 palette below; `colormap` is called directly with an
# integer index when colouring graph nodes. plt.get_cmap replaces the
# deprecated plt.cm.get_cmap (removed in Matplotlib 3.9).
cmap = plt.get_cmap("tab20")
colormap = plt.get_cmap("tab20", 20)

# 20 colours as 0-255 integer triples (alpha channel dropped).
N = 20
colors = [cmap(i) for i in range(N)]
palette = (np.array(colors)[:, :3] * 255).astype(np.uint8)

# Maps "id-id-..." group keys to the number of frames the group was seen in.
hash_table = {}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

video_path = data['path'] # UPDATED
cap = cv2.VideoCapture(video_path) # UPDATED
if not cap.isOpened():
    raise IOError(f"Could not open video: {video_path}")
spacing  = data['spacing'] # UPDATED

toProcess = data['frames']

# For every sampled frame: count its co-occurring track groups, then show a 2x2
# panel made of the annotated video frame, two depth canvases and the group graph.
# The capture is released even if an iteration raises.
try:
    for ind, frame_data in enumerate(toProcess):

        frame_id = frame_data['frame_id']
        detections = frame_data['detections']

        ### UPDATED ####

        # Each entry of "components_60" is a list of track ids; link every pair
        # inside a component (i < j) so the component is drawn fully connected.
        G = nx.DiGraph()
        components = frame_data.get("components_60", [])
        for comp in components:
            for i in range(len(comp)):
                for j in range(i + 1, len(comp)):
                    G.add_edge(comp[i], comp[j])

        ### UPDATED ####

        # Count in how many frames each exact component occurs, keyed by its
        # ids joined with "-" (e.g. "3-87-90"). This only needs the JSON, so it
        # is done before touching the video.
        for comp in components:
            key = "-".join(str(c) for c in comp)
            hash_table[key] = hash_table.get(key, 0) + 1

        # Seek to the video frame this record refers to.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id*spacing) # UPDATED

        ret, image = cap.read()
        if not ret or image is None:
            # A failed seek/decode yields no image; skip the drawing for this
            # frame instead of crashing in the conversions below.
            print(f"Frame {ind}: could not read video frame {frame_id*spacing}; skipping")
            continue

        # White canvases, 1000 px tall. coord1 and coord3 are as wide as the
        # video frame, coord2 is 1000 px wide.
        coord3 = np.full((1000, image.shape[1], 3), 255, dtype=np.uint8)
        coord2 = np.full((1000, 1000, 3), 255, dtype=np.uint8)
        coord1 = np.full((1000, image.shape[1], 3), 255, dtype=np.uint8)

        # Nodes without a matching detection in this frame stay at the origin.
        for node in G.nodes():
            G.nodes[node]['pos'] = (0,0)

        for det in detections:

            x1, y1, x2, y2 = map(int, det['bbox'])
            track_id = det['track_id']
            class_name = 'person'
            label = f"{class_name} ID:{track_id}"
            # Colour is picked from the 20-entry palette by track id modulo 20.
            rgb_tuple = tuple(int(channel) for channel in palette[track_id % 20])

            o1_mid = ((x1+x2)//2, (y1+y2)//2) # updated
            d1 = int(det['depth']) # updated

            X, Y, Z = det['3D_60FOV']

            # Marker positions on the three canvases:
            #   center1: x from X, y from Z
            #   center2: x from X, y from depth
            #   center3: x from the bbox centre column, y from depth
            center1 = (int((X+2)*100), 1000-int(Z*100)) # updated
            center2 = (int((X+2)*100), d1*4) # updated
            center3 = (o1_mid[0], d1*4) # updated

            direction = det['direction']
            offset = 25

            # plot_coord / draw_tracking come from our_utils (not shown here).
            # NOTE(review): coord1 is drawn but is not part of the panel below.
            plot_coord(coord1, center1, direction, offset, rgb_tuple)
            plot_coord(coord2, center2, direction, offset, rgb_tuple)
            plot_coord(coord3, center3, direction, offset, rgb_tuple)

            if class_name == 'person':
                draw_tracking(frame=image, bbox = det['bbox'], label=label, color=rgb_tuple)

            # Only ids that belong to some component are graph nodes; place them
            # at (bbox centre column, -depth*4).
            if track_id in G.nodes:
                G.nodes[track_id]['pos'] = (o1_mid[0], -d1*4)

        # The same first/last channel swap is applied to every panel before it
        # is handed to matplotlib.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        coord1 = cv2.cvtColor(coord1, cv2.COLOR_BGR2RGB)
        coord2 = cv2.cvtColor(coord2, cv2.COLOR_BGR2RGB)
        coord3 = cv2.cvtColor(coord3, cv2.COLOR_BGR2RGB)

        pos = nx.get_node_attributes(G, "pos")

        node_colors = [colormap(n%20) for n in G.nodes()]

        # Render the graph off-screen and grab it as an image array.
        fig, ax = plt.subplots(figsize=(4, 4))
        nx.draw(
            G, pos, ax=ax,
            with_labels=True,
            node_color=node_colors,
            edge_color="gray",
            node_size=100
        )
        ax.set_axis_off()
        plt.tight_layout()

        fig.canvas.draw()
        graph = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        graph = graph.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # 4 channels
        graph = graph[:,:,:3]
        graph = cv2.cvtColor(graph, cv2.COLOR_BGR2RGB)
        plt.close(fig)

        # 2x2 panel: [video frame | coord2] over [coord3 | graph], each 600x320.
        combined_frame1 = np.hstack([cv2.resize(image, (600, 320)), cv2.resize(coord2, (600, 320))])
        combined_frame2 = np.hstack([cv2.resize(coord3, (600, 320)), cv2.resize(graph, (600, 320))])
        combined_frame = np.vstack([combined_frame1, combined_frame2])

        panel_fig, panel_ax = plt.subplots(figsize=(36, 8))
        panel_ax.imshow(combined_frame)
        panel_ax.axis("off")
        panel_ax.set_title(f"Frame {ind}")
        plt.show()
        plt.close(panel_fig)
finally:
    cap.release()

print(hash_table)

bulk_process_1/Antananarivo_Madagascar - Antananarivo Madagascar walking tour [Antananarivo Madagascar walking tour] - [oVuj3mueH2o]
bulk_process_1/Antananarivo_Madagascar - Antananarivo Madagascar walking tour [Antananarivo Madagascar walking tour] - [oVuj3mueH2o]/Antananarivo_001_0003.json


  colormap = plt.cm.get_cmap("tab20", 20)


{'2-4': 1, '8-9-10': 1, '2-4-6': 1, '2-9': 2, '20-22': 2, '2-12': 4, '3-90': 7, '88-89': 4, '87-95': 1, '3-90-99': 4, '87-98': 2, '88-89-103': 2, '12-88-103': 2, '87-90-99': 4, '87-99': 4, '3-87-90-99': 1, '89-109': 3, '116-117': 1, '87-90': 3, '3-87-90': 5, '89-140': 1, '99-116-133-144': 2, '99-133': 1, '99-116-133': 1}


In [22]:
# Rank the (ids, count) pairs so the most frequently seen groups come first.
# The sort is stable, so groups with equal counts keep hash_table's order.
posible_groups = list(hash_table.items())
posible_groups.sort(key=lambda pair: pair[1], reverse=True)
print(posible_groups)

def draw_raw(frame, detection):
    """Outline one detection on ``frame`` and caption it 'person'.

    Args:
        frame: image handed to ``cv2.rectangle`` and ``cv2.putText``.
        detection: four values (x1, y1, x2, y2); each is converted with ``int``.
    """
    left, top, right, bottom = (int(value) for value in detection)
    stroke_color = (255, 0, 0)
    stroke_width = 2
    cv2.rectangle(frame, (left, top), (right, bottom), stroke_color, stroke_width)
    # The caption is anchored 10 px above the box's top-left corner.
    cv2.putText(frame, 'person', (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, stroke_color, stroke_width)
    
# Reload the tracking JSON selected earlier (`file`).
with open(file, "r") as fp:
    data = json.load(fp)

frames = data['current_total_frames']
spacing = data['spacing']

if not posible_groups:
    raise ValueError("No groups were found, so there is nothing to visualize.")

n = int(input("Enter number of groups to visualize: "))
if n < 1:
    # rows would be 0 (division by zero) or sqrt would fail for negatives.
    raise ValueError(f"Number of groups must be at least 1, got {n}.")
if n > len(posible_groups):
    # Asking for more groups than exist would only add empty grid cells.
    print(f"Only {len(posible_groups)} groups available; showing all of them.")
    n = len(posible_groups)

# Key of the most frequent group; indexed directly so IPython's `_` is left alone.
first_ids_str = posible_groups[0][0]
# Number of frames collected (and later displayed) per group.
max_frames = 40

# Grid layout: rows = floor(sqrt(n)), cols = ceil(n / rows), so rows * cols >= n.
rows = int(math.sqrt(n))
cols = int(math.ceil(n/rows))
group_frames = []

# For each of the n most frequent groups, walk the video once and collect up to
# max_frames half-size frames in which at least one member of the group was
# detected. Every group ends up with exactly max_frames tiles.
for ids_str, _count in posible_groups[:n]:

    ids = [int(x) for x in ids_str.split('-')]
    cap = cv2.VideoCapture(data['path'])
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video: {data['path']}")

    saved_images = []
    raw_shape = None  # shape of a decoded frame, used to build blank tiles
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, data['start_frame'])
        c_frames = 0  # frames read so far
        j = 0         # index into data['frames'] (advances once per `spacing` reads)

        while c_frames < (frames-1)*spacing and len(saved_images) < max_frames:

            ret, frame = cap.read()
            if not ret:
                break
            if raw_shape is None:
                raw_shape = frame.shape

            if c_frames % spacing == 0:
                if j >= len(data['frames']):
                    # No annotation record left for this position.
                    break
                # First bbox per track id, matching a first-match list lookup.
                bbox_by_track = {}
                for det in data['frames'][j]['detections']:
                    bbox_by_track.setdefault(det['track_id'], det['bbox'])

                show = False
                for idx in ids:
                    if idx in bbox_by_track:
                        draw_raw(frame, bbox_by_track[idx])
                        show = True
                if show:
                    saved_images.append(cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),(0, 0), fx=0.5, fy=0.5))
                j += 1

            c_frames += 1
    finally:
        cap.release()

    # Pad with black tiles up to max_frames. When the group never appeared there
    # is no saved tile to copy the size from, so size the blank from a raw frame
    # put through the same resize (the original loop never terminated here).
    if saved_images:
        blank_tile = np.zeros_like(saved_images[0])
    elif raw_shape is not None:
        blank_tile = cv2.resize(np.zeros(raw_shape, dtype=np.uint8), (0, 0), fx=0.5, fy=0.5)
    else:
        raise RuntimeError(f"No frames could be read from video: {data['path']}")
    saved_images.extend([blank_tile] * (max_frames - len(saved_images)))
    group_frames.append(saved_images)
    
# Grid layout: rows = floor(sqrt(n)), cols = ceil(n / rows), so rows * cols >= n.
rows = int(math.sqrt(n))
cols = int(math.ceil(n / rows))

# One figure per frame index; cell (row, col) shows group number row * cols + col.
for frame_idx in range(max_frames):
    grid_rows = []
    for row_idx in range(rows):
        row_tiles = []
        for col_idx in range(cols):
            slot = row_idx * cols + col_idx
            if slot < len(group_frames):
                row_tiles.append(group_frames[slot][frame_idx])
            else:
                # Cells beyond the last group are filled with a black tile.
                row_tiles.append(np.zeros_like(group_frames[0][0]))
        grid_rows.append(np.hstack(row_tiles))
    grid_image = np.vstack(grid_rows)

    plt.figure(figsize=(16, 9))
    plt.imshow(grid_image)
    plt.axis('off')
    plt.title(f"Frame {frame_idx+1} of {max_frames}")
    plt.show()

[('3-90', 7), ('3-87-90', 5), ('2-12', 4), ('88-89', 4), ('3-90-99', 4), ('87-90-99', 4), ('87-99', 4), ('89-109', 3), ('87-90', 3), ('2-9', 2), ('20-22', 2), ('87-98', 2), ('88-89-103', 2), ('12-88-103', 2), ('99-116-133-144', 2), ('2-4', 1), ('8-9-10', 1), ('2-4-6', 1), ('87-95', 1), ('3-87-90-99', 1), ('116-117', 1), ('89-140', 1), ('99-133', 1), ('99-116-133', 1)]


Enter number of groups to visualize:  4
